From f5008d80ad3739eb6fe71802b0607b45162baddf Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 3 May 2021 17:39:13 +0300 Subject: [PATCH] Optimize and improve tasks/models/projects.delete --- apiserver/apimodels/models.py | 1 - apiserver/apimodels/projects.py | 7 +- apiserver/apimodels/tasks.py | 2 + apiserver/bll/event/event_bll.py | 17 +++ apiserver/bll/project/project_cleanup.py | 137 ++++++++++++++++++ apiserver/bll/task/task_bll.py | 4 +- apiserver/bll/task/task_cleanup.py | 20 ++- apiserver/bll/task/utils.py | 2 +- apiserver/database/model/task/task.py | 1 + apiserver/schema/services/models.conf | 8 - apiserver/schema/services/projects.conf | 43 ++++++ apiserver/schema/services/tasks.conf | 27 +++- apiserver/services/models.py | 25 ++-- apiserver/services/projects.py | 48 +++--- apiserver/services/tasks.py | 18 ++- .../tests/automated/test_tasks_delete.py | 97 +++++++++---- 16 files changed, 350 insertions(+), 107 deletions(-) create mode 100644 apiserver/bll/project/project_cleanup.py diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py index 2ec4451..dcd6367 100644 --- a/apiserver/apimodels/models.py +++ b/apiserver/apimodels/models.py @@ -38,7 +38,6 @@ class ModelRequest(models.Base): class DeleteModelRequest(ModelRequest): force = fields.BoolField(default=False) - return_file_url = fields.BoolField(default=False) class PublishModelRequest(ModelRequest): diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index ccbd11e..2070576 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -6,7 +6,12 @@ from apiserver.database.model import EntityVisibility class ProjectReq(models.Base): - project = fields.StringField() + project = fields.StringField(required=True) + + +class DeleteRequest(ProjectReq): + force = fields.BoolField(default=False) + delete_contents = fields.BoolField(default=False) class GetHyperParamReq(ProjectReq): diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index cc7a027..91518ca 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -73,6 +73,7 @@ class EnqueueRequest(UpdateRequest): class DeleteRequest(UpdateRequest): move_to_trash = BoolField(default=True) return_file_urls = BoolField(default=False) + delete_output_models = BoolField(default=True) class SetRequirementsRequest(TaskRequest): @@ -140,6 +141,7 @@ class DeleteArtifactsRequest(TaskRequest): class ResetRequest(UpdateRequest): clear_all = BoolField(default=False) return_file_urls = BoolField(default=False) + delete_output_models = BoolField(default=True) class MultiTaskRequest(models.Base): diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 98a647b..b8c5444 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -943,3 +943,20 @@ class EventBLL(object): ) return es_res.get("deleted", 0) + + def delete_multi_task_events(self, company_id: str, task_ids: Sequence[str]): + """ + Delete mutliple task events. No check is done for tasks write access + so it should be checked by the calling code + """ + es_req = {"query": {"terms": {"task": task_ids}}} + with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"): + es_res = delete_company_events( + es=self.es, + company_id=company_id, + event_type=EventType.all, + body=es_req, + refresh=True, + ) + + return es_res.get("deleted", 0) diff --git a/apiserver/bll/project/project_cleanup.py b/apiserver/bll/project/project_cleanup.py new file mode 100644 index 0000000..1c5aa38 --- /dev/null +++ b/apiserver/bll/project/project_cleanup.py @@ -0,0 +1,137 @@ +from datetime import datetime +from typing import Tuple, Set + +import attr + +from apiserver.apierrors import errors +from apiserver.bll.event import EventBLL +from apiserver.bll.task.task_cleanup import ( + collect_debug_image_urls, + collect_plot_image_urls, + TaskUrls, +) +from apiserver.config_repo import config +from apiserver.database.model import EntityVisibility +from apiserver.database.model.model import Model +from apiserver.database.model.project import Project +from apiserver.database.model.task.task import Task, ArtifactModes +from apiserver.timing_context import TimingContext + +log = config.logger(__file__) +event_bll = EventBLL() + + +@attr.s(auto_attribs=True) +class DeleteProjectResult: + deleted: int = 0 + disassociated_tasks: int = 0 + deleted_models: int = 0 + deleted_tasks: int = 0 + urls: TaskUrls = None + + +def delete_project( + company: str, project_id: str, force: bool, delete_contents: bool +) -> DeleteProjectResult: + project = Project.get_for_writing(company=company, id=project_id) + if not project: + raise errors.bad_request.InvalidProjectId(id=project_id) + + if not force: + for cls, error in ( + (Task, errors.bad_request.ProjectHasTasks), + (Model, errors.bad_request.ProjectHasModels), + ): + non_archived = cls.objects( + project=project_id, system_tags__nin=[EntityVisibility.archived.value], + ).only("id") + if non_archived: + raise error("use force=true to delete", id=project_id) + + if not delete_contents: + with TimingContext("mongo", "update_children"): + for cls in (Model, Task): + updated_count = cls.objects(project=project_id).update(project=None) + res = DeleteProjectResult(disassociated_tasks=updated_count) + else: + deleted_models, model_urls = _delete_models(project=project_id) + deleted_tasks, event_urls, artifact_urls = _delete_tasks( + company=company, project=project_id + ) + res = DeleteProjectResult( + deleted_tasks=deleted_tasks, + deleted_models=deleted_models, + urls=TaskUrls( + model_urls=list(model_urls), + event_urls=list(event_urls), + artifact_urls=list(artifact_urls), + ), + ) + + res.deleted = Project.objects(id=project_id).delete() + return res + + +def _delete_tasks(company: str, project: str) -> Tuple[int, Set, Set]: + """ + Delete only the task themselves and their non published version. + Child models under the same project are deleted separately. + Children tasks should be deleted in the same api call. + If any child entities are left in another projects then updated their parent task to None + """ + tasks = Task.objects(project=project).only("id", "execution__artifacts") + if not tasks: + return 0, set(), set() + + task_ids = {t.id for t in tasks} + with TimingContext("mongo", "delete_tasks_update_children"): + Task.objects(parent__in=task_ids, project__ne=project).update(parent=None) + Model.objects(task__in=task_ids, project__ne=project).update(task=None) + + event_urls, artifact_urls = set(), set() + for task in tasks: + event_urls.update(collect_debug_image_urls(company, task.id)) + event_urls.update(collect_plot_image_urls(company, task.id)) + if task.execution and task.execution.artifacts: + artifact_urls.update( + { + a.uri + for a in task.execution.artifacts.values() + if a.mode == ArtifactModes.output and a.uri + } + ) + + event_bll.delete_multi_task_events(company, list(task_ids)) + deleted = tasks.delete() + return deleted, event_urls, artifact_urls + + +def _delete_models(project: str) -> Tuple[int, Set[str]]: + """ + Delete project models and update the tasks from other projects + that reference them to reference None. + """ + with TimingContext("mongo", "delete_models"): + models = Model.objects(project=project).only("task", "id", "uri") + if not models: + return 0, set() + + model_ids = {m.id for m in models} + Task.objects(execution__model__in=model_ids, project__ne=project).update( + execution__model=None + ) + + model_tasks = {m.task for m in models if m.task} + if model_tasks: + now = datetime.utcnow() + Task.objects( + id__in=model_tasks, project__ne=project, output__model__in=model_ids + ).update( + output__model=None, + output__error=f"model deleted on {now.isoformat()}", + last_change=now, + ) + + urls = {m.uri for m in models if m.uri} + deleted = models.delete() + return deleted, urls diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index be64004..d0e9d27 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -38,7 +38,7 @@ from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from .artifacts import artifacts_prepare_for_save from .param_utils import params_prepare_for_save -from .utils import ChangeStatusRequest, validate_status_change, update_project_time, task_deleted_prefix +from .utils import ChangeStatusRequest, validate_status_change, update_project_time, deleted_prefix log = config.logger(__file__) org_bll = OrgBLL() @@ -249,7 +249,7 @@ class TaskBLL: with TimingContext("mongo", "clone task"): parent_task = ( task.parent - if task.parent and not task.parent.startswith(task_deleted_prefix) + if task.parent and not task.parent.startswith(deleted_prefix) else None ) new_task = Task( diff --git a/apiserver/bll/task/task_cleanup.py b/apiserver/bll/task/task_cleanup.py index 96aa90e..c813d04 100644 --- a/apiserver/bll/task/task_cleanup.py +++ b/apiserver/bll/task/task_cleanup.py @@ -11,7 +11,7 @@ from apiserver.apierrors import errors from apiserver.bll.event import EventBLL from apiserver.bll.event.event_bll import PlotFields from apiserver.bll.event.event_common import EventType -from apiserver.bll.task.utils import task_deleted_prefix +from apiserver.bll.task.utils import deleted_prefix from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task, TaskStatus, ArtifactModes from apiserver.timing_context import TimingContext @@ -81,7 +81,7 @@ class CleanupResult: urls: TaskUrls = None -def _collect_plot_image_urls(company: str, task: str) -> Set[str]: +def collect_plot_image_urls(company: str, task: str) -> Set[str]: urls = set() next_scroll_id = None with TimingContext("es", "collect_plot_image_urls"): @@ -99,7 +99,7 @@ def _collect_plot_image_urls(company: str, task: str) -> Set[str]: return urls -def _collect_debug_image_urls(company: str, task: str) -> Set[str]: +def collect_debug_image_urls(company: str, task: str) -> Set[str]: """ Return the set of unique image urls Uses DebugImagesIterator to make sure that we do not retrieve recycled urls @@ -132,7 +132,11 @@ def _collect_debug_image_urls(company: str, task: str) -> Set[str]: def cleanup_task( - task: Task, force: bool = False, update_children=True, return_file_urls=False + task: Task, + force: bool = False, + update_children=True, + return_file_urls=False, + delete_output_models=True, ) -> CleanupResult: """ Validate task deletion and delete/modify all its output. @@ -144,8 +148,8 @@ def cleanup_task( event_urls, artifact_urls, model_urls = set(), set(), set() if return_file_urls: - event_urls = _collect_debug_image_urls(task.company, task.id) - event_urls.update(_collect_plot_image_urls(task.company, task.id)) + event_urls = collect_debug_image_urls(task.company, task.id) + event_urls.update(collect_plot_image_urls(task.company, task.id)) if task.execution and task.execution.artifacts: artifact_urls = { a.uri @@ -154,7 +158,7 @@ def cleanup_task( } model_urls = {m.uri for m in models.draft.objects().only("uri") if m.uri} - deleted_task_id = f"{task_deleted_prefix}{task.id}" + deleted_task_id = f"{deleted_prefix}{task.id}" if update_children: with TimingContext("mongo", "update_task_children"): updated_children = Task.objects(parent=task.id).update( @@ -163,7 +167,7 @@ def cleanup_task( else: updated_children = 0 - if models.draft: + if models.draft and delete_output_models: with TimingContext("mongo", "delete_models"): deleted_models = models.draft.objects().delete() else: diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 419a0ce..a3d7625 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -13,7 +13,7 @@ from apiserver.timing_context import TimingContext from apiserver.utilities.attrs import typed_attrs valid_statuses = get_options(TaskStatus) -task_deleted_prefix = "__DELETED__" +deleted_prefix = "__DELETED__" @typed_attrs diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index c2c30e7..0b68af7 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -155,6 +155,7 @@ class Task(AttributedDocument): "active_duration", "parent", "project", + "execution.model", ("company", "name"), ("company", "user"), ("company", "status", "type"), diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index fde6840..d7fd4d8 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -698,14 +698,6 @@ delete { } } "2.13": ${delete."2.1"} { - request { - properties { - return_file_url { - description: "If set to 'true' then return the url of the model file. Default value is 'false'" - type: boolean - } - } - } response { properties { url { diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 608a11c..28c1650 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -242,6 +242,23 @@ _definitions { } } } + urls { + type: object + properties { + model_urls { + type: array + items {type: string} + } + event_urls { + type: array + items {type: string} + } + artifact_urls { + type: array + items {type: string} + } + } + } } create { @@ -515,6 +532,32 @@ delete { } } } + "2.13": ${delete."2.1"} { + request { + properties { + delete_contents { + description: "If set to 'true' then the project tasks and models will be deleted. Otherwise their project property will be unassigned. Default value is 'false'" + type: boolean + } + } + } + response { + properties { + urls { + description: "The urls of the files that were uploaded by the project tasks and models. Returned if the 'delete_contents' was set to 'true'" + "$ref": "#/definitions/urls" + } + deleted_models { + description: "Number of models deleted" + type: integer + } + deleted_tasks { + description: "Number of tasks deleted" + type: integer + } + } + } + } } get_unique_metric_variants { "2.1" { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index e0a5b89..c305735 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -533,6 +533,23 @@ _definitions { } } } + task_urls { + type: object + properties { + model_urls { + type: array + items {type: string} + } + event_urls { + type: array + items {type: string} + } + artifact_urls { + type: array + items {type: string} + } + } + } } get_by_id { @@ -1203,9 +1220,8 @@ reset { response { properties { urls { - description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' properties was set to True" - type: array - items {type: string} + description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'" + "$ref": "#/definitions/task_urls" } } } @@ -1277,9 +1293,8 @@ delete { response { properties { urls { - description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' properties was set to True" - type: array - items {type: string} + description: "The urls of the files that were uploaded by this task. Returned if the 'return_file_urls' was set to 'true'" + "$ref": "#/definitions/task_urls" } } } diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 2e8dbba..0878cce 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -20,6 +20,7 @@ from apiserver.bll.model import ModelBLL from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL +from apiserver.bll.task.utils import deleted_prefix from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model import validate_id @@ -442,7 +443,7 @@ def set_ready(call: APICall, company_id, req_model: PublishModelRequest): @endpoint("models.delete", request_data_model=DeleteModelRequest) -def update(call: APICall, company_id, request: DeleteModelRequest): +def delete(call: APICall, company_id, request: DeleteModelRequest): model_id = request.model force = request.force @@ -452,7 +453,7 @@ def update(call: APICall, company_id, request: DeleteModelRequest): if not model: raise errors.bad_request.InvalidModelId(**query) - deleted_model_id = f"__DELETED__{model_id}" + deleted_model_id = f"{deleted_prefix}{model_id}" using_tasks = Task.objects(execution__model=model_id).only("id") if using_tasks: @@ -473,21 +474,19 @@ def update(call: APICall, company_id, request: DeleteModelRequest): raise errors.bad_request.ModelCreatingTaskExists( "and published, use force=True to delete", task=model.task ) - now = datetime.utcnow() - task.update( - output__model=deleted_model_id, - output__error=f"model deleted on {now.isoformat()}", - last_change=now, - upsert=False, - ) + if task.output and task.output.model == model_id: + now = datetime.utcnow() + task.update( + output__model=deleted_model_id, + output__error=f"model deleted on {now.isoformat()}", + last_change=now, + upsert=False, + ) del_count = Model.objects(**query).delete() if del_count: _reset_cached_tags(company_id, projects=[model.project]) - call.result.data = dict( - deleted=del_count > 0, - url=model.uri if request.return_file_url else None - ) + call.result.data = dict(deleted=del_count > 0, url=model.uri,) @endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 2628233..4cb0414 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -1,5 +1,7 @@ from datetime import datetime +from typing import Sequence +import attr from mongoengine import Q from apiserver.apierrors import errors @@ -12,15 +14,14 @@ from apiserver.apimodels.projects import ( ProjectTaskParentsRequest, ProjectHyperparamValuesRequest, ProjectsGetRequest, + DeleteRequest, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL +from apiserver.bll.project.project_cleanup import delete_project from apiserver.bll.task import TaskBLL from apiserver.database.errors import translate_errors_context -from apiserver.database.model import EntityVisibility -from apiserver.database.model.model import Model from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task from apiserver.database.utils import ( parse_from_call, get_company_or_none_constraint, @@ -178,36 +179,21 @@ def update(call: APICall): call.result.data_model = UpdateResponse(updated=updated, fields=fields) -@endpoint("projects.delete", required_fields=["project"]) -def delete(call): - assert isinstance(call, APICall) - project_id = call.data["project"] - force = call.data.get("force", False) +def _reset_cached_tags(company: str, projects: Sequence[str]): + org_bll.reset_tags(company, Tags.Task, projects=projects) + org_bll.reset_tags(company, Tags.Model, projects=projects) - with translate_errors_context(): - project = Project.get_for_writing(company=call.identity.company, id=project_id) - if not project: - raise errors.bad_request.InvalidProjectId(id=project_id) - # NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already - # have the correct project ID. - - # Find the tasks which belong to the project - for cls, error in ( - (Task, errors.bad_request.ProjectHasTasks), - (Model, errors.bad_request.ProjectHasModels), - ): - res = cls.objects( - project=project_id, system_tags__nin=[EntityVisibility.archived.value] - ).only("id") - if res and not force: - raise error("use force=true to delete", id=project_id) - - updated_count = res.update(project=None) - - project.delete() - - call.result.data = {"deleted": 1, "disassociated_tasks": updated_count} +@endpoint("projects.delete", request_data_model=DeleteRequest) +def delete(call: APICall, company_id: str, request: DeleteRequest): + res = delete_project( + company=company_id, + project_id=request.project, + force=request.force, + delete_contents=request.delete_contents, + ) + _reset_cached_tags(company_id, projects=[request.project]) + call.result.data = {**attr.asdict(res)} @endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index f41e0e4..5d49ebb 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -67,7 +67,7 @@ from apiserver.bll.task.param_utils import ( escape_paths, ) from apiserver.bll.task.task_cleanup import cleanup_task -from apiserver.bll.task.utils import update_task, task_deleted_prefix +from apiserver.bll.task.utils import update_task, deleted_prefix from apiserver.bll.util import SetFieldsResolver from apiserver.database.errors import translate_errors_context from apiserver.database.model import EntityVisibility @@ -384,7 +384,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dic @endpoint("tasks.validate", request_data_model=CreateRequest) def validate(call: APICall, company_id, req_model: CreateRequest): parent = call.data.get("parent") - if parent and parent.startswith(task_deleted_prefix): + if parent and parent.startswith(deleted_prefix): call.data.pop("parent") _validate_and_get_task_from_call(call) @@ -854,6 +854,7 @@ def reset(call: APICall, company_id, request: ResetRequest): force=force, update_children=False, return_file_urls=request.return_file_urls, + delete_output_models=request.delete_output_models, ) api_results.update(attr.asdict(cleaned_up)) @@ -943,13 +944,13 @@ def archive(call: APICall, company_id, request: ArchiveRequest): @endpoint("tasks.delete", request_data_model=DeleteRequest) -def delete(call: APICall, company_id, req_model: DeleteRequest): +def delete(call: APICall, company_id, request: DeleteRequest): task = TaskBLL.get_task_with_access( - req_model.task, company_id=company_id, requires_write_access=True + request.task, company_id=company_id, requires_write_access=True ) - move_to_trash = req_model.move_to_trash - force = req_model.force + move_to_trash = request.move_to_trash + force = request.force if task.status != TaskStatus.created and not force: raise errors.bad_request.TaskCannotBeDeleted( @@ -961,7 +962,10 @@ def delete(call: APICall, company_id, req_model: DeleteRequest): with translate_errors_context(): result = cleanup_task( - task, force=force, return_file_urls=req_model.return_file_urls + task, + force=force, + return_file_urls=request.return_file_urls, + delete_output_models=request.delete_output_models, ) if move_to_trash: diff --git a/apiserver/tests/automated/test_tasks_delete.py b/apiserver/tests/automated/test_tasks_delete.py index 305a17f..019c123 100644 --- a/apiserver/tests/automated/test_tasks_delete.py +++ b/apiserver/tests/automated/test_tasks_delete.py @@ -1,4 +1,4 @@ -from typing import Set +from typing import Set, Tuple from apiserver.apierrors import errors from apiserver.es_factory import es_factory @@ -7,7 +7,7 @@ from apiserver.tests.automated import TestService class TestTasksResetDelete(TestService): def setUp(self, **kwargs): - super().setUp(version="2.11") + super().setUp(version="2.13") def test_delete(self): # draft task can be deleted @@ -50,12 +50,12 @@ class TestTasksResetDelete(TestService): self.assertEqual(res.urls.artifact_urls, []) task = self.new_task() - model_urls = self.create_task_models(task) + published_model_urls, draft_model_urls = self.create_task_models(task) artifact_urls = self.send_artifacts(task) event_urls = self.send_debug_image_events(task) event_urls.update(self.send_plot_events(task)) res = self.assert_delete_task(task, force=True, return_file_urls=True) - self.assertEqual(set(res.urls.model_urls), model_urls) + self.assertEqual(set(res.urls.model_urls), draft_model_urls) self.assertEqual(set(res.urls.event_urls), event_urls) self.assertEqual(set(res.urls.artifact_urls), artifact_urls) @@ -73,21 +73,59 @@ class TestTasksResetDelete(TestService): self.api.tasks.reset(task=task, force=True) # test urls - task = self.new_task() - model_urls = self.create_task_models(task) - artifact_urls = self.send_artifacts(task) - event_urls = self.send_debug_image_events(task) - event_urls.update(self.send_plot_events(task)) + task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data() res = self.api.tasks.reset(task=task, force=True, return_file_urls=True) - self.assertEqual(set(res.urls.model_urls), model_urls) + self.assertEqual(set(res.urls.model_urls), draft_model_urls) self.assertEqual(set(res.urls.event_urls), event_urls) self.assertEqual(set(res.urls.artifact_urls), artifact_urls) def test_model_delete(self): model = self.new_model(uri="test") - res = self.api.models.delete(model=model, return_file_url=True) + res = self.api.models.delete(model=model) self.assertEqual(res.url, "test") + def test_project_delete(self): + # without delete_contents flag + project = self.new_project() + task = self.new_task(project=project) + res = self.api.tasks.get_by_id(task=task) + self.assertEqual(res.task.get("project"), project) + + res = self.api.projects.delete(project=project, force=True) + self.assertEqual(res.deleted, 1) + self.assertEqual(res.disassociated_tasks, 1) + self.assertEqual(res.deleted_tasks, 0) + res = self.api.tasks.get_by_id(task=task) + self.assertEqual(res.task.get("project"), None) + + # with delete_contents flag + project = self.new_project() + task, (published_model_urls, draft_model_urls), artifact_urls, event_urls = self.create_task_with_data( + project=project + ) + res = self.api.projects.delete( + project=project, force=True, delete_contents=True + ) + self.assertEqual(set(res.urls.model_urls), published_model_urls | draft_model_urls) + self.assertEqual(res.deleted, 1) + self.assertEqual(res.disassociated_tasks, 0) + self.assertEqual(res.deleted_tasks, 1) + self.assertEqual(res.deleted_models, 2) + self.assertEqual(set(res.urls.event_urls), event_urls) + self.assertEqual(set(res.urls.artifact_urls), artifact_urls) + with self.api.raises(errors.bad_request.InvalidTaskId): + self.api.tasks.get_by_id(task=task) + + def create_task_with_data( + self, **kwargs + ) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]: + task = self.new_task(**kwargs) + published_model_urls, draft_model_urls = self.create_task_models(task, **kwargs) + artifact_urls = self.send_artifacts(task) + event_urls = self.send_debug_image_events(task) + event_urls.update(self.send_plot_events(task)) + return task, (published_model_urls, draft_model_urls), artifact_urls, event_urls + def assert_delete_task(self, task_id, force=False, return_file_urls=False): tasks = self.api.tasks.get_all_ex(id=[task_id]).tasks self.assertEqual(tasks[0].id, task_id) @@ -99,15 +137,15 @@ class TestTasksResetDelete(TestService): self.assertEqual(tasks, []) return res - def create_task_models(self, task) -> Set[str]: + def create_task_models(self, task, **kwargs) -> Tuple[Set[str], Set[str]]: """ Update models from task and return only non public models """ - model_ready = self.new_model(uri="ready") - model_not_ready = self.new_model(uri="not_ready", ready=False) + model_ready = self.new_model(uri="ready", **kwargs) + model_not_ready = self.new_model(uri="not_ready", ready=False, **kwargs) self.api.models.edit(model=model_not_ready, task=task) self.api.models.edit(model=model_ready, task=task) - return {"not_ready"} + return {"ready"}, {"not_ready"} def send_artifacts(self, task) -> Set[str]: """ @@ -123,7 +161,9 @@ class TestTasksResetDelete(TestService): def send_debug_image_events(self, task) -> Set[str]: events = [ - self.create_event(task, "training_debug_image", iteration, url=f"url_{iteration}") + self.create_event( + task, "training_debug_image", iteration, url=f"url_{iteration}" + ) for iteration in range(5) ] self.send_batch(events) @@ -161,23 +201,22 @@ class TestTasksResetDelete(TestService): _, data = self.api.send_batch("events.add_batch", events) return data + name = "test task delete" + delete_params = dict(can_fail=True, force=True) + def new_task(self, **kwargs): - return self.create_temp( - "tasks", - delete_params=dict(can_fail=True), - type="testing", - name="test task delete", - input=dict(view=dict()), - **kwargs, + self.update_missing( + kwargs, name=self.name, type="testing", input=dict(view=dict()) ) + return self.create_temp("tasks", delete_params=self.delete_params, **kwargs,) def new_model(self, **kwargs): - self.update_missing(kwargs, name="test", uri="file:///a/b", labels={}) - return self.create_temp( - "models", - delete_params=dict(can_fail=True), - **kwargs, - ) + self.update_missing(kwargs, name=self.name, uri="file:///a/b", labels={}) + return self.create_temp("models", delete_params=self.delete_params, **kwargs,) + + def new_project(self, **kwargs): + self.update_missing(kwargs, name=self.name, description="") + return self.create_temp("projects", delete_params=self.delete_params, **kwargs) def publish_task(self, task_id): self.api.tasks.started(task=task_id)