add update_execution_queue parameter to tasks.enqueue

This commit is contained in:
clearml 2024-12-05 19:12:26 +02:00
parent 58b748ddf3
commit f9577f9faa
4 changed files with 53 additions and 23 deletions

View File

@ -109,6 +109,7 @@ class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
update_execution_queue = BoolField(default=True)
class DeleteRequest(UpdateRequest):

View File

@ -22,7 +22,8 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
ArtifactModes,
Execution,
DEFAULT_LAST_ITERATION, TaskType,
DEFAULT_LAST_ITERATION,
TaskType,
)
from apiserver.database.utils import get_options
from apiserver.service_repo.auth import Identity
@ -100,7 +101,9 @@ def archive_task(
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
archive_task_core(step)
@ -137,7 +140,9 @@ def unarchive_task(
)
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
unarchive_task_core(step)
@ -205,12 +210,25 @@ def enqueue_task(
queue_name: str = None,
validate: bool = False,
force: bool = False,
update_execution_queue: bool = True,
) -> Tuple[int, dict]:
if queue_id and queue_name:
raise errors.bad_request.ValidationError(
"Either queue id or queue name should be provided"
)
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
if not update_execution_queue:
if not (
task.status == TaskStatus.queued and task.execution and task.execution.queue
):
raise errors.bad_request.ValidationError(
"Cannot skip setting execution queue for a task "
"that is not enqueued or does not have execution queue set"
)
if queue_name:
queue = queue_bll.get_by_name(
company_id=company_id, queue_name=queue_name, only=("id",)
@ -223,10 +241,6 @@ def enqueue_task(
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
task = get_task_with_write_access(
task_id=task_id, company_id=company_id, identity=identity
)
user_id = identity.user
if validate:
TaskBLL.validate(task)
@ -258,13 +272,19 @@ def enqueue_task(
raise
# set the current queue ID in the task
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(execution=Execution(queue=queue_id), multi=False)
if update_execution_queue:
if task.execution:
Task.objects(id=task_id).update(execution__queue=queue_id, multi=False)
else:
Task.objects(id=task_id).update(
execution=Execution(queue=queue_id), multi=False
)
nested_set(res, ("fields", "execution.queue"), queue_id)
# make sure that the task is not queued in any other queue
TaskBLL.remove_task_from_all_queues(company_id=company_id, task_id=task_id, exclude=queue_id)
nested_set(res, ("fields", "execution.queue"), queue_id)
TaskBLL.remove_task_from_all_queues(
company_id=company_id, task_id=task_id, exclude=queue_id
)
return 1, res
@ -304,9 +324,7 @@ def delete_task(
include_pipeline_steps: bool,
) -> Tuple[int, Task, CleanupResult]:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if (
task.status != TaskStatus.created
@ -378,9 +396,7 @@ def reset_task(
clear_all: bool,
) -> Tuple[dict, CleanupResult, dict]:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force and task.status == TaskStatus.published:
raise errors.bad_request.InvalidTaskStatus(task_id=task.id, status=task.status)
@ -463,9 +479,7 @@ def publish_task(
status_reason: str = "",
) -> dict:
user_id = identity.user
task = get_task_with_write_access(
task_id, company_id=company_id, identity=identity
)
task = get_task_with_write_access(task_id, company_id=company_id, identity=identity)
if not force:
validate_status_change(task.status, TaskStatus.published)
@ -584,7 +598,9 @@ def stop_task(
).execute()
if include_pipeline_steps and (
step_tasks := _get_pipeline_steps_for_controller_task(task, company_id, only=fields)
step_tasks := _get_pipeline_steps_for_controller_task(
task, company_id, only=fields
)
):
for step in step_tasks:
stop_task_core(step, True)

View File

@ -1507,6 +1507,13 @@ Fails if the following parameters in the task were not filled:
type: boolean
}
}
"999.0": ${enqueue."2.22"} {
request.properties.update_execution_queue {
description: If set to false then the task 'execution.queue' is not updated. This can be done only for the task that is already enqueued
type: boolean
default: true
}
}
}
enqueue_many {
"2.13": ${_definitions.change_many_request} {

View File

@ -923,6 +923,11 @@ def delete_configuration(
response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, request: EnqueueRequest):
if request.verify_watched_queue and not request.update_execution_queue:
raise errors.bad_request.ValidationError(
"verify_watched_queue cannot be used with update_execution_queue=False"
)
queued, res = enqueue_task(
task_id=request.task,
company_id=company_id,
@ -932,6 +937,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
status_reason=request.status_reason,
queue_name=request.queue_name,
force=request.force,
update_execution_queue=request.update_execution_queue,
)
if request.verify_watched_queue:
res_queue = nested_get(res, ("fields", "execution.queue"))