mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Fix task can't be cloned if input model was deleted
This commit is contained in:
parent
f8d8fc40a6
commit
dcdf2a3d58
@ -100,6 +100,7 @@ class CloneRequest(TaskRequest):
|
|||||||
new_task_parent = StringField()
|
new_task_parent = StringField()
|
||||||
new_task_project = StringField()
|
new_task_project = StringField()
|
||||||
execution_overrides = DictField()
|
execution_overrides = DictField()
|
||||||
|
validate_references = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||||
|
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .project_bll import ProjectBLL
|
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from typing import Sequence, Optional
|
||||||
|
|
||||||
|
from mongoengine import Q
|
||||||
|
|
||||||
|
from config import config
|
||||||
|
from database.model.model import Model
|
||||||
|
from database.model.task.task import Task
|
||||||
|
from timing_context import TimingContext
|
||||||
|
|
||||||
|
log = config.logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectBLL:
|
||||||
|
@classmethod
|
||||||
|
def get_active_users(
|
||||||
|
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||||
|
) -> set:
|
||||||
|
"""
|
||||||
|
Get the set of user ids that created tasks/models in the given projects
|
||||||
|
If project_ids is empty then all projects are examined
|
||||||
|
If user_ids are passed then only subset of these users is returned
|
||||||
|
"""
|
||||||
|
with TimingContext("mongo", "active_users_in_projects"):
|
||||||
|
res = set()
|
||||||
|
query = Q(company=company)
|
||||||
|
if project_ids:
|
||||||
|
query &= Q(project__in=project_ids)
|
||||||
|
if user_ids:
|
||||||
|
query &= Q(user__in=user_ids)
|
||||||
|
for cls_ in (Task, Model):
|
||||||
|
res |= set(cls_.objects(query).distinct(field="user"))
|
||||||
|
|
||||||
|
return res
|
@ -164,9 +164,11 @@ class TaskBLL(object):
|
|||||||
tags: Optional[Sequence[str]] = None,
|
tags: Optional[Sequence[str]] = None,
|
||||||
system_tags: Optional[Sequence[str]] = None,
|
system_tags: Optional[Sequence[str]] = None,
|
||||||
execution_overrides: Optional[dict] = None,
|
execution_overrides: Optional[dict] = None,
|
||||||
|
validate_references: bool = False,
|
||||||
) -> Task:
|
) -> Task:
|
||||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||||
|
execution_model_overriden = False
|
||||||
if execution_overrides:
|
if execution_overrides:
|
||||||
parameters = execution_overrides.get("parameters")
|
parameters = execution_overrides.get("parameters")
|
||||||
if parameters is not None:
|
if parameters is not None:
|
||||||
@ -174,6 +176,8 @@ class TaskBLL(object):
|
|||||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||||
}
|
}
|
||||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||||
|
execution_model_overriden = execution_overrides.get("model") is not None
|
||||||
|
|
||||||
artifacts = execution_dict.get("artifacts")
|
artifacts = execution_dict.get("artifacts")
|
||||||
if artifacts:
|
if artifacts:
|
||||||
execution_dict["artifacts"] = [
|
execution_dict["artifacts"] = [
|
||||||
@ -201,25 +205,41 @@ class TaskBLL(object):
|
|||||||
else None,
|
else None,
|
||||||
execution=execution_dict,
|
execution=execution_dict,
|
||||||
)
|
)
|
||||||
cls.validate(new_task)
|
cls.validate(
|
||||||
|
new_task,
|
||||||
|
validate_model=validate_references or execution_model_overriden,
|
||||||
|
validate_parent=validate_references or parent,
|
||||||
|
validate_project=validate_references or project,
|
||||||
|
)
|
||||||
new_task.save()
|
new_task.save()
|
||||||
|
|
||||||
return new_task
|
return new_task
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls, task: Task):
|
def validate(
|
||||||
assert isinstance(task, Task)
|
cls,
|
||||||
|
task: Task,
|
||||||
if task.parent and not Task.get(
|
validate_model=True,
|
||||||
|
validate_parent=True,
|
||||||
|
validate_project=True,
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
validate_parent
|
||||||
|
and task.parent
|
||||||
|
and not Task.get(
|
||||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||||
|
|
||||||
if task.project and not Project.get_for_writing(
|
if (
|
||||||
company=task.company, id=task.project
|
validate_project
|
||||||
|
and task.project
|
||||||
|
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||||
):
|
):
|
||||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||||
|
|
||||||
|
if validate_model:
|
||||||
cls.validate_execution_model(task)
|
cls.validate_execution_model(task)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -530,59 +530,59 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"2.7" {
|
// "2.7" {
|
||||||
description: "Get 'log' events for this task"
|
// description: "Get 'log' events for this task"
|
||||||
request {
|
// request {
|
||||||
type: object
|
// type: object
|
||||||
required: [
|
// required: [
|
||||||
task
|
// task
|
||||||
]
|
// ]
|
||||||
properties {
|
// properties {
|
||||||
task {
|
// task {
|
||||||
type: string
|
// type: string
|
||||||
description: "Task ID"
|
// description: "Task ID"
|
||||||
}
|
// }
|
||||||
batch_size {
|
// batch_size {
|
||||||
type: integer
|
// type: integer
|
||||||
description: "The amount of log events to return"
|
// description: "The amount of log events to return"
|
||||||
}
|
// }
|
||||||
navigate_earlier {
|
// navigate_earlier {
|
||||||
type: boolean
|
// type: boolean
|
||||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||||
}
|
// }
|
||||||
refresh {
|
// refresh {
|
||||||
type: boolean
|
// type: boolean
|
||||||
description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||||
}
|
// }
|
||||||
scroll_id {
|
// scroll_id {
|
||||||
type: string
|
// type: string
|
||||||
description: "Scroll ID of previous call (used for getting more results)"
|
// description: "Scroll ID of previous call (used for getting more results)"
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
response {
|
// response {
|
||||||
type: object
|
// type: object
|
||||||
properties {
|
// properties {
|
||||||
events {
|
// events {
|
||||||
type: array
|
// type: array
|
||||||
items { type: object }
|
// items { type: object }
|
||||||
description: "Log items list"
|
// description: "Log items list"
|
||||||
}
|
// }
|
||||||
returned {
|
// returned {
|
||||||
type: integer
|
// type: integer
|
||||||
description: "Number of log events returned"
|
// description: "Number of log events returned"
|
||||||
}
|
// }
|
||||||
total {
|
// total {
|
||||||
type: number
|
// type: number
|
||||||
description: "Total number of log events available for this query"
|
// description: "Total number of log events available for this query"
|
||||||
}
|
// }
|
||||||
scroll_id {
|
// scroll_id {
|
||||||
type: string
|
// type: string
|
||||||
description: "Scroll ID for getting more results"
|
// description: "Scroll ID for getting more results"
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
get_task_events {
|
get_task_events {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
|
@ -591,6 +591,10 @@ clone {
|
|||||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||||
"$ref": "#/definitions/execution"
|
"$ref": "#/definitions/execution"
|
||||||
}
|
}
|
||||||
|
validate_references {
|
||||||
|
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
|
||||||
|
type: boolean
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
response {
|
response {
|
||||||
|
@ -94,26 +94,27 @@ def get_task_log_v1_7(call, company_id, req_model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
# uncomment this once the front end is ready
|
||||||
def get_task_log(call, company_id, req_model: LogEventsRequest):
|
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
||||||
task_id = req_model.task
|
# def get_task_log(call, company_id, req_model: LogEventsRequest):
|
||||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
# task_id = req_model.task
|
||||||
|
# task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||||
res = event_bll.log_events_iterator.get_task_events(
|
#
|
||||||
company_id=company_id,
|
# res = event_bll.log_events_iterator.get_task_events(
|
||||||
task_id=task_id,
|
# company_id=company_id,
|
||||||
batch_size=req_model.batch_size,
|
# task_id=task_id,
|
||||||
navigate_earlier=req_model.navigate_earlier,
|
# batch_size=req_model.batch_size,
|
||||||
refresh=req_model.refresh,
|
# navigate_earlier=req_model.navigate_earlier,
|
||||||
state_id=req_model.scroll_id,
|
# refresh=req_model.refresh,
|
||||||
)
|
# state_id=req_model.scroll_id,
|
||||||
|
# )
|
||||||
call.result.data = dict(
|
#
|
||||||
events=res.events,
|
# call.result.data = dict(
|
||||||
returned=len(res.events),
|
# events=res.events,
|
||||||
total=res.total_events,
|
# returned=len(res.events),
|
||||||
scroll_id=res.next_scroll_id,
|
# total=res.total_events,
|
||||||
)
|
# scroll_id=res.next_scroll_id,
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
@endpoint("events.download_task_log", required_fields=["task"])
|
@endpoint("events.download_task_log", required_fields=["task"])
|
||||||
|
@ -361,6 +361,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
|||||||
tags=request.new_task_tags,
|
tags=request.new_task_tags,
|
||||||
system_tags=request.new_task_system_tags,
|
system_tags=request.new_task_system_tags,
|
||||||
execution_overrides=request.execution_overrides,
|
execution_overrides=request.execution_overrides,
|
||||||
|
validate_references=request.validate_references,
|
||||||
)
|
)
|
||||||
call.result.data_model = IdResponse(id=task.id)
|
call.result.data_model = IdResponse(id=task.id)
|
||||||
|
|
||||||
|
@ -186,6 +186,7 @@ class TestTaskEvents(TestService):
|
|||||||
self.assertEqual(len(res.events), 1)
|
self.assertEqual(len(res.events), 1)
|
||||||
|
|
||||||
def test_task_logs(self):
|
def test_task_logs(self):
|
||||||
|
# this test will fail until the new api is uncommented
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
timestamp = es_factory.get_timestamp_millis()
|
timestamp = es_factory.get_timestamp_millis()
|
||||||
events = [
|
events = [
|
||||||
|
@ -74,13 +74,13 @@ class TestTasksEdit(TestService):
|
|||||||
new_name = "new test"
|
new_name = "new test"
|
||||||
new_tags = ["by"]
|
new_tags = ["by"]
|
||||||
execution_overrides = dict(framework="Caffe")
|
execution_overrides = dict(framework="Caffe")
|
||||||
new_task_id = self.api.tasks.clone(
|
new_task_id = self._clone_task(
|
||||||
task=task,
|
task=task,
|
||||||
new_task_name=new_name,
|
new_task_name=new_name,
|
||||||
new_task_tags=new_tags,
|
new_task_tags=new_tags,
|
||||||
execution_overrides=execution_overrides,
|
execution_overrides=execution_overrides,
|
||||||
new_task_parent=task,
|
new_task_parent=task,
|
||||||
).id
|
)
|
||||||
new_task = self.api.tasks.get_by_id(task=new_task_id).task
|
new_task = self.api.tasks.get_by_id(task=new_task_id).task
|
||||||
self.assertEqual(new_task.name, new_name)
|
self.assertEqual(new_task.name, new_name)
|
||||||
self.assertEqual(new_task.type, "testing")
|
self.assertEqual(new_task.type, "testing")
|
||||||
@ -91,3 +91,32 @@ class TestTasksEdit(TestService):
|
|||||||
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||||
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
||||||
self.assertEqual(new_task.system_tags, [])
|
self.assertEqual(new_task.system_tags, [])
|
||||||
|
|
||||||
|
def test_model_check_in_clone(self):
|
||||||
|
model = self.new_model()
|
||||||
|
task = self.new_task(execution=dict(model=model))
|
||||||
|
|
||||||
|
# task with deleted model still can be copied
|
||||||
|
self.api.models.delete(model=model, force=True)
|
||||||
|
self._clone_task(task=task, new_task_name="clone test")
|
||||||
|
|
||||||
|
# unless check for refs is done
|
||||||
|
with self.api.raises(InvalidModelId):
|
||||||
|
self._clone_task(
|
||||||
|
task=task, new_task_name="clone test2", validate_references=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# if the model is overriden then it is always checked
|
||||||
|
with self.api.raises(InvalidModelId):
|
||||||
|
self._clone_task(
|
||||||
|
task=task,
|
||||||
|
new_task_name="clone test3",
|
||||||
|
execution_overrides=dict(model="not existing"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _clone_task(self, task, **kwargs):
|
||||||
|
new_task = self.api.tasks.clone(task=task, **kwargs).id
|
||||||
|
self.defer(
|
||||||
|
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
|
||||||
|
)
|
||||||
|
return new_task
|
||||||
|
Loading…
Reference in New Issue
Block a user