From f9577f9faa8b21e1d06362cb16be4076bbc48ba3 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Thu, 5 Dec 2024 19:12:26 +0200 Subject: [PATCH] add update_execution_queue parameter to tasks.enqueue --- apiserver/apimodels/tasks.py | 1 + apiserver/bll/task/task_operations.py | 62 +++++++++++++++++---------- apiserver/schema/services/tasks.conf | 7 +++ apiserver/services/tasks.py | 6 +++ 4 files changed, 53 insertions(+), 23 deletions(-) diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 50a07d8..a96f017 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -109,6 +109,7 @@ class EnqueueRequest(UpdateRequest): queue = StringField() queue_name = StringField() verify_watched_queue = BoolField(default=False) + update_execution_queue = BoolField(default=True) class DeleteRequest(UpdateRequest): diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index a80922b..15f5e04 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -22,7 +22,8 @@ from apiserver.database.model.task.task import ( TaskStatusMessage, ArtifactModes, Execution, - DEFAULT_LAST_ITERATION, TaskType, + DEFAULT_LAST_ITERATION, + TaskType, ) from apiserver.database.utils import get_options from apiserver.service_repo.auth import Identity @@ -100,7 +101,9 @@ def archive_task( ) if include_pipeline_steps and ( - step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + step_tasks := _get_pipeline_steps_for_controller_task( + task, company_id, only=fields + ) ): for step in step_tasks: archive_task_core(step) @@ -137,7 +140,9 @@ def unarchive_task( ) if include_pipeline_steps and ( - step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + step_tasks := _get_pipeline_steps_for_controller_task( + task, company_id, only=fields + ) ): for step in step_tasks: unarchive_task_core(step) @@ -205,12 +210,25 @@ def enqueue_task( queue_name: str = None, validate: bool = False, force: bool = False, + update_execution_queue: bool = True, ) -> Tuple[int, dict]: if queue_id and queue_name: raise errors.bad_request.ValidationError( "Either queue id or queue name should be provided" ) + task = get_task_with_write_access( + task_id=task_id, company_id=company_id, identity=identity + ) + if not update_execution_queue: + if not ( + task.status == TaskStatus.queued and task.execution and task.execution.queue + ): + raise errors.bad_request.ValidationError( + "Cannot skip setting execution queue for a task " + "that is not enqueued or does not have execution queue set" + ) + if queue_name: queue = queue_bll.get_by_name( company_id=company_id, queue_name=queue_name, only=("id",) @@ -223,10 +241,6 @@ def enqueue_task( # try to get default queue queue_id = queue_bll.get_default(company_id).id - task = get_task_with_write_access( - task_id=task_id, company_id=company_id, identity=identity - ) - user_id = identity.user if validate: TaskBLL.validate(task) @@ -258,13 +272,19 @@ def enqueue_task( raise # set the current queue ID in the task - if task.execution: - Task.objects(id=task_id).update(execution__queue=queue_id, multi=False) - else: - Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False) + if update_execution_queue: + if task.execution: + Task.objects(id=task_id).update(execution__queue=queue_id, multi=False) + else: + Task.objects(id=task_id).update( + execution=Execution(queue=queue_id), multi=False + ) + nested_set(res, ("fields", "execution.queue"), queue_id) + # make sure that the task is not queued in any other queue - TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task_id, exclude=queue_id) - nested_set(res, ("fields", "execution.queue"), queue_id) + TaskBLL.remove_task_from_all_queues( + company_id=company_id, task_id=task_id, exclude=queue_id + ) return 1, res @@ -304,9 +324,7 @@ def delete_task( include_pipeline_steps: bool, ) -> Tuple[int, Task, CleanupResult]: user_id = identity.user - task = get_task_with_write_access( - task_id, company_id=company_id, identity=identity - ) + task = get_task_with_write_access(task_id, company_id=company_id, identity=identity) if ( task.status != TaskStatus.created @@ -378,9 +396,7 @@ def reset_task( clear_all: bool, ) -> Tuple[dict, CleanupResult, dict]: user_id = identity.user - task = get_task_with_write_access( - task_id, company_id=company_id, identity=identity - ) + task = get_task_with_write_access(task_id, company_id=company_id, identity=identity) if not force and task.status == TaskStatus.published: raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status) @@ -463,9 +479,7 @@ def publish_task( status_reason: str = "", ) -> dict: user_id = identity.user - task = get_task_with_write_access( - task_id, company_id=company_id, identity=identity - ) + task = get_task_with_write_access(task_id, company_id=company_id, identity=identity) if not force: validate_status_change(task.status, TaskStatus.published) @@ -584,7 +598,9 @@ def stop_task( ).execute() if include_pipeline_steps and ( - step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields) + step_tasks := _get_pipeline_steps_for_controller_task( + task, company_id, only=fields + ) ): for step in step_tasks: stop_task_core(step, True) diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 4d54f66..299c90a 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -1507,6 +1507,13 @@ Fails if the following parameters in the task were not filled: type: boolean } } + "999.0": ${enqueue."2.22"} { + request.properties.update_execution_queue { + description: If set to false then the task 'execution.queue' is not updated. This can be done only for the task that is already enqueued + type: boolean + default: true + } + } } enqueue_many { "2.13": ${_definitions.change_many_request} { diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 4e3f1c1..4b36ebb 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -923,6 +923,11 @@ def delete_configuration( response_data_model=EnqueueResponse, ) def enqueue(call: APICall, company_id, request: EnqueueRequest): + if request.verify_watched_queue and not request.update_execution_queue: + raise errors.bad_request.ValidationError( + "verify_watched_queue cannot be used with update_execution_queue=False" + ) + queued, res = enqueue_task( task_id=request.task, company_id=company_id, @@ -932,6 +937,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest): status_reason=request.status_reason, queue_name=request.queue_name, force=request.force, + update_execution_queue=request.update_execution_queue, ) if request.verify_watched_queue: res_queue = nested_get(res, ("fields", "execution.queue"))