Add support for tasks.clone

allegroai 2019-12-24 18:01:48 +02:00
parent f9776e4319
commit 5ae64fd791
7 changed files with 197 additions and 10 deletions

View File

@@ -24,10 +24,13 @@ from database.model.task.task import (
    TaskStatus,
    TaskStatusMessage,
    TaskSystemTags,
    ArtifactModes,
    Artifact,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from utilities.dicts import deep_merge
from utilities.threads_manager import ThreadsManager
from .utils import ChangeStatusRequest, validate_status_change
@@ -151,6 +154,51 @@ class TaskBLL(object):
        return model

    @classmethod
    def clone_task(
        cls,
        company_id,
        user_id,
        task_id,
        name: Optional[str] = None,
        comment: Optional[str] = None,
        parent: Optional[str] = None,
        project: Optional[str] = None,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
        execution_overrides: Optional[dict] = None,
    ) -> Task:
        task = cls.get_by_id(company_id=company_id, task_id=task_id)
        execution_dict = task.execution.to_proper_dict() if task.execution else {}
        if execution_overrides:
            execution_dict = deep_merge(execution_dict, execution_overrides)
        artifacts = execution_dict.get("artifacts")
        if artifacts:
            execution_dict["artifacts"] = [
                a for a in artifacts if a.get("mode") != ArtifactModes.output
            ]
        now = datetime.utcnow()
        new_task = Task(
            id=create_id(),
            user=user_id,
            company=company_id,
            created=now,
            last_update=now,
            name=name or task.name,
            comment=comment or task.comment,
            parent=parent or task.parent,
            project=project or task.project,
            tags=tags or task.tags,
            system_tags=system_tags or [],
            type=task.type,
            script=task.script,
            output=Output(destination=task.output.destination) if task.output else None,
            execution=execution_dict,
        )
        cls.validate(new_task)
        new_task.save()
        return new_task

    @classmethod
    def validate(cls, task: Task):
        assert isinstance(task, Task)
@@ -160,8 +208,10 @@ class TaskBLL(object):
        ):
            raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

-        if task.project:
-            Project.get_for_writing(company=task.company, id=task.project)
+        if task.project and not Project.get_for_writing(
+            company=task.company, id=task.project
+        ):
+            raise errors.bad_request.InvalidProjectId(id=task.project)

        cls.validate_execution_model(task)
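
To make the override handling above concrete, here is a small, self-contained sketch of what clone_task does to the execution dict. The values are invented for illustration; the real code uses the deep_merge helper added to utilities.dicts in this commit, while this flat override can get away with a plain dict merge.

# Illustrative values only -- not taken from a real task.
execution_dict = {
    "parameters": {"lr": 0.1},
    "framework": "TensorFlow",
    "artifacts": [
        {"key": "dataset", "mode": "input"},
        {"key": "model", "mode": "output"},
    ],
}
execution_overrides = {"framework": "Caffe"}

# The overrides win where specified; everything else is inherited.
execution_dict = {**execution_dict, **execution_overrides}

# Output artifacts are filtered out, so the clone does not start with
# results produced by the original run.
execution_dict["artifacts"] = [
    a for a in execution_dict["artifacts"] if a.get("mode") != "output"
]

assert execution_dict["framework"] == "Caffe"
assert [a["key"] for a in execution_dict["artifacts"]] == ["dataset"]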

View File

@@ -67,10 +67,15 @@ class ArtifactTypeData(EmbeddedDocument):
    data_hash = StringField()


class ArtifactModes:
    input = "input"
    output = "output"


class Artifact(EmbeddedDocument):
    key = StringField(required=True)
    type = StringField(required=True)
-    mode = StringField(choices=("input", "output"), default="output")
+    mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
    uri = StringField()
    hash = StringField()
    content_size = LongField()
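
get_options is an existing helper whose import lies outside the visible hunk; presumably it just collects the public string attributes of an options class such as ArtifactModes, so the new choices stay equivalent to the old ("input", "output") tuple. A rough sketch of that assumption:

def get_options(cls) -> tuple:
    # Assumed behaviour: gather the public string values defined on the class.
    return tuple(
        value
        for name, value in vars(cls).items()
        if not name.startswith("_") and isinstance(value, str)
    )

class ArtifactModes:
    input = "input"
    output = "output"

assert set(get_options(ArtifactModes)) == {"input", "output"}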

View File

@@ -550,6 +550,60 @@ get_all {
}
}
}
clone {
    "2.5" {
        description: "Clone an existing task"
        request {
            type: object
            required: [ task ]
            properties {
                task {
                    description: "ID of the task"
                    type: string
                }
                new_task_name {
                    description: "The name of the cloned task. If not provided then taken from the original task"
                    type: string
                }
                new_task_comment {
                    description: "The comment of the cloned task. If not provided then taken from the original task"
                    type: string
                }
                new_task_tags {
                    description: "The user-defined tags of the cloned task. If not provided then taken from the original task"
                    type: array
                    items { type: string }
                }
                new_task_system_tags {
                    description: "The system tags of the cloned task. If not provided then empty"
                    type: array
                    items { type: string }
                }
                new_task_parent {
                    description: "The parent of the cloned task. If not provided then taken from the original task"
                    type: string
                }
                new_task_project {
                    description: "The project of the cloned task. If not provided then taken from the original task"
                    type: string
                }
                execution_overrides {
                    description: "The execution params for the cloned task. The params not specified are taken from the original task"
                    "$ref": "#/definitions/execution"
                }
            }
        }
        response {
            type: object
            properties {
                id {
                    description: "ID of the new task"
                    type: string
                }
            }
        }
    }
}
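
For reference, a minimal request body for the new endpoint, sketched from the schema above (all values are made up; only task is required, everything else falls back to the original task):

payload = {
    "task": "<id of the task to clone>",
    "new_task_name": "my cloned task",
    "new_task_tags": ["experiment"],
    "execution_overrides": {"framework": "Caffe"},
}
# The response carries only the id of the newly created task,
# e.g. {"id": "<new task id>"}.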
create {
"2.1" {
description: "Create a new task"

View File

@@ -12,7 +12,7 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne

from apierrors import errors, APIError
-from apimodels.base import UpdateResponse
+from apimodels.base import UpdateResponse, IdResponse
from apimodels.tasks import (
    StartedResponse,
    ResetResponse,
@@ -281,7 +281,9 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
    _validate_and_get_task_from_call(call)


-@endpoint("tasks.create", request_data_model=CreateRequest)
+@endpoint(
+    "tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
+)
def create(call: APICall, company_id, req_model: CreateRequest):
    task = _validate_and_get_task_from_call(call)
@@ -289,7 +291,26 @@ def create(call: APICall, company_id, req_model: CreateRequest):
    task.save()
    update_project_time(task.project)

-    call.result.data = {"id": task.id}
+    call.result.data_model = IdResponse(id=task.id)


@endpoint(
    "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
def clone_task(call: APICall, company_id, request: CloneRequest):
    task = task_bll.clone_task(
        company_id=company_id,
        user_id=call.identity.user,
        task_id=request.task,
        name=request.new_task_name,
        comment=request.new_task_comment,
        parent=request.new_task_parent,
        project=request.new_task_project,
        tags=request.new_task_tags,
        system_tags=request.new_task_system_tags,
        execution_overrides=request.execution_overrides,
    )
    call.result.data_model = IdResponse(id=task.id)
def prepare_update_fields(call: APICall, task, call_data):
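
The handler relies on a CloneRequest data model from apimodels.tasks, whose diff is not expanded above. Judging from the schema and the attribute accesses in clone_task, its fields would be roughly the following (sketched here as a plain dataclass; the project actually defines it with its own API-model base classes):

from dataclasses import dataclass
from typing import Optional, Sequence

@dataclass
class CloneRequest:
    task: str
    new_task_name: Optional[str] = None
    new_task_comment: Optional[str] = None
    new_task_parent: Optional[str] = None
    new_task_project: Optional[str] = None
    new_task_tags: Optional[Sequence[str]] = None
    new_task_system_tags: Optional[Sequence[str]] = None
    execution_overrides: Optional[dict] = None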

View File

@@ -6,6 +6,9 @@ log = config.logger(__file__)
class TestTasksEdit(TestService):
    def setUp(self, **kwargs):
        super().setUp(version=2.5)

    def new_task(self, **kwargs):
        return self.create_temp(
            "tasks", type="testing", name="test", input=dict(view=dict()), **kwargs
@@ -34,3 +37,39 @@ class TestTasksEdit(TestService):
        self.api.models.edit(model=not_ready_model, ready=False)
        self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
        self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))

    def test_clone_task(self):
        script = dict(
            binary="python",
            requirements=dict(pip=["six"]),
            repository="https://example.come/foo/bar",
            entry_point="test.py",
            diff="foo",
        )
        execution = dict(parameters=dict(test="Test"))
        tags = ["hello"]
        system_tags = ["development", "test"]
        task = self.new_task(
            script=script, execution=execution, tags=tags, system_tags=system_tags
        )
        new_name = "new test"
        new_tags = ["by"]
        execution_overrides = dict(framework="Caffe")
        new_task_id = self.api.tasks.clone(
            task=task,
            new_task_name=new_name,
            new_task_tags=new_tags,
            execution_overrides=execution_overrides,
            new_task_parent=task,
        ).id
        new_task = self.api.tasks.get_by_id(task=new_task_id).task
        self.assertEqual(new_task.name, new_name)
        self.assertEqual(new_task.type, "testing")
        self.assertEqual(new_task.tags, new_tags)
        self.assertEqual(new_task.status, "created")
        self.assertEqual(new_task.script, script)
        self.assertEqual(new_task.parent, task)
        self.assertEqual(new_task.execution.parameters, execution["parameters"])
        self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
        self.assertEqual(new_task.system_tags, [])

View File

@@ -108,7 +108,7 @@ class TestWorkersService(TestService):
        from_date = to_date - timedelta(days=1)

        # no variants
-        res = self.api.workers.get_statistics(
+        res = self.api.workers.get_stats(
            items=[
                dict(key="cpu_usage", aggregation="avg"),
                dict(key="cpu_usage", aggregation="max"),
@@ -142,7 +142,7 @@
        )

        # split by variants
-        res = self.api.workers.get_statistics(
+        res = self.api.workers.get_stats(
            items=[dict(key="cpu_usage", aggregation="avg")],
            from_date=from_date.timestamp(),
            to_date=to_date.timestamp(),
@@ -165,7 +165,7 @@
        assert all(_check_metric_and_variants(worker) for worker in res["workers"])

-        res = self.api.workers.get_statistics(
+        res = self.api.workers.get_stats(
            items=[dict(key="cpu_usage", aggregation="avg")],
            from_date=from_date.timestamp(),
            to_date=to_date.timestamp(),

View File

@@ -12,6 +12,24 @@ def flatten_nested_items(
    for key, value in dictionary.items():
        path = prefix + (key,)
        if isinstance(value, dict) and nesting != 0:
-            yield from flatten_nested_items(value, next_nesting, include_leaves, prefix=path)
+            yield from flatten_nested_items(
+                value, next_nesting, include_leaves, prefix=path
+            )
        elif include_leaves is None or key in include_leaves:
            yield path, value


def deep_merge(source: dict, override: dict) -> dict:
    """
    Merge the override dict into source, in place.
    Unlike dpath's merge, sequences are not expanded: if override contains
    a sequence under the same key as source, the whole sequence in source
    is replaced.
    """
    for key, value in override.items():
        if key in source and isinstance(source[key], dict) and isinstance(value, dict):
            deep_merge(source[key], value)
        else:
            source[key] = value
    return source
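
A quick illustration of the merge semantics described in the docstring (invented values):

source = {"parameters": {"lr": 0.1, "epochs": 10}, "artifacts": [{"key": "a"}]}
override = {"parameters": {"lr": 0.01}, "artifacts": [{"key": "b"}]}

merged = deep_merge(source, override)
assert merged["parameters"] == {"lr": 0.01, "epochs": 10}  # nested dicts merge key by key
assert merged["artifacts"] == [{"key": "b"}]               # sequences are replaced wholesale
assert merged is source                                    # the merge happens in-place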