mirror of
https://github.com/clearml/clearml-server
synced 2025-05-17 18:04:59 +00:00
Support max_task_entries option in queues.get_by_id/get_all
Add queues.peek_task and queues.get_num_entries
This commit is contained in:
parent
6683d2d7a9
commit
d117a4f022
@ -26,6 +26,10 @@ class QueueRequest(Base):
|
|||||||
queue = StringField(required=True)
|
queue = StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class GetByIdRequest(QueueRequest):
|
||||||
|
max_task_entries = IntField()
|
||||||
|
|
||||||
|
|
||||||
class GetNextTaskRequest(QueueRequest):
|
class GetNextTaskRequest(QueueRequest):
|
||||||
queue = StringField(required=True)
|
queue = StringField(required=True)
|
||||||
get_task_info = BoolField(default=False)
|
get_task_info = BoolField(default=False)
|
||||||
|
@ -59,8 +59,16 @@ class QueueBLL(object):
|
|||||||
|
|
||||||
return qs.first()
|
return qs.first()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_task_entries_projection(max_task_entries: int) -> dict:
|
||||||
|
return dict(slice__entries=max_task_entries)
|
||||||
|
|
||||||
def get_by_id(
|
def get_by_id(
|
||||||
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
|
self,
|
||||||
|
company_id: str,
|
||||||
|
queue_id: str,
|
||||||
|
only: Optional[Sequence[str]] = None,
|
||||||
|
max_task_entries: int = None,
|
||||||
) -> Queue:
|
) -> Queue:
|
||||||
"""
|
"""
|
||||||
Get queue by id
|
Get queue by id
|
||||||
@ -71,6 +79,8 @@ class QueueBLL(object):
|
|||||||
qs = Queue.objects(**query)
|
qs = Queue.objects(**query)
|
||||||
if only:
|
if only:
|
||||||
qs = qs.only(*only)
|
qs = qs.only(*only)
|
||||||
|
if max_task_entries:
|
||||||
|
qs = qs.fields(**self._get_task_entries_projection(max_task_entries))
|
||||||
queue = qs.first()
|
queue = qs.first()
|
||||||
if not queue:
|
if not queue:
|
||||||
raise errors.bad_request.InvalidQueueId(**query)
|
raise errors.bad_request.InvalidQueueId(**query)
|
||||||
@ -136,7 +146,11 @@ class QueueBLL(object):
|
|||||||
queue.delete()
|
queue.delete()
|
||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
self, company_id: str, query_dict: dict, ret_params: dict = None,
|
self,
|
||||||
|
company_id: str,
|
||||||
|
query_dict: dict,
|
||||||
|
max_task_entries: int = None,
|
||||||
|
ret_params: dict = None,
|
||||||
) -> Sequence[dict]:
|
) -> Sequence[dict]:
|
||||||
"""Get all the queues according to the query"""
|
"""Get all the queues according to the query"""
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
@ -144,11 +158,18 @@ class QueueBLL(object):
|
|||||||
company=company_id,
|
company=company_id,
|
||||||
parameters=query_dict,
|
parameters=query_dict,
|
||||||
query_dict=query_dict,
|
query_dict=query_dict,
|
||||||
|
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||||
|
if max_task_entries
|
||||||
|
else None,
|
||||||
ret_params=ret_params,
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_queue_infos(
|
def get_queue_infos(
|
||||||
self, company_id: str, query_dict: dict, ret_params: dict = None,
|
self,
|
||||||
|
company_id: str,
|
||||||
|
query_dict: dict,
|
||||||
|
max_task_entries: int = None,
|
||||||
|
ret_params: dict = None,
|
||||||
) -> Sequence[dict]:
|
) -> Sequence[dict]:
|
||||||
"""
|
"""
|
||||||
Get infos on all the company queues, including queue tasks and workers
|
Get infos on all the company queues, including queue tasks and workers
|
||||||
@ -159,6 +180,9 @@ class QueueBLL(object):
|
|||||||
company=company_id,
|
company=company_id,
|
||||||
query_dict=query_dict,
|
query_dict=query_dict,
|
||||||
override_projection=projection,
|
override_projection=projection,
|
||||||
|
projection_fields=self._get_task_entries_projection(max_task_entries)
|
||||||
|
if max_task_entries
|
||||||
|
else None,
|
||||||
ret_params=ret_params,
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -292,5 +316,24 @@ class QueueBLL(object):
|
|||||||
|
|
||||||
return new_position
|
return new_position
|
||||||
|
|
||||||
|
def count_entries(self, company: str, queue_id: str) -> Optional[int]:
|
||||||
|
res = next(
|
||||||
|
Queue.aggregate(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"company": {"$in": [None, "", company]},
|
||||||
|
"_id": queue_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$project": {"count": {"$size": "$entries"}}},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if res is None:
|
||||||
|
raise errors.bad_request.InvalidQueueId(queue_id=queue_id)
|
||||||
|
return int(res.get("count"))
|
||||||
|
|
||||||
|
|
||||||
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)
|
MetricsRefresher.start(queue_metrics=QueueBLL().metrics)
|
||||||
|
@ -648,6 +648,7 @@ class GetMixin(PropsMixin):
|
|||||||
allow_public=False,
|
allow_public=False,
|
||||||
override_projection=None,
|
override_projection=None,
|
||||||
expand_reference_ids=True,
|
expand_reference_ids=True,
|
||||||
|
projection_fields: dict = None,
|
||||||
ret_params: dict = None,
|
ret_params: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -684,6 +685,7 @@ class GetMixin(PropsMixin):
|
|||||||
query=query,
|
query=query,
|
||||||
query_options=query_options,
|
query_options=query_options,
|
||||||
allow_public=allow_public,
|
allow_public=allow_public,
|
||||||
|
projection_fields=projection_fields,
|
||||||
ret_params=ret_params,
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -754,6 +756,7 @@ class GetMixin(PropsMixin):
|
|||||||
allow_public=False,
|
allow_public=False,
|
||||||
override_projection: Collection[str] = None,
|
override_projection: Collection[str] = None,
|
||||||
return_dicts=True,
|
return_dicts=True,
|
||||||
|
projection_fields: dict = None,
|
||||||
ret_params: dict = None,
|
ret_params: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -803,6 +806,7 @@ class GetMixin(PropsMixin):
|
|||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
override_projection=override_projection,
|
override_projection=override_projection,
|
||||||
override_collation=override_collation,
|
override_collation=override_collation,
|
||||||
|
projection_fields=projection_fields,
|
||||||
)
|
)
|
||||||
return cls.get_data_with_scroll_support(
|
return cls.get_data_with_scroll_support(
|
||||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||||
@ -813,6 +817,7 @@ class GetMixin(PropsMixin):
|
|||||||
parameters=parameters,
|
parameters=parameters,
|
||||||
override_projection=override_projection,
|
override_projection=override_projection,
|
||||||
override_collation=override_collation,
|
override_collation=override_collation,
|
||||||
|
projection_fields=projection_fields,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -837,6 +842,7 @@ class GetMixin(PropsMixin):
|
|||||||
parameters=None,
|
parameters=None,
|
||||||
override_projection=None,
|
override_projection=None,
|
||||||
override_collation=None,
|
override_collation=None,
|
||||||
|
projection_fields: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Fetch all documents matching a provided query.
|
Fetch all documents matching a provided query.
|
||||||
@ -879,6 +885,9 @@ class GetMixin(PropsMixin):
|
|||||||
if exclude:
|
if exclude:
|
||||||
qs = qs.exclude(*exclude)
|
qs = qs.exclude(*exclude)
|
||||||
|
|
||||||
|
if projection_fields:
|
||||||
|
qs = qs.fields(**projection_fields)
|
||||||
|
|
||||||
if start is not None and size:
|
if start is not None and size:
|
||||||
# add paging
|
# add paging
|
||||||
qs = qs.skip(start).limit(size)
|
qs = qs.skip(start).limit(size)
|
||||||
@ -920,6 +929,7 @@ class GetMixin(PropsMixin):
|
|||||||
parameters: dict = None,
|
parameters: dict = None,
|
||||||
override_projection: Collection[str] = None,
|
override_projection: Collection[str] = None,
|
||||||
override_collation: dict = None,
|
override_collation: dict = None,
|
||||||
|
projection_fields: dict = None,
|
||||||
) -> Sequence[dict]:
|
) -> Sequence[dict]:
|
||||||
"""
|
"""
|
||||||
Fetch all documents matching a provided query. For the first order by field
|
Fetch all documents matching a provided query. For the first order by field
|
||||||
@ -977,6 +987,9 @@ class GetMixin(PropsMixin):
|
|||||||
if exclude:
|
if exclude:
|
||||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||||
|
|
||||||
|
if projection_fields:
|
||||||
|
query_sets = [qs.fields(**projection_fields) for qs in query_sets]
|
||||||
|
|
||||||
if start is None or not size:
|
if start is None or not size:
|
||||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||||
|
|
||||||
|
@ -112,6 +112,12 @@ get_by_id {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_by_id."2.4"} {
|
||||||
|
request.properties.max_task_entries {
|
||||||
|
description: Max number of queue task entries to return
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// typescript generation hack
|
// typescript generation hack
|
||||||
get_all_ex {
|
get_all_ex {
|
||||||
@ -140,6 +146,12 @@ get_all_ex {
|
|||||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_all_ex."2.15"} {
|
||||||
|
request.properties.max_task_entries {
|
||||||
|
description: Max number of queue task entries to return
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
get_all {
|
get_all {
|
||||||
"2.4" {
|
"2.4" {
|
||||||
@ -226,6 +238,12 @@ get_all {
|
|||||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_all."2.15"} {
|
||||||
|
request.properties.max_task_entries {
|
||||||
|
description: Max number of queue task entries to return
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
get_default {
|
get_default {
|
||||||
"2.4" {
|
"2.4" {
|
||||||
@ -707,3 +725,51 @@ delete_metadata {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
peek_task {
|
||||||
|
"2.15" {
|
||||||
|
description: "Peek the next task from a given queue"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [ queue ]
|
||||||
|
properties {
|
||||||
|
queue {
|
||||||
|
description: "ID of the queue"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
task {
|
||||||
|
description: "Task ID"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
get_num_entries {
|
||||||
|
"2.15" {
|
||||||
|
description: "Get the number of task entries in the given queue"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
required: [ queue ]
|
||||||
|
properties {
|
||||||
|
queue {
|
||||||
|
description: "ID of the queue"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
num {
|
||||||
|
description: "Number of entries"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -14,6 +14,7 @@ from apiserver.apimodels.queues import (
|
|||||||
AddOrUpdateMetadataRequest,
|
AddOrUpdateMetadataRequest,
|
||||||
DeleteMetadataRequest,
|
DeleteMetadataRequest,
|
||||||
GetNextTaskRequest,
|
GetNextTaskRequest,
|
||||||
|
GetByIdRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.model import Metadata
|
from apiserver.bll.model import Metadata
|
||||||
from apiserver.bll.queue import QueueBLL
|
from apiserver.bll.queue import QueueBLL
|
||||||
@ -33,9 +34,11 @@ worker_bll = WorkerBLL()
|
|||||||
queue_bll = QueueBLL(worker_bll)
|
queue_bll = QueueBLL(worker_bll)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
|
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
|
||||||
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
|
||||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
queue = queue_bll.get_by_id(
|
||||||
|
company_id, request.queue, max_task_entries=request.max_task_entries
|
||||||
|
)
|
||||||
queue_dict = queue.to_proper_dict()
|
queue_dict = queue.to_proper_dict()
|
||||||
conform_output_tags(call, queue_dict)
|
conform_output_tags(call, queue_dict)
|
||||||
unescape_metadata(call, queue_dict)
|
unescape_metadata(call, queue_dict)
|
||||||
@ -55,7 +58,10 @@ def get_all_ex(call: APICall):
|
|||||||
|
|
||||||
Metadata.escape_query_parameters(call)
|
Metadata.escape_query_parameters(call)
|
||||||
queues = queue_bll.get_queue_infos(
|
queues = queue_bll.get_queue_infos(
|
||||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
company_id=call.identity.company,
|
||||||
|
query_dict=call.data,
|
||||||
|
max_task_entries=call.data.pop("max_task_entries", None),
|
||||||
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
conform_output_tags(call, queues)
|
conform_output_tags(call, queues)
|
||||||
unescape_metadata(call, queues)
|
unescape_metadata(call, queues)
|
||||||
@ -68,7 +74,10 @@ def get_all(call: APICall):
|
|||||||
ret_params = {}
|
ret_params = {}
|
||||||
Metadata.escape_query_parameters(call)
|
Metadata.escape_query_parameters(call)
|
||||||
queues = queue_bll.get_all(
|
queues = queue_bll.get_all(
|
||||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
company_id=call.identity.company,
|
||||||
|
query_dict=call.data,
|
||||||
|
max_task_entries=call.data.pop("max_task_entries", None),
|
||||||
|
ret_params=ret_params,
|
||||||
)
|
)
|
||||||
conform_output_tags(call, queues)
|
conform_output_tags(call, queues)
|
||||||
unescape_metadata(call, queues)
|
unescape_metadata(call, queues)
|
||||||
@ -272,3 +281,17 @@ def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataReque
|
|||||||
queue_id = request.queue
|
queue_id = request.queue
|
||||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||||
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("queues.peek_task", min_version="2.15")
|
||||||
|
def peek_task(call: APICall, company_id: str, request: QueueRequest):
|
||||||
|
queue_id = request.queue
|
||||||
|
queue = queue_bll.get_by_id(
|
||||||
|
company_id=company_id, queue_id=queue_id, max_task_entries=1
|
||||||
|
)
|
||||||
|
return {"task": queue.entries[0].task if queue.entries else None}
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("queues.get_num_entries", min_version="2.15")
|
||||||
|
def get_num_entries(call: APICall, company_id: str, request: QueueRequest):
|
||||||
|
return {"num": queue_bll.count_entries(company=company_id, queue_id=request.queue)}
|
||||||
|
@ -9,9 +9,6 @@ from apiserver.tests.automated import TestService, utc_now_tz_aware
|
|||||||
|
|
||||||
|
|
||||||
class TestQueues(TestService):
|
class TestQueues(TestService):
|
||||||
def setUp(self, version="2.4"):
|
|
||||||
super().setUp(version=version)
|
|
||||||
|
|
||||||
def test_default_queue(self):
|
def test_default_queue(self):
|
||||||
res = self.api.queues.get_default()
|
res = self.api.queues.get_default()
|
||||||
self.assertIsNotNone(res.id)
|
self.assertIsNotNone(res.id)
|
||||||
@ -63,6 +60,34 @@ class TestQueues(TestService):
|
|||||||
self.assertQueueTasks(res.queue, [task])
|
self.assertQueueTasks(res.queue, [task])
|
||||||
self.assertTaskTags(task, system_tags=[])
|
self.assertTaskTags(task, system_tags=[])
|
||||||
|
|
||||||
|
def test_max_queue_entries(self):
|
||||||
|
queue = self._temp_queue("TestTempQueue")
|
||||||
|
tasks = [
|
||||||
|
self._create_temp_queued_task(t, queue)["id"]
|
||||||
|
for t in ("temp task1", "temp task2", "temp task3")
|
||||||
|
]
|
||||||
|
|
||||||
|
num = self.api.queues.get_num_entries(queue=queue).num
|
||||||
|
self.assertEqual(num, 3)
|
||||||
|
|
||||||
|
task_id = self.api.queues.peek_task(queue=queue).task
|
||||||
|
self.assertEqual(task_id, tasks[0])
|
||||||
|
|
||||||
|
res = self.api.queues.get_by_id(queue=queue)
|
||||||
|
self.assertQueueTasks(res.queue, tasks)
|
||||||
|
|
||||||
|
res = self.api.queues.get_all(id=[queue]).queues[0]
|
||||||
|
self.assertQueueTasks(res, tasks)
|
||||||
|
|
||||||
|
res = self.api.queues.get_all(id=[queue], max_task_entries=2).queues[0]
|
||||||
|
self.assertQueueTasks(res, tasks[:2])
|
||||||
|
|
||||||
|
res = self.api.queues.get_all_ex(id=[queue]).queues[0]
|
||||||
|
self.assertEqual([e.task.id for e in res.entries], tasks)
|
||||||
|
|
||||||
|
res = self.api.queues.get_all_ex(id=[queue], max_task_entries=2).queues[0]
|
||||||
|
self.assertEqual([e.task.id for e in res.entries], tasks[:2])
|
||||||
|
|
||||||
def test_move_task(self):
|
def test_move_task(self):
|
||||||
queue = self._temp_queue("TestTempQueue")
|
queue = self._temp_queue("TestTempQueue")
|
||||||
tasks = [
|
tasks = [
|
||||||
|
Loading…
Reference in New Issue
Block a user