From 8b464e7ae67c78932b7c537b0c73bc6ba46b3362 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 3 May 2021 17:38:09 +0300 Subject: [PATCH] Return file urls for tasks.delete/reset and models.delete --- apiserver/apimodels/__init__.py | 2 +- apiserver/apimodels/models.py | 10 +- apiserver/apimodels/tasks.py | 3 + apiserver/bll/event/event_bll.py | 53 +++- apiserver/bll/task/__init__.py | 1 - apiserver/bll/task/task_bll.py | 9 +- apiserver/bll/task/task_cleanup.py | 247 +++++++++++++++++ apiserver/bll/task/utils.py | 19 +- apiserver/schema/services/models.conf | 18 ++ apiserver/schema/services/tasks.conf | 38 +++ apiserver/services/models.py | 16 +- apiserver/services/tasks.py | 163 +---------- apiserver/tests/automated/test_task_events.py | 4 +- .../tests/automated/test_tasks_delete.py | 256 ++++++++++++------ apiserver/tests/requirements.txt | 1 - 15 files changed, 575 insertions(+), 265 deletions(-) create mode 100644 apiserver/bll/task/task_cleanup.py diff --git a/apiserver/apimodels/__init__.py b/apiserver/apimodels/__init__.py index c59d9cd..50dcb48 100644 --- a/apiserver/apimodels/__init__.py +++ b/apiserver/apimodels/__init__.py @@ -218,7 +218,7 @@ class ActualEnumField(fields.StringField): ) def parse_value(self, value): - if value in (NotSet, None) and not self.required: + if value is NotSet and not self.required: return self.get_default_value() try: # noinspection PyArgumentList diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py index db74ab2..2ec4451 100644 --- a/apiserver/apimodels/models.py +++ b/apiserver/apimodels/models.py @@ -32,8 +32,16 @@ class CreateModelResponse(models.Base): created = fields.BoolField(required=True) -class PublishModelRequest(models.Base): +class ModelRequest(models.Base): model = fields.StringField(required=True) + + +class DeleteModelRequest(ModelRequest): + force = fields.BoolField(default=False) + return_file_url = fields.BoolField(default=False) + + +class PublishModelRequest(ModelRequest): force_publish_task = fields.BoolField(default=False) publish_task = fields.BoolField(default=True) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 96fef42..cc7a027 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -53,6 +53,7 @@ class ResetResponse(UpdateResponse): frames = DictField() events = DictField() model_deleted = IntField() + urls = DictField() class TaskRequest(models.Base): @@ -71,6 +72,7 @@ class EnqueueRequest(UpdateRequest): class DeleteRequest(UpdateRequest): move_to_trash = BoolField(default=True) + return_file_urls = BoolField(default=False) class SetRequirementsRequest(TaskRequest): @@ -137,6 +139,7 @@ class DeleteArtifactsRequest(TaskRequest): class ResetRequest(UpdateRequest): clear_all = BoolField(default=False) + return_file_urls = BoolField(default=False) class MultiTaskRequest(models.Base): diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 5e3a2f3..98a647b 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -1,5 +1,6 @@ import base64 import hashlib +import re import zlib from collections import defaultdict from contextlib import closing @@ -36,11 +37,10 @@ from apiserver.redis_manager import redman from apiserver.timing_context import TimingContext from apiserver.tools import safe_get from apiserver.utilities.dicts import flatten_nested_items - -# noinspection PyTypeChecker from apiserver.utilities.json import loads -EVENT_TYPES = set(map(attrgetter("value"), EventType)) +# 
noinspection PyTypeChecker +EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType)) LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) @@ -49,11 +49,16 @@ class PlotFields: plot_len = "plot_len" plot_str = "plot_str" plot_data = "plot_data" + source_urls = "source_urls" class EventBLL(object): id_fields = ("task", "iter", "metric", "variant", "key") empty_scroll = "FFFF" + img_source_regex = re.compile( + r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]", + flags=re.IGNORECASE, + ) def __init__(self, events_es=None, redis=None): self.es = events_es or es_factory.connect("events") @@ -269,6 +274,11 @@ class EventBLL(object): event[PlotFields.plot_len] = plot_len if validate: event[PlotFields.valid_plot] = self._is_valid_json(plot_str) + + urls = {match for match in self.img_source_regex.findall(plot_str)} + if urls: + event[PlotFields.source_urls] = list(urls) + if compression_threshold and plot_len >= compression_threshold: event[PlotFields.plot_data] = base64.encodebytes( zlib.compress(plot_str.encode(), level=1) @@ -504,7 +514,7 @@ class EventBLL(object): scroll_id: str = None, ): if scroll_id == self.empty_scroll: - return [], scroll_id, 0 + return TaskEventsResult() if scroll_id: with translate_errors_context(), TimingContext("es", "get_task_events"): @@ -598,6 +608,41 @@ class EventBLL(object): return events, total_events, next_scroll_id + def get_plot_image_urls( + self, company_id: str, task_id: str, scroll_id: Optional[str] + ) -> Tuple[Sequence[dict], Optional[str]]: + if scroll_id == self.empty_scroll: + return [], None + + if scroll_id: + es_res = self.es.scroll(scroll_id=scroll_id, scroll="10m") + else: + if check_empty_data(self.es, company_id, EventType.metrics_plot): + return [], None + + es_req = { + "size": 1000, + "_source": [PlotFields.source_urls], + "query": { + "bool": { + "must": [ + {"term": {"task": task_id}}, + {"exists": {"field": PlotFields.source_urls}}, + ] + } + }, + } + es_res = search_company_events( + self.es, + company_id=company_id, + event_type=EventType.metrics_plot, + body=es_req, + scroll="10m", + ) + + events, _, next_scroll_id = self._get_events_from_es_res(es_res) + return events, next_scroll_id + def get_task_events( self, company_id: str, diff --git a/apiserver/bll/task/__init__.py b/apiserver/bll/task/__init__.py index 544b289..8ce5c16 100644 --- a/apiserver/bll/task/__init__.py +++ b/apiserver/bll/task/__init__.py @@ -3,5 +3,4 @@ from .utils import ( ChangeStatusRequest, update_project_time, validate_status_change, - split_by, ) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index a220523..be64004 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -38,7 +38,7 @@ from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from .artifacts import artifacts_prepare_for_save from .param_utils import params_prepare_for_save -from .utils import ChangeStatusRequest, validate_status_change, update_project_time +from .utils import ChangeStatusRequest, validate_status_change, update_project_time, task_deleted_prefix log = config.logger(__file__) org_bll = OrgBLL() @@ -247,6 +247,11 @@ class TaskBLL: ] with TimingContext("mongo", "clone task"): + parent_task = ( + task.parent + if task.parent and not task.parent.startswith(task_deleted_prefix) + else None + ) new_task = Task( id=create_id(), user=user_id, @@ -256,7 +261,7 @@ class TaskBLL: last_change=now, name=name or task.name, 
comment=comment or task.comment, - parent=parent or task.parent, + parent=parent or parent_task, project=project or task.project, tags=tags or task.tags, system_tags=system_tags or clean_system_tags(task.system_tags), diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py new file mode 100644 index 0000000..96aa90e --- /dev/null +++ b/apiserver/bll/task/task_cleanup.py @@ -0,0 +1,247 @@ +import itertools +from collections import defaultdict +from operator import attrgetter +from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set + +import attr +from boltons.iterutils import partition +from mongoengine import QuerySet, Document + +from apiserver.apierrors import errors +from apiserver.bll.event import EventBLL +from apiserver.bll.event.event_bll import PlotFields +from apiserver.bll.event.event_common import EventType +from apiserver.bll.task.utils import task_deleted_prefix +from apiserver.database.model.model import Model +from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes +from apiserver.timing_context import TimingContext + +event_bll = EventBLL() +T = TypeVar("T", bound=Document) + + +class DocumentGroup(List[T]): + """ + Operate on a list of documents as if they were a query result + """ + + def __init__(self, document_type: Type[T], documents: Iterable[T]): + super(DocumentGroup, self).__init__(documents) + self.type = document_type + + @property + def ids(self) -> Set[str]: + return {obj.id for obj in self} + + def objects(self, *args, **kwargs) -> QuerySet: + return self.type.objects(id__in=self.ids, *args, **kwargs) + + +class TaskOutputs(Generic[T]): + """ + Split task outputs of the same type by the ready state + """ + + published: DocumentGroup[T] + draft: DocumentGroup[T] + + def __init__( + self, + is_published: Callable[[T], bool], + document_type: Type[T], + children: Iterable[T], + ): + """ + :param is_published: predicate returning whether items is considered published + :param document_type: type of output + :param children: output documents + """ + self.published, self.draft = map( + lambda x: DocumentGroup(document_type, x), + partition(children, key=is_published), + ) + + +@attr.s(auto_attribs=True) +class TaskUrls: + model_urls: Sequence[str] + event_urls: Sequence[str] + artifact_urls: Sequence[str] + + +@attr.s(auto_attribs=True) +class CleanupResult: + """ + Counts of objects modified in task cleanup operation + """ + + updated_children: int + updated_models: int + deleted_models: int + urls: TaskUrls = None + + +def _collect_plot_image_urls(company: str, task: str) -> Set[str]: + urls = set() + next_scroll_id = None + with TimingContext("es", "collect_plot_image_urls"): + while True: + events, next_scroll_id = event_bll.get_plot_image_urls( + company_id=company, task_id=task, scroll_id=next_scroll_id + ) + if not events: + break + for event in events: + event_urls = event.get(PlotFields.source_urls) + if event_urls: + urls.update(set(event_urls)) + + return urls + + +def _collect_debug_image_urls(company: str, task: str) -> Set[str]: + """ + Return the set of unique image urls + Uses DebugImagesIterator to make sure that we do not retrieve recycled urls + """ + metrics = event_bll.get_metrics_and_variants( + company_id=company, task_id=task, event_type=EventType.metrics_image + ) + if not metrics: + return set() + + task_metrics = [(task, metric) for metric in metrics] + scroll_id = None + urls = defaultdict(set) + while True: + res = 
event_bll.debug_images_iterator.get_task_events( + company_id=company, metrics=task_metrics, iter_count=100, state_id=scroll_id + ) + if not res.metric_events or not any( + events for _, _, events in res.metric_events + ): + break + + scroll_id = res.next_scroll_id + for _, metric, iterations in res.metric_events: + metric_urls = set(ev.get("url") for it in iterations for ev in it["events"]) + metric_urls.discard(None) + urls[metric].update(metric_urls) + + return set(itertools.chain.from_iterable(urls.values())) + + +def cleanup_task( + task: Task, force: bool = False, update_children=True, return_file_urls=False +) -> CleanupResult: + """ + Validate task deletion and delete/modify all its output. + :param task: task object + :param force: whether to delete task with published outputs + :return: count of delete and modified items + """ + models = verify_task_children_and_ouptuts(task, force) + + event_urls, artifact_urls, model_urls = set(), set(), set() + if return_file_urls: + event_urls = _collect_debug_image_urls(task.company, task.id) + event_urls.update(_collect_plot_image_urls(task.company, task.id)) + if task.execution and task.execution.artifacts: + artifact_urls = { + a.uri + for a in task.execution.artifacts.values() + if a.mode == ArtifactModes.output and a.uri + } + model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri} + + deleted_task_id = f"{task_deleted_prefix}{task.id}" + if update_children: + with TimingContext("mongo", "update_task_children"): + updated_children = Task.objects(parent=task.id).update( + parent=deleted_task_id + ) + else: + updated_children = 0 + + if models.draft: + with TimingContext("mongo", "delete_models"): + deleted_models = models.draft.objects().delete() + else: + deleted_models = 0 + + if models.published and update_children: + with TimingContext("mongo", "update_task_models"): + updated_models = models.published.objects().update(task=deleted_task_id) + else: + updated_models = 0 + + event_bll.delete_task_events(task.company, task.id, allow_locked=force) + + return CleanupResult( + deleted_models=deleted_models, + updated_children=updated_children, + updated_models=updated_models, + urls=TaskUrls( + event_urls=list(event_urls), + artifact_urls=list(artifact_urls), + model_urls=list(model_urls), + ) + if return_file_urls + else None, + ) + + +def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]: + if not force: + with TimingContext("mongo", "count_published_children"): + published_children_count = Task.objects( + parent=task.id, status=TaskStatus.published + ).count() + if published_children_count: + raise errors.bad_request.TaskCannotBeDeleted( + "has children, use force=True", + task=task.id, + children=published_children_count, + ) + + with TimingContext("mongo", "get_task_models"): + models = TaskOutputs( + attrgetter("ready"), + Model, + Model.objects(task=task.id).only("id", "task", "ready"), + ) + if not force and models.published: + raise errors.bad_request.TaskCannotBeDeleted( + "has output models, use force=True", + task=task.id, + models=len(models.published), + ) + + if task.output.model: + with TimingContext("mongo", "get_task_output_model"): + output_model = Model.objects(id=task.output.model).first() + if output_model: + if output_model.ready: + if not force: + raise errors.bad_request.TaskCannotBeDeleted( + "has output model, use force=True", + task=task.id, + model=task.output.model, + ) + models.published.append(output_model) + else: + models.draft.append(output_model) + + if 
models.draft: + with TimingContext("mongo", "get_execution_models"): + model_ids = models.draft.ids + dependent_tasks = Task.objects(execution__model__in=model_ids).only( + "id", "execution.model" + ) + input_models = [t.execution.model for t in dependent_tasks] + if input_models: + models.draft = DocumentGroup( + Model, (m for m in models.draft if m.id not in input_models) + ) + + return models diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 51c3a2c..419a0ce 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import TypeVar, Callable, Tuple, Sequence, Union +from typing import Sequence, Union import attr import six @@ -13,6 +13,7 @@ from apiserver.timing_context import TimingContext from apiserver.utilities.attrs import typed_attrs valid_statuses = get_options(TaskStatus) +task_deleted_prefix = "__DELETED__" @typed_attrs @@ -164,22 +165,6 @@ def update_project_time(project_ids: Union[str, Sequence[str]]): return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow()) -T = TypeVar("T") - - -def split_by( - condition: Callable[[T], bool], items: Sequence[T] -) -> Tuple[Sequence[T], Sequence[T]]: - """ - split "items" to two lists by "condition" - """ - applied = zip(map(condition, items), items) - return ( - [item for cond, item in applied if cond], - [item for cond, item in applied if not cond], - ) - - def get_task_for_update( company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False ) -> Task: diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 6a6c1fd..fde6840 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -697,6 +697,24 @@ delete { } } } + "2.13": ${delete."2.1"} { + request { + properties { + return_file_url { + description: "If set to 'true' then return the url of the model file. Default value is 'false'" + type: boolean + } + } + } + response { + properties { + url { + descrition: "The url of the model file" + type: string + } + } + } + } } make_public { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 73ca1bb..e0a5b89 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -1191,6 +1191,25 @@ reset { } } } + "2.13": ${reset."2.1"} { + request { + properties { + return_file_urls { + description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'" + type: boolean + } + } + } + response { + properties { + urls { + description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' properties was set to True" + type: array + items {type: string} + } + } + } + } } delete { "2.1" { @@ -1246,6 +1265,25 @@ delete { } } } + "2.13": ${delete."2.1"} { + request { + properties { + return_file_urls { + description: "If set to 'true' then return the urls of the files that were uploaded by this task. Default value is 'false'" + type: boolean + } + } + } + response { + properties { + urls { + description: "The urls of the files that were uploaded by this task. 
Returned if the 'return_file_urls' properties was set to True" + type: array + items {type: string} + } + } + } + } } archive { "2.12" { diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 40f6ef2..2e8dbba 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -14,6 +14,7 @@ from apiserver.apimodels.models import ( PublishModelResponse, ModelTaskPublishResponse, GetFrameworksRequest, + DeleteModelRequest, ) from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags @@ -440,14 +441,14 @@ def set_ready(call: APICall, company_id, req_model: PublishModelRequest): ) -@endpoint("models.delete", required_fields=["model"]) -def update(call: APICall, company_id, _): - model_id = call.data["model"] - force = call.data.get("force", False) +@endpoint("models.delete", request_data_model=DeleteModelRequest) +def update(call: APICall, company_id, request: DeleteModelRequest): + model_id = request.model + force = request.force with translate_errors_context(): query = dict(id=model_id, company=company_id) - model = Model.objects(**query).only("id", "task", "project").first() + model = Model.objects(**query).only("id", "task", "project", "uri").first() if not model: raise errors.bad_request.InvalidModelId(**query) @@ -483,7 +484,10 @@ def update(call: APICall, company_id, _): del_count = Model.objects(**query).delete() if del_count: _reset_cached_tags(company_id, projects=[model.project]) - call.result.data = dict(deleted=del_count > 0) + call.result.data = dict( + deleted=del_count > 0, + url=model.uri if request.return_file_url else None + ) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index f3b7c71..f41e0e4 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -1,11 +1,9 @@ from copy import deepcopy from datetime import datetime -from operator import attrgetter -from typing import Sequence, Callable, Type, TypeVar, Union, Tuple +from typing import Sequence, Union, Tuple import attr import dpath -import mongoengine from mongoengine import EmbeddedDocument, Q from mongoengine.queryset.transform import COMPARISON_OPERATORS from pymongo import UpdateOne @@ -55,7 +53,6 @@ from apiserver.bll.task import ( TaskBLL, ChangeStatusRequest, update_project_time, - split_by, ) from apiserver.bll.task.artifacts import ( artifacts_prepare_for_save, @@ -69,11 +66,11 @@ from apiserver.bll.task.param_utils import ( params_unprepare_from_saved, escape_paths, ) -from apiserver.bll.task.utils import update_task +from apiserver.bll.task.task_cleanup import cleanup_task +from apiserver.bll.task.utils import update_task, task_deleted_prefix from apiserver.bll.util import SetFieldsResolver from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility -from apiserver.database.model.model import Model from apiserver.database.model.task.output import Output from apiserver.database.model.task.task import ( Task, @@ -386,6 +383,9 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic @endpoint("tasks.validate", request_data_model=CreateRequest) def validate(call: APICall, company_id, req_model: CreateRequest): + parent = call.data.get("parent") + if parent and parent.startswith(task_deleted_prefix): + call.data.pop("parent") _validate_and_get_task_from_call(call) @@ -849,7 +849,12 @@ def reset(call: APICall, company_id, request: 
ResetRequest): if dequeued: api_results.update(dequeued=dequeued) - cleaned_up = cleanup_task(task, force=force, update_children=False) + cleaned_up = cleanup_task( + task, + force=force, + update_children=False, + return_file_urls=request.return_file_urls, + ) api_results.update(attr.asdict(cleaned_up)) updates.update( @@ -937,146 +942,6 @@ def archive(call: APICall, company_id, request: ArchiveRequest): call.result.data_model = ArchiveResponse(archived=archived) -class DocumentGroup(list): - """ - Operate on a list of documents as if they were a query result - """ - - def __init__(self, document_type, documents): - super(DocumentGroup, self).__init__(documents) - self.type = document_type - - def objects(self, *args, **kwargs): - return self.type.objects(id__in=[obj.id for obj in self], *args, **kwargs) - - -T = TypeVar("T") - - -class TaskOutputs(object): - """ - Split task outputs of the same type by the ready state - """ - - published = None # type: DocumentGroup - draft = None # type: DocumentGroup - - def __init__(self, is_published, document_type, children): - # type: (Callable[[T], bool], Type[mongoengine.Document], Sequence[T]) -> () - """ - :param is_published: predicate returning whether items is considered published - :param document_type: type of output - :param children: output documents - """ - self.published, self.draft = map( - lambda x: DocumentGroup(document_type, x), split_by(is_published, children) - ) - - -@attr.s -class CleanupResult(object): - """ - Counts of objects modified in task cleanup operation - """ - - updated_children = attr.ib(type=int) - updated_models = attr.ib(type=int) - deleted_models = attr.ib(type=int) - - -def cleanup_task(task: Task, force: bool = False, update_children=True): - """ - Validate task deletion and delete/modify all its output. 
- :param task: task object - :param force: whether to delete task with published outputs - :return: count of delete and modified items - """ - models, child_tasks = get_outputs_for_deletion(task, force) - deleted_task_id = trash_task_id(task.id) - if child_tasks and update_children: - with TimingContext("mongo", "update_task_children"): - updated_children = child_tasks.update(parent=deleted_task_id) - else: - updated_children = 0 - - if models.draft: - with TimingContext("mongo", "delete_models"): - deleted_models = models.draft.objects().delete() - else: - deleted_models = 0 - - if models.published and update_children: - with TimingContext("mongo", "update_task_models"): - updated_models = models.published.objects().update(task=deleted_task_id) - else: - updated_models = 0 - - event_bll.delete_task_events(task.company, task.id, allow_locked=force) - - return CleanupResult( - deleted_models=deleted_models, - updated_children=updated_children, - updated_models=updated_models, - ) - - -def get_outputs_for_deletion(task, force=False): - with TimingContext("mongo", "get_task_models"): - models = TaskOutputs( - attrgetter("ready"), - Model, - Model.objects(task=task.id).only("id", "task", "ready"), - ) - if not force and models.published: - raise errors.bad_request.TaskCannotBeDeleted( - "has output models, use force=True", - task=task.id, - models=len(models.published), - ) - - if task.output.model: - output_model = get_output_model(task, force) - if output_model: - if output_model.ready: - models.published.append(output_model) - else: - models.draft.append(output_model) - - if models.draft: - with TimingContext("mongo", "get_execution_models"): - model_ids = [m.id for m in models.draft] - dependent_tasks = Task.objects(execution__model__in=model_ids).only( - "id", "execution.model" - ) - busy_models = [t.execution.model for t in dependent_tasks] - models.draft[:] = [m for m in models.draft if m.id not in busy_models] - - with TimingContext("mongo", "get_task_children"): - tasks = Task.objects(parent=task.id).only("id", "parent", "status") - published_tasks = [ - task for task in tasks if task.status == TaskStatus.published - ] - if not force and published_tasks: - raise errors.bad_request.TaskCannotBeDeleted( - "has children, use force=True", task=task.id, children=published_tasks - ) - return models, tasks - - -def get_output_model(task, force=False): - with TimingContext("mongo", "get_task_output_model"): - output_model = Model.objects(id=task.output.model).first() - if output_model and output_model.ready and not force: - raise errors.bad_request.TaskCannotBeDeleted( - "has output model, use force=True", task=task.id, model=task.output.model - ) - return output_model - - -def trash_task_id(task_id): - return "__DELETED__{}".format(task_id) - - @endpoint("tasks.delete", request_data_model=DeleteRequest) def delete(call: APICall, company_id, req_model: DeleteRequest): task = TaskBLL.get_task_with_access( @@ -1095,7 +960,9 @@ def delete(call: APICall, company_id, req_model: DeleteRequest): ) with translate_errors_context(): - result = cleanup_task(task, force=force) + result = cleanup_task( + task, force=force, return_file_urls=req_model.return_file_urls + ) if move_to_trash: collection_name = task._get_collection_name() diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index f4cce05..984ebcc 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -204,7 +204,9 @@ class 
TestTaskEvents(TestService): self.send_batch(events) for key in None, "iter", "timestamp", "iso_time": with self.subTest(key=key): - data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key) + data = self.api.events.scalar_metrics_iter_histogram( + task=task, **(dict(key=key) if key is not None else {}) + ) self.assertIn(metric, data) self.assertIn(variant, data[metric]) self.assertIn("x", data[metric][variant]) diff --git a/apiserver/tests/automated/test_tasks_delete.py b/apiserver/tests/automated/test_tasks_delete.py index f4bccff..305a17f 100644 --- a/apiserver/tests/automated/test_tasks_delete.py +++ b/apiserver/tests/automated/test_tasks_delete.py @@ -1,95 +1,185 @@ -from parameterized import parameterized +from typing import Set -from apiserver.config_repo import config +from apiserver.apierrors import errors +from apiserver.es_factory import es_factory from apiserver.tests.automated import TestService -log = config.logger(__file__) - - -continuations = ( - (lambda self, task: self.tasks.reset(task=task),), - (lambda self, task: self.tasks.delete(task=task),), -) - - -def reset_and_delete(): - """ - Parametrize a test for both delete and reset operations, - which should yield the same results. - NOTE: "parameterized" engages in call stack manipulation, - so be careful when changing the application of the decorator. - For example, receiving "func" as a parameter and passing it to - "expand" doesn't work. - """ - return parameterized.expand( - [ - (lambda self, task: self.tasks.delete(task=task),), - (lambda self, task: self.tasks.reset(task=task),), - ], - name_func=lambda func, num, _: "{}_{}".format( - func.__name__, ["delete", "reset"][int(num)] - ), - ) - class TestTasksResetDelete(TestService): + def setUp(self, **kwargs): + super().setUp(version="2.11") - TASK_CANNOT_BE_DELETED_CODES = (400, 123) + def test_delete(self): + # draft task can be deleted + task = self.new_task() + res = self.assert_delete_task(task) + self.assertIsNone(res.get("urls")) + # published task can be deleted only with force flag + task = self.new_task() + self.publish_task(task) + with self.api.raises(errors.bad_request.TaskCannotBeDeleted): + self.assert_delete_task(task) + self.assert_delete_task(task, force=True) - def setUp(self, version="1.7"): - super(TestTasksResetDelete, self).setUp(version=version) - self.tasks = self.api.tasks - self.models = self.api.models + # task with published children can only be deleted with force flag + task = self.new_task() + child = self.new_task(parent=task) + self.publish_task(child) + with self.api.raises(errors.bad_request.TaskCannotBeDeleted): + self.assert_delete_task(task) + res = self.assert_delete_task(task, force=True) + self.assertEqual(res.updated_children, 1) + # make sure that the child model is valid after the parent deletion + self.api.tasks.validate(**self.api.tasks.get_by_id(task=child).task) + + # task with published model can only be deleted with force flag + task = self.new_task() + model = self.new_model() + self.api.models.edit(model=model, task=task, ready=True) + with self.api.raises(errors.bad_request.TaskCannotBeDeleted): + self.assert_delete_task(task) + res = self.assert_delete_task(task, force=True) + self.assertEqual(res.updated_models, 1) + + def test_return_file_urls(self): + # empty task + task = self.new_task() + res = self.assert_delete_task(task, return_file_urls=True) + self.assertEqual(res.urls.model_urls, []) + self.assertEqual(res.urls.event_urls, []) + self.assertEqual(res.urls.artifact_urls, []) + + task = 
self.new_task() + model_urls = self.create_task_models(task) + artifact_urls = self.send_artifacts(task) + event_urls = self.send_debug_image_events(task) + event_urls.update(self.send_plot_events(task)) + res = self.assert_delete_task(task, force=True, return_file_urls=True) + self.assertEqual(set(res.urls.model_urls), model_urls) + self.assertEqual(set(res.urls.event_urls), event_urls) + self.assertEqual(set(res.urls.artifact_urls), artifact_urls) + + def test_reset(self): + # draft task can be deleted + task = self.new_task() + res = self.api.tasks.reset(task=task) + self.assertFalse(res.get("urls")) + + # published task can be reset only with force flag + task = self.new_task() + self.publish_task(task) + with self.api.raises(errors.bad_request.InvalidTaskStatus): + self.api.tasks.reset(task=task) + self.api.tasks.reset(task=task, force=True) + + # test urls + task = self.new_task() + model_urls = self.create_task_models(task) + artifact_urls = self.send_artifacts(task) + event_urls = self.send_debug_image_events(task) + event_urls.update(self.send_plot_events(task)) + res = self.api.tasks.reset(task=task, force=True, return_file_urls=True) + self.assertEqual(set(res.urls.model_urls), model_urls) + self.assertEqual(set(res.urls.event_urls), event_urls) + self.assertEqual(set(res.urls.artifact_urls), artifact_urls) + + def test_model_delete(self): + model = self.new_model(uri="test") + res = self.api.models.delete(model=model, return_file_url=True) + self.assertEqual(res.url, "test") + + def assert_delete_task(self, task_id, force=False, return_file_urls=False): + tasks = self.api.tasks.get_all_ex(id=[task_id]).tasks + self.assertEqual(tasks[0].id, task_id) + res = self.api.tasks.delete( + task=task_id, force=force, return_file_urls=return_file_urls + ) + self.assertTrue(res.deleted) + tasks = self.api.tasks.get_all_ex(id=[task_id]).tasks + self.assertEqual(tasks, []) + return res + + def create_task_models(self, task) -> Set[str]: + """ + Update models from task and return only non public models + """ + model_ready = self.new_model(uri="ready") + model_not_ready = self.new_model(uri="not_ready", ready=False) + self.api.models.edit(model=model_not_ready, task=task) + self.api.models.edit(model=model_ready, task=task) + return {"not_ready"} + + def send_artifacts(self, task) -> Set[str]: + """ + Add input and output artifacts and return output artifact names + """ + artifacts = [ + dict(key="a", type="str", uri="test1", mode="input"), + dict(key="b", type="int", uri="test2"), + ] + # test create/get and get_all + self.api.tasks.add_or_update_artifacts(task=task, artifacts=artifacts) + return {"test2"} + + def send_debug_image_events(self, task) -> Set[str]: + events = [ + self.create_event(task, "training_debug_image", iteration, url=f"url_{iteration}") + for iteration in range(5) + ] + self.send_batch(events) + return set(ev["url"] for ev in events) + + def send_plot_events(self, task) -> Set[str]: + plots = [ + '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 514], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 514, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 514, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png"}], "showlegend": false, "title": "Feature 
importance/plot image", "name": null}}', + '{"data": [], "layout": {"xaxis": {"visible": false, "range": [0, 640]}, "yaxis": {"visible": false, "range": [0, 200], "scaleanchor": "x"}, "margin": {"l": 0, "r": 0, "t": 64, "b": 0}, "images": [{"sizex": 640, "sizey": 200, "xref": "x", "yref": "y", "opacity": 1.0, "x": 0, "y": 200, "sizing": "contain", "layer": "below", "source": "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg"}], "showlegend": false, "title": "untitled 00/plot image", "name": null}}', + '{"data": [{"y": ["lying", "sitting", "standing", "people", "backgroun"], "x": ["lying", "sitting", "standing", "people", "backgroun"], "z": [[758, 163, 0, 0, 23], [63, 858, 3, 0, 0], [0, 50, 188, 21, 35], [0, 22, 8, 40, 4], [12, 91, 26, 29, 368]], "type": "heatmap"}], "layout": {"title": "Confusion Matrix for iter 100", "xaxis": {"title": "Predicted value"}, "yaxis": {"title": "Real value"}}}', + ] + events = [ + self.create_event(task, "plot", iteration, plot_str=plot_str) + for iteration, plot_str in enumerate(plots) + ] + self.send_batch(events) + return { + "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/Feature%20importance/plot%20image/Feature%20importance_plot%20image_00000000.png", + "https://files.community-master.hosted.allegro.ai/examples/XGBoost%20simple%20example.35abd481a6ea4a6a976c217e80191dcd/metrics/untitled%2000/plot%20image/untitled%2000_plot%20image_00000000.jpeg", + } + + def create_event(self, task, type_, iteration, **kwargs) -> dict: + return { + "worker": "test", + "type": type_, + "task": task, + "iter": iteration, + "timestamp": es_factory.get_timestamp_millis(), + "metric": "Metric1", + "variant": "Variant1", + **kwargs, + } + + def send_batch(self, events): + _, data = self.api.send_batch("events.add_batch", events) + return data def new_task(self, **kwargs): - task_id = self.tasks.create( - type='testing', name='server-test', input=dict(view=dict()), **kwargs - )['id'] - self.defer(self.tasks.delete, can_fail=True, task=task_id, force=True) - return task_id + return self.create_temp( + "tasks", + delete_params=dict(can_fail=True), + type="testing", + name="test task delete", + input=dict(view=dict()), + **kwargs, + ) def new_model(self, **kwargs): - model_id = self.models.create(name='test', uri='file:///a', labels={}, **kwargs)['id'] - self.defer(self.models.delete, can_fail=True, model=model_id, force=True) - return model_id + self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) + return self.create_temp( + "models", + delete_params=dict(can_fail=True), + **kwargs, + ) - def delete_failure(self): - return self.api.raises(self.TASK_CANNOT_BE_DELETED_CODES) - - def publish_created_task(self, task_id): - self.tasks.started(task=task_id) - self.tasks.stopped(task=task_id) - self.tasks.publish(task=task_id) - - @reset_and_delete() - def test_plain(self, cont): - cont(self, self.new_task()) - - @reset_and_delete() - def test_draft_child(self, cont): - parent = self.new_task() - self.new_task(parent=parent) - cont(self, parent) - - @reset_and_delete() - def test_published_child(self, cont): - parent = self.new_task() - child = self.new_task(parent=parent) - self.publish_created_task(child) - with self.delete_failure(): - cont(self, parent) - - @reset_and_delete() - def test_draft_model(self, cont): - task = self.new_task() - 
model = self.new_model() - self.models.edit(model=model, task=task, ready=False) - cont(self, task) - - @reset_and_delete() - def test_published_model(self, cont): - task = self.new_task() - model = self.new_model() - self.models.edit(model=model, task=task, ready=True) - with self.delete_failure(): - cont(self, task) + def publish_task(self, task_id): + self.api.tasks.started(task=task_id) + self.api.tasks.stopped(task=task_id) + self.api.tasks.publish(task=task_id) diff --git a/apiserver/tests/requirements.txt b/apiserver/tests/requirements.txt index 41e3da6..9eccdff 100644 --- a/apiserver/tests/requirements.txt +++ b/apiserver/tests/requirements.txt @@ -1,2 +1 @@ nose==1.3.7 -parameterized>=0.7.1
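
The flags introduced by this patch can be exercised from a client once the server is running. Below is a minimal sketch, assuming an apiserver reachable at http://localhost:8008 and placeholder credentials; the address, auth scheme and credential values are assumptions for illustration, not defined by this patch. The response fields follow the "2.13" schema entries and the tests above (data.urls.model_urls / event_urls / artifact_urls for tasks.delete, data.url for models.delete).

import requests

# Assumptions (not part of this patch): a locally reachable apiserver and
# placeholder credentials; adapt the address and auth scheme to your deployment.
BASE_URL = "http://localhost:8008/v2.13"
AUTH = ("<access_key>", "<secret_key>")


def delete_task_and_collect_urls(task_id: str) -> dict:
    """Delete a task and ask the server for the file urls it referenced."""
    res = requests.post(
        f"{BASE_URL}/tasks.delete",
        json={"task": task_id, "force": True, "return_file_urls": True},
        auth=AUTH,
    )
    res.raise_for_status()
    data = res.json()["data"]
    # Per the tests above, "urls" holds model_urls, event_urls and artifact_urls
    return data.get("urls") or {}


def delete_model_and_collect_url(model_id: str) -> str:
    """Delete a model and ask the server for the url of its model file."""
    res = requests.post(
        f"{BASE_URL}/models.delete",
        json={"model": model_id, "force": True, "return_file_url": True},
        auth=AUTH,
    )
    res.raise_for_status()
    return res.json()["data"].get("url")

Returning these urls lets a caller clean up externally stored files (model weights, debug images, plot image sources and output artifacts) after the server-side records have been deleted or reset.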