Add worker_pattern parameter to workers.get_all and get_count endpoints

This commit is contained in:
allegroai 2024-06-20 17:59:28 +03:00
parent dd0ecb712d
commit f1c876089b
5 changed files with 64 additions and 14 deletions

View File

@ -100,6 +100,7 @@ class GetAllRequest(Base):
last_seen = IntField(default=3600) last_seen = IntField(default=3600)
tags = ListField(str) tags = ListField(str)
system_tags = ListField(str) system_tags = ListField(str)
worker_pattern = StringField()
class GetAllResponse(Base): class GetAllResponse(Base):

View File

@ -1,4 +1,5 @@
import itertools import itertools
import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from time import time from time import time
from typing import Sequence, Set, Optional from typing import Sequence, Set, Optional
@ -34,6 +35,8 @@ log = config.logger(__file__)
class WorkerBLL: class WorkerBLL:
_key_regex_trans = str.maketrans({"*": ".*", "?": ".?"})
def __init__(self, es=None, redis=None): def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers") self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree()) self.config = config.get("services.workers", ConfigTree())
@ -207,15 +210,25 @@ class WorkerBLL:
last_seen: Optional[int] = None, last_seen: Optional[int] = None,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
worker_pattern: str = None,
): ):
if not last_seen: if not last_seen:
return len( return len(
self._get_keys(company_id, user_tags=tags, system_tags=system_tags) self._get_keys(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
) )
return len( return len(
self.get_all( self.get_all(
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
) )
) )
@ -225,6 +238,7 @@ class WorkerBLL:
last_seen: Optional[int] = None, last_seen: Optional[int] = None,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]: ) -> Sequence[WorkerEntry]:
""" """
Get all the company workers that were active during the last_seen period Get all the company workers that were active during the last_seen period
@ -233,7 +247,12 @@ class WorkerBLL:
:return: :return:
""" """
try: try:
workers = self._get(company_id, user_tags=tags, system_tags=system_tags) workers = self._get(
company_id,
user_tags=tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
)
except Exception as e: except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0]) raise server_error.DataError("failed loading worker entries", err=e.args[0])
@ -253,6 +272,7 @@ class WorkerBLL:
last_seen: int, last_seen: int,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerResponseEntry]: ) -> Sequence[WorkerResponseEntry]:
helpers = [ helpers = [
WorkerConversionHelper.from_worker_entry(entry) WorkerConversionHelper.from_worker_entry(entry)
@ -261,6 +281,7 @@ class WorkerBLL:
last_seen=last_seen, last_seen=last_seen,
tags=tags, tags=tags,
system_tags=system_tags, system_tags=system_tags,
worker_pattern=worker_pattern,
) )
] ]
@ -320,7 +341,7 @@ class WorkerBLL:
for helper in helpers: for helper in helpers:
worker = helper.worker worker = helper.worker
if helper.task_id: if helper.task_id:
task = tasks_info.get(helper.task_id, None) task: Task = tasks_info.get(helper.task_id, None)
if task: if task:
worker.task.running_time = (task.active_duration or 0) * 1000 worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration worker.task.last_iteration = task.last_iteration
@ -416,16 +437,25 @@ class WorkerBLL:
user: str = "*", user: str = "*",
user_tags: Sequence[str] = None, user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[bytes]: ) -> Sequence[bytes]:
if not (user_tags or system_tags): if not (user_tags or system_tags):
match = self._get_worker_key(company, user, "*") match = self._get_worker_key(company, user, worker_pattern or "*")
return list(self.redis.scan_iter(match)) return list(self.redis.scan_iter(match))
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]: def filter_by_user_and_pattern(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*": if user != "*":
return in_keys
user_bytes = user.encode() user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k} in_keys = {k for k in in_keys if user_bytes in k}
if worker_pattern:
worker_pattern_bytes = (
f"{worker_pattern.translate(self._key_regex_trans)}$".encode()
)
regex = re.compile(worker_pattern_bytes)
in_keys = {k for k in in_keys if regex.search(k)}
return in_keys
worker_keys = set() worker_keys = set()
for tags, tags_field in ( for tags, tags_field in (
@ -448,7 +478,7 @@ class WorkerBLL:
) )
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1)) tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
tagged_workers = filter_by_user(tagged_workers) tagged_workers = filter_by_user_and_pattern(tagged_workers)
worker_keys = ( worker_keys = (
worker_keys.intersection(tagged_workers) worker_keys.intersection(tagged_workers)
if worker_keys if worker_keys
@ -462,7 +492,7 @@ class WorkerBLL:
all_workers_key = self._get_all_workers_key(company) all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp) self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1)) worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys) worker_keys = filter_by_user_and_pattern(worker_keys)
if not worker_keys: if not worker_keys:
return [] return []
@ -487,13 +517,18 @@ class WorkerBLL:
user: str = "*", user: str = "*",
user_tags: Sequence[str] = None, user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
worker_pattern: str = None,
) -> Sequence[WorkerEntry]: ) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns""" """Get worker entries matching the company and user, worker patterns"""
entries = [] entries = []
for keys in chunked_iter( for keys in chunked_iter(
self._get_keys( self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags company,
user=user,
user_tags=user_tags,
system_tags=system_tags,
worker_pattern=worker_pattern,
), ),
1000, 1000,
): ):

View File

@ -310,6 +310,12 @@ get_all {
items { type: string } items { type: string }
} }
} }
"2.30": ${get_all."2.22"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys returned
type: string
}
}
} }
get_count { get_count {
"2.26": { "2.26": {
@ -345,6 +351,12 @@ get_count {
} }
} }
} }
"2.30": ${get_count."2.26"} {
request.properties.worker_pattern {
description: The worker name pattern. If specified then only matching keys are counted
type: string
}
}
} }
register { register {
"2.4" { "2.4" {

View File

@ -47,6 +47,7 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
request.last_seen, request.last_seen,
tags=request.tags, tags=request.tags,
system_tags=request.system_tags, system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
) )
) )
@ -61,6 +62,7 @@ def get_all(call: APICall, company_id: str, request: GetCountRequest):
request.last_seen, request.last_seen,
tags=request.tags, tags=request.tags,
system_tags=request.system_tags, system_tags=request.system_tags,
worker_pattern=request.worker_pattern,
) )
} }

View File

@ -32,7 +32,7 @@ class TestWorkersService(TestService):
self.api.workers.register(worker=w, system_tags=[system_tag]) self.api.workers.register(worker=w, system_tags=[system_tag])
# total workers count include the new ones # total workers count include the new ones
count = self.api.workers.get_count().count count = self.api.workers.get_count().count
self.assertGreater(count, len(test_workers)) self.assertGreaterEqual(count, len(test_workers))
# filter by system tag and last seen # filter by system tag and last seen
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
self.assertEqual(count, len(test_workers)) self.assertEqual(count, len(test_workers))