From 4b93f1f5080fae182cd61f5fa361686b298c8470 Mon Sep 17 00:00:00 2001 From: clearml <> Date: Thu, 5 Dec 2024 22:15:43 +0200 Subject: [PATCH] Add queues.clear_queue Add new parameter 'update_task_status' to queues.remove_task --- apiserver/apimodels/queues.py | 4 + apiserver/bll/queue/queue_bll.py | 142 +++++++++++++----- apiserver/bll/task/task_bll.py | 20 ++- apiserver/bll/task/task_operations.py | 8 +- apiserver/schema/services/queues.conf | 35 ++++- apiserver/services/queues.py | 74 +++++---- apiserver/tests/automated/test_queues.py | 47 ++++++ apiserver/tests/automated/test_reports.py | 8 +- apiserver/tests/automated/test_subprojects.py | 2 +- 9 files changed, 253 insertions(+), 87 deletions(-) diff --git a/apiserver/apimodels/queues.py b/apiserver/apimodels/queues.py index d784137..de1b69a 100644 --- a/apiserver/apimodels/queues.py +++ b/apiserver/apimodels/queues.py @@ -56,6 +56,10 @@ class TaskRequest(QueueRequest): task = StringField(required=True) +class RemoveTaskRequest(TaskRequest): + update_task_status = BoolField(default=False) + + class AddTaskRequest(TaskRequest): update_execution_queue = BoolField(default=True) diff --git a/apiserver/bll/queue/queue_bll.py b/apiserver/bll/queue/queue_bll.py index 65d4957..ab8a357 100644 --- a/apiserver/bll/queue/queue_bll.py +++ b/apiserver/bll/queue/queue_bll.py @@ -1,6 +1,6 @@ from collections import defaultdict from datetime import datetime -from typing import Sequence, Optional, Tuple, Union +from typing import Sequence, Optional, Tuple, Union, Iterable from elasticsearch import Elasticsearch from mongoengine import Q @@ -135,51 +135,74 @@ class QueueBLL(object): self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",)) return Queue.safe_update(company_id, queue_id, update_fields) - def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> None: + def _update_task_status_on_removal_from_queue( + self, + company_id: str, + user_id: str, + task_ids: Iterable[str], + queue_id: str, + reason: str + ) -> Sequence[str]: + from apiserver.bll.task import ChangeStatusRequest + tasks = [] + for task_id in task_ids: + try: + task = Task.get( + company=company_id, + id=task_id, + _only=[ + "id", + "company", + "status", + "enqueue_status", + "project", + ], + ) + if not task: + continue + + tasks.append(task.id) + ChangeStatusRequest( + task=task, + new_status=task.enqueue_status or TaskStatus.created, + status_reason=reason, + status_message="", + user_id=user_id, + force=True, + ).execute(enqueue_status=None) + except Exception as ex: + log.error( + f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}" + ) + + return tasks + + def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]: """ Delete the queue :raise errors.bad_request.InvalidQueueId: if the queue is not found :raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set """ - with translate_errors_context(): - queue = self.get_by_id(company_id=company_id, queue_id=queue_id) - if queue.entries: - if not force: - raise errors.bad_request.QueueNotEmpty( - "use force=true to delete", id=queue_id - ) - from apiserver.bll.task import ChangeStatusRequest - - for item in queue.entries: - try: - task = Task.get( - company=company_id, - id=item.task, - _only=[ - "id", - "company", - "status", - "enqueue_status", - "project", - ], - ) - if not task: - continue - - ChangeStatusRequest( - task=task, - new_status=task.enqueue_status or TaskStatus.created, - status_reason="Queue was deleted", - status_message="", - user_id=user_id, - force=True, - ).execute(enqueue_status=None) - except Exception as ex: - log.exception( - f"Failed dequeuing task {item.task} from queue: {queue_id}" - ) - + queue = self.get_by_id(company_id=company_id, queue_id=queue_id) + if not queue.entries: queue.delete() + return [] + + if not force: + raise errors.bad_request.QueueNotEmpty( + "use force=true to delete", id=queue_id + ) + + tasks = self._update_task_status_on_removal_from_queue( + company_id=company_id, + user_id=user_id, + task_ids={item.task for item in queue.entries}, + queue_id=queue_id, + reason=f"Queue {queue_id} was deleted", + ) + + queue.delete() + return tasks def get_all( self, @@ -307,7 +330,36 @@ class QueueBLL(object): return queue.entries[0] - def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int: + def clear_queue( + self, + company_id: str, + user_id: str, + queue_id: str, + ): + queue = Queue.objects(company=company_id, id=queue_id).first() + if not queue: + raise errors.bad_request.InvalidQueueId( + queue=queue_id + ) + + if not queue.entries: + return [] + + tasks = self._update_task_status_on_removal_from_queue( + company_id=company_id, + user_id=user_id, + task_ids={item.task for item in queue.entries}, + queue_id=queue_id, + reason=f"Queue {queue_id} was cleared", + ) + + queue.update(entries=[]) + queue.reload() + self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue]) + + return tasks + + def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int: """ Removes the task from the queue and returns the number of removed items :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue @@ -322,6 +374,14 @@ class QueueBLL(object): res = Queue.objects(entries__task=task_id, **query).update_one( pull_all__entries=entries_to_remove, last_update=datetime.utcnow() ) + if res and update_task_status: + self._update_task_status_on_removal_from_queue( + company_id=company_id, + user_id=user_id, + task_ids=[task_id], + queue_id=queue_id, + reason=f"Task was removed from the queue {queue_id}", + ) queue.reload() self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue]) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index ebfae5c..3f1e893 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -168,7 +168,9 @@ class TaskBLL: configuration_overrides: Optional[dict] = None, ) -> Tuple[Task, dict]: validate_tags(tags, system_tags) - task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True) + task: Task = cls.get_by_id( + company_id=company_id, task_id=task_id, allow_public=True + ) params_dict = {} if hyperparams: @@ -187,8 +189,7 @@ class TaskBLL: params_dict["configuration"] = configuration elif configuration_overrides: updated_configuration = { - k: value - for k, value in (task.configuration or {}).items() + k: value for k, value in (task.configuration or {}).items() } for key, value in configuration_overrides.items(): updated_configuration[key] = value @@ -457,7 +458,9 @@ class TaskBLL: return ret @staticmethod - def remove_task_from_all_queues(company_id: str, task_id: str, exclude: str = None) -> int: + def remove_task_from_all_queues( + company_id: str, task_id: str, exclude: str = None + ) -> int: more = {} if exclude: more["id__ne"] = exclude @@ -478,7 +481,7 @@ class TaskBLL: new_status_for_aborted_task=None, ): try: - cls.dequeue(task, company_id, silent_fail=True) + cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True) except APIError: # dequeue may fail if the queue was deleted pass @@ -502,7 +505,7 @@ class TaskBLL: ).execute(enqueue_status=None) @classmethod - def dequeue(cls, task: Task, company_id: str, silent_fail=False): + def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False): """ Dequeue the task from the queue :param task: task to dequeue @@ -529,6 +532,9 @@ class TaskBLL: return { "removed": queue_bll.remove_task( - company_id=company_id, queue_id=task.execution.queue, task_id=task.id + company_id=company_id, + user_id=user_id, + queue_id=task.execution.queue, + task_id=task.id, ) } diff --git a/apiserver/bll/task/task_operations.py b/apiserver/bll/task/task_operations.py index 15f5e04..7df9a79 100644 --- a/apiserver/bll/task/task_operations.py +++ b/apiserver/bll/task/task_operations.py @@ -405,7 +405,9 @@ def reset_task( updates = {} try: - dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True) + dequeued = TaskBLL.dequeue( + task, company_id=company_id, user_id=user_id, silent_fail=True + ) except APIError: # dequeue may fail if the task was not enqueued pass @@ -577,7 +579,9 @@ def stop_task( if set_stopped: if is_queued: try: - TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True) + TaskBLL.dequeue( + task_, company_id=company_id, user_id=user_id, silent_fail=True + ) except APIError: # dequeue may fail if the task was not enqueued pass diff --git a/apiserver/schema/services/queues.conf b/apiserver/schema/services/queues.conf index 1e7c8cd..4ac3db9 100644 --- a/apiserver/schema/services/queues.conf +++ b/apiserver/schema/services/queues.conf @@ -537,8 +537,41 @@ remove_task { } } } + "999.0": ${remove_task."2.4"} { + request.properties { + update_task_status { + type: boolean + default: false + description: If set to 'true' then change the removed task status to the one it had prior to enqueuing or 'created' + } + } + } +} +clear_queue { + "999.0" { + description: Remove all tasks from the queue and change their statuses to what they were prior to enqueuing or 'created' + request { + type: object + required: [queue] + properties { + queue { + description: "Queue id" + type: string + } + } + } + response { + type: object + properties { + removed_tasks { + description: IDs of the removed tasks + type: array + items {type: string} + } + } + } + } } - move_task_forward: { "2.4" { description: "Moves a task entry one step forward towards the top of the queue." diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py index 9598844..225a141 100644 --- a/apiserver/services/queues.py +++ b/apiserver/services/queues.py @@ -21,6 +21,7 @@ from apiserver.apimodels.queues import ( GetByIdRequest, GetAllRequest, AddTaskRequest, + RemoveTaskRequest, ) from apiserver.bll.model import Metadata from apiserver.bll.queue import QueueBLL @@ -47,7 +48,7 @@ def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]): unescape_metadata(call, queue_data) -@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest) +@endpoint("queues.get_by_id", min_version="2.4") def get_by_id(call: APICall, company_id, request: GetByIdRequest): queue = queue_bll.get_by_id( company_id, request.queue, max_task_entries=request.max_task_entries @@ -112,7 +113,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest): call.result.data = {"queues": queues, **ret_params} -@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest) +@endpoint("queues.create", min_version="2.4") def create(call: APICall, company_id, request: CreateRequest): tags, system_tags = conform_tags( call, request.tags, request.system_tags, validate=True @@ -130,27 +131,26 @@ def create(call: APICall, company_id, request: CreateRequest): @endpoint( "queues.update", min_version="2.4", - request_data_model=UpdateRequest, response_data_model=UpdateResponse, ) -def update(call: APICall, company_id, req_model: UpdateRequest): +def update(call: APICall, company_id, request: UpdateRequest): data = call.data_model_for_partial_update conform_tag_fields(call, data, validate=True) escape_metadata(data) updated, fields = queue_bll.update( - company_id=company_id, queue_id=req_model.queue, **data + company_id=company_id, queue_id=request.queue, **data ) conform_queue_data(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) -@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest) -def delete(call: APICall, company_id, req_model: DeleteRequest): +@endpoint("queues.delete", min_version="2.4") +def delete(call: APICall, company_id, request: DeleteRequest): queue_bll.delete( company_id=company_id, user_id=call.identity.user, - queue_id=req_model.queue, - force=req_model.force, + queue_id=request.queue, + force=request.force, ) call.result.data = {"deleted": 1} @@ -167,7 +167,7 @@ def add_task(call: APICall, company_id, request: AddTaskRequest): call.result.data = {"added": added} -@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest) +@endpoint("queues.get_next_task") def get_next_task(call: APICall, company_id, request: GetNextTaskRequest): entry = queue_bll.get_next_task( company_id=company_id, queue_id=request.queue, task_id=request.task @@ -187,11 +187,26 @@ def get_next_task(call: APICall, company_id, request: GetNextTaskRequest): call.result.data = data -@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest) -def remove_task(call: APICall, company_id, req_model: TaskRequest): +@endpoint("queues.remove_task", min_version="2.4") +def remove_task(call: APICall, company_id, request: RemoveTaskRequest): call.result.data = { "removed": queue_bll.remove_task( - company_id=company_id, queue_id=req_model.queue, task_id=req_model.task + company_id=company_id, + user_id=call.identity.user, + queue_id=request.queue, + task_id=request.task, + update_task_status=request.update_task_status, + ) + } + + +@endpoint("queues.clear_queue") +def clear_queue(call: APICall, company_id, request: QueueRequest): + call.result.data = { + "removed_tasks": queue_bll.clear_queue( + company_id=company_id, + user_id=call.identity.user, + queue_id=request.queue, ) } @@ -199,16 +214,15 @@ def remove_task(call: APICall, company_id, req_model: TaskRequest): @endpoint( "queues.move_task_forward", min_version="2.4", - request_data_model=MoveTaskRequest, response_data_model=MoveTaskResponse, ) -def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest): +def move_task_forward(call: APICall, company_id, request: MoveTaskRequest): call.result.data_model = MoveTaskResponse( position=queue_bll.reposition_task( company_id=company_id, - queue_id=req_model.queue, - task_id=req_model.task, - move_count=-req_model.count, + queue_id=request.queue, + task_id=request.task, + move_count=-request.count, ) ) @@ -216,16 +230,15 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest): @endpoint( "queues.move_task_backward", min_version="2.4", - request_data_model=MoveTaskRequest, response_data_model=MoveTaskResponse, ) -def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest): +def move_task_backward(call: APICall, company_id, request: MoveTaskRequest): call.result.data_model = MoveTaskResponse( position=queue_bll.reposition_task( company_id=company_id, - queue_id=req_model.queue, - task_id=req_model.task, - move_count=req_model.count, + queue_id=request.queue, + task_id=request.task, + move_count=request.count, ) ) @@ -233,15 +246,14 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest): @endpoint( "queues.move_task_to_front", min_version="2.4", - request_data_model=TaskRequest, response_data_model=MoveTaskResponse, ) -def move_task_to_front(call: APICall, company_id, req_model: TaskRequest): +def move_task_to_front(call: APICall, company_id, request: TaskRequest): call.result.data_model = MoveTaskResponse( position=queue_bll.reposition_task( company_id=company_id, - queue_id=req_model.queue, - task_id=req_model.task, + queue_id=request.queue, + task_id=request.task, move_count=MOVE_FIRST, ) ) @@ -250,15 +262,14 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest): @endpoint( "queues.move_task_to_back", min_version="2.4", - request_data_model=TaskRequest, response_data_model=MoveTaskResponse, ) -def move_task_to_back(call: APICall, company_id, req_model: TaskRequest): +def move_task_to_back(call: APICall, company_id, request: TaskRequest): call.result.data_model = MoveTaskResponse( position=queue_bll.reposition_task( company_id=company_id, - queue_id=req_model.queue, - task_id=req_model.task, + queue_id=request.queue, + task_id=request.task, move_count=MOVE_LAST, ) ) @@ -267,7 +278,6 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest): @endpoint( "queues.get_queue_metrics", min_version="2.4", - request_data_model=GetMetricsRequest, response_data_model=GetMetricsResponse, ) def get_queue_metrics( diff --git a/apiserver/tests/automated/test_queues.py b/apiserver/tests/automated/test_queues.py index 1959d91..fb402e4 100644 --- a/apiserver/tests/automated/test_queues.py +++ b/apiserver/tests/automated/test_queues.py @@ -40,6 +40,53 @@ class TestQueues(TestService): ) self.assertMetricQueues(res["queues"], queue_id) + def test_add_remove_clear(self): + queue1 = self._temp_queue("TestTempQueue1") + queue2 = self._temp_queue("TestTempQueue2") + + task_names = ["TempDevTask1", "TempDevTask2"] + tasks = [self._temp_task(name) for name in task_names] + + for task in tasks: + self.api.tasks.enqueue(task=task, queue=queue1) + + # remove task with and without status update + res = self.api.queues.remove_task(task=tasks[0], queue=queue1) + self.assertEqual(res.removed, 1) + res = self.api.tasks.get_by_id(task=tasks[0]) + self.assertEqual(res.task.status, "queued") + self.assertEqual(res.task.execution.queue, queue1) + + res = self.api.queues.remove_task(task=tasks[1], queue=queue1, update_task_status=True) + self.assertEqual(res.removed, 1) + res = self.api.tasks.get_by_id(task=tasks[1]) + self.assertEqual(res.task.status, "created") + + res = self.api.queues.get_by_id(queue=queue1) + self.assertQueueTasks(res.queue, []) + + # add task + res = self.api.queues.add_task(queue=queue2, task=tasks[0]) + self.assertEqual(res.added, 1) + res = self.api.tasks.get_by_id(task=tasks[0]) + self.assertEqual(res.task.status, "queued") + self.assertEqual(res.task.execution.queue, queue2) + + res = self.api.queues.get_by_id(queue=queue2) + self.assertQueueTasks(res.queue, [tasks[0]]) + + # clear queue + res = self.api.queues.clear_queue(queue=queue1) + self.assertEqual(res.removed_tasks, []) + res = self.api.queues.clear_queue(queue=queue2) + self.assertEqual(res.removed_tasks, [tasks[0]]) + + res = self.api.tasks.get_by_id(task=tasks[0]) + self.assertEqual(res.task.status, "created") + + res = self.api.queues.get_by_id(queue=queue2) + self.assertQueueTasks(res.queue, []) + def test_hidden_queues(self): hidden_name = "TestHiddenQueue" hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"]) diff --git a/apiserver/tests/automated/test_reports.py b/apiserver/tests/automated/test_reports.py index 654391f..1b99dcb 100644 --- a/apiserver/tests/automated/test_reports.py +++ b/apiserver/tests/automated/test_reports.py @@ -12,7 +12,7 @@ class TestReports(TestService): def _delete_project(self, name): existing_project = first( self.api.projects.get_all_ex( - name=f"^{re.escape(name)}$", search_hidden=True + name=f"^{re.escape(name)}$", search_hidden=True, allow_public=False ).projects ) if existing_project: @@ -34,10 +34,10 @@ class TestReports(TestService): self.assertEqual(set(task.tags), set(tags)) self.assertEqual(task.type, "report") self.assertEqual(set(task.system_tags), {"hidden", "reports"}) - projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects + projects = self.api.projects.get_all_ex(name=r"^\.reports$", allow_public=False).projects self.assertEqual(len(projects), 0) project = self.api.projects.get_all_ex( - name=r"^\.reports$", search_hidden=True + name=r"^\.reports$", search_hidden=True, allow_public=False ).projects[0] self.assertEqual(project.id, task.project.id) self.assertEqual(set(project.system_tags), {"hidden", "reports"}) @@ -108,6 +108,7 @@ class TestReports(TestService): include_stats=True, check_own_contents=True, search_hidden=True, + allow_public=False, ).projects self.assertEqual(len(projects), 1) p = projects[0] @@ -120,6 +121,7 @@ class TestReports(TestService): include_stats=True, check_own_contents=True, search_hidden=True, + allow_public=False, ).projects self.assertEqual(len(projects), 1) p = projects[0] diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 0158391..cb10afa 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -15,7 +15,7 @@ class TestSubProjects(TestService): def test_dataset_stats(self): project = self._temp_project(name="Dataset test", system_tags=["dataset"]) res = self.api.organization.get_entities_count( - datasets={"system_tags": ["dataset"]} + datasets={"system_tags": ["dataset"]}, allow_public=False, ) self.assertEqual(res.datasets, 1)