Add queues.clear_queue

Add new parameter 'update_task_status' to queues.remove_task
This commit is contained in:
clearml 2024-12-05 22:15:43 +02:00
parent 2752c4df54
commit 4b93f1f508
9 changed files with 253 additions and 87 deletions

View File

@ -56,6 +56,10 @@ class TaskRequest(QueueRequest):
task = StringField(required=True)
class RemoveTaskRequest(TaskRequest):
update_task_status = BoolField(default=False)
class AddTaskRequest(TaskRequest):
update_execution_queue = BoolField(default=True)

View File

@ -1,6 +1,6 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence, Optional, Tuple, Union
from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch
from mongoengine import Q
@ -135,51 +135,74 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None:
def _update_task_status_on_removal_from_queue(
self,
company_id: str,
user_id: str,
task_ids: Iterable[str],
queue_id: str,
reason: str
) -> Sequence[str]:
from apiserver.bll.task import ChangeStatusRequest
tasks = []
for task_id in task_ids:
try:
task = Task.get(
company=company_id,
id=task_id,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
tasks.append(task.id)
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason=reason,
status_message="",
user_id=user_id,
force=True,
).execute(enqueue_status=None)
except Exception as ex:
log.error(
f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
)
return tasks
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get(
company=company_id,
id=item.task,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
force=True,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if not queue.entries:
queue.delete()
return []
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was deleted",
)
queue.delete()
return tasks
def get_all(
self,
@ -307,7 +330,36 @@ class QueueBLL(object):
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
def clear_queue(
self,
company_id: str,
user_id: str,
queue_id: str,
):
queue = Queue.objects(company=company_id, id=queue_id).first()
if not queue:
raise errors.bad_request.InvalidQueueId(
queue=queue_id
)
if not queue.entries:
return []
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was cleared",
)
queue.update(entries=[])
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return tasks
def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@ -322,6 +374,14 @@ class QueueBLL(object):
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
if res and update_task_status:
self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids=[task_id],
queue_id=queue_id,
reason=f"Task was removed from the queue {queue_id}",
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])

View File

@ -168,7 +168,9 @@ class TaskBLL:
configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]:
validate_tags(tags, system_tags)
task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
task: Task = cls.get_by_id(
company_id=company_id, task_id=task_id, allow_public=True
)
params_dict = {}
if hyperparams:
@ -187,8 +189,7 @@ class TaskBLL:
params_dict["configuration"] = configuration
elif configuration_overrides:
updated_configuration = {
k: value
for k, value in (task.configuration or {}).items()
k: value for k, value in (task.configuration or {}).items()
}
for key, value in configuration_overrides.items():
updated_configuration[key] = value
@ -457,7 +458,9 @@ class TaskBLL:
return ret
@staticmethod
def remove_task_from_all_queues(company_id: str, task_id: str, exclude: str = None) -> int:
def remove_task_from_all_queues(
company_id: str, task_id: str, exclude: str = None
) -> int:
more = {}
if exclude:
more["id__ne"] = exclude
@ -478,7 +481,7 @@ class TaskBLL:
new_status_for_aborted_task=None,
):
try:
cls.dequeue(task, company_id, silent_fail=True)
cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError:
# dequeue may fail if the queue was deleted
pass
@ -502,7 +505,7 @@ class TaskBLL:
).execute(enqueue_status=None)
@classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False):
def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
@ -529,6 +532,9 @@ class TaskBLL:
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
company_id=company_id,
user_id=user_id,
queue_id=task.execution.queue,
task_id=task.id,
)
}

View File

