Fix: a task could not be cloned if its input model was deleted

This commit is contained in:
allegroai 2020-06-01 12:23:29 +03:00
parent f8d8fc40a6
commit dcdf2a3d58
10 changed files with 175 additions and 84 deletions

View File

@ -100,6 +100,7 @@ class CloneRequest(TaskRequest):
new_task_parent = StringField()
new_task_project = StringField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
class AddOrUpdateArtifactsRequest(TaskRequest):

View File

@ -0,0 +1 @@
from .project_bll import ProjectBLL

View File

@ -0,0 +1,33 @@
from typing import Sequence, Optional
from mongoengine import Q
from config import config
from database.model.model import Model
from database.model.task.task import Task
from timing_context import TimingContext
log = config.logger(__file__)
class ProjectBLL:
    @classmethod
    def get_active_users(
        cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
    ) -> set:
        """
        Return the set of user ids that created tasks/models in the given projects.

        If project_ids is empty then all of the company's projects are examined.
        If user_ids is passed then only the subset of those users is returned.
        """
        with TimingContext("mongo", "active_users_in_projects"):
            # Build a single query shared by both document types.
            query = Q(company=company)
            if project_ids:
                query &= Q(project__in=project_ids)
            if user_ids:
                query &= Q(user__in=user_ids)

            # Creators of either tasks or models count as "active".
            users = set()
            for doc_cls in (Task, Model):
                users.update(doc_cls.objects(query).distinct(field="user"))
            return users

View File

@ -164,9 +164,11 @@ class TaskBLL(object):
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
@ -174,6 +176,8 @@ class TaskBLL(object):
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
artifacts = execution_dict.get("artifacts")
if artifacts:
execution_dict["artifacts"] = [
@ -201,26 +205,42 @@ class TaskBLL(object):
else None,
execution=execution_dict,
)
cls.validate(new_task)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
new_task.save()
return new_task
@classmethod
def validate(cls, task: Task):
assert isinstance(task, Task)
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
def validate(
cls,
task: Task,
validate_model=True,
validate_parent=True,
validate_project=True,
):
if (
validate_parent
and task.parent
and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
)
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project and not Project.get_for_writing(
company=task.company, id=task.project
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
cls.validate_execution_model(task)
if validate_model:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):

View File

@ -530,59 +530,59 @@
}
}
}
"2.7" {
description: "Get 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
batch_size {
type: integer
description: "The amount of log events to return"
}
navigate_earlier {
type: boolean
description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
}
refresh {
type: boolean
description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of log events returned"
}
total {
type: number
description: "Total number of log events available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
// "2.7" {
// description: "Get 'log' events for this task"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// type: string
// description: "Task ID"
// }
// batch_size {
// type: integer
// description: "The amount of log events to return"
// }
// navigate_earlier {
// type: boolean
// description: "If set then log events are retrieved from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
// }
// refresh {
// type: boolean
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
// }
// scroll_id {
// type: string
// description: "Scroll ID of previous call (used for getting more results)"
// }
// }
// }
// response {
// type: object
// properties {
// events {
// type: array
// items { type: object }
// description: "Log items list"
// }
// returned {
// type: integer
// description: "Number of log events returned"
// }
// total {
// type: number
// description: "Total number of log events available for this query"
// }
// scroll_id {
// type: string
// description: "Scroll ID for getting more results"
// }
// }
// }
// }
}
get_task_events {
"2.1" {

View File

@ -591,6 +591,10 @@ clone {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
}
validate_references {
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
type: boolean
}
}
}
response {

View File

@ -94,26 +94,27 @@ def get_task_log_v1_7(call, company_id, req_model):
)
@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, req_model: LogEventsRequest):
# Return one batch of 'log' events for a single task, with scroll-based paging.
# NOTE(review): indentation was flattened in this diff view; the statements
# below are the function body.
task_id = req_model.task
# Raises if the task does not exist; public (shared) tasks are accepted.
task_bll.assert_exists(company_id, task_id, allow_public=True)
res = event_bll.log_events_iterator.get_task_events(
company_id=company_id,
task_id=task_id,
batch_size=req_model.batch_size,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
)
# scroll_id lets the caller continue paging from where this call stopped.
call.result.data = dict(
events=res.events,
returned=len(res.events),
total=res.total_events,
scroll_id=res.next_scroll_id,
)
# uncomment this once the front end is ready
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
# def get_task_log(call, company_id, req_model: LogEventsRequest):
# task_id = req_model.task
# task_bll.assert_exists(company_id, task_id, allow_public=True)
#
# res = event_bll.log_events_iterator.get_task_events(
# company_id=company_id,
# task_id=task_id,
# batch_size=req_model.batch_size,
# navigate_earlier=req_model.navigate_earlier,
# refresh=req_model.refresh,
# state_id=req_model.scroll_id,
# )
#
# call.result.data = dict(
# events=res.events,
# returned=len(res.events),
# total=res.total_events,
# scroll_id=res.next_scroll_id,
# )
@endpoint("events.download_task_log", required_fields=["task"])

View File

@ -361,6 +361,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
)
call.result.data_model = IdResponse(id=task.id)

View File

@ -186,6 +186,7 @@ class TestTaskEvents(TestService):
self.assertEqual(len(res.events), 1)
def test_task_logs(self):
# this test will fail until the new api is uncommented
task = self._temp_task()
timestamp = es_factory.get_timestamp_millis()
events = [

View File

@ -74,13 +74,13 @@ class TestTasksEdit(TestService):
new_name = "new test"
new_tags = ["by"]
execution_overrides = dict(framework="Caffe")
new_task_id = self.api.tasks.clone(
new_task_id = self._clone_task(
task=task,
new_task_name=new_name,
new_task_tags=new_tags,
execution_overrides=execution_overrides,
new_task_parent=task,
).id
)
new_task = self.api.tasks.get_by_id(task=new_task_id).task
self.assertEqual(new_task.name, new_name)
self.assertEqual(new_task.type, "testing")
@ -91,3 +91,32 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.system_tags, [])
def test_model_check_in_clone(self):
    """Cloning behavior when the task's input model has been deleted."""
    model_id = self.new_model()
    task_id = self.new_task(execution=dict(model=model_id))
    self.api.models.delete(model=model_id, force=True)

    # By default the references copied from the original task are not
    # validated, so the clone succeeds despite the dangling model.
    self._clone_task(task=task_id, new_task_name="clone test")

    # Opting into reference validation surfaces the missing model.
    with self.api.raises(InvalidModelId):
        self._clone_task(
            task=task_id, new_task_name="clone test2", validate_references=True
        )

    # A model supplied via execution overrides is always validated,
    # regardless of the validate_references flag.
    with self.api.raises(InvalidModelId):
        self._clone_task(
            task=task_id,
            new_task_name="clone test3",
            execution_overrides=dict(model="not existing"),
        )
def _clone_task(self, task, **kwargs):
    """Clone *task*, deferring hard deletion of the clone to test teardown."""
    response = self.api.tasks.clone(task=task, **kwargs)
    cloned_id = response.id
    # Clean up the clone without leaving it in the trash.
    self.defer(
        self.api.tasks.delete, task=cloned_id, move_to_trash=False, force=True
    )
    return cloned_id