Return file urls for tasks.delete/reset and models.delete

This commit is contained in:
allegroai 2021-05-03 17:38:09 +03:00
parent 78e4a58c91
commit 8b464e7ae6
15 changed files with 575 additions and 265 deletions

View File

@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField):
) )
def parse_value(self, value): def parse_value(self, value):
if value in (NotSet, None) and not self.required: if value is NotSet and not self.required:
return self.get_default_value() return self.get_default_value()
try: try:
# noinspection PyArgumentList # noinspection PyArgumentList

View File

@ -32,8 +32,16 @@ class CreateModelResponse(models.Base):
created = fields.BoolField(required=True) created = fields.BoolField(required=True)
class PublishModelRequest(models.Base): class ModelRequest(models.Base):
model = fields.StringField(required=True) model = fields.StringField(required=True)
class DeleteModelRequest(ModelRequest):
force = fields.BoolField(default=False)
return_file_url = fields.BoolField(default=False)
class PublishModelRequest(ModelRequest):
force_publish_task = fields.BoolField(default=False) force_publish_task = fields.BoolField(default=False)
publish_task = fields.BoolField(default=True) publish_task = fields.BoolField(default=True)

View File

@ -53,6 +53,7 @@ class ResetResponse(UpdateResponse):
frames = DictField() frames = DictField()
events = DictField() events = DictField()
model_deleted = IntField() model_deleted = IntField()
urls = DictField()
class TaskRequest(models.Base): class TaskRequest(models.Base):
@ -71,6 +72,7 @@ class EnqueueRequest(UpdateRequest):
class DeleteRequest(UpdateRequest): class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True) move_to_trash = BoolField(default=True)
return_file_urls = BoolField(default=False)
class SetRequirementsRequest(TaskRequest): class SetRequirementsRequest(TaskRequest):
@ -137,6 +139,7 @@ class DeleteArtifactsRequest(TaskRequest):
class ResetRequest(UpdateRequest): class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False) clear_all = BoolField(default=False)
return_file_urls = BoolField(default=False)
class MultiTaskRequest(models.Base): class MultiTaskRequest(models.Base):

View File

@ -1,5 +1,6 @@
import base64 import base64
import hashlib import hashlib
import re
import zlib import zlib
from collections import defaultdict from collections import defaultdict
from contextlib import closing from contextlib import closing
@ -36,11 +37,10 @@ from apiserver.redis_manager import redman
from apiserver.timing_context import TimingContext from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items from apiserver.utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
from apiserver.utilities.json import loads from apiserver.utilities.json import loads
EVENT_TYPES = set(map(attrgetter("value"), EventType)) # noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@ -49,11 +49,16 @@ class PlotFields:
plot_len = "plot_len" plot_len = "plot_len"
plot_str = "plot_str" plot_str = "plot_str"
plot_data = "plot_data" plot_data = "plot_data"
source_urls = "source_urls"
class EventBLL(object): class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key") id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF" empty_scroll = "FFFF"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
flags=re.IGNORECASE,
)
def __init__(self, events_es=None, redis=None): def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events") self.es = events_es or es_factory.connect("events")
@ -269,6 +274,11 @@ class EventBLL(object):
event[PlotFields.plot_len] = plot_len event[PlotFields.plot_len] = plot_len
if validate: if validate:
event[PlotFields.valid_plot] = self._is_valid_json(plot_str) event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
urls = {match for match in self.img_source_regex.findall(plot_str)}
if urls:
event[PlotFields.source_urls] = list(urls)
if compression_threshold and plot_len >= compression_threshold: if compression_threshold and plot_len >= compression_threshold:
event[PlotFields.plot_data] = base64.encodebytes( event[PlotFields.plot_data] = base64.encodebytes(
zlib.compress(plot_str.encode(), level=1) zlib.compress(plot_str.encode(), level=1)
@ -504,7 +514,7 @@ class EventBLL(object):
scroll_id: str = None, scroll_id: str = None,
): ):
if scroll_id == self.empty_scroll: if scroll_id == self.empty_scroll:
return [], scroll_id, 0 return TaskEventsResult()
if scroll_id: if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"): with translate_errors_context(), TimingContext("es", "get_task_events"):
@ -598,6 +608,41 @@ class EventBLL(object):
return events, total_events, next_scroll_id return events, total_events, next_scroll_id
    def get_plot_image_urls(
        self, company_id: str, task_id: str, scroll_id: Optional[str]
    ) -> Tuple[Sequence[dict], Optional[str]]:
        """
        Return one page of plot events that carry image source urls for the task,
        together with the ES scroll id for fetching the next page.

        :param company_id: company whose events index is searched
        :param task_id: task whose plot events are scanned
        :param scroll_id: None to start a new scan, an ES scroll id to continue one,
            or self.empty_scroll to signal that the previous scan was exhausted
        :return: (events, next_scroll_id); events is empty when there is no data,
            in which case next_scroll_id is None
        """
        if scroll_id == self.empty_scroll:
            # sentinel from a previous call: the scan is already finished
            return [], None

        if scroll_id:
            # continue an existing scan; keep the scroll context alive for 10 minutes
            es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m")
        else:
            if check_empty_data(self.es, company_id, EventType.metrics_plot):
                return [], None

            # only fetch the source_urls field of plot events that actually have it
            es_req = {
                "size": 1000,
                "_source": [PlotFields.source_urls],
                "query": {
                    "bool": {
                        "must": [
                            {"term": {"task": task_id}},
                            {"exists": {"field": PlotFields.source_urls}},
                        ]
                    }
                },
            }
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=EventType.metrics_plot,
                body=es_req,
                scroll="10m",
            )

        events, _, next_scroll_id = self._get_events_from_es_res(es_res)
        return events, next_scroll_id
