From 29c792d45957e4722ea7727d8001f397898325bd Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 18:15:01 +0200 Subject: [PATCH] Fix tasks.clone --- apiserver/apimodels/tasks.py | 4 +-- apiserver/bll/task/task_bll.py | 27 +++++++++------ apiserver/schema/services/tasks.conf | 12 +++++++ apiserver/services/tasks.py | 34 +++++++++---------- .../automated/test_move_under_project.py | 8 +++-- 5 files changed, 54 insertions(+), 31 deletions(-) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 27ca2a7..96fef42 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -111,8 +111,8 @@ class CloneRequest(TaskRequest): new_task_system_tags = ListField([str]) new_task_parent = StringField() new_task_project = StringField() - new_hyperparams = DictField() - new_configuration = DictField() + new_task_hyperparams = DictField() + new_task_configuration = DictField() execution_overrides = DictField() validate_references = BoolField(default=False) new_project_name = StringField() diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index fbf4205..aa3873e 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -29,6 +29,7 @@ from apiserver.database.model.task.task import ( from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.es_factory import es_factory from apiserver.service_repo import APICall +from apiserver.services.utils import validate_tags from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from .artifacts import artifacts_prepare_for_save @@ -179,7 +180,8 @@ class TaskBLL: execution_overrides: Optional[dict] = None, validate_references: bool = False, new_project_name: str = None, - ) -> Task: + ) -> Tuple[Task, dict]: + validate_tags(tags, system_tags) params_dict = { field: value for field, value in ( @@ -215,18 +217,20 @@ class TaskBLL: if a.get("mode") != ArtifactModes.output } + new_project_data = None if not project and new_project_name: # Use a project with the provided name, or create a new project - project = project_bll.find_or_create( + project = ProjectBLL.find_or_create( project_name=new_project_name, user=user_id, company=company_id, description="Auto-generated while cloning", ) + new_project_data = {"id": project, "name": new_project_name} now = datetime.utcnow() - with translate_errors_context(): + with TimingContext("mongo", "clone task"): new_task = Task( id=create_id(), user=user_id, @@ -270,7 +274,7 @@ class TaskBLL: system_tags=updated_system_tags, ) - return new_task + return new_task, new_project_data @classmethod def validate( @@ -280,6 +284,11 @@ class TaskBLL: validate_parent=True, validate_project=True, ): + """ + Validate task properties according to the flag + Task project is always checked for being writable + in order to disable the modification of public projects + """ if ( validate_parent and task.parent @@ -289,12 +298,10 @@ class TaskBLL: ): raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent) - if ( - validate_project - and task.project - and not Project.get_for_writing(company=task.company, id=task.project) - ): - raise errors.bad_request.InvalidProjectId(id=task.project) + if task.project: + project = Project.get_for_writing(company=task.company, id=task.project) + if validate_project and not project: + raise errors.bad_request.InvalidProjectId(id=task.project) if validate_model: cls.validate_execution_model(task) diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 208ae0f..e0e1143 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -781,6 +781,18 @@ clone { } } } + response { + properties { + new_project { + description: "In case the new_project_name was specified returns the target project details" + type: object + properties { + id: "The ID of the target project" + name: "The name of the target project" + } + } + } + } } } create { diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 1ac8692..1daa8f2 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -87,7 +87,6 @@ from apiserver.service_repo import APICall, endpoint from apiserver.services.utils import ( conform_tag_fields, conform_output_tags, - validate_tags, ) from apiserver.timing_context import TimingContext from apiserver.utilities.partial_version import PartialVersion @@ -415,12 +414,9 @@ def create(call: APICall, company_id, req_model: CreateRequest): call.result.data_model = IdResponse(id=task.id) -@endpoint( - "tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse -) +@endpoint("tasks.clone", request_data_model=CloneRequest) def clone_task(call: APICall, company_id, request: CloneRequest): - validate_tags(request.new_task_tags, request.new_task_system_tags) - task = task_bll.clone_task( + task, new_project = task_bll.clone_task( company_id=company_id, user_id=call.identity.user, task_id=request.task, @@ -430,13 +426,16 @@ def clone_task(call: APICall, company_id, request: CloneRequest): project=request.new_task_project, tags=request.new_task_tags, system_tags=request.new_task_system_tags, - hyperparams=request.new_hyperparams, - configuration=request.new_configuration, + hyperparams=request.new_task_hyperparams, + configuration=request.new_task_configuration, execution_overrides=request.execution_overrides, validate_references=request.validate_references, new_project_name=request.new_project_name, ) - call.result.data_model = IdResponse(id=task.id) + call.result.data = { + "id": task.id, + **({"new_project": new_project} if new_project else {}), + } def prepare_update_fields(call: APICall, task, call_data): @@ -468,9 +467,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest): return UpdateResponse(updated=0) updated_count, updated_fields = Task.safe_update( - company_id=company_id, - id=task_id, - partial_update_dict=partial_update_dict, + company_id=company_id, id=task_id, partial_update_dict=partial_update_dict, ) if updated_count: new_project = updated_fields.get("project", task.project) @@ -879,7 +876,13 @@ def reset(call: APICall, company_id, request: ResetRequest): force=force, status_reason="reset", status_message="reset", - ).execute(started=None, completed=None, published=None, active_duration=None, **updates) + ).execute( + started=None, + completed=None, + published=None, + active_duration=None, + **updates, + ) ) # do not return artifacts since they are not serializable @@ -906,10 +909,7 @@ def archive(call: APICall, company_id, request: ArchiveRequest): for task in tasks: try: TaskBLL.dequeue_and_change_status( - task, - company_id, - request.status_message, - request.status_reason, + task, company_id, request.status_message, request.status_reason, ) except APIError: # dequeue may fail if the task was not enqueued diff --git a/apiserver/tests/automated/test_move_under_project.py b/apiserver/tests/automated/test_move_under_project.py index 456412b..7c2d83c 100644 --- a/apiserver/tests/automated/test_move_under_project.py +++ b/apiserver/tests/automated/test_move_under_project.py @@ -18,10 +18,14 @@ class TestMoveUnderProject(TestService): # task clone p2_name = "project_for_clone" - task2 = self.api.tasks.clone(task=task, new_project_name=p2_name).id + res = self.api.tasks.clone(task=task, new_project_name=p2_name) + task2 = res.id + project_data = res.new_project + self.assertTrue(project_data.id) + self.assertEqual(p2_name, project_data.name) tasks = self.api.tasks.get_all_ex(id=[task2]).tasks project2 = tasks[0].project.id - self.assertTrue(project2) + self.assertEqual(project_data.id, project2) projects = self.api.projects.get_all_ex(id=[project2]).projects self.assertEqual(p2_name, projects[0].name) self.api.projects.delete(project=project2, force=True)