Support max_task_entries option in queues.get_by_id/get_all

Add queues.peek_task and queues.get_num_entries
allegroai 2022-07-08 17:42:20 +03:00
parent 6683d2d7a9
commit d117a4f022
6 changed files with 185 additions and 11 deletions
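
For reference, a minimal client-side sketch of what this commit enables. The `api` object below stands for an authenticated API client whose attribute access maps to apiserver endpoints (as used in the automated tests); it is not part of this commit:

```python
# Sketch only: `api` is an assumed, already-authenticated API client session.
queue_id = "<queue-id>"  # placeholder

# Limit how many task entries are returned per queue
queues = api.queues.get_all(id=[queue_id], max_task_entries=2).queues

# Peek at the next task in the queue without removing it
next_task = api.queues.peek_task(queue=queue_id).task

# Count the task entries currently in the queue
num_entries = api.queues.get_num_entries(queue=queue_id).num
```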

View File

@@ -26,6 +26,10 @@ class QueueRequest(Base):
queue = StringField(required=True)
class GetByIdRequest(QueueRequest):
max_task_entries = IntField()
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)

View File

@@ -59,8 +59,16 @@ class QueueBLL(object):
return qs.first()
@staticmethod
def _get_task_entries_projection(max_task_entries: int) -> dict:
return dict(slice__entries=max_task_entries)
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
self,
company_id: str,
queue_id: str,
only: Optional[Sequence[str]] = None,
max_task_entries: int = None,
) -> Queue:
"""
Get queue by id
@@ -71,6 +79,8 @@ class QueueBLL(object):
qs = Queue.objects(**query)
if only:
qs = qs.only(*only)
if max_task_entries:
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
queue = qs.first()
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
@@ -136,7 +146,11 @@ class QueueBLL(object):
queue.delete()
def get_all(
self, company_id: str, query_dict: dict, ret_params: dict = None,
self,
company_id: str,
query_dict: dict,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
@@ -144,11 +158,18 @@ class QueueBLL(object):
company=company_id,
parameters=query_dict,
query_dict=query_dict,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
def get_queue_infos(
self, company_id: str, query_dict: dict, ret_params: dict = None,
self,
company_id: str,
query_dict: dict,
max_task_entries: int = None,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
@@ -159,6 +180,9 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
projection_fields=self._get_task_entries_projection(max_task_entries)
if max_task_entries
else None,
ret_params=ret_params,
)
@@ -292,5 +316,24 @@ class QueueBLL(object):
return new_position
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
res = next(
Queue.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": queue_id,
}
},
{"$project": {"count": {"$size": "$entries"}}},
]
),
None,
)
if res is None:
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
return int(res.get("count"))
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)
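
For context, the new count_entries helper counts queue entries server-side with a $size projection instead of loading the entries array. A rough pymongo equivalent of the same pipeline (the connection details, database and collection names below are assumptions, not taken from this commit) would be:

```python
from pymongo import MongoClient

# Assumed deployment details; adjust the URI, database and collection names as needed.
db = MongoClient("mongodb://localhost:27017")["backend"]
company_id, queue_id = "<company-id>", "<queue-id>"  # placeholders

pipeline = [
    {"$match": {"company": {"$in": [None, "", company_id]}, "_id": queue_id}},
    {"$project": {"count": {"$size": "$entries"}}},  # count entries without fetching them
]
doc = next(db.queue.aggregate(pipeline), None)
num_entries = int(doc["count"]) if doc else None
```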

View File

@@ -648,6 +648,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -684,6 +685,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
projection_fields=projection_fields,
ret_params=ret_params,
)
@@ -754,6 +756,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
projection_fields: dict = None,
ret_params: dict = None,
):
"""
@@ -803,6 +806,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
@@ -813,6 +817,7 @@ class GetMixin(PropsMixin):
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
projection_fields=projection_fields,
)
@classmethod
@@ -837,6 +842,7 @@ class GetMixin(PropsMixin):
parameters=None,
override_projection=None,
override_collation=None,
projection_fields: dict = None,
):
"""
Fetch all documents matching a provided query.
@@ -879,6 +885,9 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if projection_fields:
qs = qs.fields(**projection_fields)
if start is not None and size:
# add paging
qs = qs.skip(start).limit(size)
@@ -920,6 +929,7 @@ class GetMixin(PropsMixin):
parameters: dict = None,
override_projection: Collection[str] = None,
override_collation: dict = None,
projection_fields: dict = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
@@ -977,6 +987,9 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if projection_fields:
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
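
The new projection_fields hook ends up calling mongoengine's QuerySet.fields(); for the queue use case it applies a MongoDB $slice projection so only the first N items of the entries array are materialized. A minimal sketch, with the queryset setup simplified:

```python
# Queue model as defined in this repository; the query below is deliberately simplified.
from apiserver.database.model.queue import Queue

# slice__entries=N is mongoengine's spelling of an {"entries": {"$slice": N}} projection,
# so each returned queue document carries at most N task entries.
projection_fields = dict(slice__entries=2)

qs = Queue.objects(company="<company-id>")  # placeholder company filter
qs = qs.fields(**projection_fields)
```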

View File

@@ -112,6 +112,12 @@ get_by_id {
}
}
}
"999.0": ${get_by_id."2.4"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
// typescript generation hack
get_all_ex {
@@ -140,6 +146,12 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"999.0": ${get_all_ex."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
get_all {
"2.4" {
@@ -226,6 +238,12 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"999.0": ${get_all."2.15"} {
request.properties.max_task_entries {
description: Max number of queue task entries to return
type: integer
}
}
}
get_default {
"2.4" {
@@ -707,3 +725,51 @@ delete_metadata {
}
}
}
peek_task {
"2.15" {
description: "Peek the next task from a given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
task {
description: "Task ID"
type: string
}
}
}
}
}
get_num_entries {
"2.15" {
description: "Get the number of task entries in the given queue"
request {
type: object
required: [ queue ]
properties {
queue {
description: "ID of the queue"
type: string
}
}
}
response {
type: object
properties {
num {
description: "Number of entries"
type: integer
}
}
}
}
}
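
Purely illustrative request/response shapes for the two calls defined above, issued over plain HTTP. The host, credentials, and the response envelope used here are assumptions based on a default apiserver setup, not part of this commit:

```python
import requests

session = requests.Session()
session.auth = ("<access_key>", "<secret_key>")  # assumed basic-auth credentials
base = "http://localhost:8008"                   # assumed default apiserver address

# Peek the next task from a given queue
resp = session.post(f"{base}/queues.peek_task", json={"queue": "<queue-id>"})
print(resp.json()["data"]["task"])   # task ID, or None when the queue is empty

# Get the number of task entries in the given queue
resp = session.post(f"{base}/queues.get_num_entries", json={"queue": "<queue-id>"})
print(resp.json()["data"]["num"])    # integer count of entries
```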

View File

@@ -14,6 +14,7 @@ from apiserver.apimodels.queues import (
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
GetByIdRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
@@ -33,9 +34,11 @@ worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
queue = queue_bll.get_by_id(
company_id, request.queue, max_task_entries=request.max_task_entries
)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
@@ -55,7 +58,10 @@ def get_all_ex(call: APICall):
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=call.identity.company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -68,7 +74,10 @@ def get_all(call: APICall):
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
company_id=call.identity.company,
query_dict=call.data,
max_task_entries=call.data.pop("max_task_entries", None),
ret_params=ret_params,
)
conform_output_tags(call, queues)
unescape_metadata(call, queues)
@@ -272,3 +281,17 @@ def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataReque
queue_id = request.queue
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
@endpoint("queues.peek_task", min_version="2.15")
def peek_task(call: APICall, company_id: str, request: QueueRequest):
queue_id = request.queue
queue = queue_bll.get_by_id(
company_id=company_id, queue_id=queue_id, max_task_entries=1
)
return {"task": queue.entries[0].task if queue.entries else None}
@endpoint("queues.get_num_entries", min_version="2.15")
def get_num_entries(call: APICall, company_id: str, request: QueueRequest):
return {"num": queue_bll.count_entries(company=company_id, queue_id=request.queue)}

View File

@@ -9,9 +9,6 @@ from apiserver.tests.automated import TestService, utc_now_tz_aware
class TestQueues(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def test_default_queue(self):
res = self.api.queues.get_default()
self.assertIsNotNone(res.id)
@@ -63,6 +60,34 @@ class TestQueues(TestService):
self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[])
def test_max_queue_entries(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3")
]
num = self.api.queues.get_num_entries(queue=queue).num
self.assertEqual(num, 3)
task_id = self.api.queues.peek_task(queue=queue).task
self.assertEqual(task_id, tasks[0])
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
res = self.api.queues.get_all(id=[queue]).queues[0]
self.assertQueueTasks(res, tasks)
res = self.api.queues.get_all(id=[queue], max_task_entries=2).queues[0]
self.assertQueueTasks(res, tasks[:2])
res = self.api.queues.get_all_ex(id=[queue]).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks)
res = self.api.queues.get_all_ex(id=[queue], max_task_entries=2).queues[0]
self.assertEqual([e.task.id for e in res.entries], tasks[:2])
def test_move_task(self):
queue = self._temp_queue("TestTempQueue")
tasks = [