Support workers system tags

This commit is contained in:
allegroai 2022-11-29 17:35:25 +02:00
parent caaf801cd0
commit 6b3eff1426
5 changed files with 173 additions and 24 deletions

View File

@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base): class WorkerRequest(Base):
worker = StringField(required=True) worker = StringField(required=True)
tags = ListField(str) tags = ListField(str)
system_tags = ListField(str)
class RegisterRequest(WorkerRequest): class RegisterRequest(WorkerRequest):
@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
last_activity_time = DateTimeField(required=True) last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField() last_report_time = DateTimeField()
tags = ListField(str) tags = ListField(str)
system_tags = ListField(str)
class CurrentTaskEntry(IdNameEntry): class CurrentTaskEntry(IdNameEntry):
@ -97,6 +99,7 @@ class WorkerResponseEntry(WorkerEntry):
class GetAllRequest(Base): class GetAllRequest(Base):
last_seen = IntField(default=3600) last_seen = IntField(default=3600)
tags = ListField(str) tags = ListField(str)
system_tags = ListField(str)
class GetAllResponse(Base): class GetAllResponse(Base):

View File

@ -1,9 +1,11 @@
import itertools import itertools
from datetime import datetime, timedelta from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional from typing import Sequence, Set, Optional
import attr import attr
import elasticsearch.helpers import elasticsearch.helpers
from boltons.iterutils import partition
from apiserver.es_factory import es_factory from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError from apiserver.apierrors import APIError
@ -50,6 +52,7 @@ class WorkerBLL:
queues: Sequence[str] = None, queues: Sequence[str] = None,
timeout: int = 0, timeout: int = 0,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> WorkerEntry: ) -> WorkerEntry:
""" """
Register a worker Register a worker
@ -94,9 +97,10 @@ class WorkerBLL:
register_timeout=timeout, register_timeout=timeout,
last_activity_time=now, last_activity_time=now,
tags=tags, tags=tags,
system_tags=system_tags,
) )
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json()) self._save_worker_data(entry)
return entry return entry
@ -121,6 +125,7 @@ class WorkerBLL:
ip: str, ip: str,
report: StatusReportRequest, report: StatusReportRequest,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> None: ) -> None:
""" """
Write worker status report Write worker status report
@ -141,6 +146,8 @@ class WorkerBLL:
if tags is not None: if tags is not None:
entry.tags = tags entry.tags = tags
if system_tags is not None:
entry.system_tags = system_tags
if report.machine_stats: if report.machine_stats:
self._log_stats_to_es( self._log_stats_to_es(
@ -198,6 +205,7 @@ class WorkerBLL:
company_id: str, company_id: str,
last_seen: Optional[int] = None, last_seen: Optional[int] = None,
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]: ) -> Sequence[WorkerEntry]:
""" """
Get all the company workers that were active during the last_seen period Get all the company workers that were active during the last_seen period
@ -206,7 +214,7 @@ class WorkerBLL:
:return: :return:
""" """
try: try:
workers = self._get(company_id) workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
except Exception as e: except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0]) raise server_error.DataError("failed loading worker entries", err=e.args[0])
@ -218,26 +226,25 @@ class WorkerBLL:
if w.last_activity_time.replace(tzinfo=None) >= ref_time if w.last_activity_time.replace(tzinfo=None) >= ref_time
] ]
if tags:
include = {t for t in tags if not t.startswith("-")}
exclude = {t[1:] for t in tags if t.startswith("-")}
workers = [
w
for w in workers
if (not include or any(t in include for t in w.tags))
and (not exclude or all(t not in exclude for t in w.tags))
]
return workers return workers
def get_all_with_projection( def get_all_with_projection(
self, company_id: str, last_seen: int, tags: Sequence[str] = None self,
company_id: str,
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]: ) -> Sequence[WorkerResponseEntry]:
helpers = list( helpers = list(
map( map(
WorkerConversionHelper.from_worker_entry, WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags), self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
) )
) )
@ -355,24 +362,115 @@ class WorkerBLL:
raise bad_request.InvalidWorkerId(worker=worker) raise bad_request.InvalidWorkerId(worker=worker)
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers.{tags_field}_{company}_{tag}"
@staticmethod
def _get_all_workers_key(company: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"workers_{company}"
def _save_worker_data(self, entry: WorkerEntry):
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
company_id = entry.company.id
expiration = int(time()) + entry.register_timeout
worker_item = {entry.key: expiration}
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
for tags, tags_field in (
(entry.tags, "tags"),
(entry.system_tags, "systemtags"),
):
for tag in tags:
name = self._get_tagged_workers_key(company_id, tags_field, tag)
self.redis.zadd(name, worker_item)
def _save_worker(self, entry: WorkerEntry) -> None: def _save_worker(self, entry: WorkerEntry) -> None:
"""Save worker entry in Redis""" """Save worker entry in Redis"""
try: try:
self.redis.setex( self._save_worker_data(entry)
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
except Exception: except Exception:
msg = "Failed saving worker entry" msg = "Failed saving worker entry"
log.exception(msg) log.exception(msg)
def _get( def _get(
self, company: str, user: str = "*", worker_id: str = "*" self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]: ) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns""" """Get worker entries matching the company and user, worker patterns"""
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = [] entries = []
match = self._get_worker_key(company, user, worker_id) for key in worker_keys:
for r in self.redis.scan_iter(match): data = self.redis.get(key)
data = self.redis.get(r)
if data: if data:
entries.append(WorkerEntry.from_json(data)) entries.append(WorkerEntry.from_json(data))

View File

@ -152,6 +152,15 @@ _definitions {
type: array type: array
items: { type: string } items: { type: string }
} }
system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
key {
description: "Worker entry key"
type: string
}
} }
} }
@ -159,11 +168,11 @@ _definitions {
type: object type: object
properties { properties {
id { id {
description: "ID" description: "Worker ID"
type: string type: string
} }
name { name {
description: "Name" description: "Worker name"
type: string type: string
} }
} }
@ -294,6 +303,13 @@ get_all {
items { type: string } items { type: string }
} }
} }
"999.0": ${get_all."2.20"} {
request.properties.system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
} }
register { register {
"2.4" { "2.4" {
@ -328,6 +344,13 @@ register {
properties {} properties {}
} }
} }
"999.0": ${register."2.4"} {
request.properties.system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
}
} }
unregister { unregister {
"2.4" { "2.4" {
@ -395,6 +418,13 @@ status_report {
properties {} properties {}
} }
} }
"999.0": ${status_report."2.4"} {
request.properties.system_tags {
description: "New system tags for the worker"
type: array
items: { type: string }
}
}
} }
get_metric_keys { get_metric_keys {
"2.4" { "2.4" {

View File

@ -42,7 +42,10 @@ worker_bll = WorkerBLL()
def get_all(call: APICall, company_id: str, request: GetAllRequest): def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse( call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection( workers=worker_bll.get_all_with_projection(
company_id, request.last_seen, tags=request.tags company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
) )
) )
@ -66,6 +69,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
queues=queues, queues=queues,
timeout=timeout, timeout=timeout,
tags=request.tags, tags=request.tags,
system_tags=request.system_tags,
) )
@ -84,6 +88,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
ip=call.real_ip, ip=call.real_ip,
report=request, report=request,
tags=request.tags, tags=request.tags,
system_tags=request.system_tags,
) )

View File

@ -37,6 +37,19 @@ class TestWorkersService(TestService):
time.sleep(5) time.sleep(5)
self._check_exists(test_worker, False) self._check_exists(test_worker, False)
def test_system_tags(self):
test_worker = f"test_{uuid4().hex}"
tag = uuid4().hex
# system_tags support
worker = self.api.workers.get_all(tags=[tag], system_tags=["Application"]).workers[0]
self.assertEqual(worker.id, test_worker)
self.assertEqual(worker.tags, [tag])
self.assertEqual(worker.system_tags, ["Application"])
workers = self.api.workers.get_all(tags=[tag], system_tags=["-Application"]).workers
self.assertFalse(workers)
def test_filters(self): def test_filters(self):
test_worker = f"test_{uuid4().hex}" test_worker = f"test_{uuid4().hex}"
self.api.workers.register(worker=test_worker, tags=["application"], timeout=3) self.api.workers.register(worker=test_worker, tags=["application"], timeout=3)