mirror of
https://github.com/clearml/clearml-server
synced 2025-03-09 21:51:54 +00:00
Support max_task_entries option in queues.get_by_id/get_all
Add queues.peek_task and queues.get_num_entries
This commit is contained in:
parent
6683d2d7a9
commit
d117a4f022
@ -26,6 +26,10 @@ class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetByIdRequest(QueueRequest):
|
||||
max_task_entries = IntField()
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
|
@ -59,8 +59,16 @@ class QueueBLL(object):
|
||||
|
||||
return qs.first()
|
||||
|
||||
@staticmethod
|
||||
def _get_task_entries_projection(max_task_entries: int) -> dict:
|
||||
return dict(slice__entries=max_task_entries)
|
||||
|
||||
def get_by_id(
|
||||
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
|
||||
self,
|
||||
company_id: str,
|
||||
queue_id: str,
|
||||
only: Optional[Sequence[str]] = None,
|
||||
max_task_entries: int = None,
|
||||
) -> Queue:
|
||||
"""
|
||||
Get queue by id
|
||||
@ -71,6 +79,8 @@ class QueueBLL(object):
|
||||
qs = Queue.objects(**query)
|
||||
if only:
|
||||
qs = qs.only(*only)
|
||||
if max_task_entries:
|
||||
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
|
||||
queue = qs.first()
|
||||
if not queue:
|
||||
raise errors.bad_request.InvalidQueueId(**query)
|
||||
@ -136,7 +146,11 @@ class QueueBLL(object):
|
||||
queue.delete()
|
||||
|
||||
def get_all(
|
||||
self, company_id: str, query_dict: dict, ret_params: dict = None,
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
@ -144,11 +158,18 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(
|
||||
self, company_id: str, query_dict: dict, ret_params: dict = None,
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
max_task_entries: int = None,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
@ -159,6 +180,9 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
override_projection=projection,
|
||||
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||
if max_task_entries
|
||||
else None,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
@ -292,5 +316,24 @@ class QueueBLL(object):
|
||||
|
||||
return new_position
|
||||
|
||||
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
|
||||
res = next(
|
||||
Queue.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": queue_id,
|
||||
}
|
||||
},
|
||||
{"$project": {"count": {"$size": "$entries"}}},
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if res is None:
|
||||
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||
return int(res.get("count"))
|
||||
|
||||
|
||||
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)
|
||||
|
@ -648,6 +648,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
projection_fields: dict = None,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
@ -684,6 +685,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
projection_fields=projection_fields,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
@ -754,6 +756,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
projection_fields: dict = None,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
@ -803,6 +806,7 @@ class GetMixin(PropsMixin):
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
projection_fields=projection_fields,
|
||||
)
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
@ -813,6 +817,7 @@ class GetMixin(PropsMixin):
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
projection_fields=projection_fields,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -837,6 +842,7 @@ class GetMixin(PropsMixin):
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
override_collation=None,
|
||||
projection_fields: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
@ -879,6 +885,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if projection_fields:
|
||||
qs = qs.fields(**projection_fields)
|
||||
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(start).limit(size)
|
||||
@ -920,6 +929,7 @@ class GetMixin(PropsMixin):
|
||||
parameters: dict = None,
|
||||
override_projection: Collection[str] = None,
|
||||
override_collation: dict = None,
|
||||
projection_fields: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
@ -977,6 +987,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if projection_fields:
|
||||
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
|
||||
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
|
@ -112,6 +112,12 @@ get_by_id {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_by_id."2.4"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
// typescript generation hack
|
||||
get_all_ex {
|
||||
@ -140,6 +146,12 @@ get_all_ex {
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
"999.0": ${get_all_ex."2.15"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@ -226,6 +238,12 @@ get_all {
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
"999.0": ${get_all."2.15"} {
|
||||
request.properties.max_task_entries {
|
||||
description: Max number of queue task entries to return
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
@ -707,3 +725,51 @@ delete_metadata {
|
||||
}
|
||||
}
|
||||
}
|
||||
peek_task {
|
||||
"2.15" {
|
||||
description: "Peek the next task from a given queue"
|
||||
request {
|
||||
type: object
|
||||
required: [ queue ]
|
||||
properties {
|
||||
queue {
|
||||
description: "ID of the queue"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_num_entries {
|
||||
"2.15" {
|
||||
description: "Get the number of task entries in the given queue"
|
||||
request {
|
||||
type: object
|
||||
required: [ queue ]
|
||||
properties {
|
||||
queue {
|
||||
description: "ID of the queue"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
num {
|
||||
description: "Number of entries"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ from apiserver.apimodels.queues import (
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
GetNextTaskRequest,
|
||||
GetByIdRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
@ -33,9 +34,11 @@ worker_bll = WorkerBLL()
|
||||
queue_bll = QueueBLL(worker_bll)
|
||||
|
||||
|
||||
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
|
||||
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
||||
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
|
||||
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
|
||||
queue = queue_bll.get_by_id(
|
||||
company_id, request.queue, max_task_entries=request.max_task_entries
|
||||
)
|
||||
queue_dict = queue.to_proper_dict()
|
||||
conform_output_tags(call, queue_dict)
|
||||
unescape_metadata(call, queue_dict)
|
||||
@ -55,7 +58,10 @@ def get_all_ex(call: APICall):
|
||||
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
max_task_entries=call.data.pop("max_task_entries", None),
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
unescape_metadata(call, queues)
|
||||
@ -68,7 +74,10 @@ def get_all(call: APICall):
|
||||
ret_params = {}
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
max_task_entries=call.data.pop("max_task_entries", None),
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
unescape_metadata(call, queues)
|
||||
@ -272,3 +281,17 @@ def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataReque
|
||||
queue_id = request.queue
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||
|
||||
|
||||
@endpoint("queues.peek_task", min_version="2.15")
|
||||
def peek_task(call: APICall, company_id: str, request: QueueRequest):
|
||||
queue_id = request.queue
|
||||
queue = queue_bll.get_by_id(
|
||||
company_id=company_id, queue_id=queue_id, max_task_entries=1
|
||||
)
|
||||
return {"task": queue.entries[0].task if queue.entries else None}
|
||||
|
||||
|
||||
@endpoint("queues.get_num_entries", min_version="2.15")
|
||||
def get_num_entries(call: APICall, company_id: str, request: QueueRequest):
|
||||
return {"num": queue_bll.count_entries(company=company_id, queue_id=request.queue)}
|
||||
|
@ -9,9 +9,6 @@ from apiserver.tests.automated import TestService, utc_now_tz_aware
|
||||
|
||||
|
||||
class TestQueues(TestService):
|
||||
def setUp(self, version="2.4"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_default_queue(self):
|
||||
res = self.api.queues.get_default()
|
||||
self.assertIsNotNone(res.id)
|
||||
@ -63,6 +60,34 @@ class TestQueues(TestService):
|
||||
self.assertQueueTasks(res.queue, [task])
|
||||
self.assertTaskTags(task, system_tags=[])
|
||||
|
||||
def test_max_queue_entries(self):
|
||||
queue = self._temp_queue("TestTempQueue")
|
||||
tasks = [
|
||||
self._create_temp_queued_task(t, queue)["id"]
|
||||
for t in ("temp task1", "temp task2", "temp task3")
|
||||
]
|
||||
|
||||
num = self.api.queues.get_num_entries(queue=queue).num
|
||||
self.assertEqual(num, 3)
|
||||
|
||||
task_id = self.api.queues.peek_task(queue=queue).task
|
||||
self.assertEqual(task_id, tasks[0])
|
||||
|
||||
res = self.api.queues.get_by_id(queue=queue)
|
||||
self.assertQueueTasks(res.queue, tasks)
|
||||
|
||||
res = self.api.queues.get_all(id=[queue]).queues[0]
|
||||
self.assertQueueTasks(res, tasks)
|
||||
|
||||
res = self.api.queues.get_all(id=[queue], max_task_entries=2).queues[0]
|
||||
self.assertQueueTasks(res, tasks[:2])
|
||||
|
||||
res = self.api.queues.get_all_ex(id=[queue]).queues[0]
|
||||
self.assertEqual([e.task.id for e in res.entries], tasks)
|
||||
|
||||
res = self.api.queues.get_all_ex(id=[queue], max_task_entries=2).queues[0]
|
||||
self.assertEqual([e.task.id for e in res.entries], tasks[:2])
|
||||
|
||||
def test_move_task(self):
|
||||
queue = self._temp_queue("TestTempQueue")
|
||||
tasks = [
|
||||
|
Loading…
Reference in New Issue
Block a user