Add queues.clear_queue

Add new parameter 'update_task_status' to queues.remove_task
This commit is contained in:
clearml 2024-12-05 22:15:43 +02:00
parent 2752c4df54
commit 4b93f1f508
9 changed files with 253 additions and 87 deletions

View File

@ -56,6 +56,10 @@ class TaskRequest(QueueRequest):
task = StringField(required=True) task = StringField(required=True)
class RemoveTaskRequest(TaskRequest):
    """Request model for removing a task from a queue, with an optional status update."""

    # Off by default: unless the caller opts in, only the queue entry is removed
    update_task_status = BoolField(default=False)
class AddTaskRequest(TaskRequest): class AddTaskRequest(TaskRequest):
update_execution_queue = BoolField(default=True) update_execution_queue = BoolField(default=True)

View File

@ -1,6 +1,6 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import Sequence, Optional, Tuple, Union from typing import Sequence, Optional, Tuple, Union, Iterable
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from mongoengine import Q from mongoengine import Q
@ -135,51 +135,74 @@ class QueueBLL(object):
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",)) self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields) return Queue.safe_update(company_id, queue_id, update_fields)
def _update_task_status_on_removal_from_queue(
    self,
    company_id: str,
    user_id: str,
    task_ids: Iterable[str],
    queue_id: str,
    reason: str,
) -> Sequence[str]:
    """
    Change the status of tasks that are being taken out of a queue.

    Every task that can be loaded is moved to its 'enqueue_status', or to
    'created' when no enqueue status is set. Errors are logged per task and
    never propagated, so one failing task does not stop the remaining ones.

    :param company_id: company used to look up the tasks
    :param user_id: user passed to the status change request
    :param task_ids: IDs of the tasks leaving the queue
    :param queue_id: ID of the queue; used only in the error log message
    :param reason: text passed as the status reason of the change
    :return: IDs of the tasks that were found. An ID is recorded before its
        status change is executed, so a task whose status change failed is
        still included.
    """
    # NOTE(review): function-scope import, presumably to avoid an import cycle - confirm
    from apiserver.bll.task import ChangeStatusRequest

    tasks = []
    for task_id in task_ids:
        try:
            task = Task.get(
                company=company_id,
                id=task_id,
                _only=[
                    "id",
                    "company",
                    "status",
                    "enqueue_status",
                    "project",
                ],
            )
            if not task:
                continue

            tasks.append(task.id)
            ChangeStatusRequest(
                task=task,
                new_status=task.enqueue_status or TaskStatus.created,
                status_reason=reason,
                status_message="",
                user_id=user_id,
                force=True,
            ).execute(enqueue_status=None)
        except Exception as ex:
            # log.exception (not log.error) so the traceback of the failure is kept
            log.exception(
                f"Failed updating task {task_id} status on removal from queue: {queue_id}, {str(ex)}"
            )

    return tasks
