Fix model Id handling when deleting models for tasks

This commit is contained in:
allegroai 2023-05-25 19:35:18 +03:00
parent b22f26129e
commit 0c37ced2a1
4 changed files with 56 additions and 12 deletions

View File

@ -108,25 +108,27 @@ class ModelBLL:
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
if task:
now = datetime.utcnow()
if task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
Task._get_collection().update_one(
filter={"_id": model.task, "models.output.model": model_id},
update={
"$set": {
"models.output.$[elem].model": deleted_model_id,
"output.error": f"model deleted on {now.isoformat()}",
"last_change": now,
},
"last_change": now,
},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
else:
task.update(pull__models__output__model=model_id, set__last_change=now)
del_count = Model.objects(id=model_id, company=company_id).delete()
return del_count, model

View File

@ -1,4 +1,5 @@
from collections import defaultdict
from datetime import datetime
from typing import Tuple, Set, Sequence
import attr
@ -15,7 +16,7 @@ from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType
from apiserver.database.model.task.task import Task, ArtifactModes, TaskType, TaskStatus
from .project_bll import ProjectBLL
from .sub_projects import _ids_with_children
@ -185,29 +186,43 @@ def _delete_models(
return 0, set(), set()
model_ids = list({m.id for m in models})
deleted = "__DELETED__"
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
update={"$set": {"models.input.$[elem].model": deleted}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
now = datetime.utcnow()
# update published tasks
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
"status": TaskStatus.published,
},
update={
"$set": {
"models.output.$[elem].model": deleted,
"last_change": now,
}
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
# update unpublished tasks
Task.objects(
id__in=model_tasks,
project__nin=projects,
status__ne=TaskStatus.published,
).update(pull__models__output__model__in=model_ids, set__last_change=now)
event_urls, model_urls = set(), set()
for m in models:

View File

@ -266,6 +266,7 @@ def delete_task(
if move_to_trash:
# make sure that whatever changes were done to the task are saved
# the task itself will be deleted later in the move_tasks_to_trash operation
task.last_update = datetime.utcnow()
task.save()
else:
task.delete()

View File

@ -1,3 +1,4 @@
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.tests.automated import TestService
@ -11,6 +12,31 @@ class TestModelsService(TestService):
def setUp(self, version="2.9"):
super().setUp(version=version)
def test_delete_model_for_task(self):
# non published task
task_id, model_id = self._create_task_and_model()
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output[0].model, model_id)
res = self.api.models.delete(model=model_id)
self.assertTrue(res.deleted)
with self.api.raises(errors.bad_request.InvalidModelId):
self.api.models.get_by_id(model=model_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output, [])
# published task
task_id, model_id = self._create_task_and_model()
self.api.tasks.stopped(task=task_id)
self.api.tasks.publish(task=task_id, publish_model=False)
with self.api.raises(errors.bad_request.ModelCreatingTaskExists):
self.api.models.delete(model=model_id)
res = self.api.models.delete(model=model_id, force=True)
self.assertTrue(res.deleted)
with self.api.raises(errors.bad_request.InvalidModelId):
self.api.models.get_by_id(model=model_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.models.output[0].model, f"__DELETED__{model_id}")
def test_publish_output_model_running_task(self):
task_id, model_id = self._create_task_and_model()
self._assert_model_ready(model_id, False)