Fix task can't be cloned if input model was deleted

allegroai 2020-06-01 12:23:29 +03:00
parent f8d8fc40a6
commit dcdf2a3d58
10 changed files with 175 additions and 84 deletions

View File

@@ -100,6 +100,7 @@ class CloneRequest(TaskRequest):
     new_task_parent = StringField()
     new_task_project = StringField()
     execution_overrides = DictField()
+    validate_references = BoolField(default=False)


 class AddOrUpdateArtifactsRequest(TaskRequest):
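The new field defaults to False, so existing clients are unaffected. A minimal sketch of a tasks.clone request body using it (the task id is a placeholder):

    # hypothetical tasks.clone payload; "validate_references" is the new field
    clone_payload = {
        "task": "<task-id>",
        "new_task_name": "cloned task",
        "validate_references": True,  # opt in to validating copied references
    }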

View File

@@ -0,0 +1 @@
+from .project_bll import ProjectBLL

View File

@ -0,0 +1,33 @@
from typing import Sequence, Optional
from mongoengine import Q
from config import config
from database.model.model import Model
from database.model.task.task import Task
from timing_context import TimingContext
log = config.logger(__file__)
class ProjectBLL:
@classmethod
def get_active_users(
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
) -> set:
"""
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if project_ids:
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
return res

View File

@@ -164,9 +164,11 @@ class TaskBLL(object):
         tags: Optional[Sequence[str]] = None,
         system_tags: Optional[Sequence[str]] = None,
         execution_overrides: Optional[dict] = None,
+        validate_references: bool = False,
     ) -> Task:
         task = cls.get_by_id(company_id=company_id, task_id=task_id)
         execution_dict = task.execution.to_proper_dict() if task.execution else {}
+        execution_model_overriden = False
         if execution_overrides:
             parameters = execution_overrides.get("parameters")
             if parameters is not None:
@@ -174,6 +176,8 @@ class TaskBLL(object):
                     ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
                 }
             execution_dict = deep_merge(execution_dict, execution_overrides)
+            execution_model_overriden = execution_overrides.get("model") is not None
+
         artifacts = execution_dict.get("artifacts")
         if artifacts:
             execution_dict["artifacts"] = [
@@ -201,25 +205,41 @@ class TaskBLL(object):
             else None,
             execution=execution_dict,
         )
-        cls.validate(new_task)
+        cls.validate(
+            new_task,
+            validate_model=validate_references or execution_model_overriden,
+            validate_parent=validate_references or parent,
+            validate_project=validate_references or project,
+        )
         new_task.save()
         return new_task

     @classmethod
-    def validate(cls, task: Task):
-        assert isinstance(task, Task)
-
-        if task.parent and not Task.get(
-            company=task.company, id=task.parent, _only=("id",), include_public=True
-        ):
+    def validate(
+        cls,
+        task: Task,
+        validate_model=True,
+        validate_parent=True,
+        validate_project=True,
+    ):
+        assert isinstance(task, Task)
+
+        if (
+            validate_parent
+            and task.parent
+            and not Task.get(
+                company=task.company, id=task.parent, _only=("id",), include_public=True
+            )
+        ):
             raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

-        if task.project and not Project.get_for_writing(
-            company=task.company, id=task.project
-        ):
+        if (
+            validate_project
+            and task.project
+            and not Project.get_for_writing(company=task.company, id=task.project)
+        ):
             raise errors.bad_request.InvalidProjectId(id=task.project)

-        cls.validate_execution_model(task)
+        if validate_model:
+            cls.validate_execution_model(task)

     @staticmethod
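The net effect: each reference (model, parent, project) is validated either when the caller passes validate_references=True or when that particular field is explicitly supplied for the clone. A standalone sketch of the flag resolution (the helper name is illustrative, not part of the codebase):

    # illustrative helper mirroring the decision made in clone_task above
    def resolve_validation_flags(validate_references, execution_overrides, parent, project):
        model_overridden = bool(
            execution_overrides and execution_overrides.get("model") is not None
        )
        return dict(
            validate_model=validate_references or model_overridden,
            validate_parent=validate_references or bool(parent),
            validate_project=validate_references or bool(project),
        )

    # cloning with no overrides and no flag skips all reference checks,
    # which is what allows cloning a task whose input model was deleted
    assert resolve_validation_flags(False, None, None, None) == dict(
        validate_model=False, validate_parent=False, validate_project=False
    )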

View File

@@ -530,59 +530,59 @@
                 }
             }
         }
-        "2.7" {
-            description: "Get 'log' events for this task"
-            request {
-                type: object
-                required: [
-                    task
-                ]
-                properties {
-                    task {
-                        type: string
-                        description: "Task ID"
-                    }
-                    batch_size {
-                        type: integer
-                        description: "The amount of log events to return"
-                    }
-                    navigate_earlier {
-                        type: boolean
-                        description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
-                    }
-                    refresh {
-                        type: boolean
-                        description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
-                    }
-                    scroll_id {
-                        type: string
-                        description: "Scroll ID of previous call (used for getting more results)"
-                    }
-                }
-            }
-            response {
-                type: object
-                properties {
-                    events {
-                        type: array
-                        items { type: object }
-                        description: "Log items list"
-                    }
-                    returned {
-                        type: integer
-                        description: "Number of log events returned"
-                    }
-                    total {
-                        type: number
-                        description: "Total number of log events available for this query"
-                    }
-                    scroll_id {
-                        type: string
-                        description: "Scroll ID for getting more results"
-                    }
-                }
-            }
-        }
+        // "2.7" {
+        //     description: "Get 'log' events for this task"
+        //     request {
+        //         type: object
+        //         required: [
+        //             task
+        //         ]
+        //         properties {
+        //             task {
+        //                 type: string
+        //                 description: "Task ID"
+        //             }
+        //             batch_size {
+        //                 type: integer
+        //                 description: "The amount of log events to return"
+        //             }
+        //             navigate_earlier {
+        //                 type: boolean
+        //                 description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
+        //             }
+        //             refresh {
+        //                 type: boolean
+        //                 description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
+        //             }
+        //             scroll_id {
+        //                 type: string
+        //                 description: "Scroll ID of previous call (used for getting more results)"
+        //             }
+        //         }
+        //     }
+        //     response {
+        //         type: object
+        //         properties {
+        //             events {
+        //                 type: array
+        //                 items { type: object }
+        //                 description: "Log items list"
+        //             }
+        //             returned {
+        //                 type: integer
+        //                 description: "Number of log events returned"
+        //             }
+        //             total {
+        //                 type: number
+        //                 description: "Total number of log events available for this query"
+        //             }
+        //             scroll_id {
+        //                 type: string
+        //                 description: "Scroll ID for getting more results"
+        //             }
+        //         }
+        //     }
+        // }
     }
     get_task_events {
         "2.1" {

View File

@@ -591,6 +591,10 @@ clone {
                 description: "The execution params for the cloned task. The params not specified are taken from the original task"
                 "$ref": "#/definitions/execution"
             }
+            validate_references {
+                description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
+                type: boolean
+            }
         }
     }
     response {
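Putting it together, a hedged sketch of both clone modes, assuming an APIClient-style wrapper (the "client" object is a stand-in for the self.api helper used in the tests below):

    # default: references copied from the original task are not validated
    new_id = client.tasks.clone(task="<task-id>", new_task_name="copy").id

    # opt in to validation of the copied model/parent/project references
    new_id = client.tasks.clone(
        task="<task-id>", new_task_name="validated copy", validate_references=True
    ).id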

View File

@@ -94,26 +94,27 @@ def get_task_log_v1_7(call, company_id, req_model):
     )


-@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
-def get_task_log(call, company_id, req_model: LogEventsRequest):
-    task_id = req_model.task
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
-
-    res = event_bll.log_events_iterator.get_task_events(
-        company_id=company_id,
-        task_id=task_id,
-        batch_size=req_model.batch_size,
-        navigate_earlier=req_model.navigate_earlier,
-        refresh=req_model.refresh,
-        state_id=req_model.scroll_id,
-    )
-
-    call.result.data = dict(
-        events=res.events,
-        returned=len(res.events),
-        total=res.total_events,
-        scroll_id=res.next_scroll_id,
-    )
+# uncomment this once the front end is ready
+# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
+# def get_task_log(call, company_id, req_model: LogEventsRequest):
+#     task_id = req_model.task
+#     task_bll.assert_exists(company_id, task_id, allow_public=True)
+#
+#     res = event_bll.log_events_iterator.get_task_events(
+#         company_id=company_id,
+#         task_id=task_id,
+#         batch_size=req_model.batch_size,
+#         navigate_earlier=req_model.navigate_earlier,
+#         refresh=req_model.refresh,
+#         state_id=req_model.scroll_id,
+#     )
+#
+#     call.result.data = dict(
+#         events=res.events,
+#         returned=len(res.events),
+#         total=res.total_events,
+#         scroll_id=res.next_scroll_id,
+#     )


 @endpoint("events.download_task_log", required_fields=["task"])

View File

@@ -361,6 +361,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
         tags=request.new_task_tags,
         system_tags=request.new_task_system_tags,
         execution_overrides=request.execution_overrides,
+        validate_references=request.validate_references,
     )
     call.result.data_model = IdResponse(id=task.id)

View File

@@ -186,6 +186,7 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(res.events), 1)

     def test_task_logs(self):
+        # this test will fail until the new api is uncommented
         task = self._temp_task()
         timestamp = es_factory.get_timestamp_millis()
         events = [

View File

@@ -74,13 +74,13 @@ class TestTasksEdit(TestService):
         new_name = "new test"
         new_tags = ["by"]
         execution_overrides = dict(framework="Caffe")
-        new_task_id = self.api.tasks.clone(
+        new_task_id = self._clone_task(
             task=task,
             new_task_name=new_name,
             new_task_tags=new_tags,
             execution_overrides=execution_overrides,
             new_task_parent=task,
-        ).id
+        )
         new_task = self.api.tasks.get_by_id(task=new_task_id).task
         self.assertEqual(new_task.name, new_name)
         self.assertEqual(new_task.type, "testing")
@@ -91,3 +91,32 @@ class TestTasksEdit(TestService):
         self.assertEqual(new_task.execution.parameters, execution["parameters"])
         self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
         self.assertEqual(new_task.system_tags, [])
+
+    def test_model_check_in_clone(self):
+        model = self.new_model()
+        task = self.new_task(execution=dict(model=model))
+
+        # a task whose model was deleted can still be cloned
+        self.api.models.delete(model=model, force=True)
+        self._clone_task(task=task, new_task_name="clone test")
+
+        # unless reference validation is explicitly requested
+        with self.api.raises(InvalidModelId):
+            self._clone_task(
+                task=task, new_task_name="clone test2", validate_references=True
+            )
+
+        # if the model is overridden then it is always checked
+        with self.api.raises(InvalidModelId):
+            self._clone_task(
+                task=task,
+                new_task_name="clone test3",
+                execution_overrides=dict(model="not existing"),
+            )
+
+    def _clone_task(self, task, **kwargs):
+        new_task = self.api.tasks.clone(task=task, **kwargs).id
+        self.defer(
+            self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
+        )
+        return new_task