def get_task_events( def get_task_events(
self, self,
company_id: str, company_id: str,

View File

@ -3,5 +3,4 @@ from .utils import (
ChangeStatusRequest, ChangeStatusRequest,
update_project_time, update_project_time,
validate_status_change, validate_status_change,
split_by,
) )

View File

@ -38,7 +38,7 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time from .utils import ChangeStatusRequest, validate_status_change, update_project_time, task_deleted_prefix
log = config.logger(__file__) log = config.logger(__file__)
org_bll = OrgBLL() org_bll = OrgBLL()
@ -247,6 +247,11 @@ class TaskBLL:
] ]
with TimingContext("mongo", "clone task"): with TimingContext("mongo", "clone task"):
parent_task = (
task.parent
if task.parent and not task.parent.startswith(task_deleted_prefix)
else None
)
new_task = Task( new_task = Task(
id=create_id(), id=create_id(),
user=user_id, user=user_id,
@ -256,7 +261,7 @@ class TaskBLL:
last_change=now, last_change=now,
name=name or task.name, name=name or task.name,
comment=comment or task.comment, comment=comment or task.comment,
parent=parent or task.parent, parent=parent or parent_task,
project=project or task.project, project=project or task.project,
tags=tags or task.tags, tags=tags or task.tags,
system_tags=system_tags or clean_system_tags(task.system_tags), system_tags=system_tags or clean_system_tags(task.system_tags),

View File

@ -0,0 +1,247 @@
import itertools
from collections import defaultdict
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
import attr
from boltons.iterutils import partition
from mongoengine import QuerySet, Document
from apiserver.apierrors import errors
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_bll import PlotFields
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.utils import task_deleted_prefix
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes
from apiserver.timing_context import TimingContext
event_bll = EventBLL()
T = TypeVar("T", bound=Document)
class DocumentGroup(List[T]):
    """
    A list of documents that can also be turned back into a query result.
    """

    def __init__(self, document_type: Type[T], documents: Iterable[T]):
        # remember the concrete document class so queries can be issued later
        super().__init__(documents)
        self.type = document_type

    @property
    def ids(self) -> Set[str]:
        # unique ids of all the documents currently held in this group
        return set(doc.id for doc in self)

    def objects(self, *args, **kwargs) -> QuerySet:
        # build a fresh queryset restricted to exactly the documents in this group
        return self.type.objects(id__in=self.ids, *args, **kwargs)
class TaskOutputs(Generic[T]):
    """
    Task outputs of a single type, separated into published and draft groups.
    """

    published: DocumentGroup[T]
    draft: DocumentGroup[T]

    def __init__(
        self,
        is_published: Callable[[T], bool],
        document_type: Type[T],
        children: Iterable[T],
    ):
        """
        :param is_published: predicate returning whether an item is considered published
        :param document_type: type of output
        :param children: output documents
        """
        # partition() yields (truthy, falsy) by the key predicate
        ready, not_ready = partition(children, key=is_published)
        self.published = DocumentGroup(document_type, ready)
        self.draft = DocumentGroup(document_type, not_ready)
@attr.s(auto_attribs=True)
class TaskUrls:
    """Urls of the files uploaded by a task, grouped by their origin."""

    # urls of the task's draft model files
    model_urls: Sequence[str]
    # urls referenced from task events (debug images and plot image sources)
    event_urls: Sequence[str]
    # urls of the task's output-mode artifacts
    artifact_urls: Sequence[str]
@attr.s(auto_attribs=True)
class CleanupResult:
    """
    Counts of objects modified in task cleanup operation
    """

    # number of child tasks re-parented to the deleted-task placeholder id
    updated_children: int
    # number of published models re-pointed to the deleted-task placeholder id
    updated_models: int
    # number of draft models deleted
    deleted_models: int
    # file urls of the deleted outputs; None unless return_file_urls was requested
    urls: TaskUrls = None
def _collect_plot_image_urls(company: str, task: str) -> Set[str]:
    """Return all unique image source urls stored on the task's plot events."""
    collected = set()
    scroll_id = None
    with TimingContext("es", "collect_plot_image_urls"):
        # page through the plot events until the scroll is exhausted
        while True:
            events, scroll_id = event_bll.get_plot_image_urls(
                company_id=company, task_id=task, scroll_id=scroll_id
            )
            if not events:
                break
            collected.update(
                url
                for event in events
                for url in (event.get(PlotFields.source_urls) or [])
            )

    return collected
def _collect_debug_image_urls(company: str, task: str) -> Set[str]:
    """
    Return the set of unique image urls
    Uses DebugImagesIterator to make sure that we do not retrieve recycled urls
    """
    metrics = event_bll.get_metrics_and_variants(
        company_id=company, task_id=task, event_type=EventType.metrics_image
    )
    if not metrics:
        return set()

    task_metrics = [(task, metric) for metric in metrics]
    state_id = None
    collected = set()
    # page through the debug image events; the iterator tracks state via state_id
    while True:
        res = event_bll.debug_images_iterator.get_task_events(
            company_id=company, metrics=task_metrics, iter_count=100, state_id=state_id
        )
        if not res.metric_events or not any(
            events for _, _, events in res.metric_events
        ):
            break

        state_id = res.next_scroll_id
        for _, _, iterations in res.metric_events:
            for iteration in iterations:
                for ev in iteration["events"]:
                    url = ev.get("url")
                    if url is not None:
                        collected.add(url)

    return collected
def cleanup_task(
    task: Task, force: bool = False, update_children=True, return_file_urls=False
) -> CleanupResult:
    """
    Validate task deletion and delete/modify all its output.

    :param task: task object
    :param force: whether to delete task with published outputs
    :param update_children: whether to re-parent child tasks and published models
        to the deleted-task placeholder id
    :param return_file_urls: whether to collect and return the urls of the files
        uploaded by the task (events, output artifacts and draft models)
    :return: counts of deleted and modified items, plus the collected file urls
        when return_file_urls is set
    """
    models = verify_task_children_and_ouptuts(task, force)

    event_urls, artifact_urls, model_urls = set(), set(), set()
    if return_file_urls:
        event_urls = _collect_debug_image_urls(task.company, task.id)
        event_urls.update(_collect_plot_image_urls(task.company, task.id))
        if task.execution and task.execution.artifacts:
            # only output artifacts that actually have a uri are reported
            artifact_urls = {
                a.uri
                for a in task.execution.artifacts.values()
                if a.mode == ArtifactModes.output and a.uri
            }
        model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri}

    deleted_task_id = f"{task_deleted_prefix}{task.id}"
    if update_children:
        with TimingContext("mongo", "update_task_children"):
            updated_children = Task.objects(parent=task.id).update(
                parent=deleted_task_id
            )
    else:
        updated_children = 0

    # draft models are deleted outright; published models are only re-pointed
    if models.draft:
        with TimingContext("mongo", "delete_models"):
            deleted_models = models.draft.objects().delete()
    else:
        deleted_models = 0

    if models.published and update_children:
        with TimingContext("mongo", "update_task_models"):
            updated_models = models.published.objects().update(task=deleted_task_id)
    else:
        updated_models = 0

    event_bll.delete_task_events(task.company, task.id, allow_locked=force)

    return CleanupResult(
        deleted_models=deleted_models,
        updated_children=updated_children,
        updated_models=updated_models,
        urls=TaskUrls(
            event_urls=list(event_urls),
            artifact_urls=list(artifact_urls),
            model_urls=list(model_urls),
        )
        if return_file_urls
        else None,
    )
# NOTE(review): "ouptuts" in the name is a typo for "outputs"; renaming requires
# updating all callers, so the name is kept as-is here.
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
    """
    Validate that the task may be deleted and return its output models split
    into published and draft groups.

    :param task: task object being deleted
    :param force: when False, published children or published output models
        block the deletion
    :raises errors.bad_request.TaskCannotBeDeleted: when force is False and the
        task has published children or published output models
    :return: the task's models; the draft group excludes models used as
        execution input by other tasks
    """
    if not force:
        with TimingContext("mongo", "count_published_children"):
            published_children_count = Task.objects(
                parent=task.id, status=TaskStatus.published
            ).count()
            if published_children_count:
                raise errors.bad_request.TaskCannotBeDeleted(
                    "has children, use force=True",
                    task=task.id,
                    children=published_children_count,
                )

    with TimingContext("mongo", "get_task_models"):
        # split the task's models by their ready (published) state
        models = TaskOutputs(
            attrgetter("ready"),
            Model,
            Model.objects(task=task.id).only("id", "task", "ready"),
        )
        if not force and models.published:
            raise errors.bad_request.TaskCannotBeDeleted(
                "has output models, use force=True",
                task=task.id,
                models=len(models.published),
            )

    if task.output.model:
        with TimingContext("mongo", "get_task_output_model"):
            output_model = Model.objects(id=task.output.model).first()
            if output_model:
                if output_model.ready:
                    if not force:
                        raise errors.bad_request.TaskCannotBeDeleted(
                            "has output model, use force=True",
                            task=task.id,
                            model=task.output.model,
                        )
                    models.published.append(output_model)
                else:
                    models.draft.append(output_model)

    if models.draft:
        with TimingContext("mongo", "get_execution_models"):
            # keep draft models that other tasks use as execution input
            model_ids = models.draft.ids
            dependent_tasks = Task.objects(execution__model__in=model_ids).only(
                "id", "execution.model"
            )
            input_models = [t.execution.model for t in dependent_tasks]
            if input_models:
                models.draft = DocumentGroup(
                    Model, (m for m in models.draft if m.id not in input_models)
                )

    return models

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence, Union from typing import Sequence, Union
import attr import attr
import six import six
@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.attrs import typed_attrs from apiserver.utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus) valid_statuses = get_options(TaskStatus)
task_deleted_prefix = "__DELETED__"
@typed_attrs @typed_attrs
@ -164,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow()) return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
T = TypeVar("T")
def split_by(
condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
"""
split "items" to two lists by "condition"
"""
applied = zip(map(condition, items), items)
return (
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
def get_task_for_update( def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task: ) -> Task:

View File

@ -697,6 +697,24 @@ delete {
} }
} }
} }
"2.13": ${delete."2.1"} {
request {
properties {
return_file_url {
description: "If set to 'true' then return the url of the model file. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
url {
description: "The url of the model file"
type: string
}
}
}
}
} }
make_public { make_public {

View File

@ -1191,6 +1191,25 @@ reset {
} }
} }
} }
"2.13": ${reset."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' property was set to True"
type: array
items {type: string}
}
}
}
}
} }
delete { delete {
"2.1" { "2.1" {
@ -1246,6 +1265,25 @@ delete {
} }
} }
} }
"2.13": ${delete."2.1"} {
request {
properties {
return_file_urls {
description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'"
type: boolean
}
}
}
response {
properties {
urls {
description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' property was set to True"
type: array
items {type: string}
}
}
}
}
} }
archive { archive {
"2.12" { "2.12" {

View File

@ -14,6 +14,7 @@ from apiserver.apimodels.models import (
PublishModelResponse, PublishModelResponse,
ModelTaskPublishResponse, ModelTaskPublishResponse,
GetFrameworksRequest, GetFrameworksRequest,
DeleteModelRequest,
) )
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
@ -440,14 +441,14 @@ def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
) )
@endpoint("models.delete", required_fields=["model"]) @endpoint("models.delete", request_data_model=DeleteModelRequest)
def update(call: APICall, company_id, _): def update(call: APICall, company_id, request: DeleteModelRequest):
model_id = call.data["model"] model_id = request.model
force = call.data.get("force", False) force = request.force
with translate_errors_context(): with translate_errors_context():
query = dict(id=model_id, company=company_id) query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task", "project").first() model = Model.objects(**query).only("id", "task", "project", "uri").first()
if not model: if not model:
raise errors.bad_request.InvalidModelId(**query) raise errors.bad_request.InvalidModelId(**query)
@ -483,7 +484,10 @@ def update(call: APICall, company_id, _):
del_count = Model.objects(**query).delete() del_count = Model.objects(**query).delete()
if del_count: if del_count:
_reset_cached_tags(company_id, projects=[model.project]) _reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0) call.result.data = dict(
deleted=del_count > 0,
url=model.uri if request.return_file_url else None
)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)

