Fix task can't be cloned if input model was deleted

This commit is contained in:
allegroai
2020-06-01 12:23:29 +03:00
parent f8d8fc40a6
commit dcdf2a3d58
10 changed files with 175 additions and 84 deletions

View File

@@ -164,9 +164,11 @@ class TaskBLL(object):
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
@@ -174,6 +176,8 @@ class TaskBLL(object):
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
@@ -201,26 +205,42 @@ class TaskBLL(object):
else None,
execution=execution_dict,
)
cls.validate(new_task)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
return new_task
@classmethod
def validate(cls, task: Task):
assert isinstance(task, Task)
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
def validate(
cls,
task: Task,
validate_model=True,
validate_parent=True,
validate_project=True,
):
if (
validate_parent
and task.parent
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project and not Project.get_for_writing(
company=task.company, id=task.project
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
cls.validate_execution_model(task)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):