def delete(self, company_id: str, user_id: str, queue_id: str, force: bool) -> Sequence[str]:
""" """
Delete the queue Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found :raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set :raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
""" """
with translate_errors_context(): queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
queue = self.get_by_id(company_id=company_id, queue_id=queue_id) if not queue.entries:
if queue.entries:
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
from apiserver.bll.task import ChangeStatusRequest
for item in queue.entries:
try:
task = Task.get(
company=company_id,
id=item.task,
_only=[
"id",
"company",
"status",
"enqueue_status",
"project",
],
)
if not task:
continue
ChangeStatusRequest(
task=task,
new_status=task.enqueue_status or TaskStatus.created,
status_reason="Queue was deleted",
status_message="",
user_id=user_id,
force=True,
).execute(enqueue_status=None)
except Exception as ex:
log.exception(
f"Failed dequeuing task {item.task} from queue: {queue_id}"
)
queue.delete() queue.delete()
return []
if not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
tasks = self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids={item.task for item in queue.entries},
queue_id=queue_id,
reason=f"Queue {queue_id} was deleted",
)
queue.delete()
return tasks
def get_all( def get_all(
self, self,
@ -307,7 +330,36 @@ class QueueBLL(object):
return queue.entries[0] return queue.entries[0]
def clear_queue(
    self,
    company_id: str,
    user_id: str,
    queue_id: str,
) -> Sequence[str]:
    """
    Remove all the entries currently in the queue and update the statuses
    of the corresponding tasks.

    :param company_id: company that the queue is looked up in
    :param user_id: user passed on to the task status changes
    :param queue_id: ID of the queue to clear
    :return: IDs of the tasks reported by the status update helper, in queue
        order; an empty list if the queue had no entries
    :raise errors.bad_request.InvalidQueueId: if the queue is not found
    """
    queue = Queue.objects(company=company_id, id=queue_id).first()
    if not queue:
        raise errors.bad_request.InvalidQueueId(queue=queue_id)

    # Snapshot of the entries this call is responsible for
    entries = list(queue.entries)
    if not entries:
        return []

    tasks = self._update_task_status_on_removal_from_queue(
        company_id=company_id,
        user_id=user_id,
        # dict.fromkeys drops duplicate task IDs while keeping the queue order
        task_ids=list(dict.fromkeys(item.task for item in entries)),
        queue_id=queue_id,
        reason=f"Queue {queue_id} was cleared",
    )

    # Pull only the entries from the snapshot instead of overwriting the whole
    # list, so an entry added while the statuses were being updated is not
    # dropped without its task being updated. last_update is set the same way
    # remove_task sets it when it pulls entries.
    queue.update(pull_all__entries=entries, last_update=datetime.utcnow())
    queue.reload()
    self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
    return tasks
def remove_task(self, company_id: str, user_id: str, queue_id: str, task_id: str, update_task_status: bool = False) -> int:
""" """
Removes the task from the queue and returns the number of removed items Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
@ -322,6 +374,14 @@ class QueueBLL(object):
res = Queue.objects(entries__task=task_id, **query).update_one( res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow() pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
) )
if res and update_task_status:
self._update_task_status_on_removal_from_queue(
company_id=company_id,
user_id=user_id,
task_ids=[task_id],
queue_id=queue_id,
reason=f"Task was removed from the queue {queue_id}",
)
queue.reload() queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue]) self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])

View File

@ -168,7 +168,9 @@ class TaskBLL:
configuration_overrides: Optional[dict] = None, configuration_overrides: Optional[dict] = None,
) -> Tuple[Task, dict]: ) -> Tuple[Task, dict]:
validate_tags(tags, system_tags) validate_tags(tags, system_tags)
task: Task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True) task: Task = cls.get_by_id(
company_id=company_id, task_id=task_id, allow_public=True
)
params_dict = {} params_dict = {}
if hyperparams: if hyperparams:
@ -187,8 +189,7 @@ class TaskBLL:
params_dict["configuration"] = configuration params_dict["configuration"] = configuration
elif configuration_overrides: elif configuration_overrides:
updated_configuration = { updated_configuration = {
k: value k: value for k, value in (task.configuration or {}).items()
for k, value in (task.configuration or {}).items()
} }
for key, value in configuration_overrides.items(): for key, value in configuration_overrides.items():
updated_configuration[key] = value updated_configuration[key] = value
@ -457,7 +458,9 @@ class TaskBLL:
return ret return ret
@staticmethod @staticmethod
def remove_task_from_all_queues(company_id: str, task_id: str, exclude: str = None) -> int: def remove_task_from_all_queues(
company_id: str, task_id: str, exclude: str = None
) -> int:
more = {} more = {}
if exclude: if exclude:
more["id__ne"] = exclude more["id__ne"] = exclude
@ -478,7 +481,7 @@ class TaskBLL:
new_status_for_aborted_task=None, new_status_for_aborted_task=None,
): ):
try: try:
cls.dequeue(task, company_id, silent_fail=True) cls.dequeue(task, company_id=company_id, user_id=user_id, silent_fail=True)
except APIError: except APIError:
# dequeue may fail if the queue was deleted # dequeue may fail if the queue was deleted
pass pass
@ -502,7 +505,7 @@ class TaskBLL:
).execute(enqueue_status=None) ).execute(enqueue_status=None)
@classmethod @classmethod
def dequeue(cls, task: Task, company_id: str, silent_fail=False): def dequeue(cls, task: Task, company_id: str, user_id: str, silent_fail=False):
""" """
Dequeue the task from the queue Dequeue the task from the queue
:param task: task to dequeue :param task: task to dequeue
@ -529,6 +532,9 @@ class TaskBLL:
return { return {
"removed": queue_bll.remove_task( "removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id company_id=company_id,
user_id=user_id,
queue_id=task.execution.queue,
task_id=task.id,
) )
} }

View File

