mirror of
https://github.com/clearml/clearml-server
synced 2025-04-28 17:51:24 +00:00
Add support for tasks.clone
This commit is contained in:
parent
f9776e4319
commit
5ae64fd791
@ -24,10 +24,13 @@ from database.model.task.task import (
|
||||
TaskStatus,
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
Artifact,
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from utilities.threads_manager import ThreadsManager
|
||||
from .utils import ChangeStatusRequest, validate_status_change
|
||||
|
||||
@ -151,6 +154,51 @@ class TaskBLL(object):
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
cls,
|
||||
company_id,
|
||||
user_id,
|
||||
task_id,
|
||||
name: Optional[str] = None,
|
||||
comment: Optional[str] = None,
|
||||
parent: Optional[str] = None,
|
||||
project: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
) -> Task:
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
if execution_overrides:
|
||||
execution_dict = deep_merge(execution_dict, execution_overrides)
|
||||
artifacts = execution_dict.get("artifacts")
|
||||
if artifacts:
|
||||
execution_dict["artifacts"] = [
|
||||
a for a in artifacts if a.get("mode") != ArtifactModes.output
|
||||
]
|
||||
now = datetime.utcnow()
|
||||
new_task = Task(
|
||||
id=create_id(),
|
||||
user=user_id,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name or task.name,
|
||||
comment=comment or task.comment,
|
||||
parent=parent or task.parent,
|
||||
project=project or task.project,
|
||||
tags=tags or task.tags,
|
||||
system_tags=system_tags or [],
|
||||
type=task.type,
|
||||
script=task.script,
|
||||
output=Output(destination=task.output.destination) if task.output else None,
|
||||
execution=execution_dict,
|
||||
)
|
||||
cls.validate(new_task)
|
||||
new_task.save()
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
def validate(cls, task: Task):
|
||||
assert isinstance(task, Task)
|
||||
@ -160,8 +208,10 @@ class TaskBLL(object):
|
||||
):
|
||||
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
|
||||
|
||||
if task.project:
|
||||
Project.get_for_writing(company=task.company, id=task.project)
|
||||
if task.project and not Project.get_for_writing(
|
||||
company=task.company, id=task.project
|
||||
):
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
|
@ -67,10 +67,15 @@ class ArtifactTypeData(EmbeddedDocument):
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class ArtifactModes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=("input", "output"), default="output")
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
|
@ -550,6 +550,60 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
clone {
|
||||
"2.5" {
|
||||
description: "Clone an existing task"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task"
|
||||
type: string
|
||||
}
|
||||
new_task_name {
|
||||
description: "The name of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_comment {
|
||||
description: "The comment of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_tags {
|
||||
description: "The user-defined tags of the cloned task. If not provided then taken from the original task"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
new_task_system_tags {
|
||||
description: "The system tags of the cloned task. If not provided then empty"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
new_task_parent {
|
||||
description: "The parent of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
new_task_project {
|
||||
description: "The project of the cloned task. If not provided then taken from the original task"
|
||||
type: string
|
||||
}
|
||||
execution_overrides {
|
||||
description: "The execution params for the cloned task. The params not specified are taken from the original task"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID of the new task"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
"2.1" {
|
||||
description: "Create a new task"
|
||||
|
@ -12,7 +12,7 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
|
||||
from apierrors import errors, APIError
|
||||
from apimodels.base import UpdateResponse
|
||||
from apimodels.base import UpdateResponse, IdResponse
|
||||
from apimodels.tasks import (
|
||||
StartedResponse,
|
||||
ResetResponse,
|
||||
@ -281,7 +281,9 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
|
||||
_validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
@endpoint("tasks.create", request_data_model=CreateRequest)
|
||||
@endpoint(
|
||||
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
|
||||
)
|
||||
def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
task = _validate_and_get_task_from_call(call)
|
||||
|
||||
@ -289,7 +291,26 @@ def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
task.save()
|
||||
update_project_time(task.project)
|
||||
|
||||
call.result.data = {"id": task.id}
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
|
||||
)
|
||||
def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
task = task_bll.clone_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
name=request.new_task_name,
|
||||
comment=request.new_task_comment,
|
||||
parent=request.new_task_parent,
|
||||
project=request.new_task_project,
|
||||
tags=request.new_task_tags,
|
||||
system_tags=request.new_task_system_tags,
|
||||
execution_overrides=request.execution_overrides,
|
||||
)
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
|
@ -6,6 +6,9 @@ log = config.logger(__file__)
|
||||
|
||||
|
||||
class TestTasksEdit(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version=2.5)
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks", type="testing", name="test", input=dict(view=dict()), **kwargs
|
||||
@ -34,3 +37,39 @@ class TestTasksEdit(TestService):
|
||||
self.api.models.edit(model=not_ready_model, ready=False)
|
||||
self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
|
||||
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))
|
||||
|
||||
def test_clone_task(self):
|
||||
script = dict(
|
||||
binary="python",
|
||||
requirements=dict(pip=["six"]),
|
||||
repository="https://example.come/foo/bar",
|
||||
entry_point="test.py",
|
||||
diff="foo",
|
||||
)
|
||||
execution = dict(parameters=dict(test="Test"))
|
||||
tags = ["hello"]
|
||||
system_tags = ["development", "test"]
|
||||
task = self.new_task(
|
||||
script=script, execution=execution, tags=tags, system_tags=system_tags
|
||||
)
|
||||
|
||||
new_name = "new test"
|
||||
new_tags = ["by"]
|
||||
execution_overrides = dict(framework="Caffe")
|
||||
new_task_id = self.api.tasks.clone(
|
||||
task=task,
|
||||
new_task_name=new_name,
|
||||
new_task_tags=new_tags,
|
||||
execution_overrides=execution_overrides,
|
||||
new_task_parent=task,
|
||||
).id
|
||||
new_task = self.api.tasks.get_by_id(task=new_task_id).task
|
||||
self.assertEqual(new_task.name, new_name)
|
||||
self.assertEqual(new_task.type, "testing")
|
||||
self.assertEqual(new_task.tags, new_tags)
|
||||
self.assertEqual(new_task.status, "created")
|
||||
self.assertEqual(new_task.script, script)
|
||||
self.assertEqual(new_task.parent, task)
|
||||
self.assertEqual(new_task.execution.parameters, execution["parameters"])
|
||||
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
|
||||
self.assertEqual(new_task.system_tags, [])
|
||||
|
@ -108,7 +108,7 @@ class TestWorkersService(TestService):
|
||||
from_date = to_date - timedelta(days=1)
|
||||
|
||||
# no variants
|
||||
res = self.api.workers.get_statistics(
|
||||
res = self.api.workers.get_stats(
|
||||
items=[
|
||||
dict(key="cpu_usage", aggregation="avg"),
|
||||
dict(key="cpu_usage", aggregation="max"),
|
||||
@ -142,7 +142,7 @@ class TestWorkersService(TestService):
|
||||
)
|
||||
|
||||
# split by variants
|
||||
res = self.api.workers.get_statistics(
|
||||
res = self.api.workers.get_stats(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
@ -165,7 +165,7 @@ class TestWorkersService(TestService):
|
||||
|
||||
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
|
||||
|
||||
res = self.api.workers.get_statistics(
|
||||
res = self.api.workers.get_stats(
|
||||
items=[dict(key="cpu_usage", aggregation="avg")],
|
||||
from_date=from_date.timestamp(),
|
||||
to_date=to_date.timestamp(),
|
||||
|
@ -12,6 +12,24 @@ def flatten_nested_items(
|
||||
for key, value in dictionary.items():
|
||||
path = prefix + (key,)
|
||||
if isinstance(value, dict) and nesting != 0:
|
||||
yield from flatten_nested_items(value, next_nesting, include_leaves, prefix=path)
|
||||
yield from flatten_nested_items(
|
||||
value, next_nesting, include_leaves, prefix=path
|
||||
)
|
||||
elif include_leaves is None or key in include_leaves:
|
||||
yield path, value
|
||||
|
||||
|
||||
def deep_merge(source: dict, override: dict) -> dict:
|
||||
"""
|
||||
Merge the override dict into the source in-place
|
||||
Contrary to the dpath.merge the sequences are not expanded
|
||||
If override contains the sequence with the same name as source
|
||||
then the whole sequence in the source is overridden
|
||||
"""
|
||||
for key, value in override.items():
|
||||
if key in source and isinstance(source[key], dict) and isinstance(value, dict):
|
||||
deep_merge(source[key], value)
|
||||
else:
|
||||
source[key] = value
|
||||
|
||||
return source
|
||||
|
Loading…
Reference in New Issue
Block a user