Fix model Id handling when deleting models for tasks

This commit is contained in:
allegroai 2023-05-25 19:35:18 +03:00
parent b22f26129e
commit 0c37ced2a1
4 changed files with 56 additions and 12 deletions

View File

@ -108,25 +108,27 @@ class ModelBLL:
if model.task: if model.task:
task = Task.objects(id=model.task).first() task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published: if task:
if not force: now = datetime.utcnow()
raise errors.bad_request.ModelCreatingTaskExists( if task.status == TaskStatus.published:
"and published, use force=True to delete", task=model.task if not force:
) raise errors.bad_request.ModelCreatingTaskExists(
if task.models.output and model_id in task.models.output: "and published, use force=True to delete", task=model.task
now = datetime.utcnow() )
Task._get_collection().update_one( Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id}, filter={"_id": model.task, "models.output.model": model_id},
update={ update={
"$set": { "$set": {
"models.output.$[elem].model": deleted_model_id, "models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}", "output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
}, },
"last_change": now,
}, },
array_filters=[{"elem.model": model_id}], array_filters=[{"elem.model": model_id}],
upsert=False, upsert=False,
) )
else:
task.update(pull__models__output__model=model_id, set__last_change=now)
del_count = Model.objects(id=model_id, company=company_id).delete() del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model return del_count, model

View File

@ -1,4 +1,5 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime
from typing import Tuple, Set, Sequence from typing import Tuple, Set, Sequence
import attr import attr
@ -15,7 +16,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
from .project_bll import ProjectBLL from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children from .sub_projects import _ids_with_children
@ -185,29 +186,43 @@ def _delete_models(
return 0, set(), set() return 0, set(), set()
model_ids = list({m.id for m in models}) model_ids = list({m.id for m in models})
deleted = "__DELETED__"
Task._get_collection().update_many( Task._get_collection().update_many(
filter={ filter={
"project": {"$nin": projects}, "project": {"$nin": projects},
"models.input.model": {"$in": model_ids}, "models.input.model": {"$in": model_ids},
}, },
update={"$set": {"models.input.$[elem].model": None}}, update={"$set": {"models.input.$[elem].model": deleted}},
array_filters=[{"elem.model": {"$in": model_ids}}], array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False, upsert=False,
) )
model_tasks = list({m.task for m in models if m.task}) model_tasks = list({m.task for m in models if m.task})
if model_tasks: if model_tasks:
now = datetime.utcnow()
# update published tasks
Task._get_collection().update_many( Task._get_collection().update_many(
filter={ filter={
"_id": {"$in": model_tasks}, "_id": {"$in": model_tasks},
"project": {"$nin": projects}, "project": {"$nin": projects},
"models.output.model": {"$in": model_ids}, "models.output.model": {"$in": model_ids},
"status": TaskStatus.published,
},
update={
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
}
}, },
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}], array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False, upsert=False,
) )
# update unpublished tasks
Task.objects(
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(pull__models__output__model__in=model_ids, set__last_change=now)
event_urls, model_urls = set(), set() event_urls, model_urls = set(), set()
for m in models: for m in models:

View File

@ -266,6 +266,7 @@ def delete_task(
if move_to_trash: if move_to_trash:
# make sure that whatever changes were done to the task are saved # make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation # the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save() task.save()
else: else:
task.delete() task.delete()

View File

@ -1,3 +1,4 @@
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
@ -11,6 +12,31 @@ class TestModelsService(TestService):
def setUp(self, version="2.9"): def setUp(self, version="2.9"):
super().setUp(version=version) super().setUp(version=version)
def test_delete_model_for_task(self):
# non published task
task_id, model_id = self._create_task_and_model()
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output[0].model, model_id)
res = self.api.models.delete(model=model_id)
self.assertTrue(res.deleted)
with self.api.raises(errors.bad_request.InvalidModelId):
self.api.models.get_by_id(model=model_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output, [])
# published task
task_id, model_id = self._create_task_and_model()
self.api.tasks.stopped(task=task_id)
self.api.tasks.publish(task=task_id, publish_model=False)
with self.api.raises(errors.bad_request.ModelCreatingTaskExists):
self.api.models.delete(model=model_id)
res = self.api.models.delete(model=model_id, force=True)
self.assertTrue(res.deleted)
with self.api.raises(errors.bad_request.InvalidModelId):
self.api.models.get_by_id(model=model_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output[0].model, f"__DELETED__{model_id}")
def test_publish_output_model_running_task(self): def test_publish_output_model_running_task(self):
task_id, model_id = self._create_task_and_model() task_id, model_id = self._create_task_and_model()
self._assert_model_ready(model_id, False) self._assert_model_ready(model_id, False)