From 5ae64fd791a1fb3c7ecb72309bc7d14f1537746a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Dec 2019 18:01:48 +0200 Subject: [PATCH] Add support for tasks.clone --- server/bll/task/task_bll.py | 54 ++++++++++++++++++++++- server/database/model/task/task.py | 7 ++- server/schema/services/tasks.conf | 54 +++++++++++++++++++++++ server/services/tasks.py | 27 ++++++++++-- server/tests/automated/test_tasks_edit.py | 39 ++++++++++++++++ server/tests/automated/test_workers.py | 6 +-- server/utilities/dicts.py | 20 ++++++++- 7 files changed, 197 insertions(+), 10 deletions(-) diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index 921946b..b0eb18a 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -24,10 +24,13 @@ from database.model.task.task import ( TaskStatus, TaskStatusMessage, TaskSystemTags, + ArtifactModes, + Artifact, ) from database.utils import get_company_or_none_constraint, id as create_id from service_repo import APICall from timing_context import TimingContext +from utilities.dicts import deep_merge from utilities.threads_manager import ThreadsManager from .utils import ChangeStatusRequest, validate_status_change @@ -151,6 +154,51 @@ class TaskBLL(object): return model + @classmethod + def clone_task( + cls, + company_id, + user_id, + task_id, + name: Optional[str] = None, + comment: Optional[str] = None, + parent: Optional[str] = None, + project: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + system_tags: Optional[Sequence[str]] = None, + execution_overrides: Optional[dict] = None, + ) -> Task: + task = cls.get_by_id(company_id=company_id, task_id=task_id) + execution_dict = task.execution.to_proper_dict() if task.execution else {} + if execution_overrides: + execution_dict = deep_merge(execution_dict, execution_overrides) + artifacts = execution_dict.get("artifacts") + if artifacts: + execution_dict["artifacts"] = [ + a for a in artifacts if a.get("mode") != ArtifactModes.output + ] + now = datetime.utcnow() + new_task = Task( + id=create_id(), + user=user_id, + company=company_id, + created=now, + last_update=now, + name=name or task.name, + comment=comment or task.comment, + parent=parent or task.parent, + project=project or task.project, + tags=tags or task.tags, + system_tags=system_tags or [], + type=task.type, + script=task.script, + output=Output(destination=task.output.destination) if task.output else None, + execution=execution_dict, + ) + cls.validate(new_task) + new_task.save() + return new_task + @classmethod def validate(cls, task: Task): assert isinstance(task, Task) @@ -160,8 +208,10 @@ class TaskBLL(object): ): raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent) - if task.project: - Project.get_for_writing(company=task.company, id=task.project) + if task.project and not Project.get_for_writing( + company=task.company, id=task.project + ): + raise errors.bad_request.InvalidProjectId(id=task.project) cls.validate_execution_model(task) diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index 5af3440..e7b2f49 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -67,10 +67,15 @@ class ArtifactTypeData(EmbeddedDocument): data_hash = StringField() +class ArtifactModes: + input = "input" + output = "output" + + class Artifact(EmbeddedDocument): key = StringField(required=True) type = StringField(required=True) - mode = StringField(choices=("input", "output"), default="output") + mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output) uri = StringField() hash = StringField() content_size = LongField() diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index 2b1682a..b4cffca 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -550,6 +550,60 @@ get_all { } } } +clone { + "2.5" { + description: "Clone an existing task" + request { + type: object + required: [ task ] + properties { + task { + description: "ID of the task" + type: string + } + new_task_name { + description: "The name of the cloned task. If not provided then taken from the original task" + type: string + } + new_task_comment { + description: "The comment of the cloned task. If not provided then taken from the original task" + type: string + } + new_task_tags { + description: "The user-defined tags of the cloned task. If not provided then taken from the original task" + type: array + items { type: string } + } + new_task_system_tags { + description: "The system tags of the cloned task. If not provided then empty" + type: array + items { type: string } + } + new_task_parent { + description: "The parent of the cloned task. If not provided then taken from the original task" + type: string + } + new_task_project { + description: "The project of the cloned task. If not provided then taken from the original task" + type: string + } + execution_overrides { + description: "The execution params for the cloned task. The params not specified are taken from the original task" + "$ref": "#/definitions/execution" + } + } + } + response { + type: object + properties { + id { + description: "ID of the new task" + type: string + } + } + } + } +} create { "2.1" { description: "Create a new task" diff --git a/server/services/tasks.py b/server/services/tasks.py index f0418be..53b15b3 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -12,7 +12,7 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS from pymongo import UpdateOne from apierrors import errors, APIError -from apimodels.base import UpdateResponse +from apimodels.base import UpdateResponse, IdResponse from apimodels.tasks import ( StartedResponse, ResetResponse, @@ -281,7 +281,9 @@ def validate(call: APICall, company_id, req_model: CreateRequest): _validate_and_get_task_from_call(call) -@endpoint("tasks.create", request_data_model=CreateRequest) +@endpoint( + "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse +) def create(call: APICall, company_id, req_model: CreateRequest): task = _validate_and_get_task_from_call(call) @@ -289,7 +291,26 @@ def create(call: APICall, company_id, req_model: CreateRequest): task.save() update_project_time(task.project) - call.result.data = {"id": task.id} + call.result.data_model = IdResponse(id=task.id) + + +@endpoint( + "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse +) +def clone_task(call: APICall, company_id, request: CloneRequest): + task = task_bll.clone_task( + company_id=company_id, + user_id=call.identity.user, + task_id=request.task, + name=request.new_task_name, + comment=request.new_task_comment, + parent=request.new_task_parent, + project=request.new_task_project, + tags=request.new_task_tags, + system_tags=request.new_task_system_tags, + execution_overrides=request.execution_overrides, + ) + call.result.data_model = IdResponse(id=task.id) def prepare_update_fields(call: APICall, task, call_data): diff --git a/server/tests/automated/test_tasks_edit.py b/server/tests/automated/test_tasks_edit.py index 0827bf4..8f5e6f1 100644 --- a/server/tests/automated/test_tasks_edit.py +++ b/server/tests/automated/test_tasks_edit.py @@ -6,6 +6,9 @@ log = config.logger(__file__) class TestTasksEdit(TestService): + def setUp(self, **kwargs): + super().setUp(version=2.5) + def new_task(self, **kwargs): return self.create_temp( "tasks", type="testing", name="test", input=dict(view=dict()), **kwargs @@ -34,3 +37,39 @@ class TestTasksEdit(TestService): self.api.models.edit(model=not_ready_model, ready=False) self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready) self.api.tasks.edit(task=task, execution=dict(model=not_ready_model)) + + def test_clone_task(self): + script = dict( + binary="python", + requirements=dict(pip=["six"]), + repository="https://example.come/foo/bar", + entry_point="test.py", + diff="foo", + ) + execution = dict(parameters=dict(test="Test")) + tags = ["hello"] + system_tags = ["development", "test"] + task = self.new_task( + script=script, execution=execution, tags=tags, system_tags=system_tags + ) + + new_name = "new test" + new_tags = ["by"] + execution_overrides = dict(framework="Caffe") + new_task_id = self.api.tasks.clone( + task=task, + new_task_name=new_name, + new_task_tags=new_tags, + execution_overrides=execution_overrides, + new_task_parent=task, + ).id + new_task = self.api.tasks.get_by_id(task=new_task_id).task + self.assertEqual(new_task.name, new_name) + self.assertEqual(new_task.type, "testing") + self.assertEqual(new_task.tags, new_tags) + self.assertEqual(new_task.status, "created") + self.assertEqual(new_task.script, script) + self.assertEqual(new_task.parent, task) + self.assertEqual(new_task.execution.parameters, execution["parameters"]) + self.assertEqual(new_task.execution.framework, execution_overrides["framework"]) + self.assertEqual(new_task.system_tags, []) diff --git a/server/tests/automated/test_workers.py b/server/tests/automated/test_workers.py index 02c9fbf..1489d92 100644 --- a/server/tests/automated/test_workers.py +++ b/server/tests/automated/test_workers.py @@ -108,7 +108,7 @@ class TestWorkersService(TestService): from_date = to_date - timedelta(days=1) # no variants - res = self.api.workers.get_statistics( + res = self.api.workers.get_stats( items=[ dict(key="cpu_usage", aggregation="avg"), dict(key="cpu_usage", aggregation="max"), @@ -142,7 +142,7 @@ class TestWorkersService(TestService): ) # split by variants - res = self.api.workers.get_statistics( + res = self.api.workers.get_stats( items=[dict(key="cpu_usage", aggregation="avg")], from_date=from_date.timestamp(), to_date=to_date.timestamp(), @@ -165,7 +165,7 @@ class TestWorkersService(TestService): assert all(_check_metric_and_variants(worker) for worker in res["workers"]) - res = self.api.workers.get_statistics( + res = self.api.workers.get_stats( items=[dict(key="cpu_usage", aggregation="avg")], from_date=from_date.timestamp(), to_date=to_date.timestamp(), diff --git a/server/utilities/dicts.py b/server/utilities/dicts.py index ba27882..3790063 100644 --- a/server/utilities/dicts.py +++ b/server/utilities/dicts.py @@ -12,6 +12,24 @@ def flatten_nested_items( for key, value in dictionary.items(): path = prefix + (key,) if isinstance(value, dict) and nesting != 0: - yield from flatten_nested_items(value, next_nesting, include_leaves, prefix=path) + yield from flatten_nested_items( + value, next_nesting, include_leaves, prefix=path + ) elif include_leaves is None or key in include_leaves: yield path, value + + +def deep_merge(source: dict, override: dict) -> dict: + """ + Merge the override dict into the source in-place + Contrary to the dpath.merge the sequences are not expanded + If override contains the sequence with the same name as source + then the whole sequence in the source is overridden + """ + for key, value in override.items(): + if key in source and isinstance(source[key], dict) and isinstance(value, dict): + deep_merge(source[key], value) + else: + source[key] = value + + return source