mirror of
https://github.com/clearml/clearml-server
synced 2025-04-06 22:14:10 +00:00
Return file urls for tasks.delete/reset and models.delete
This commit is contained in:
parent
78e4a58c91
commit
8b464e7ae6
@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
|
||||
)
|
||||
|
||||
def parse_value(self, value):
|
||||
if value in (NotSet, None) and not self.required:
|
||||
if value is NotSet and not self.required:
|
||||
return self.get_default_value()
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
|
@ -32,8 +32,16 @@ class CreateModelResponse(models.Base):
|
||||
created = fields.BoolField(required=True)
|
||||
|
||||
|
||||
class PublishModelRequest(models.Base):
|
||||
class ModelRequest(models.Base):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class DeleteModelRequest(ModelRequest):
|
||||
force = fields.BoolField(default=False)
|
||||
return_file_url = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class PublishModelRequest(ModelRequest):
|
||||
force_publish_task = fields.BoolField(default=False)
|
||||
publish_task = fields.BoolField(default=True)
|
||||
|
||||
|
@ -53,6 +53,7 @@ class ResetResponse(UpdateResponse):
|
||||
frames = DictField()
|
||||
events = DictField()
|
||||
model_deleted = IntField()
|
||||
urls = DictField()
|
||||
|
||||
|
||||
class TaskRequest(models.Base):
|
||||
@ -71,6 +72,7 @@ class EnqueueRequest(UpdateRequest):
|
||||
|
||||
class DeleteRequest(UpdateRequest):
|
||||
move_to_trash = BoolField(default=True)
|
||||
return_file_urls = BoolField(default=False)
|
||||
|
||||
|
||||
class SetRequirementsRequest(TaskRequest):
|
||||
@ -137,6 +139,7 @@ class DeleteArtifactsRequest(TaskRequest):
|
||||
|
||||
class ResetRequest(UpdateRequest):
|
||||
clear_all = BoolField(default=False)
|
||||
return_file_urls = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskRequest(models.Base):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import re
|
||||
import zlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
@ -36,11 +37,10 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
@ -49,11 +49,16 @@ class PlotFields:
|
||||
plot_len = "plot_len"
|
||||
plot_str = "plot_str"
|
||||
plot_data = "plot_data"
|
||||
source_urls = "source_urls"
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
img_source_regex = re.compile(
|
||||
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@ -269,6 +274,11 @@ class EventBLL(object):
|
||||
event[PlotFields.plot_len] = plot_len
|
||||
if validate:
|
||||
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
|
||||
|
||||
urls = {match for match in self.img_source_regex.findall(plot_str)}
|
||||
if urls:
|
||||
event[PlotFields.source_urls] = list(urls)
|
||||
|
||||
if compression_threshold and plot_len >= compression_threshold:
|
||||
event[PlotFields.plot_data] = base64.encodebytes(
|
||||
zlib.compress(plot_str.encode(), level=1)
|
||||
@ -504,7 +514,7 @@ class EventBLL(object):
|
||||
scroll_id: str = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
@ -598,6 +608,41 @@ class EventBLL(object):
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_plot_image_urls(
|
||||
self, company_id: str, task_id: str, scroll_id: Optional[str]
|
||||
) -> Tuple[Sequence[dict], Optional[str]]:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], None
|
||||
|
||||
if scroll_id:
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m")
|
||||
else:
|
||||
if check_empty_data(self.es, company_id, EventType.metrics_plot):
|
||||
return [], None
|
||||
|
||||
es_req = {
|
||||
"size": 1000,
|
||||
"_source": [PlotFields.source_urls],
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"exists": {"field": PlotFields.source_urls}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=EventType.metrics_plot,
|
||||
body=es_req,
|
||||
scroll="10m",
|
||||
)
|
||||
|
||||
events, _, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
|
@ -3,5 +3,4 @@ from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
split_by,
|
||||
)
|
||||
|
@ -38,7 +38,7 @@ from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time, task_deleted_prefix
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@ -247,6 +247,11 @@ class TaskBLL:
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
parent_task = (
|
||||
task.parent
|
||||
if task.parent and not task.parent.startswith(task_deleted_prefix)
|
||||
else None
|
||||
)
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
@ -256,7 +261,7 @@ class TaskBLL:
|
||||
last_change=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
parent=parent or parent_task,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or clean_system_tags(task.system_tags),
|
||||
|
247
apiserver/bll/task/task_cleanup.py
Normal file
247
apiserver/bll/task/task_cleanup.py
Normal file
@ -0,0 +1,247 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import partition
|
||||
from mongoengine import QuerySet, Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_bll import PlotFields
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.task.utils import task_deleted_prefix
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
event_bll = EventBLL()
|
||||
T = TypeVar("T", bound=Document)
|
||||
|
||||
|
||||
class DocumentGroup(List[T]):
|
||||
"""
|
||||
Operate on a list of documents as if they were a query result
|
||||
"""
|
||||
|
||||
def __init__(self, document_type: Type[T], documents: Iterable[T]):
|
||||
super(DocumentGroup, self).__init__(documents)
|
||||
self.type = document_type
|
||||
|
||||
@property
|
||||
def ids(self) -> Set[str]:
|
||||
return {obj.id for obj in self}
|
||||
|
||||
def objects(self, *args, **kwargs) -> QuerySet:
|
||||
return self.type.objects(id__in=self.ids, *args, **kwargs)
|
||||
|
||||
|
||||
class TaskOutputs(Generic[T]):
|
||||
"""
|
||||
Split task outputs of the same type by the ready state
|
||||
"""
|
||||
|
||||
published: DocumentGroup[T]
|
||||
draft: DocumentGroup[T]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_published: Callable[[T], bool],
|
||||
document_type: Type[T],
|
||||
children: Iterable[T],
|
||||
):
|
||||
"""
|
||||
:param is_published: predicate returning whether items is considered published
|
||||
:param document_type: type of output
|
||||
:param children: output documents
|
||||
"""
|
||||
self.published, self.draft = map(
|
||||
lambda x: DocumentGroup(document_type, x),
|
||||
partition(children, key=is_published),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskUrls:
|
||||
model_urls: Sequence[str]
|
||||
event_urls: Sequence[str]
|
||||
artifact_urls: Sequence[str]
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class CleanupResult:
|
||||
"""
|
||||
Counts of objects modified in task cleanup operation
|
||||
"""
|
||||
|
||||
updated_children: int
|
||||
updated_models: int
|
||||
deleted_models: int
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def _collect_plot_image_urls(company: str, task: str) -> Set[str]:
|
||||
urls = set()
|
||||
next_scroll_id = None
|
||||
with TimingContext("es", "collect_plot_image_urls"):
|
||||
while True:
|
||||
events, next_scroll_id = event_bll.get_plot_image_urls(
|
||||
company_id=company, task_id=task, scroll_id=next_scroll_id
|
||||
)
|
||||
if not events:
|
||||
break
|
||||
for event in events:
|
||||
event_urls = event.get(PlotFields.source_urls)
|
||||
if event_urls:
|
||||
urls.update(set(event_urls))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def _collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
"""
|
||||
Return the set of unique image urls
|
||||
Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
|
||||
"""
|
||||
metrics = event_bll.get_metrics_and_variants(
|
||||
company_id=company, task_id=task, event_type=EventType.metrics_image
|
||||
)
|
||||
if not metrics:
|
||||
return set()
|
||||
|
||||
task_metrics = [(task, metric) for metric in metrics]
|
||||
scroll_id = None
|
||||
urls = defaultdict(set)
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
events for _, _, events in res.metric_events
|
||||
):
|
||||
break
|
||||
|
||||
scroll_id = res.next_scroll_id
|
||||
for _, metric, iterations in res.metric_events:
|
||||
metric_urls = set(ev.get("url") for it in iterations for ev in it["events"])
|
||||
metric_urls.discard(None)
|
||||
urls[metric].update(metric_urls)
|
||||
|
||||
return set(itertools.chain.from_iterable(urls.values()))
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
task: Task, force: bool = False, update_children=True, return_file_urls=False
|
||||
) -> CleanupResult:
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
models = verify_task_children_and_ouptuts(task, force)
|
||||
|
||||
event_urls, artifact_urls, model_urls = set(), set(), set()
|
||||
if return_file_urls:
|
||||
event_urls = _collect_debug_image_urls(task.company, task.id)
|
||||
event_urls.update(_collect_plot_image_urls(task.company, task.id))
|
||||
if task.execution and task.execution.artifacts:
|
||||
artifact_urls = {
|
||||
a.uri
|
||||
for a in task.execution.artifacts.values()
|
||||
if a.mode == ArtifactModes.output and a.uri
|
||||
}
|
||||
model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}
|
||||
|
||||
deleted_task_id = f"{task_deleted_prefix}{task.id}"
|
||||
if update_children:
|
||||
with TimingContext("mongo", "update_task_children"):
|
||||
updated_children = Task.objects(parent=task.id).update(
|
||||
parent=deleted_task_id
|
||||
)
|
||||
else:
|
||||
updated_children = 0
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
deleted_models = models.draft.objects().delete()
|
||||
else:
|
||||
deleted_models = 0
|
||||
|
||||
if models.published and update_children:
|
||||
with TimingContext("mongo", "update_task_models"):
|
||||
updated_models = models.published.objects().update(task=deleted_task_id)
|
||||
else:
|
||||
updated_models = 0
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
updated_children=updated_children,
|
||||
updated_models=updated_models,
|
||||
urls=TaskUrls(
|
||||
event_urls=list(event_urls),
|
||||
artifact_urls=list(artifact_urls),
|
||||
model_urls=list(model_urls),
|
||||
)
|
||||
if return_file_urls
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
if not force:
|
||||
with TimingContext("mongo", "count_published_children"):
|
||||
published_children_count = Task.objects(
|
||||
parent=task.id, status=TaskStatus.published
|
||||
).count()
|
||||
if published_children_count:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True",
|
||||
task=task.id,
|
||||
children=published_children_count,
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "get_task_models"):
|
||||
models = TaskOutputs(
|
||||
attrgetter("ready"),
|
||||
Model,
|
||||
Model.objects(task=task.id).only("id", "task", "ready"),
|
||||
)
|
||||
if not force and models.published:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(models.published),
|
||||
)
|
||||
|
||||
if task.output.model:
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
output_model = Model.objects(id=task.output.model).first()
|
||||
if output_model:
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=task.output.model,
|
||||
)
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
models.draft.append(output_model)
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = models.draft.ids
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
)
|
||||
input_models = [t.execution.model for t in dependent_tasks]
|
||||
if input_models:
|
||||
models.draft = DocumentGroup(
|
||||
Model, (m for m in models.draft if m.id not in input_models)
|
||||
)
|
||||
|
||||
return models
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Callable, Tuple, Sequence, Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
import attr
|
||||
import six
|
||||
@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
|
||||
valid_statuses = get_options(TaskStatus)
|
||||
task_deleted_prefix = "__DELETED__"
|
||||
|
||||
|
||||
@typed_attrs
|
||||
@ -164,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def split_by(
|
||||
condition: Callable[[T], bool], items: Sequence[T]
|
||||
) -> Tuple[Sequence[T], Sequence[T]]:
|
||||
"""
|
||||
split "items" to two lists by "condition"
|
||||
"""
|
||||
applied = zip(map(condition, items), items)
|
||||
return (
|
||||
[item for cond, item in applied if cond],
|
||||
[item for cond, item in applied if not cond],
|
||||
)
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
|
@ -697,6 +697,24 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${delete."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
return_file_url {
|
||||
description: "If set to 'true' then return the url of the model file. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
url {
|
||||
descrition: "The url of the model file"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
|
@ -1191,6 +1191,25 @@ reset {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${reset."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
urls {
|
||||
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' properties was set to True"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
@ -1246,6 +1265,25 @@ delete {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${delete."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
return_file_urls {
|
||||
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
urls {
|
||||
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' properties was set to True"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
archive {
|
||||
"2.12" {
|
||||
|
@ -14,6 +14,7 @@ from apiserver.apimodels.models import (
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
GetFrameworksRequest,
|
||||
DeleteModelRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -440,14 +441,14 @@ def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.delete", required_fields=["model"])
|
||||
def update(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
force = call.data.get("force", False)
|
||||
@endpoint("models.delete", request_data_model=DeleteModelRequest)
|
||||
def update(call: APICall, company_id, request: DeleteModelRequest):
|
||||
model_id = request.model
|
||||
force = request.force
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task", "project").first()
|
||||
model = Model.objects(**query).only("id", "task", "project", "uri").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
@ -483,7 +484,10 @@ def update(call: APICall, company_id, _):
|
||||
del_count = Model.objects(**query).delete()
|
||||
if del_count:
|
||||
_reset_cached_tags(company_id, projects=[model.project])
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
call.result.data = dict(
|
||||
deleted=del_count > 0,
|
||||
url=model.uri if request.return_file_url else None
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
|
@ -1,11 +1,9 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
|
||||
from typing import Sequence, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
import mongoengine
|
||||
from mongoengine import EmbeddedDocument, Q
|
||||
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
@ -55,7 +53,6 @@ from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
split_by,
|
||||
)
|
||||
from apiserver.bll.task.artifacts import (
|
||||
artifacts_prepare_for_save,
|
||||
@ -69,11 +66,11 @@ from apiserver.bll.task.param_utils import (
|
||||
params_unprepare_from_saved,
|
||||
escape_paths,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task
|
||||
from apiserver.bll.task.utils import update_task, task_deleted_prefix
|
||||
from apiserver.bll.util import SetFieldsResolver
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.output import Output
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
@ -386,6 +383,9 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
|
||||
|
||||
@endpoint("tasks.validate", request_data_model=CreateRequest)
|
||||
def validate(call: APICall, company_id, req_model: CreateRequest):
|
||||
parent = call.data.get("parent")
|
||||
if parent and parent.startswith(task_deleted_prefix):
|
||||
call.data.pop("parent")
|
||||
_validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
@ -849,7 +849,12 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
if dequeued:
|
||||
api_results.update(dequeued=dequeued)
|
||||
|
||||
cleaned_up = cleanup_task(task, force=force, update_children=False)
|
||||
cleaned_up = cleanup_task(
|
||||
task,
|
||||
force=force,
|
||||
update_children=False,
|
||||
return_file_urls=request.return_file_urls,
|
||||
)
|
||||
api_results.update(attr.asdict(cleaned_up))
|
||||
|
||||
updates.update(
|
||||
@ -937,146 +942,6 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
call.result.data_model = ArchiveResponse(archived=archived)
|
||||
|
||||
|
||||
class DocumentGroup(list):
|
||||
"""
|
||||
Operate on a list of documents as if they were a query result
|
||||
"""
|
||||
|
||||
def __init__(self, document_type, documents):
|
||||
super(DocumentGroup, self).__init__(documents)
|
||||
self.type = document_type
|
||||
|
||||
def objects(self, *args, **kwargs):
|
||||
return self.type.objects(id__in=[obj.id for obj in self], *args, **kwargs)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskOutputs(object):
|
||||
"""
|
||||
Split task outputs of the same type by the ready state
|
||||
"""
|
||||
|
||||
published = None # type: DocumentGroup
|
||||
draft = None # type: DocumentGroup
|
||||
|
||||
def __init__(self, is_published, document_type, children):
|
||||
# type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> ()
|
||||
"""
|
||||
:param is_published: predicate returning whether items is considered published
|
||||
:param document_type: type of output
|
||||
:param children: output documents
|
||||
"""
|
||||
self.published, self.draft = map(
|
||||
lambda x: DocumentGroup(document_type, x), split_by(is_published, children)
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
class CleanupResult(object):
|
||||
"""
|
||||
Counts of objects modified in task cleanup operation
|
||||
"""
|
||||
|
||||
updated_children = attr.ib(type=int)
|
||||
updated_models = attr.ib(type=int)
|
||||
deleted_models = attr.ib(type=int)
|
||||
|
||||
|
||||
def cleanup_task(task: Task, force: bool = False, update_children=True):
|
||||
"""
|
||||
Validate task deletion and delete/modify all its output.
|
||||
:param task: task object
|
||||
:param force: whether to delete task with published outputs
|
||||
:return: count of delete and modified items
|
||||
"""
|
||||
models, child_tasks = get_outputs_for_deletion(task, force)
|
||||
deleted_task_id = trash_task_id(task.id)
|
||||
if child_tasks and update_children:
|
||||
with TimingContext("mongo", "update_task_children"):
|
||||
updated_children = child_tasks.update(parent=deleted_task_id)
|
||||
else:
|
||||
updated_children = 0
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "delete_models"):
|
||||
deleted_models = models.draft.objects().delete()
|
||||
else:
|
||||
deleted_models = 0
|
||||
|
||||
if models.published and update_children:
|
||||
with TimingContext("mongo", "update_task_models"):
|
||||
updated_models = models.published.objects().update(task=deleted_task_id)
|
||||
else:
|
||||
updated_models = 0
|
||||
|
||||
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
|
||||
|
||||
return CleanupResult(
|
||||
deleted_models=deleted_models,
|
||||
updated_children=updated_children,
|
||||
updated_models=updated_models,
|
||||
)
|
||||
|
||||
|
||||
def get_outputs_for_deletion(task, force=False):
|
||||
with TimingContext("mongo", "get_task_models"):
|
||||
models = TaskOutputs(
|
||||
attrgetter("ready"),
|
||||
Model,
|
||||
Model.objects(task=task.id).only("id", "task", "ready"),
|
||||
)
|
||||
if not force and models.published:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output models, use force=True",
|
||||
task=task.id,
|
||||
models=len(models.published),
|
||||
)
|
||||
|
||||
if task.output.model:
|
||||
output_model = get_output_model(task, force)
|
||||
if output_model:
|
||||
if output_model.ready:
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
models.draft.append(output_model)
|
||||
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = [m.id for m in models.draft]
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
)
|
||||
busy_models = [t.execution.model for t in dependent_tasks]
|
||||
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
|
||||
|
||||
with TimingContext("mongo", "get_task_children"):
|
||||
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
|
||||
published_tasks = [
|
||||
task for task in tasks if task.status == TaskStatus.published
|
||||
]
|
||||
if not force and published_tasks:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has children, use force=True", task=task.id, children=published_tasks
|
||||
)
|
||||
return models, tasks
|
||||
|
||||
|
||||
def get_output_model(task, force=False):
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
output_model = Model.objects(id=task.output.model).first()
|
||||
if output_model and output_model.ready and not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True", task=task.id, model=task.output.model
|
||||
)
|
||||
return output_model
|
||||
|
||||
|
||||
def trash_task_id(task_id):
|
||||
return "__DELETED__{}".format(task_id)
|
||||
|
||||
|
||||
@endpoint("tasks.delete", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
task = TaskBLL.get_task_with_access(
|
||||
@ -1095,7 +960,9 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
result = cleanup_task(task, force=force)
|
||||
result = cleanup_task(
|
||||
task, force=force, return_file_urls=req_model.return_file_urls
|
||||
)
|
||||
|
||||
if move_to_trash:
|
||||
collection_name = task._get_collection_name()
|
||||
|
@ -204,7 +204,9 @@ class TestTaskEvents(TestService):
|
||||
self.send_batch(events)
|
||||
for key in None, "iter", "timestamp", "iso_time":
|
||||
with self.subTest(key=key):
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(
|
||||
task=task, **(dict(key=key) if key is not None else {})
|
||||
)
|
||||
self.assertIn(metric, data)
|
||||
self.assertIn(variant, data[metric])
|
||||
self.assertIn("x", data[metric][variant])
|
||||
|
@ -1,95 +1,185 @@
|
||||
from parameterized import parameterized
|
||||
from typing import Set
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
continuations = (
|
||||
(lambda self, task: self.tasks.reset(task=task),),
|
||||
(lambda self, task: self.tasks.delete(task=task),),
|
||||
)
|
||||
|
||||
|
||||
def reset_and_delete():
|
||||
"""
|
||||
Parametrize a test for both delete and reset operations,
|
||||
which should yield the same results.
|
||||
NOTE: "parameterized" engages in call stack manipulation,
|
||||
so be careful when changing the application of the decorator.
|
||||
For example, receiving "func" as a parameter and passing it to
|
||||
"expand" doesn't work.
|
||||
"""
|
||||
return parameterized.expand(
|
||||
[
|
||||
(lambda self, task: self.tasks.delete(task=task),),
|
||||
(lambda self, task: self.tasks.reset(task=task),),
|
||||
],
|
||||
name_func=lambda func, num, _: "{}_{}".format(
|
||||
func.__name__, ["delete", "reset"][int(num)]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestTasksResetDelete(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.11")
|
||||
|
||||
TASK_CANNOT_BE_DELETED_CODES = (400, 123)
|
||||
def test_delete(self):
|
||||
# draft task can be deleted
|
||||
task = self.new_task()
|
||||
res = self.assert_delete_task(task)
|
||||
self.assertIsNone(res.get("urls"))
|
||||
# published task can be deleted only with force flag
|
||||
task = self.new_task()
|
||||
self.publish_task(task)
|
||||
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
|
||||
self.assert_delete_task(task)
|
||||
self.assert_delete_task(task, force=True)
|
||||
|
||||
def setUp(self, version="1.7"):
|
||||
super(TestTasksResetDelete, self).setUp(version=version)
|
||||
self.tasks = self.api.tasks
|
||||
self.models = self.api.models
|
||||
# task with published children can only be deleted with force flag
|
||||
task = self.new_task()
|
||||
child = self.new_task(parent=task)
|
||||
self.publish_task(child)
|
||||
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
|
||||
self.assert_delete_task(task)
|
||||
res = self.assert_delete_task(task, force=True)
|
||||
self.assertEqual(res.updated_children, 1)
|
||||
# make sure that the child model is valid after the parent deletion
|
||||
self.api.tasks.validate(**self.api.tasks.get_by_id(task=child).task)
|
||||
|
||||
# task with published model can only be deleted with force flag
|
||||
task = self.new_task()
|
||||
model = self.new_model()
|
||||
self.api.models.edit(model=model, task=task, ready=True)
|
||||
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
|
||||
self.assert_delete_task(task)
|
||||
res = self.assert_delete_task(task, force=True)
|
||||
self.assertEqual(res.updated_models, 1)
|
||||
|
||||
def test_return_file_urls(self):
|
||||
# empty task
|
||||
task = self.new_task()
|
||||
res = self.assert_delete_task(task, return_file_urls=True)
|
||||
self.assertEqual(res.urls.model_urls, [])
|
||||
self.assertEqual(res.urls.event_urls, [])
|
||||
self.assertEqual(res.urls.artifact_urls, [])
|
||||
|
||||
task = self.new_task()
|
||||
model_urls = self.create_task_models(task)
|
||||
artifact_urls = self.send_artifacts(task)
|
||||
event_urls = self.send_debug_image_events(task)
|
||||
event_urls.update(self.send_plot_events(task))
|
||||
res = self.assert_delete_task(task, force=True, return_file_urls=True)
|
||||
self.assertEqual(set(res.urls.model_urls), model_urls)
|
||||
self.assertEqual(set(res.urls.event_urls), event_urls)
|
||||
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
|
||||
|
||||
def test_reset(self):
    """tasks.reset: force-flag handling and return_file_urls reporting."""
    # A draft task can be reset; no urls requested -> none returned.
    draft = self.new_task()
    response = self.api.tasks.reset(task=draft)
    self.assertFalse(response.get("urls"))

    # A published task can be reset only with the force flag.
    published = self.new_task()
    self.publish_task(published)
    with self.api.raises(errors.bad_request.InvalidTaskStatus):
        self.api.tasks.reset(task=published)
    self.api.tasks.reset(task=published, force=True)

    # return_file_urls=True reports the urls of all files attached to the task.
    populated = self.new_task()
    expected_models = self.create_task_models(populated)
    expected_artifacts = self.send_artifacts(populated)
    expected_events = self.send_debug_image_events(populated)
    expected_events.update(self.send_plot_events(populated))
    response = self.api.tasks.reset(task=populated, force=True, return_file_urls=True)
    self.assertEqual(set(response.urls.model_urls), expected_models)
    self.assertEqual(set(response.urls.event_urls), expected_events)
    self.assertEqual(set(response.urls.artifact_urls), expected_artifacts)
||||
def test_model_delete(self):
    """models.delete with return_file_url=True returns the deleted model's uri."""
    model_id = self.new_model(uri="test")
    deletion = self.api.models.delete(model=model_id, return_file_url=True)
    self.assertEqual(deletion.url, "test")
||||
def assert_delete_task(self, task_id, force=False, return_file_urls=False):
    """Delete a task, asserting it existed before and is gone after.

    Returns the tasks.delete response for further assertions by the caller.
    """
    existing = self.api.tasks.get_all_ex(id=[task_id]).tasks
    self.assertEqual(existing[0].id, task_id)
    res = self.api.tasks.delete(
        task=task_id, force=force, return_file_urls=return_file_urls
    )
    self.assertTrue(res.deleted)
    remaining = self.api.tasks.get_all_ex(id=[task_id]).tasks
    self.assertEqual(remaining, [])
    return res
||||
def create_task_models(self, task) -> Set[str]:
    """
    Attach a ready and a not-ready model to the task and return the uris
    expected in the delete/reset url report (only the non-ready model).
    """
    ready_model = self.new_model(uri="ready")
    draft_model = self.new_model(uri="not_ready", ready=False)
    self.api.models.edit(model=draft_model, task=task)
    self.api.models.edit(model=ready_model, task=task)
    return {"not_ready"}
||||
def send_artifacts(self, task) -> Set[str]:
    """
    Attach one input and one output artifact to the task and return the
    output artifact uris (only those are reported on delete/reset).
    """
    self.api.tasks.add_or_update_artifacts(
        task=task,
        artifacts=[
            dict(key="a", type="str", uri="test1", mode="input"),
            dict(key="b", type="int", uri="test2"),
        ],
    )
    return {"test2"}
||||
def send_debug_image_events(self, task) -> Set[str]:
    """Send five debug-image events and return the set of their urls."""
    urls = set()
    events = []
    for iteration in range(5):
        url = f"url_{iteration}"
        urls.add(url)
        events.append(
            self.create_event(task, "training_debug_image", iteration, url=url)
        )
    self.send_batch(events)
    return urls
||||
def send_plot_events(self, task) -> Set[str]:
    """
    Send three plot events (two with embedded image sources, one plain
    heatmap) and return the set of image urls referenced by the plots.
    """
    plots = (
        '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 514], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 514, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 514, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png"}], "showlegend": false, "title": "Feature importance/plot image", "name": null}}',
        '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 200], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 200, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 200, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg"}], "showlegend": false, "title": "untitled 00/plot image", "name": null}}',
        '{"data": [{"y": ["lying", "sitting", "standing", "people", "backgroun"], "x": ["lying", "sitting", "standing", "people", "backgroun"], "z": [[758, 163, 0, 0, 23], [63, 858, 3, 0, 0], [0, 50, 188, 21, 35], [0, 22, 8, 40, 4], [12, 91, 26, 29, 368]], "type": "heatmap"}], "layout": {"title": "Confusion Matrix for iter 100", "xaxis": {"title": "Predicted value"}, "yaxis": {"title": "Real value"}}}',
    )
    self.send_batch(
        [
            self.create_event(task, "plot", index, plot_str=payload)
            for index, payload in enumerate(plots)
        ]
    )
    # Only the two plots with embedded image sources contribute urls.
    return {
        "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png",
        "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg",
    }
||||
def create_event(self, task, type_, iteration, **kwargs) -> dict:
    """Build a minimal event payload for events.add_batch; kwargs override defaults."""
    event = dict(
        worker="test",
        type=type_,
        task=task,
        iter=iteration,
        timestamp=es_factory.get_timestamp_millis(),
        metric="Metric1",
        variant="Variant1",
    )
    event.update(kwargs)
    return event
||||
def send_batch(self, events):
    """POST the events through events.add_batch and return the response data."""
    _, response_data = self.api.send_batch("events.add_batch", events)
    return response_data
||||
def new_task(self, **kwargs):
    """Create a temporary testing task (auto-deleted on teardown) and return its id.

    NOTE(review): the original block contained a second, unreachable
    implementation after the first ``return`` (two diff versions fused by
    the extraction); the ``create_temp``-based variant is kept here.
    """
    return self.create_temp(
        "tasks",
        delete_params=dict(can_fail=True),
        type="testing",
        name="test task delete",
        input=dict(view=dict()),
        **kwargs,
    )
|
||||
|
||||
def new_model(self, **kwargs):
    """Create a temporary model (auto-deleted on teardown) and return its id.

    NOTE(review): the original block contained a second, unreachable
    implementation after the first ``return`` (two diff versions fused by
    the extraction); the ``create_temp``-based variant is kept here.
    """
    # Fill in defaults only for fields the caller did not provide.
    self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
    return self.create_temp(
        "models",
        delete_params=dict(can_fail=True),
        **kwargs,
    )
|
||||
|
||||
def delete_failure(self):
    """Context manager asserting that the wrapped call fails with a task-delete error code."""
    return self.api.raises(self.TASK_CANNOT_BE_DELETED_CODES)
|
||||
|
||||
def publish_created_task(self, task_id):
    """Drive a freshly created task through started -> stopped -> publish."""
    for operation in (self.tasks.started, self.tasks.stopped, self.tasks.publish):
        operation(task=task_id)
|
||||
|
||||
@reset_and_delete()
def test_plain(self, cont):
    """A plain draft task resets/deletes without errors."""
    task_id = self.new_task()
    cont(self, task_id)
|
||||
|
||||
@reset_and_delete()
def test_draft_child(self, cont):
    """A parent whose only child is still a draft can be reset/deleted."""
    parent_id = self.new_task()
    self.new_task(parent=parent_id)
    cont(self, parent_id)
|
||||
|
||||
@reset_and_delete()
def test_published_child(self, cont):
    """A parent with a published child cannot be reset/deleted (without force)."""
    parent_id = self.new_task()
    self.publish_created_task(self.new_task(parent=parent_id))
    with self.delete_failure():
        cont(self, parent_id)
|
||||
|
||||
@reset_and_delete()
def test_draft_model(self, cont):
    """A task with only a non-ready model attached can be reset/deleted."""
    task_id = self.new_task()
    self.models.edit(model=self.new_model(), task=task_id, ready=False)
    cont(self, task_id)
|
||||
|
||||
@reset_and_delete()
def test_published_model(self, cont):
    """A task with a ready (published) model cannot be reset/deleted (without force)."""
    task_id = self.new_task()
    self.models.edit(model=self.new_model(), task=task_id, ready=True)
    with self.delete_failure():
        cont(self, task_id)
|
||||
def publish_task(self, task_id):
    """Drive a task through started -> stopped -> publish via the raw api accessor."""
    for operation in (
        self.api.tasks.started,
        self.api.tasks.stopped,
        self.api.tasks.publish,
    ):
        operation(task=task_id)
|
||||
|
@ -1,2 +1 @@
|
||||
nose==1.3.7
|
||||
parameterized>=0.7.1
|
||||
|
Loading…
Reference in New Issue
Block a user