diff --git a/apiserver/apimodels/workers.py b/apiserver/apimodels/workers.py index b7123a2..a7524b7 100644 --- a/apiserver/apimodels/workers.py +++ b/apiserver/apimodels/workers.py @@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60 class WorkerRequest(Base): worker = StringField(required=True) tags = ListField(str) + system_tags = ListField(str) class RegisterRequest(WorkerRequest): @@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin): last_activity_time = DateTimeField(required=True) last_report_time = DateTimeField() tags = ListField(str) + system_tags = ListField(str) class CurrentTaskEntry(IdNameEntry): @@ -97,6 +99,7 @@ class WorkerResponseEntry(WorkerEntry): class GetAllRequest(Base): last_seen = IntField(default=3600) tags = ListField(str) + system_tags = ListField(str) class GetAllResponse(Base): diff --git a/apiserver/bll/workers/__init__.py b/apiserver/bll/workers/__init__.py index 28db111..00a9fc9 100644 --- a/apiserver/bll/workers/__init__.py +++ b/apiserver/bll/workers/__init__.py @@ -1,9 +1,11 @@ import itertools from datetime import datetime, timedelta +from time import time from typing import Sequence, Set, Optional import attr import elasticsearch.helpers +from boltons.iterutils import partition from apiserver.es_factory import es_factory from apiserver.apierrors import APIError @@ -50,6 +52,7 @@ class WorkerBLL: queues: Sequence[str] = None, timeout: int = 0, tags: Sequence[str] = None, + system_tags: Sequence[str] = None, ) -> WorkerEntry: """ Register a worker @@ -94,9 +97,10 @@ class WorkerBLL: register_timeout=timeout, last_activity_time=now, tags=tags, + system_tags=system_tags, ) - self.redis.setex(key, timedelta(seconds=timeout), entry.to_json()) + self._save_worker_data(entry) return entry @@ -121,6 +125,7 @@ class WorkerBLL: ip: str, report: StatusReportRequest, tags: Sequence[str] = None, + system_tags: Sequence[str] = None, ) -> None: """ Write worker status report @@ -141,6 +146,8 @@ class WorkerBLL: if tags is not None: 
entry.tags = tags + if system_tags is not None: + entry.system_tags = system_tags if report.machine_stats: self._log_stats_to_es( @@ -198,6 +205,7 @@ class WorkerBLL: company_id: str, last_seen: Optional[int] = None, tags: Sequence[str] = None, + system_tags: Sequence[str] = None, ) -> Sequence[WorkerEntry]: """ Get all the company workers that were active during the last_seen period @@ -206,7 +214,7 @@ class WorkerBLL: :return: """ try: - workers = self._get(company_id) + workers = self._get(company_id, user_tags=tags, system_tags=system_tags) except Exception as e: raise server_error.DataError("failed loading worker entries", err=e.args[0]) @@ -218,26 +226,25 @@ class WorkerBLL: if w.last_activity_time.replace(tzinfo=None) >= ref_time ] - if tags: - include = {t for t in tags if not t.startswith("-")} - exclude = {t[1:] for t in tags if t.startswith("-")} - workers = [ - w - for w in workers - if (not include or any(t in include for t in w.tags)) - and (not exclude or all(t not in exclude for t in w.tags)) - ] - return workers def get_all_with_projection( - self, company_id: str, last_seen: int, tags: Sequence[str] = None + self, + company_id: str, + last_seen: int, + tags: Sequence[str] = None, + system_tags: Sequence[str] = None, ) -> Sequence[WorkerResponseEntry]: helpers = list( map( WorkerConversionHelper.from_worker_entry, - self.get_all(company_id=company_id, last_seen=last_seen, tags=tags), + self.get_all( + company_id=company_id, + last_seen=last_seen, + tags=tags, + system_tags=system_tags, + ), ) ) @@ -355,24 +362,115 @@ class WorkerBLL: raise bad_request.InvalidWorkerId(worker=worker) + @staticmethod + def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str: + """Build redis key of the per-tag workers index from tags field, company and tag""" + return f"workers.{tags_field}_{company}_{tag}" + + @staticmethod + def _get_all_workers_key(company: str) -> str: + """Build redis key of the all-workers index from company""" + return f"workers_{company}" + + def 
_save_worker_data(self, entry: WorkerEntry): + self.redis.setex( + entry.key, timedelta(seconds=entry.register_timeout), entry.to_json() + ) + company_id = entry.company.id + expiration = int(time()) + entry.register_timeout + worker_item = {entry.key: expiration} + self.redis.zadd(self._get_all_workers_key(company_id), worker_item) + for tags, tags_field in ( + (entry.tags, "tags"), + (entry.system_tags, "systemtags"), + ): + for tag in tags: + name = self._get_tagged_workers_key(company_id, tags_field, tag) + self.redis.zadd(name, worker_item) + def _save_worker(self, entry: WorkerEntry) -> None: """Save worker entry in Redis""" try: - self.redis.setex( - entry.key, timedelta(seconds=entry.register_timeout), entry.to_json() - ) + self._save_worker_data(entry) except Exception: msg = "Failed saving worker entry" log.exception(msg) def _get( - self, company: str, user: str = "*", worker_id: str = "*" + self, + company: str, + user: str = "*", + worker_id: str = "*", + user_tags: Sequence[str] = None, + system_tags: Sequence[str] = None, ) -> Sequence[WorkerEntry]: """Get worker entries matching the company and user, worker patterns""" + + def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]: + if user == "*": + return in_keys + user_bytes = user.encode() + return {k for k in in_keys if user_bytes in k} + + if user_tags or system_tags: + worker_keys = set() + for tags, tags_field in ( + (user_tags, "tags"), + (system_tags, "systemtags"), + ): + if not tags: + continue + timestamp = int(time()) + include, exclude = partition(tags, key=lambda x: x[0] != "-") + if include: + tagged_workers = set() + for tag in include: + tagged_workers_key = self._get_tagged_workers_key( + company, tags_field, tag + ) + self.redis.zremrangebyscore( + tagged_workers_key, min=0, max=timestamp + ) + tagged_workers.update( + self.redis.zrange(tagged_workers_key, 0, -1) + ) + tagged_workers = filter_by_user(tagged_workers) + worker_keys = ( + worker_keys.intersection(tagged_workers) + if 
worker_keys + else tagged_workers + ) + if not worker_keys: + return [] + if exclude: + if not worker_keys: + all_workers_key = self._get_all_workers_key(company) + self.redis.zremrangebyscore( + all_workers_key, min=0, max=timestamp + ) + worker_keys.update(self.redis.zrange(all_workers_key, 0, -1)) + worker_keys = filter_by_user(worker_keys) + if not worker_keys: + return [] + for tag in exclude: + tagged_workers_key = self._get_tagged_workers_key( + company, tags_field, tag[1:] + ) + self.redis.zremrangebyscore( + tagged_workers_key, min=0, max=timestamp + ) + worker_keys.difference_update( + self.redis.zrange(tagged_workers_key, 0, -1) + ) + if not worker_keys: + return [] + else: + match = self._get_worker_key(company, user, "*") + worker_keys = self.redis.scan_iter(match) + entries = [] - match = self._get_worker_key(company, user, worker_id) - for r in self.redis.scan_iter(match): - data = self.redis.get(r) + for key in worker_keys: + data = self.redis.get(key) if data: entries.append(WorkerEntry.from_json(data)) diff --git a/apiserver/schema/services/workers.conf b/apiserver/schema/services/workers.conf index 3c437ac..3794f7f 100644 --- a/apiserver/schema/services/workers.conf +++ b/apiserver/schema/services/workers.conf @@ -152,6 +152,15 @@ _definitions { type: array items: { type: string } } + system_tags { + description: "System tags for the worker" + type: array + items: { type: string } + } + key { + description: "Worker entry key" + type: string + } } } @@ -159,11 +168,11 @@ _definitions { type: object properties { id { - description: "ID" + description: "Worker ID" type: string } name { - description: "Name" + description: "Worker name" type: string } } @@ -294,6 +303,13 @@ get_all { items { type: string } } } + "999.0": ${get_all."2.20"} { + request.properties.system_tags { + description: The list of allowed worker system tags. 
Prepend tag value with '-' in order to exclude + type: array + items { type: string } + } + } } register { "2.4" { @@ -328,6 +344,13 @@ register { properties {} } } + "999.0": ${register."2.4"} { + request.properties.system_tags { + description: "System tags for the worker" + type: array + items: { type: string } + } + } } unregister { "2.4" { @@ -395,6 +418,13 @@ status_report { properties {} } } + "999.0": ${status_report."2.4"} { + request.properties.system_tags { + description: "New system tags for the worker" + type: array + items: { type: string } + } + } } get_metric_keys { "2.4" { diff --git a/apiserver/services/workers.py b/apiserver/services/workers.py index a2e4906..b7d2222 100644 --- a/apiserver/services/workers.py +++ b/apiserver/services/workers.py @@ -42,7 +42,10 @@ worker_bll = WorkerBLL() def get_all(call: APICall, company_id: str, request: GetAllRequest): call.result.data_model = GetAllResponse( workers=worker_bll.get_all_with_projection( - company_id, request.last_seen, tags=request.tags + company_id, + request.last_seen, + tags=request.tags, + system_tags=request.system_tags, ) ) @@ -66,6 +69,7 @@ def register(call: APICall, company_id, request: RegisterRequest): queues=queues, timeout=timeout, tags=request.tags, + system_tags=request.system_tags, ) @@ -84,6 +88,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest): ip=call.real_ip, report=request, tags=request.tags, + system_tags=request.system_tags, ) diff --git a/apiserver/tests/automated/test_workers.py b/apiserver/tests/automated/test_workers.py index 703ee53..4f043c8 100644 --- a/apiserver/tests/automated/test_workers.py +++ b/apiserver/tests/automated/test_workers.py @@ -37,6 +37,19 @@ class TestWorkersService(TestService): time.sleep(5) self._check_exists(test_worker, False) + def test_system_tags(self): + test_worker = f"test_{uuid4().hex}" + tag = uuid4().hex + + # system_tags support + worker = self.api.workers.get_all(tags=[tag], 
system_tags=["Application"]).workers[0] + self.assertEqual(worker.id, test_worker) + self.assertEqual(worker.tags, [tag]) + self.assertEqual(worker.system_tags, ["Application"]) + + workers = self.api.workers.get_all(tags=[tag], system_tags=["-Application"]).workers + self.assertFalse(workers) + def test_filters(self): test_worker = f"test_{uuid4().hex}" self.api.workers.register(worker=test_worker, tags=["application"], timeout=3)