mirror of
https://github.com/clearml/clearml-server
synced 2025-03-09 21:51:54 +00:00
Fix task can't be cloned if input model was deleted
This commit is contained in:
parent
f8d8fc40a6
commit
dcdf2a3d58
@ -100,6 +100,7 @@ class CloneRequest(TaskRequest):
|
||||
new_task_parent = StringField()
|
||||
new_task_project = StringField()
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
|
||||
|
||||
class AddOrUpdateArtifactsRequest(TaskRequest):
|
||||
|
1
server/bll/project/__init__.py
Normal file
1
server/bll/project/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .project_bll import ProjectBLL
|
33
server/bll/project/project_bll.py
Normal file
33
server/bll/project/project_bll.py
Normal file
@ -0,0 +1,33 @@
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectBLL:
|
||||
@classmethod
|
||||
def get_active_users(
|
||||
cls, company, project_ids: Sequence, user_ids: Optional[Sequence] = None
|
||||
) -> set:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if project_ids:
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
return res
|
@ -164,9 +164,11 @@ class TaskBLL(object):
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
parameters = execution_overrides.get("parameters")
|
||||
if parameters is not None:
|
||||
@ -174,6 +176,8 @@ class TaskBLL(object):
|
||||
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
|
||||
}
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
@ -201,26 +205,42 @@ class TaskBLL(object):
|
||||
else None,
|
||||
execution=execution_dict,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
|
||||
if task.parent and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
if (
|
||||
validate_parent
|
||||
and task.parent
|
||||
and not Task.get(
|
||||
company=task.company, id=task.parent, _only=("id",), include_public=True
|
||||
)
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project and not Project.get_for_writing(
|
||||
company=task.company, id=task.project
|
||||
if (
|
||||
validate_project
|
||||
and task.project
|
||||
and not Project.get_for_writing(company=task.company, id=task.project)
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
|
@ -530,59 +530,59 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.7" {
|
||||
description: "Get 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "The amount of log events to return"
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||
}
|
||||
refresh {
|
||||
type: boolean
|
||||
description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "Log items list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of log events returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of log events available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// "2.7" {
|
||||
// description: "Get 'log' events for this task"
|
||||
// request {
|
||||
// type: object
|
||||
// required: [
|
||||
// task
|
||||
// ]
|
||||
// properties {
|
||||
// task {
|
||||
// type: string
|
||||
// description: "Task ID"
|
||||
// }
|
||||
// batch_size {
|
||||
// type: integer
|
||||
// description: "The amount of log events to return"
|
||||
// }
|
||||
// navigate_earlier {
|
||||
// type: boolean
|
||||
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
|
||||
// }
|
||||
// refresh {
|
||||
// type: boolean
|
||||
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID of previous call (used for getting more results)"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// response {
|
||||
// type: object
|
||||
// properties {
|
||||
// events {
|
||||
// type: array
|
||||
// items { type: object }
|
||||
// description: "Log items list"
|
||||
// }
|
||||
// returned {
|
||||
// type: integer
|
||||
// description: "Number of log events returned"
|
||||
// }
|
||||
// total {
|
||||
// type: number
|
||||
// description: "Total number of log events available for this query"
|
||||
// }
|
||||
// scroll_id {
|
||||
// type: string
|
||||
// description: "Scroll ID for getting more results"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
get_task_events {
|
||||
"2.1" {
|
||||
|
@ -591,6 +591,10 @@ clone {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
validate_references {
|
||||
description: "If set to 'false' then the task fields that are copied from the original task are not validated. The default is false."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
@ -94,26 +94,27 @@ def get_task_log_v1_7(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
||||
def get_task_log(call, company_id, req_model: LogEventsRequest):
|
||||
task_id = req_model.task
|
||||
task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
|
||||
res = event_bll.log_events_iterator.get_task_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=req_model.batch_size,
|
||||
navigate_earlier=req_model.navigate_earlier,
|
||||
refresh=req_model.refresh,
|
||||
state_id=req_model.scroll_id,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
events=res.events,
|
||||
returned=len(res.events),
|
||||
total=res.total_events,
|
||||
scroll_id=res.next_scroll_id,
|
||||
)
|
||||
# uncomment this once the front end is ready
|
||||
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
|
||||
# def get_task_log(call, company_id, req_model: LogEventsRequest):
|
||||
# task_id = req_model.task
|
||||
# task_bll.assert_exists(company_id, task_id, allow_public=True)
|
||||
#
|
||||
# res = event_bll.log_events_iterator.get_task_events(
|
||||
# company_id=company_id,
|
||||
# task_id=task_id,
|
||||
# batch_size=req_model.batch_size,
|
||||
# navigate_earlier=req_model.navigate_earlier,
|
||||
# refresh=req_model.refresh,
|
||||
# state_id=req_model.scroll_id,
|
||||
# )
|
||||
#
|
||||
# call.result.data = dict(
|
||||
# events=res.events,
|
||||
# returned=len(res.events),
|
||||
# total=res.total_events,
|
||||
# scroll_id=res.next_scroll_id,
|
||||
# )
|
||||
|
||||
|
||||
@endpoint("events.download_task_log", required_fields=["task"])
|
||||
|
@ -361,6 +361,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
tags=request.new_task_tags,
|
||||
system_tags=request.new_task_system_tags,
|
||||
execution_overrides=request.execution_overrides,
|
||||
validate_references=request.validate_references,
|
||||
)
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
|
||||
|
@ -186,6 +186,7 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(len(res.events), 1)
|
||||
|
||||
def test_task_logs(self):
|
||||
# this test will fail until the new api is uncommented
|
||||
task = self._temp_task()
|
||||
timestamp = es_factory.get_timestamp_millis()
|
||||
events = [
|
||||
|
@ -74,13 +74,13 @@ class TestTasksEdit(TestService):
|
||||
new_name = "new test"
|
||||
new_tags = ["by"]
|
||||
execution_overrides = dict(framework="Caffe")
|
||||
new_task_id = self.api.tasks.clone(
|
||||
new_task_id = self._clone_task(
|
||||
task=task,
|
||||
new_task_name=new_name,
|
||||
new_task_tags=new_tags,
|
||||
execution_overrides=execution_overrides,
|
||||
new_task_parent=task,
|
||||
).id
|
||||
)
|
||||
new_task = self.api.tasks.get_by_id(task=new_task_id).task
|
||||
self.assertEqual(new_task.name, new_name)
|
||||
self.assertEqual(new_task.type, "testing")
|
||||
@ -91,3 +91,32 @@ class TestTasksEdit(TestService):
|
||||
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
||||
self.assertEqual(new_task.system_tags, [])
|
||||
|
||||
def test_model_check_in_clone(self):
|
||||
model = self.new_model()
|
||||
task = self.new_task(execution=dict(model=model))
|
||||
|
||||
# task with deleted model still can be copied
|
||||
self.api.models.delete(model=model, force=True)
|
||||
self._clone_task(task=task, new_task_name="clone test")
|
||||
|
||||
# unless check for refs is done
|
||||
with self.api.raises(InvalidModelId):
|
||||
self._clone_task(
|
||||
task=task, new_task_name="clone test2", validate_references=True
|
||||
)
|
||||
|
||||
# if the model is overriden then it is always checked
|
||||
with self.api.raises(InvalidModelId):
|
||||
self._clone_task(
|
||||
task=task,
|
||||
new_task_name="clone test3",
|
||||
execution_overrides=dict(model="not existing"),
|
||||
)
|
||||
|
||||
def _clone_task(self, task, **kwargs):
|
||||
new_task = self.api.tasks.clone(task=task, **kwargs).id
|
||||
self.defer(
|
||||
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
|
||||
)
|
||||
return new_task
|
||||
|
Loading…
Reference in New Issue
Block a user