Support workers system tags

This commit is contained in:
allegroai 2022-11-29 17:35:25 +02:00
parent caaf801cd0
commit 6b3eff1426
5 changed files with 173 additions and 24 deletions

View File

@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60
# Base request model for worker calls: identifies the worker and optionally carries its tags
class WorkerRequest(Base):
# Worker identifier (the only mandatory field)
worker = StringField(required=True)
# User-defined tags for the worker
tags = ListField(str)
# System tags for the worker
system_tags = ListField(str)
class RegisterRequest(WorkerRequest):
@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
system_tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):
@ -97,6 +99,7 @@ class WorkerResponseEntry(WorkerEntry):
# Request model for listing workers, optionally filtered by activity period and tags
class GetAllRequest(Base):
# Only workers active during this period are returned (default 3600; presumably seconds - confirm)
last_seen = IntField(default=3600)
# Allowed worker tags; a tag prefixed with '-' excludes workers carrying it
tags = ListField(str)
# Allowed worker system tags; a tag prefixed with '-' excludes workers carrying it
system_tags = ListField(str)
class GetAllResponse(Base):

View File

@ -1,9 +1,11 @@
import itertools
from datetime import datetime, timedelta
from time import time
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
@ -50,6 +52,7 @@ class WorkerBLL:
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@ -94,9 +97,10 @@ class WorkerBLL:
register_timeout=timeout,
last_activity_time=now,
tags=tags,
system_tags=system_tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
self._save_worker_data(entry)
return entry
@ -121,6 +125,7 @@ class WorkerBLL:
ip: str,
report: StatusReportRequest,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
@ -141,6 +146,8 @@ class WorkerBLL:
if tags is not None:
entry.tags = tags
if system_tags is not None:
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
@ -198,6 +205,7 @@ class WorkerBLL:
company_id: str,
last_seen: Optional[int] = None,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
@ -206,7 +214,7 @@ class WorkerBLL:
:return:
"""
try:
workers = self._get(company_id)
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
@ -218,26 +226,25 @@ class WorkerBLL:
if w.last_activity_time.replace(tzinfo=None) >= ref_time
]
if tags:
include = {t for t in tags if not t.startswith("-")}
exclude = {t[1:] for t in tags if t.startswith("-")}
workers = [
w
for w in workers
if (not include or any(t in include for t in w.tags))
and (not exclude or all(t not in exclude for t in w.tags))
]
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int, tags: Sequence[str] = None
self,
company_id: str,
last_seen: int,
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
)
)
@ -355,24 +362,115 @@ class WorkerBLL:
raise bad_request.InvalidWorkerId(worker=worker)
@staticmethod
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
"""Build redis key of the sorted set holding the company worker keys for a tag (tags_field is "tags" or "systemtags")"""
return f"workers.{tags_field}_{company}_{tag}"
@staticmethod
def _get_all_workers_key(company: str) -> str:
"""Build redis key of the sorted set holding all the worker keys of the company"""
return f"workers_{company}"
# Store the serialized worker entry in redis and index its key in the lookup sets:
# the company-wide "all workers" sorted set and one sorted set per tag / system tag.
# Errors are not handled here and propagate to the caller.
def _save_worker_data(self, entry: WorkerEntry):
# The entry itself expires after register_timeout seconds
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
company_id = entry.company.id
# Index members are scored by their expiration timestamp so that readers
# can drop expired workers with a score range removal before listing
expiration = int(time()) + entry.register_timeout
worker_item = {entry.key: expiration}
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
# The second tuple item is the tags_field part of the per-tag set key
for tags, tags_field in (
(entry.tags, "tags"),
(entry.system_tags, "systemtags"),
):
for tag in tags:
name = self._get_tagged_workers_key(company_id, tags_field, tag)
self.redis.zadd(name, worker_item)
# Best-effort save: any failure is logged with its traceback and is not re-raised
def _save_worker(self, entry: WorkerEntry) -> None:
"""Save worker entry in Redis"""
try:
# NOTE(review): _save_worker_data issues the same setex call for entry.key,
# so this direct write looks redundant - confirm whether it should remain
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
self._save_worker_data(entry)
except Exception:
msg = "Failed saving worker entry"
log.exception(msg)
def _get(
self, company: str, user: str = "*", worker_id: str = "*"
self,
company: str,
user: str = "*",
worker_id: str = "*",
user_tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
# Narrow a set of redis worker keys (bytes) to those belonging to the requested user.
# Reads `user` from the enclosing scope; "*" means no filtering.
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
if user == "*":
return in_keys
user_bytes = user.encode()
# NOTE(review): this is a substring test on the whole key, so a user id that
# occurs inside another part of the key would also match - confirm acceptable
return {k for k in in_keys if user_bytes in k}
if user_tags or system_tags:
worker_keys = set()
for tags, tags_field in (
(user_tags, "tags"),
(system_tags, "systemtags"),
):
if not tags:
continue
timestamp = int(time())
include, exclude = partition(tags, key=lambda x: x[0] != "-")
if include:
tagged_workers = set()
for tag in include:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
tagged_workers.update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
tagged_workers = filter_by_user(tagged_workers)
worker_keys = (
worker_keys.intersection(tagged_workers)
if worker_keys
else tagged_workers
)
if not worker_keys:
return []
if exclude:
if not worker_keys:
all_workers_key = self._get_all_workers_key(company)
self.redis.zremrangebyscore(
all_workers_key, min=0, max=timestamp
)
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
worker_keys = filter_by_user(worker_keys)
if not worker_keys:
return []
for tag in exclude:
tagged_workers_key = self._get_tagged_workers_key(
company, tags_field, tag[1:]
)
self.redis.zremrangebyscore(
tagged_workers_key, min=0, max=timestamp
)
worker_keys.difference_update(
self.redis.zrange(tagged_workers_key, 0, -1)
)
if not worker_keys:
return []
else:
match = self._get_worker_key(company, user, "*")
worker_keys = self.redis.scan_iter(match)
entries = []
match = self._get_worker_key(company, user, worker_id)
for r in self.redis.scan_iter(match):
data = self.redis.get(r)
for key in worker_keys:
data = self.redis.get(key)
if data:
entries.append(WorkerEntry.from_json(data))

View File

@ -152,6 +152,15 @@ _definitions {
type: array
items: { type: string }
}
system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
key {
description: "Worker entry key"
type: string
}
}
}
@ -159,11 +168,11 @@ _definitions {
type: object
properties {
id {
description: "ID"
description: "Worker ID"
type: string
}
name {
description: "Name"
description: "Worker name"
type: string
}
}
@ -294,6 +303,13 @@ get_all {
items { type: string }
}
}
"999.0": ${get_all."2.20"} {
request.properties.system_tags {
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
type: array
items { type: string }
}
}
}
register {
"2.4" {
@ -328,6 +344,13 @@ register {
properties {}
}
}
"999.0": ${register."2.4"} {
request.properties.system_tags {
description: "System tags for the worker"
type: array
items: { type: string }
}
}
}
unregister {
"2.4" {
@ -395,6 +418,13 @@ status_report {
properties {}
}
}
"999.0": ${status_report."2.4"} {
request.properties.system_tags {
description: "New system tags for the worker"
type: array
items: { type: string }
}
}
}
get_metric_keys {
"2.4" {

View File

@ -42,7 +42,10 @@ worker_bll = WorkerBLL()
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(
company_id, request.last_seen, tags=request.tags
company_id,
request.last_seen,
tags=request.tags,
system_tags=request.system_tags,
)
)
@ -66,6 +69,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
queues=queues,
timeout=timeout,
tags=request.tags,
system_tags=request.system_tags,
)
@ -84,6 +88,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
ip=call.real_ip,
report=request,
tags=request.tags,
system_tags=request.system_tags,
)

View File

@ -37,6 +37,19 @@ class TestWorkersService(TestService):
time.sleep(5)
self._check_exists(test_worker, False)
def test_system_tags(self):
    """
    Check that a worker registered with system tags is returned with them
    and can be included or excluded by a system_tags filter
    """
    test_worker = f"test_{uuid4().hex}"
    tag = uuid4().hex
    # The worker has to be registered before it can be queried; the unique
    # user tag isolates this test from any other workers on the server
    self.api.workers.register(
        worker=test_worker, tags=[tag], system_tags=["Application"], timeout=5
    )

    # system_tags support
    workers = self.api.workers.get_all(
        tags=[tag], system_tags=["Application"]
    ).workers
    # Fail with a clear assertion instead of an IndexError when nothing matches
    self.assertEqual(len(workers), 1)
    worker = workers[0]
    self.assertEqual(worker.id, test_worker)
    self.assertEqual(worker.tags, [tag])
    self.assertEqual(worker.system_tags, ["Application"])

    # A '-' prefix turns the system tag into an exclusion filter
    workers = self.api.workers.get_all(
        tags=[tag], system_tags=["-Application"]
    ).workers
    self.assertFalse(workers)
def test_filters(self):
test_worker = f"test_{uuid4().hex}"
self.api.workers.register(worker=test_worker, tags=["application"], timeout=3)