@ -405,7 +405,9 @@ def reset_task(
updates = {}
try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True)
dequeued = TaskBLL.dequeue(
task, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
@ -577,7 +579,9 @@ def stop_task(
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True)
TaskBLL.dequeue(
task_, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError:
# dequeue may fail if the task was not enqueued
pass

View File

@ -537,8 +537,41 @@ remove_task {
}
}
}
"999.0": ${remove_task."2.4"} {
request.properties {
update_task_status {
type: boolean
default: false
description: If set to 'true' then change the removed task status to the one it had prior to enqueuing or 'created'
}
}
}
}
clear_queue {
"999.0" {
description: Remove all tasks from the queue and change their statuses to what they were prior to enqueuing or 'created'
request {
type: object
required: [queue]
properties {
queue {
description: "Queue id"
type: string
}
}
}
response {
type: object
properties {
removed_tasks {
description: IDs of the removed tasks
type: array
items {type: string}
}
}
}
}
}
move_task_forward: {
"2.4" {
description: "Moves a task entry one step forward towards the top of the queue."

View File

@ -21,6 +21,7 @@ from apiserver.apimodels.queues import (
GetByIdRequest,
GetAllRequest,
AddTaskRequest,
RemoveTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
@ -47,7 +48,7 @@ def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
unescape_metadata(call, queue_data)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
@endpoint("queues.get_by_id", min_version="2.4")
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
@ -112,7 +113,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
@endpoint("queues.create", min_version="2.4")
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
@ -130,27 +131,26 @@ def create(call: APICall, company_id, request: CreateRequest):
@endpoint(
"queues.update",
min_version="2.4",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
def update(call: APICall, company_id, req_model: UpdateRequest):
def update(call: APICall, company_id, request: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True)
escape_metadata(data)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
company_id=company_id, queue_id=request.queue, **data
)
conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
@endpoint("queues.delete", min_version="2.4")
def delete(call: APICall, company_id, request: DeleteRequest):
queue_bll.delete(
company_id=company_id,
user_id=call.identity.user,
queue_id=req_model.queue,
force=req_model.force,
queue_id=request.queue,
force=request.force,
)
call.result.data = {"deleted": 1}
@ -167,7 +167,7 @@ def add_task(call: APICall, company_id, request: AddTaskRequest):
call.result.data = {"added": added}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
@endpoint("queues.get_next_task")
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=request.queue, task_id=request.task
@ -187,11 +187,26 @@ def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
call.result.data = data
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint("queues.remove_task", min_version="2.4")
def remove_task(call: APICall, company_id, request: RemoveTaskRequest):
call.result.data = {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
task_id=request.task,
update_task_status=request.update_task_status,
)
}
@endpoint("queues.clear_queue")
def clear_queue(call: APICall, company_id, request: QueueRequest):
call.result.data = {
"removed_tasks": queue_bll.clear_queue(
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
)
}
@ -199,16 +214,15 @@ def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_forward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
def move_task_forward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
move_count=-req_model.count,
queue_id=request.queue,
task_id=request.task,
move_count=-request.count,
)
)
@ -216,16 +230,15 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_backward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
def move_task_backward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
move_count=req_model.count,
queue_id=request.queue,
task_id=request.task,
move_count=request.count,
)
)
@ -233,15 +246,14 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint(
"queues.move_task_to_front",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
def move_task_to_front(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
queue_id=request.queue,
task_id=request.task,
move_count=MOVE_FIRST,
)
)
@ -250,15 +262,14 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.move_task_to_back",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
def move_task_to_back(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
queue_id=request.queue,
task_id=request.task,
move_count=MOVE_LAST,
)
)
@ -267,7 +278,6 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
@endpoint(
"queues.get_queue_metrics",
min_version="2.4",
request_data_model=GetMetricsRequest,
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(

View File

@ -40,6 +40,53 @@ class TestQueues(TestService):
)
self.assertMetricQueues(res["queues"], queue_id)
def test_add_remove_clear(self):
queue1 = self._temp_queue("TestTempQueue1")
queue2 = self._temp_queue("TestTempQueue2")
task_names = ["TempDevTask1", "TempDevTask2"]
tasks = [self._temp_task(name) for name in task_names]
for task in tasks:
self.api.tasks.enqueue(task=task, queue=queue1)
# remove task with and without status update
res = self.api.queues.remove_task(task=tasks[0], queue=queue1)
self.assertEqual(res.removed, 1)
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "queued")
self.assertEqual(res.task.execution.queue, queue1)
res = self.api.queues.remove_task(task=tasks[1], queue=queue1, update_task_status=True)
self.assertEqual(res.removed, 1)
res = self.api.tasks.get_by_id(task=tasks[1])
self.assertEqual(res.task.status, "created")
res = self.api.queues.get_by_id(queue=queue1)
self.assertQueueTasks(res.queue, [])
# add task
res = self.api.queues.add_task(queue=queue2, task=tasks[0])
self.assertEqual(res.added, 1)
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "queued")
self.assertEqual(res.task.execution.queue, queue2)
res = self.api.queues.get_by_id(queue=queue2)
self.assertQueueTasks(res.queue, [tasks[0]])
# clear queue
res = self.api.queues.clear_queue(queue=queue1)
self.assertEqual(res.removed_tasks, [])
res = self.api.queues.clear_queue(queue=queue2)
self.assertEqual(res.removed_tasks, [tasks[0]])
res = self.api.tasks.get_by_id(task=tasks[0])
self.assertEqual(res.task.status, "created")
res = self.api.queues.get_by_id(queue=queue2)
self.assertQueueTasks(res.queue, [])
def test_hidden_queues(self):
hidden_name = "TestHiddenQueue"
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])

View File

@ -12,7 +12,7 @@ class TestReports(TestService):
def _delete_project(self, name):
existing_project = first(
self.api.projects.get_all_ex(
name=f"^{re.escape(name)}$", search_hidden=True
name=f"^{re.escape(name)}$", search_hidden=True, allow_public=False
).projects
)
if existing_project:
@ -34,10 +34,10 @@ class TestReports(TestService):
self.assertEqual(set(task.tags), set(tags))
self.assertEqual(task.type, "report")
self.assertEqual(set(task.system_tags), {"hidden", "reports"})
projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects
projects = self.api.projects.get_all_ex(name=r"^\.reports$", allow_public=False).projects
self.assertEqual(len(projects), 0)
project = self.api.projects.get_all_ex(
name=r"^\.reports$", search_hidden=True
name=r"^\.reports$", search_hidden=True, allow_public=False
).projects[0]
self.assertEqual(project.id, task.project.id)
self.assertEqual(set(project.system_tags), {"hidden", "reports"})
@ -108,6 +108,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]
@ -120,6 +121,7 @@ class TestReports(TestService):
include_stats=True,
check_own_contents=True,
search_hidden=True,
allow_public=False,
).projects
self.assertEqual(len(projects), 1)
p = projects[0]

View File

@ -15,7 +15,7 @@ class TestSubProjects(TestService):
def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count(
datasets={"system_tags": ["dataset"]}
datasets={"system_tags": ["dataset"]}, allow_public=False,
)
self.assertEqual(res.datasets, 1)