@ -405,7 +405,9 @@ def reset_task(
updates = {} updates = {}
try: try:
dequeued = TaskBLL.dequeue(task, company_id, silent_fail=True) dequeued = TaskBLL.dequeue(
task, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
pass pass
@ -577,7 +579,9 @@ def stop_task(
if set_stopped: if set_stopped:
if is_queued: if is_queued:
try: try:
TaskBLL.dequeue(task_, company_id=company_id, silent_fail=True) TaskBLL.dequeue(
task_, company_id=company_id, user_id=user_id, silent_fail=True
)
except APIError: except APIError:
# dequeue may fail if the task was not enqueued # dequeue may fail if the task was not enqueued
pass pass

View File

@ -537,8 +537,41 @@ remove_task {
} }
} }
} }
"999.0": ${remove_task."2.4"} {
request.properties {
update_task_status {
type: boolean
default: false
description: "If set to 'true', restore the removed task to the status it had before it was enqueued, or to 'created' if no such status is recorded"
}
}
}
}
clear_queue {
"999.0" {
description: "Remove all tasks from the queue and restore each task to the status it had before it was enqueued, or to 'created' if no such status is recorded"
request {
type: object
required: [queue]
properties {
queue {
description: "Queue id"
type: string
}
}
}
response {
type: object
properties {
removed_tasks {
description: IDs of the removed tasks
type: array
items {type: string}
}
}
}
}
} }
move_task_forward: { move_task_forward: {
"2.4" { "2.4" {
description: "Moves a task entry one step forward towards the top of the queue." description: "Moves a task entry one step forward towards the top of the queue."

View File

@ -21,6 +21,7 @@ from apiserver.apimodels.queues import (
GetByIdRequest, GetByIdRequest,
GetAllRequest, GetAllRequest,
AddTaskRequest, AddTaskRequest,
RemoveTaskRequest,
) )
from apiserver.bll.model import Metadata from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
@ -47,7 +48,7 @@ def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
unescape_metadata(call, queue_data) unescape_metadata(call, queue_data)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest) @endpoint("queues.get_by_id", min_version="2.4")
def get_by_id(call: APICall, company_id, request: GetByIdRequest): def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id( queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries company_id, request.queue, max_task_entries=request.max_task_entries
@ -112,7 +113,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
call.result.data = {"queues": queues, **ret_params} call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest) @endpoint("queues.create", min_version="2.4")
def create(call: APICall, company_id, request: CreateRequest): def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags( tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True call, request.tags, request.system_tags, validate=True
@ -130,27 +131,26 @@ def create(call: APICall, company_id, request: CreateRequest):
@endpoint( @endpoint(
"queues.update", "queues.update",
min_version="2.4", min_version="2.4",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse, response_data_model=UpdateResponse,
) )
def update(call: APICall, company_id, req_model: UpdateRequest): def update(call: APICall, company_id, request: UpdateRequest):
data = call.data_model_for_partial_update data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True) conform_tag_fields(call, data, validate=True)
escape_metadata(data) escape_metadata(data)
updated, fields = queue_bll.update( updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data company_id=company_id, queue_id=request.queue, **data
) )
conform_queue_data(call, fields) conform_queue_data(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest) @endpoint("queues.delete", min_version="2.4")
def delete(call: APICall, company_id, req_model: DeleteRequest): def delete(call: APICall, company_id, request: DeleteRequest):
queue_bll.delete( queue_bll.delete(
company_id=company_id, company_id=company_id,
user_id=call.identity.user, user_id=call.identity.user,
queue_id=req_model.queue, queue_id=request.queue,
force=req_model.force, force=request.force,
) )
call.result.data = {"deleted": 1} call.result.data = {"deleted": 1}
@ -167,7 +167,7 @@ def add_task(call: APICall, company_id, request: AddTaskRequest):
call.result.data = {"added": added} call.result.data = {"added": added}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest) @endpoint("queues.get_next_task")
def get_next_task(call: APICall, company_id, request: GetNextTaskRequest): def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
entry = queue_bll.get_next_task( entry = queue_bll.get_next_task(
company_id=company_id, queue_id=request.queue, task_id=request.task company_id=company_id, queue_id=request.queue, task_id=request.task
@ -187,11 +187,26 @@ def get_next_task(call: APICall, company_id, request: GetNextTaskRequest):
call.result.data = data call.result.data = data
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest) @endpoint("queues.remove_task", min_version="2.4")
def remove_task(call: APICall, company_id, req_model: TaskRequest): def remove_task(call: APICall, company_id, request: RemoveTaskRequest):
call.result.data = { call.result.data = {
"removed": queue_bll.remove_task( "removed": queue_bll.remove_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
task_id=request.task,
update_task_status=request.update_task_status,
)
}
@endpoint("queues.clear_queue")
def clear_queue(call: APICall, company_id, request: QueueRequest):
call.result.data = {
"removed_tasks": queue_bll.clear_queue(
company_id=company_id,
user_id=call.identity.user,
queue_id=request.queue,
) )
} }
@ -199,16 +214,15 @@ def remove_task(call: APICall, company_id, req_model: TaskRequest):
@endpoint( @endpoint(
"queues.move_task_forward", "queues.move_task_forward",
min_version="2.4", min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse, response_data_model=MoveTaskResponse,
) )
def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest): def move_task_forward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse( call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task( position=queue_bll.reposition_task(
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=request.queue,
task_id=req_model.task, task_id=request.task,
move_count=-req_model.count, move_count=-request.count,
) )
) )
@ -216,16 +230,15 @@ def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint( @endpoint(
"queues.move_task_backward", "queues.move_task_backward",
min_version="2.4", min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse, response_data_model=MoveTaskResponse,
) )
def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest): def move_task_backward(call: APICall, company_id, request: MoveTaskRequest):
call.result.data_model = MoveTaskResponse( call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task( position=queue_bll.reposition_task(
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=request.queue,
task_id=req_model.task, task_id=request.task,
move_count=req_model.count, move_count=request.count,
) )
) )
@ -233,15 +246,14 @@ def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
@endpoint( @endpoint(
"queues.move_task_to_front", "queues.move_task_to_front",
min_version="2.4", min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse, response_data_model=MoveTaskResponse,
) )
def move_task_to_front(call: APICall, company_id, req_model: TaskRequest): def move_task_to_front(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse( call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task( position=queue_bll.reposition_task(
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=request.queue,
task_id=req_model.task, task_id=request.task,
move_count=MOVE_FIRST, move_count=MOVE_FIRST,
) )
) )
@ -250,15 +262,14 @@ def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
@endpoint( @endpoint(
"queues.move_task_to_back", "queues.move_task_to_back",
min_version="2.4", min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse, response_data_model=MoveTaskResponse,
) )
def move_task_to_back(call: APICall, company_id, req_model: TaskRequest): def move_task_to_back(call: APICall, company_id, request: TaskRequest):
call.result.data_model = MoveTaskResponse( call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task( position=queue_bll.reposition_task(
company_id=company_id, company_id=company_id,
queue_id=req_model.queue, queue_id=request.queue,
task_id=req_model.task, task_id=request.task,
move_count=MOVE_LAST, move_count=MOVE_LAST,
) )
) )
@ -267,7 +278,6 @@ def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
@endpoint( @endpoint(
"queues.get_queue_metrics", "queues.get_queue_metrics",
min_version="2.4", min_version="2.4",
request_data_model=GetMetricsRequest,
response_data_model=GetMetricsResponse, response_data_model=GetMetricsResponse,
) )
def get_queue_metrics( def get_queue_metrics(

View File

@ -40,6 +40,53 @@ class TestQueues(TestService):
) )
self.assertMetricQueues(res["queues"], queue_id) self.assertMetricQueues(res["queues"], queue_id)
def test_add_remove_clear(self):
    """
    Exercise queues.remove_task (with and without 'update_task_status'),
    queues.add_task and queues.clear_queue, checking the task status and
    the queue contents after each step.
    """
    api = self.api
    source_queue = self._temp_queue("TestTempQueue1")
    target_queue = self._temp_queue("TestTempQueue2")
    plain_task, restored_task = (
        self._temp_task(task_name)
        for task_name in ("TempDevTask1", "TempDevTask2")
    )
    for task_id in (plain_task, restored_task):
        api.tasks.enqueue(task=task_id, queue=source_queue)

    def fetch_task(task_id):
        return api.tasks.get_by_id(task=task_id).task

    def fetch_queue(queue_id):
        return api.queues.get_by_id(queue=queue_id).queue

    # removal without a status update leaves the task marked as queued
    removal = api.queues.remove_task(task=plain_task, queue=source_queue)
    self.assertEqual(removal.removed, 1)
    task_data = fetch_task(plain_task)
    self.assertEqual(task_data.status, "queued")
    self.assertEqual(task_data.execution.queue, source_queue)

    # removal with a status update puts the task back to 'created'
    removal = api.queues.remove_task(
        task=restored_task, queue=source_queue, update_task_status=True
    )
    self.assertEqual(removal.removed, 1)
    self.assertEqual(fetch_task(restored_task).status, "created")
    self.assertQueueTasks(fetch_queue(source_queue), [])

    # adding the first task to the second queue
    addition = api.queues.add_task(queue=target_queue, task=plain_task)
    self.assertEqual(addition.added, 1)
    task_data = fetch_task(plain_task)
    self.assertEqual(task_data.status, "queued")
    self.assertEqual(task_data.execution.queue, target_queue)
    self.assertQueueTasks(fetch_queue(target_queue), [plain_task])

    # clearing an empty queue and then the populated one
    cleared = api.queues.clear_queue(queue=source_queue)
    self.assertEqual(cleared.removed_tasks, [])
    cleared = api.queues.clear_queue(queue=target_queue)
    self.assertEqual(cleared.removed_tasks, [plain_task])
    self.assertEqual(fetch_task(plain_task).status, "created")
    self.assertQueueTasks(fetch_queue(target_queue), [])