View File

@ -1,11 +1,9 @@
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from operator import attrgetter from typing import Sequence, Union, Tuple
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
import attr import attr
import dpath import dpath
import mongoengine
from mongoengine import EmbeddedDocument, Q from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne from pymongo import UpdateOne
@ -55,7 +53,6 @@ from apiserver.bll.task import (
TaskBLL, TaskBLL,
ChangeStatusRequest, ChangeStatusRequest,
update_project_time, update_project_time,
split_by,
) )
from apiserver.bll.task.artifacts import ( from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save, artifacts_prepare_for_save,
@ -69,11 +66,11 @@ from apiserver.bll.task.param_utils import (
params_unprepare_from_saved, params_unprepare_from_saved,
escape_paths, escape_paths,
) )
from apiserver.bll.task.utils import update_task from apiserver.bll.task.task_cleanup import cleanup_task
from apiserver.bll.task.utils import update_task, task_deleted_prefix
from apiserver.bll.util import SetFieldsResolver from apiserver.bll.util import SetFieldsResolver
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.output import Output from apiserver.database.model.task.output import Output
from apiserver.database.model.task.task import ( from apiserver.database.model.task.task import (
Task, Task,
@ -386,6 +383,9 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic
@endpoint("tasks.validate", request_data_model=CreateRequest) @endpoint("tasks.validate", request_data_model=CreateRequest)
def validate(call: APICall, company_id, req_model: CreateRequest): def validate(call: APICall, company_id, req_model: CreateRequest):
parent = call.data.get("parent")
if parent and parent.startswith(task_deleted_prefix):
call.data.pop("parent")
_validate_and_get_task_from_call(call) _validate_and_get_task_from_call(call)
@ -849,7 +849,12 @@ def reset(call: APICall, company_id, request: ResetRequest):
if dequeued: if dequeued:
api_results.update(dequeued=dequeued) api_results.update(dequeued=dequeued)
cleaned_up = cleanup_task(task, force=force, update_children=False) cleaned_up = cleanup_task(
task,
force=force,
update_children=False,
return_file_urls=request.return_file_urls,
)
api_results.update(attr.asdict(cleaned_up)) api_results.update(attr.asdict(cleaned_up))
updates.update( updates.update(
@ -937,146 +942,6 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
call.result.data_model = ArchiveResponse(archived=archived) call.result.data_model = ArchiveResponse(archived=archived)
class DocumentGroup(list):
"""
Operate on a list of documents as if they were a query result
"""
def __init__(self, document_type, documents):
super(DocumentGroup, self).__init__(documents)
self.type = document_type
def objects(self, *args, **kwargs):
return self.type.objects(id__in=[obj.id for obj in self], *args, **kwargs)
T = TypeVar("T")
class TaskOutputs(object):
"""
Split task outputs of the same type by the ready state
"""
published = None # type: DocumentGroup
draft = None # type: DocumentGroup
def __init__(self, is_published, document_type, children):
# type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> ()
"""
:param is_published: predicate returning whether items is considered published
:param document_type: type of output
:param children: output documents
"""
self.published, self.draft = map(
lambda x: DocumentGroup(document_type, x), split_by(is_published, children)
)
@attr.s
class CleanupResult(object):
"""
Counts of objects modified in task cleanup operation
"""
updated_children = attr.ib(type=int)
updated_models = attr.ib(type=int)
deleted_models = attr.ib(type=int)
def cleanup_task(task: Task, force: bool = False, update_children=True):
"""
Validate task deletion and delete/modify all its output.
:param task: task object
:param force: whether to delete task with published outputs
:return: count of delete and modified items
"""
models, child_tasks = get_outputs_for_deletion(task, force)
deleted_task_id = trash_task_id(task.id)
if child_tasks and update_children:
with TimingContext("mongo", "update_task_children"):
updated_children = child_tasks.update(parent=deleted_task_id)
else:
updated_children = 0
if models.draft:
with TimingContext("mongo", "delete_models"):
deleted_models = models.draft.objects().delete()
else:
deleted_models = 0
if models.published and update_children:
with TimingContext("mongo", "update_task_models"):
updated_models = models.published.objects().update(task=deleted_task_id)
else:
updated_models = 0
event_bll.delete_task_events(task.company, task.id, allow_locked=force)
return CleanupResult(
deleted_models=deleted_models,
updated_children=updated_children,
updated_models=updated_models,
)
def get_outputs_for_deletion(task, force=False):
with TimingContext("mongo", "get_task_models"):
models = TaskOutputs(
attrgetter("ready"),
Model,
Model.objects(task=task.id).only("id", "task", "ready"),
)
if not force and models.published:
raise errors.bad_request.TaskCannotBeDeleted(
"has output models, use force=True",
task=task.id,
models=len(models.published),
)
if task.output.model:
output_model = get_output_model(task, force)
if output_model:
if output_model.ready:
models.published.append(output_model)
else:
models.draft.append(output_model)
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = [m.id for m in models.draft]
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
)
busy_models = [t.execution.model for t in dependent_tasks]
models.draft[:] = [m for m in models.draft if m.id not in busy_models]
with TimingContext("mongo", "get_task_children"):
tasks = Task.objects(parent=task.id).only("id", "parent", "status")
published_tasks = [
task for task in tasks if task.status == TaskStatus.published
]
if not force and published_tasks:
raise errors.bad_request.TaskCannotBeDeleted(
"has children, use force=True", task=task.id, children=published_tasks
)
return models, tasks
def get_output_model(task, force=False):
with TimingContext("mongo", "get_task_output_model"):
output_model = Model.objects(id=task.output.model).first()
if output_model and output_model.ready and not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True", task=task.id, model=task.output.model
)
return output_model
def trash_task_id(task_id):
return "__DELETED__{}".format(task_id)
@endpoint("tasks.delete", request_data_model=DeleteRequest) @endpoint("tasks.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest): def delete(call: APICall, company_id, req_model: DeleteRequest):
task = TaskBLL.get_task_with_access( task = TaskBLL.get_task_with_access(
@ -1095,7 +960,9 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
) )
with translate_errors_context(): with translate_errors_context():
result = cleanup_task(task, force=force) result = cleanup_task(
task, force=force, return_file_urls=req_model.return_file_urls
)
if move_to_trash: if move_to_trash:
collection_name = task._get_collection_name() collection_name = task._get_collection_name()

View File

@ -204,7 +204,9 @@ class TestTaskEvents(TestService):
self.send_batch(events) self.send_batch(events)
for key in None, "iter", "timestamp", "iso_time": for key in None, "iter", "timestamp", "iso_time":
with self.subTest(key=key): with self.subTest(key=key):
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key) data = self.api.events.scalar_metrics_iter_histogram(
task=task, **(dict(key=key) if key is not None else {})
)
self.assertIn(metric, data) self.assertIn(metric, data)
self.assertIn(variant, data[metric]) self.assertIn(variant, data[metric])
self.assertIn("x", data[metric][variant]) self.assertIn("x", data[metric][variant])

View File

@ -1,95 +1,185 @@
from parameterized import parameterized from typing import Set
from apiserver.config_repo import config from apiserver.apierrors import errors
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
log = config.logger(__file__)
continuations = (
(lambda self, task: self.tasks.reset(task=task),),
(lambda self, task: self.tasks.delete(task=task),),
)
def reset_and_delete():
"""
Parametrize a test for both delete and reset operations,
which should yield the same results.
NOTE: "parameterized" engages in call stack manipulation,
so be careful when changing the application of the decorator.
For example, receiving "func" as a parameter and passing it to
"expand" doesn't work.
"""
return parameterized.expand(
[
(lambda self, task: self.tasks.delete(task=task),),
(lambda self, task: self.tasks.reset(task=task),),
],
name_func=lambda func, num, _: "{}_{}".format(
func.__name__, ["delete", "reset"][int(num)]
),
)
class TestTasksResetDelete(TestService): class TestTasksResetDelete(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.11")
TASK_CANNOT_BE_DELETED_CODES = (400, 123) def test_delete(self):
# draft task can be deleted
task = self.new_task()
res = self.assert_delete_task(task)
self.assertIsNone(res.get("urls"))
# published task can be deleted only with force flag
task = self.new_task()
self.publish_task(task)
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
self.assert_delete_task(task)
self.assert_delete_task(task, force=True)
def setUp(self, version="1.7"): # task with published children can only be deleted with force flag
super(TestTasksResetDelete, self).setUp(version=version) task = self.new_task()
self.tasks = self.api.tasks child = self.new_task(parent=task)
self.models = self.api.models self.publish_task(child)
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
self.assert_delete_task(task)
res = self.assert_delete_task(task, force=True)
self.assertEqual(res.updated_children, 1)
# make sure that the child model is valid after the parent deletion
self.api.tasks.validate(**self.api.tasks.get_by_id(task=child).task)
# task with published model can only be deleted with force flag
task = self.new_task()
model = self.new_model()
self.api.models.edit(model=model, task=task, ready=True)
with self.api.raises(errors.bad_request.TaskCannotBeDeleted):
self.assert_delete_task(task)
res = self.assert_delete_task(task, force=True)
self.assertEqual(res.updated_models, 1)
def test_return_file_urls(self):
# empty task
task = self.new_task()
res = self.assert_delete_task(task, return_file_urls=True)
self.assertEqual(res.urls.model_urls, [])
self.assertEqual(res.urls.event_urls, [])
self.assertEqual(res.urls.artifact_urls, [])
task = self.new_task()
model_urls = self.create_task_models(task)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
res = self.assert_delete_task(task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_reset(self):
# draft task can be deleted
task = self.new_task()
res = self.api.tasks.reset(task=task)
self.assertFalse(res.get("urls"))
# published task can be reset only with force flag
task = self.new_task()
self.publish_task(task)
with self.api.raises(errors.bad_request.InvalidTaskStatus):
self.api.tasks.reset(task=task)
self.api.tasks.reset(task=task, force=True)
# test urls
task = self.new_task()
model_urls = self.create_task_models(task)
artifact_urls = self.send_artifacts(task)
event_urls = self.send_debug_image_events(task)
event_urls.update(self.send_plot_events(task))
res = self.api.tasks.reset(task=task, force=True, return_file_urls=True)
self.assertEqual(set(res.urls.model_urls), model_urls)
self.assertEqual(set(res.urls.event_urls), event_urls)
self.assertEqual(set(res.urls.artifact_urls), artifact_urls)
def test_model_delete(self):
model = self.new_model(uri="test")
res = self.api.models.delete(model=model, return_file_url=True)
self.assertEqual(res.url, "test")
def assert_delete_task(self, task_id, force=False, return_file_urls=False):
tasks = self.api.tasks.get_all_ex(id=[task_id]).tasks
self.assertEqual(tasks[0].id, task_id)
res = self.api.tasks.delete(
task=task_id, force=force, return_file_urls=return_file_urls
)
self.assertTrue(res.deleted)
tasks = self.api.tasks.get_all_ex(id=[task_id]).tasks
self.assertEqual(tasks, [])
return res
def create_task_models(self, task) -> Set[str]:
    """
    Attach a ready and a not-ready model to the task.
    Return the uris of the non-public models only (the not-ready one).
    """
    ready = self.new_model(uri="ready")
    not_ready = self.new_model(uri="not_ready", ready=False)
    # same edit order as before: not-ready first, then ready
    for model_id in (not_ready, ready):
        self.api.models.edit(model=model_id, task=task)
    return {"not_ready"}
def send_artifacts(self, task) -> Set[str]:
    """
    Attach one input and one output artifact to the task.
    Return the uris of the output artifacts only.
    """
    self.api.tasks.add_or_update_artifacts(
        task=task,
        artifacts=[
            dict(key="a", type="str", uri="test1", mode="input"),
            dict(key="b", type="int", uri="test2"),
        ],
    )
    # only "test2" is an output artifact; "a" is an input and is excluded
    return {"test2"}
def send_debug_image_events(self, task) -> Set[str]:
    """Send five debug-image events for the task and return their urls."""
    events = []
    for iteration in range(5):
        events.append(
            self.create_event(
                task, "training_debug_image", iteration, url=f"url_{iteration}"
            )
        )
    self.send_batch(events)
    return {ev["url"] for ev in events}
def send_plot_events(self, task) -> Set[str]:
    """
    Send three plot events (two embedding image sources, one heatmap)
    and return the urls of the embedded plot images.
    """
    plots = (
        '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 514], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 514, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 514, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png"}], "showlegend": false, "title": "Feature importance/plot image", "name": null}}',
        '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 200], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 200, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 200, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg"}], "showlegend": false, "title": "untitled 00/plot image", "name": null}}',
        '{"data": [{"y": ["lying", "sitting", "standing", "people", "backgroun"], "x": ["lying", "sitting", "standing", "people", "backgroun"], "z": [[758, 163, 0, 0, 23], [63, 858, 3, 0, 0], [0, 50, 188, 21, 35], [0, 22, 8, 40, 4], [12, 91, 26, 29, 368]], "type": "heatmap"}], "layout": {"title": "Confusion Matrix for iter 100", "xaxis": {"title": "Predicted value"}, "yaxis": {"title": "Real value"}}}',
    )
    self.send_batch(
        [
            self.create_event(task, "plot", idx, plot_str=plot)
            for idx, plot in enumerate(plots)
        ]
    )
    # only the first two plots embed image sources; the heatmap has none
    return {
        "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png",
        "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg",
    }
def create_event(self, task, type_, iteration, **kwargs) -> dict:
    """Build a minimal event payload; extra fields are merged in from kwargs."""
    event = dict(
        worker="test",
        type=type_,
        task=task,
        iter=iteration,
        timestamp=es_factory.get_timestamp_millis(),
        metric="Metric1",
        variant="Variant1",
    )
    # kwargs may override any of the defaults above
    event.update(kwargs)
    return event
def send_batch(self, events):
    """Post the events through events.add_batch and return the response data."""
    _, response_data = self.api.send_batch("events.add_batch", events)
    return response_data
# NOTE(review): this span was corrupted by a side-by-side diff merge (each line
# contained old and new file text concatenated) and was not valid Python.
# Reconstructed below: the updated create_temp-based helpers plus the helpers
# (delete_failure, publish_created_task) that later tests still reference.

def new_task(self, **kwargs):
    """Create a temporary task (cleaned up on teardown) and return its id."""
    return self.create_temp(
        "tasks",
        delete_params=dict(can_fail=True),
        type="testing",
        name="test task delete",
        input=dict(view=dict()),
        **kwargs,
    )

def new_model(self, **kwargs):
    """Create a temporary model (cleaned up on teardown) and return its id."""
    self.update_missing(kwargs, name="test", uri="file:///a/b", labels={})
    return self.create_temp(
        "models",
        delete_params=dict(can_fail=True),
        **kwargs,
    )

def publish_task(self, task_id):
    """Drive a task through started -> stopped -> publish."""
    self.api.tasks.started(task=task_id)
    self.api.tasks.stopped(task=task_id)
    self.api.tasks.publish(task=task_id)

def delete_failure(self):
    """Context manager asserting the wrapped call fails with a cannot-delete code."""
    return self.api.raises(self.TASK_CANNOT_BE_DELETED_CODES)

def publish_created_task(self, task_id):
    """Drive a task through started -> stopped -> publish.

    NOTE(review): uses the older ``self.tasks`` service proxy as in the
    original text — confirm the attribute still exists on this test class.
    """
    self.tasks.started(task=task_id)
    self.tasks.stopped(task=task_id)
    self.tasks.publish(task=task_id)
@reset_and_delete()
def test_plain(self, cont):
    """A standalone draft task can be reset/deleted."""
    cont(self, self.new_task())
@reset_and_delete()
def test_draft_child(self, cont):
    """A parent whose only child is still a draft can be reset/deleted."""
    parent_id = self.new_task()
    self.new_task(parent=parent_id)
    cont(self, parent_id)
@reset_and_delete()
def test_published_child(self, cont):
    """A parent with a published child cannot be reset/deleted."""
    parent_id = self.new_task()
    child_id = self.new_task(parent=parent_id)
    self.publish_created_task(child_id)
    with self.delete_failure():
        cont(self, parent_id)
@reset_and_delete()
def test_draft_model(self, cont):
    """A task whose only model is a draft (not ready) can be reset/deleted."""
    task_id = self.new_task()
    draft_model = self.new_model()
    self.models.edit(model=draft_model, task=task_id, ready=False)
    cont(self, task_id)
@reset_and_delete()
def test_published_model(self, cont):
    """A task with a published (ready) model cannot be reset/deleted."""
    task_id = self.new_task()
    ready_model = self.new_model()
    self.models.edit(model=ready_model, task=task_id, ready=True)
    with self.delete_failure():
        cont(self, task_id)

View File

@ -1,2 +1 @@
nose==1.3.7 nose==1.3.7
parameterized>=0.7.1