Add workers.get_count endpoint

This commit is contained in:
allegroai 2023-07-26 18:21:52 +03:00
parent 1b650b1689
commit db021f2863
5 changed files with 159 additions and 58 deletions

View File

@ -12,7 +12,7 @@ from jsonmodels.fields import (
) )
from jsonmodels.models import Base from jsonmodels.models import Base
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60 DEFAULT_TIMEOUT = 10 * 60
@ -104,6 +104,10 @@ class GetAllResponse(Base):
workers = ListField(WorkerResponseEntry) workers = ListField(WorkerResponseEntry)
class GetCountRequest(GetAllRequest):
last_seen = IntField(default=0)
class StatsBase(Base): class StatsBase(Base):
worker_ids = ListField(str) worker_ids = ListField(str)

View File

@ -200,6 +200,24 @@ class WorkerBLL:
finally: finally:
self._save_worker(entry) self._save_worker(entry)
def get_count(
self,
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
):
if not last_seen:
return len(
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
)
return len(
self.get_all(
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
)
)
def get_all( def get_all(
self, self,
company_id: str, company_id: str,
@ -396,15 +414,16 @@ class WorkerBLL:
msg = "Failed saving worker entry" msg = "Failed saving worker entry"
log.exception(msg) log.exception(msg)
def _get( def _get_keys(
self, self,
company: str, company: str,
user: str = "*", user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None, user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]: ) -> Sequence[bytes]:
"""Get worker entries matching the company and user, worker patterns""" if not (user_tags or system_tags):
match = self._get_worker_key(company, user, "*")
return list(self.redis.scan_iter(match))
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]: def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*": if user == "*":
@ -412,64 +431,73 @@ class WorkerBLL:
user_bytes = user.encode() user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k} return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags: worker_keys = set()
worker_keys = set() for tags, tags_field in (
for tags, tags_field in ( (user_tags, "tags"),
(user_tags, "tags"), (system_tags, "systemtags"),
(system_tags, "systemtags"), ):
): if not tags:
if not tags: continue
continue
timestamp = int(time()) timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-") include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include: if include:
tagged_workers = set() tagged_workers = set()
for tag in include: for tag in include:
tagged_workers_key = self._get_tagged_workers_key( tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag company, tags_field, tag
) )
self.redis.zremrangebyscore( self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp tagged_workers_key, min=0, max=timestamp
) )
tagged_workers.update( tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
self.redis.zrange(tagged_workers_key, 0, -1)
) tagged_workers = filter_by_user(tagged_workers)
tagged_workers = filter_by_user(tagged_workers) worker_keys = (
worker_keys = ( worker_keys.intersection(tagged_workers)
worker_keys.intersection(tagged_workers) if worker_keys
if worker_keys else tagged_workers
else tagged_workers )
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
) )
if not worker_keys: if not worker_keys:
return [] return []
if exclude:
if not worker_keys: return list(worker_keys)
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore( def _get(
all_workers_key, min=0, max=timestamp self,
) company: str,
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1)) user: str = "*",
worker_keys = filter_by_user(worker_keys) user_tags: Sequence[str] = None,
if not worker_keys: system_tags: Sequence[str] = None,
return [] ) -> Sequence[WorkerEntry]:
for tag in exclude: """Get worker entries matching the company and user, worker patterns"""
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = [] entries = []
for key in worker_keys: for key in self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
):
data = self.redis.get(key) data = self.redis.get(key)
if data: if data:
entries.append(WorkerEntry.from_json(data)) entries.append(WorkerEntry.from_json(data))

View File

@ -311,6 +311,41 @@ get_all {
} }
} }
} }
get_count {
"999.0": {
description: "Returns the number of registered workers."
request {
type: object
properties {
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 0
}
tags {
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
}
response {
type: object
properties {
count {
description: Workers count
type: integer
}
}
}
}
}
register { register {
"2.4" { "2.4" {
description: "Register a worker in the system. Called by the Worker Daemon." description: "Register a worker in the system. Called by the Worker Daemon."

View File

@ -22,6 +22,7 @@ from apiserver.apimodels.workers import (
GetActivityReportRequest, GetActivityReportRequest,
GetActivityReportResponse, GetActivityReportResponse,
ActivityReportSeries, ActivityReportSeries,
GetCountRequest,
) )
from apiserver.bll.workers import WorkerBLL from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config from apiserver.config_repo import config
@ -50,6 +51,20 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
) )
@endpoint(
"workers.get_count", request_data_model=GetCountRequest,
)
def get_all(call: APICall, company_id: str, request: GetCountRequest):
call.result.data = {
"count": worker_bll.get_count(
company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
)
}
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest) @endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
def register(call: APICall, company_id, request: RegisterRequest): def register(call: APICall, company_id, request: RegisterRequest):
worker = request.worker worker = request.worker

View File

@ -27,6 +27,25 @@ class TestWorkersService(TestService):
self.api.workers.unregister(worker=test_worker) self.api.workers.unregister(worker=test_worker)
self._check_exists(test_worker, False) self._check_exists(test_worker, False)
def test_get_count(self):
test_workers = [f"test_{uuid4().hex}" for _ in range(2)]
system_tag = f"tag_{uuid4().hex}"
for w in test_workers:
self.api.workers.register(worker=w, system_tags=[system_tag])
# total workers count include the new ones
count = self.api.workers.get_count().count
self.assertGreater(count, len(test_workers))
# filter by system tag and last seen
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
self.assertEqual(count, len(test_workers))
time.sleep(5)
# workers not seen recently
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
self.assertEqual(count, 0)
# but still visible without the last seen filter
count = self.api.workers.get_count(system_tags=[system_tag]).count
self.assertEqual(count, len(test_workers))
def test_workers_timeout(self): def test_workers_timeout(self):
test_worker = f"test_{uuid4().hex}" test_worker = f"test_{uuid4().hex}"
self._check_exists(test_worker, False) self._check_exists(test_worker, False)