Fix tasks.clone

This commit is contained in:
allegroai 2021-01-05 18:15:01 +02:00
parent df334d083e
commit 29c792d459
5 changed files with 54 additions and 31 deletions

View File

@ -111,8 +111,8 @@ class CloneRequest(TaskRequest):
new_task_system_tags = ListField([str])
new_task_parent = StringField()
new_task_project = StringField()
new_hyperparams = DictField()
new_configuration = DictField()
new_task_hyperparams = DictField()
new_task_configuration = DictField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()

View File

@ -29,6 +29,7 @@ from apiserver.database.model.task.task import (
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
@ -179,7 +180,8 @@ class TaskBLL:
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Task:
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
params_dict = {
field: value
for field, value in (
@ -215,18 +217,20 @@ class TaskBLL:
if a.get("mode") != ArtifactModes.output
}
new_project_data = None
if not project and new_project_name:
# Use a project with the provided name, or create a new project
project = project_bll.find_or_create(
project = ProjectBLL.find_or_create(
project_name=new_project_name,
user=user_id,
company=company_id,
description="Auto-generated while cloning",
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
with translate_errors_context():
with TimingContext("mongo", "clone task"):
new_task = Task(
id=create_id(),
user=user_id,
@ -270,7 +274,7 @@ class TaskBLL:
system_tags=updated_system_tags,
)
return new_task
return new_task, new_project_data
@classmethod
def validate(
@ -280,6 +284,11 @@ class TaskBLL:
validate_parent=True,
validate_project=True,
):
"""
Validate task properties according to the flag
Task project is always checked for being writable
in order to disable the modification of public projects
"""
if (
validate_parent
and task.parent
@ -289,12 +298,10 @@ class TaskBLL:
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if (
validate_project
and task.project
and not Project.get_for_writing(company=task.company, id=task.project)
):
raise errors.bad_request.InvalidProjectId(id=task.project)
if task.project:
project = Project.get_for_writing(company=task.company, id=task.project)
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)

View File

@ -781,6 +781,18 @@ clone {
}
}
}
response {
properties {
new_project {
description: "In case the new_project_name was specified returns the target project details"
type: object
properties {
id: "The ID of the target project"
name: "The name of the target project"
}
}
}
}
}
}
create {

View File

@ -87,7 +87,6 @@ from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
validate_tags,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion
@ -415,12 +414,9 @@ def create(call: APICall, company_id, req_model: CreateRequest):
call.result.data_model = IdResponse(id=task.id)
@endpoint(
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
@endpoint("tasks.clone", request_data_model=CloneRequest)
def clone_task(call: APICall, company_id, request: CloneRequest):
validate_tags(request.new_task_tags, request.new_task_system_tags)
task = task_bll.clone_task(
task, new_project = task_bll.clone_task(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
@ -430,13 +426,16 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
project=request.new_task_project,
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
hyperparams=request.new_hyperparams,
configuration=request.new_configuration,
hyperparams=request.new_task_hyperparams,
configuration=request.new_task_configuration,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
new_project_name=request.new_project_name,
)
call.result.data_model = IdResponse(id=task.id)
call.result.data = {
"id": task.id,
**({"new_project": new_project} if new_project else {}),
}
def prepare_update_fields(call: APICall, task, call_data):
@ -468,9 +467,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
return UpdateResponse(updated=0)
updated_count, updated_fields = Task.safe_update(
company_id=company_id,
id=task_id,
partial_update_dict=partial_update_dict,
company_id=company_id, id=task_id, partial_update_dict=partial_update_dict,
)
if updated_count:
new_project = updated_fields.get("project", task.project)
@ -879,7 +876,13 @@ def reset(call: APICall, company_id, request: ResetRequest):
force=force,
status_reason="reset",
status_message="reset",
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
).execute(
started=None,
completed=None,
published=None,
active_duration=None,
**updates,
)
)
# do not return artifacts since they are not serializable
@ -906,10 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
for task in tasks:
try:
TaskBLL.dequeue_and_change_status(
task,
company_id,
request.status_message,
request.status_reason,
task, company_id, request.status_message, request.status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued

View File

@ -18,10 +18,14 @@ class TestMoveUnderProject(TestService):
# task clone
p2_name = "project_for_clone"
task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id
res = self.api.tasks.clone(task=task, new_project_name=p2_name)
task2 = res.id
project_data = res.new_project
self.assertTrue(project_data.id)
self.assertEqual(p2_name, project_data.name)
tasks = self.api.tasks.get_all_ex(id=[task2]).tasks
project2 = tasks[0].project.id
self.assertTrue(project2)
self.assertEqual(project_data.id, project2)
projects = self.api.projects.get_all_ex(id=[project2]).projects
self.assertEqual(p2_name, projects[0].name)
self.api.projects.delete(project=project2, force=True)