Support workers filtering with tags

This commit is contained in:
allegroai 2022-07-08 17:37:33 +03:00
parent b41ab8c550
commit b2feafac09
4 changed files with 473 additions and 450 deletions

View File

@ -96,6 +96,7 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base):
last_seen = IntField(default=3600)
tags = ListField(str)
class GetAllResponse(Base):

View File

@ -76,7 +76,7 @@ class WorkerBLL:
raise bad_request.InvalidUserId(**query)
company = Company.objects(id=company_id).only("id", "name").first()
if not company:
raise server_error.InternalError("invalid company", company=company_id)
raise bad_request.InvalidId("invalid company", company=company_id)
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
if len(queue_objs) < len(queues):
@ -189,7 +189,10 @@ class WorkerBLL:
self._save_worker(entry)
def get_all(
self, company_id: str, last_seen: Optional[int] = None
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@ -210,16 +213,26 @@ class WorkerBLL:
if w.last_activity_time.replace(tzinfo=None) >= ref_time
]
if tags:
include = {t for t in tags if not t.startswith("-")}
exclude = {t[1:] for t in tags if t.startswith("-")}
workers = [
w
for w in workers
if (not include or any(t in include for t in w.tags))
and (not exclude or all(t not in exclude for t in w.tags))
]
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int
self, company_id: str, last_seen: int, tags: Sequence[str] = None
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen),
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
)
)

View File

@ -1,4 +1,3 @@
{
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
@ -288,6 +287,13 @@
}
}
}
"999.0": ${get_all."2.4"} {
request.properties.tags {
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
}
register {
"2.4" {
@ -499,4 +505,3 @@
}
}
}
}

View File

@ -41,7 +41,9 @@ worker_bll = WorkerBLL()
)
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(company_id, request.last_seen)
workers=worker_bll.get_all_with_projection(
company_id, request.last_seen, tags=request.tags
)
)
@ -72,7 +74,9 @@ def unregister(call: APICall, company_id, req_model: WorkerRequest):
worker_bll.unregister_worker(company_id, call.identity.user, req_model.worker)
@endpoint("workers.status_report", min_version="2.4", request_data_model=StatusReportRequest)
@endpoint(
"workers.status_report", min_version="2.4", request_data_model=StatusReportRequest
)
def status_report(call: APICall, company_id, request: StatusReportRequest):
worker_bll.status_report(
company_id=company_id,