diff --git a/apiserver/apimodels/workers.py b/apiserver/apimodels/workers.py index 52f49f8..ba98503 100644 --- a/apiserver/apimodels/workers.py +++ b/apiserver/apimodels/workers.py @@ -100,6 +100,7 @@ class GetAllRequest(Base): last_seen = IntField(default=3600) tags = ListField(str) system_tags = ListField(str) + worker_pattern = StringField() class GetAllResponse(Base): diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py index 0d04aa8..1caec6c 100644 --- a/apiserver/bll/workers/__init__.py +++ b/apiserver/bll/workers/__init__.py @@ -1,4 +1,5 @@ import itertools +import re from datetime import datetime, timedelta from time import time from typing import Sequence, Set, Optional @@ -34,6 +35,8 @@ log = config.logger(__file__) class WorkerBLL: + _key_regex_trans = str.maketrans({"*": ".*", "?": ".?"}) + def __init__(self, es=None, redis=None): self.es_client = es or es_factory.connect("workers") self.config = config.get("services.workers", ConfigTree()) @@ -207,15 +210,25 @@ class WorkerBLL: last_seen: Optional[int] = None, tags: Sequence[str] = None, system_tags: Sequence[str] = None, + worker_pattern: str = None, ): if not last_seen: return len( - self._get_keys(company_id, user_tags=tags, system_tags=system_tags) + self._get_keys( + company_id, + user_tags=tags, + system_tags=system_tags, + worker_pattern=worker_pattern, + ) ) return len( self.get_all( - company_id, last_seen=last_seen, tags=tags, system_tags=system_tags + company_id, + last_seen=last_seen, + tags=tags, + system_tags=system_tags, + worker_pattern=worker_pattern, ) ) @@ -225,6 +238,7 @@ class WorkerBLL: last_seen: Optional[int] = None, tags: Sequence[str] = None, system_tags: Sequence[str] = None, + worker_pattern: str = None, ) -> Sequence[WorkerEntry]: """ Get all the company workers that were active during the last_seen period @@ -233,7 +247,12 @@ class WorkerBLL: :return: """ try: - workers = self._get(company_id, user_tags=tags, system_tags=system_tags) + workers = self._get( + company_id, + user_tags=tags, + system_tags=system_tags, + worker_pattern=worker_pattern, + ) except Exception as e: raise server_error.DataError("failed loading worker entries", err=e.args[0]) @@ -253,6 +272,7 @@ class WorkerBLL: last_seen: int, tags: Sequence[str] = None, system_tags: Sequence[str] = None, + worker_pattern: str = None, ) -> Sequence[WorkerResponseEntry]: helpers = [ WorkerConversionHelper.from_worker_entry(entry) @@ -261,6 +281,7 @@ class WorkerBLL: last_seen=last_seen, tags=tags, system_tags=system_tags, + worker_pattern=worker_pattern, ) ] @@ -320,7 +341,7 @@ class WorkerBLL: for helper in helpers: worker = helper.worker if helper.task_id: - task = tasks_info.get(helper.task_id, None) + task: Task = tasks_info.get(helper.task_id, None) if task: worker.task.running_time = (task.active_duration or 0) * 1000 worker.task.last_iteration = task.last_iteration @@ -416,16 +437,25 @@ class WorkerBLL: user: str = "*", user_tags: Sequence[str] = None, system_tags: Sequence[str] = None, + worker_pattern: str = None, ) -> Sequence[bytes]: if not (user_tags or system_tags): - match = self._get_worker_key(company, user, "*") + match = self._get_worker_key(company, user, worker_pattern or "*") return list(self.redis.scan_iter(match)) - def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]: - if user == "*": - return in_keys - user_bytes = user.encode() - return {k for k in in_keys if user_bytes in k} + def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]: + if user != "*": + user_bytes = user.encode() + in_keys = {k for k in in_keys if user_bytes in k} + + if worker_pattern: + worker_pattern_bytes = ( + f"{worker_pattern.translate(self._key_regex_trans)}$".encode() + ) + regex = re.compile(worker_pattern_bytes) + in_keys = {k for k in in_keys if regex.search(k)} + + return in_keys worker_keys = set() for tags, tags_field in ( @@ -448,7 +478,7 @@ class WorkerBLL: ) tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1)) - tagged_workers = filter_by_user(tagged_workers) + tagged_workers = filter_by_user_and_pattern(tagged_workers) worker_keys = ( worker_keys.intersection(tagged_workers) if worker_keys @@ -462,7 +492,7 @@ class WorkerBLL: all_workers_key = self._get_all_workers_key(company) self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp) worker_keys.update(self.redis.zrange(all_workers_key, 0, -1)) - worker_keys = filter_by_user(worker_keys) + worker_keys = filter_by_user_and_pattern(worker_keys) if not worker_keys: return [] @@ -487,13 +517,18 @@ class WorkerBLL: user: str = "*", user_tags: Sequence[str] = None, system_tags: Sequence[str] = None, + worker_pattern: str = None, ) -> Sequence[WorkerEntry]: """Get worker entries matching the company and user, worker patterns""" entries = [] for keys in chunked_iter( self._get_keys( - company, user=user, user_tags=user_tags, system_tags=system_tags + company, + user=user, + user_tags=user_tags, + system_tags=system_tags, + worker_pattern=worker_pattern, ), 1000, ): diff --git a/apiserver/schema/services/workers.conf b/apiserver/schema/services/workers.conf index 1d732bf..42f6fd8 100644 --- a/apiserver/schema/services/workers.conf +++ b/apiserver/schema/services/workers.conf @@ -310,6 +310,12 @@ get_all { items { type: string } } } + "2.30": ${get_all."2.22"} { + request.properties.worker_pattern { + description: The worker name pattern. If specified then only matching keys returned + type: string + } + } } get_count { "2.26": { @@ -345,6 +351,12 @@ get_count { } } } + "2.30": ${get_count."2.26"} { + request.properties.worker_pattern { + description: The worker name pattern. If specified then only matching keys are counted + type: string + } + } } register { "2.4" { diff --git a/apiserver/services/workers.py b/apiserver/services/workers.py index 92d064d..4bea19a 100644 --- a/apiserver/services/workers.py +++ b/apiserver/services/workers.py @@ -47,6 +47,7 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest): request.last_seen, tags=request.tags, system_tags=request.system_tags, + worker_pattern=request.worker_pattern, ) ) @@ -61,6 +62,7 @@ def get_all(call: APICall, company_id: str, request: GetCountRequest): request.last_seen, tags=request.tags, system_tags=request.system_tags, + worker_pattern=request.worker_pattern, ) } diff --git a/apiserver/tests/automated/test_workers.py b/apiserver/tests/automated/test_workers.py index 93f41e8..3108e0c 100644 --- a/apiserver/tests/automated/test_workers.py +++ b/apiserver/tests/automated/test_workers.py @@ -32,7 +32,7 @@ class TestWorkersService(TestService): self.api.workers.register(worker=w, system_tags=[system_tag]) # total workers count include the new ones count = self.api.workers.get_count().count - self.assertGreater(count, len(test_workers)) + self.assertGreaterEqual(count, len(test_workers)) # filter by system tag and last seen count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count self.assertEqual(count, len(test_workers))