def test_hidden_queues(self): def test_hidden_queues(self):
hidden_name = "TestHiddenQueue" hidden_name = "TestHiddenQueue"
hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"]) hidden_queue = self._temp_queue(hidden_name, system_tags=["k8s-glue"])

View File

@ -12,7 +12,7 @@ class TestReports(TestService):
def _delete_project(self, name): def _delete_project(self, name):
existing_project = first( existing_project = first(
self.api.projects.get_all_ex( self.api.projects.get_all_ex(
name=f"^{re.escape(name)}$", search_hidden=True name=f"^{re.escape(name)}$", search_hidden=True, allow_public=False
).projects ).projects
) )
if existing_project: if existing_project:
@ -34,10 +34,10 @@ class TestReports(TestService):
self.assertEqual(set(task.tags), set(tags)) self.assertEqual(set(task.tags), set(tags))
self.assertEqual(task.type, "report") self.assertEqual(task.type, "report")
self.assertEqual(set(task.system_tags), {"hidden", "reports"}) self.assertEqual(set(task.system_tags), {"hidden", "reports"})
projects = self.api.projects.get_all_ex(name=r"^\.reports$").projects projects = self.api.projects.get_all_ex(name=r"^\.reports$", allow_public=False).projects
self.assertEqual(len(projects), 0) self.assertEqual(len(projects), 0)
project = self.api.projects.get_all_ex( project = self.api.projects.get_all_ex(
name=r"^\.reports$", search_hidden=True name=r"^\.reports$", search_hidden=True, allow_public=False
).projects[0] ).projects[0]
self.assertEqual(project.id, task.project.id) self.assertEqual(project.id, task.project.id)
self.assertEqual(set(project.system_tags), {"hidden", "reports"}) self.assertEqual(set(project.system_tags), {"hidden", "reports"})
@ -108,6 +108,7 @@ class TestReports(TestService):
include_stats=True, include_stats=True,
check_own_contents=True, check_own_contents=True,
search_hidden=True, search_hidden=True,
allow_public=False,
).projects ).projects
self.assertEqual(len(projects), 1) self.assertEqual(len(projects), 1)
p = projects[0] p = projects[0]
@ -120,6 +121,7 @@ class TestReports(TestService):
include_stats=True, include_stats=True,
check_own_contents=True, check_own_contents=True,
search_hidden=True, search_hidden=True,
allow_public=False,
).projects ).projects
self.assertEqual(len(projects), 1) self.assertEqual(len(projects), 1)
p = projects[0] p = projects[0]

View File

@ -15,7 +15,7 @@ class TestSubProjects(TestService):
def test_dataset_stats(self): def test_dataset_stats(self):
project = self._temp_project(name="Dataset test", system_tags=["dataset"]) project = self._temp_project(name="Dataset test", system_tags=["dataset"])
res = self.api.organization.get_entities_count( res = self.api.organization.get_entities_count(
datasets={"system_tags": ["dataset"]} datasets={"system_tags": ["dataset"]}, allow_public=False,
) )
self.assertEqual(res.datasets, 1) self.assertEqual(res.datasets, 1)