mirror of
https://github.com/clearml/clearml-server
synced 2025-05-08 14:04:44 +00:00
Support workers system tags
This commit is contained in:
parent
caaf801cd0
commit
6b3eff1426
@ -20,6 +20,7 @@ DEFAULT_TIMEOUT = 10 * 60
|
||||
class WorkerRequest(Base):
|
||||
worker = StringField(required=True)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class RegisterRequest(WorkerRequest):
|
||||
@ -76,6 +77,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
last_activity_time = DateTimeField(required=True)
|
||||
last_report_time = DateTimeField()
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class CurrentTaskEntry(IdNameEntry):
|
||||
@ -97,6 +99,7 @@ class WorkerResponseEntry(WorkerEntry):
|
||||
class GetAllRequest(Base):
|
||||
last_seen = IntField(default=3600)
|
||||
tags = ListField(str)
|
||||
system_tags = ListField(str)
|
||||
|
||||
|
||||
class GetAllResponse(Base):
|
||||
|
@ -1,9 +1,11 @@
|
||||
import itertools
|
||||
from datetime import datetime, timedelta
|
||||
from time import time
|
||||
from typing import Sequence, Set, Optional
|
||||
|
||||
import attr
|
||||
import elasticsearch.helpers
|
||||
from boltons.iterutils import partition
|
||||
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import APIError
|
||||
@ -50,6 +52,7 @@ class WorkerBLL:
|
||||
queues: Sequence[str] = None,
|
||||
timeout: int = 0,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> WorkerEntry:
|
||||
"""
|
||||
Register a worker
|
||||
@ -94,9 +97,10 @@ class WorkerBLL:
|
||||
register_timeout=timeout,
|
||||
last_activity_time=now,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
)
|
||||
|
||||
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
|
||||
self._save_worker_data(entry)
|
||||
|
||||
return entry
|
||||
|
||||
@ -121,6 +125,7 @@ class WorkerBLL:
|
||||
ip: str,
|
||||
report: StatusReportRequest,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Write worker status report
|
||||
@ -141,6 +146,8 @@ class WorkerBLL:
|
||||
|
||||
if tags is not None:
|
||||
entry.tags = tags
|
||||
if system_tags is not None:
|
||||
entry.system_tags = system_tags
|
||||
|
||||
if report.machine_stats:
|
||||
self._log_stats_to_es(
|
||||
@ -198,6 +205,7 @@ class WorkerBLL:
|
||||
company_id: str,
|
||||
last_seen: Optional[int] = None,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerEntry]:
|
||||
"""
|
||||
Get all the company workers that were active during the last_seen period
|
||||
@ -206,7 +214,7 @@ class WorkerBLL:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
workers = self._get(company_id)
|
||||
workers = self._get(company_id, user_tags=tags, system_tags=system_tags)
|
||||
except Exception as e:
|
||||
raise server_error.DataError("failed loading worker entries", err=e.args[0])
|
||||
|
||||
@ -218,26 +226,25 @@ class WorkerBLL:
|
||||
if w.last_activity_time.replace(tzinfo=None) >= ref_time
|
||||
]
|
||||
|
||||
if tags:
|
||||
include = {t for t in tags if not t.startswith("-")}
|
||||
exclude = {t[1:] for t in tags if t.startswith("-")}
|
||||
workers = [
|
||||
w
|
||||
for w in workers
|
||||
if (not include or any(t in include for t in w.tags))
|
||||
and (not exclude or all(t not in exclude for t in w.tags))
|
||||
]
|
||||
|
||||
return workers
|
||||
|
||||
def get_all_with_projection(
|
||||
self, company_id: str, last_seen: int, tags: Sequence[str] = None
|
||||
self,
|
||||
company_id: str,
|
||||
last_seen: int,
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
) -> Sequence[WorkerResponseEntry]:
|
||||
|
||||
helpers = list(
|
||||
map(
|
||||
WorkerConversionHelper.from_worker_entry,
|
||||
self.get_all(company_id=company_id, last_seen=last_seen, tags=tags),
|
||||
self.get_all(
|
||||
company_id=company_id,
|
||||
last_seen=last_seen,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@ -355,24 +362,115 @@ class WorkerBLL:
|
||||
|
||||
raise bad_request.InvalidWorkerId(worker=worker)
|
||||
|
||||
@staticmethod
|
||||
def _get_tagged_workers_key(company: str, tags_field: str, tag: str) -> str:
|
||||
"""Build redis key for the workers of a company that carry the given tag in the given tags field"""
|
||||
return f"workers.{tags_field}_{company}_{tag}"
|
||||
|
||||
@staticmethod
|
||||
def _get_all_workers_key(company: str) -> str:
|
||||
"""Build redis key under which all the workers of the company are tracked"""
|
||||
return f"workers_{company}"
|
||||
|
||||
def _save_worker_data(self, entry: WorkerEntry):
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
company_id = entry.company.id
|
||||
expiration = int(time()) + entry.register_timeout
|
||||
worker_item = {entry.key: expiration}
|
||||
self.redis.zadd(self._get_all_workers_key(company_id), worker_item)
|
||||
for tags, tags_field in (
|
||||
(entry.tags, "tags"),
|
||||
(entry.system_tags, "systemtags"),
|
||||
):
|
||||
for tag in tags:
|
||||
name = self._get_tagged_workers_key(company_id, tags_field, tag)
|
||||
self.redis.zadd(name, worker_item)
|
||||
|
||||
def _save_worker(self, entry: WorkerEntry) -> None:
|
||||
"""Save worker entry in Redis"""
|
||||
try:
|
||||
self.redis.setex(
|
||||
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
|
||||
)
|
||||
self._save_worker_data(entry)
|
||||
except Exception:
|
||||
msg = "Failed saving worker entry"
|
||||
log.exception(msg)
|
||||
|
||||
def _get(
|
||||
self, company: str, user: str = "*", worker_id: str = "*"
|
||||
self,
|
||||
company: str,
|
||||
user: str = "*",
|
||||
worker_id: str = "*",
|
||||
user_tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
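) -> Sequence[WorkerEntry]:
|
||||
"""Get worker entries matching the company, user and worker patterns, optionally filtered by user tags and system tags (a tag prefixed with '-' excludes workers that have it)"""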
|
||||
|
||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||
if user == "*":
|
||||
return in_keys
|
||||
user_bytes = user.encode()
|
||||
return {k for k in in_keys if user_bytes in k}
|
||||
|
||||
if user_tags or system_tags:
|
||||
worker_keys = set()
|
||||
for tags, tags_field in (
|
||||
(user_tags, "tags"),
|
||||
(system_tags, "systemtags"),
|
||||
):
|
||||
if not tags:
|
||||
continue
|
||||
timestamp = int(time())
|
||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||
if include:
|
||||
tagged_workers = set()
|
||||
for tag in include:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
tagged_workers.update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
tagged_workers = filter_by_user(tagged_workers)
|
||||
worker_keys = (
|
||||
worker_keys.intersection(tagged_workers)
|
||||
if worker_keys
|
||||
else tagged_workers
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
if exclude:
|
||||
if not worker_keys:
|
||||
all_workers_key = self._get_all_workers_key(company)
|
||||
self.redis.zremrangebyscore(
|
||||
all_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||
worker_keys = filter_by_user(worker_keys)
|
||||
if not worker_keys:
|
||||
return []
|
||||
for tag in exclude:
|
||||
tagged_workers_key = self._get_tagged_workers_key(
|
||||
company, tags_field, tag[1:]
|
||||
)
|
||||
self.redis.zremrangebyscore(
|
||||
tagged_workers_key, min=0, max=timestamp
|
||||
)
|
||||
worker_keys.difference_update(
|
||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
||||
)
|
||||
if not worker_keys:
|
||||
return []
|
||||
else:
|
||||
match = self._get_worker_key(company, user, "*")
|
||||
worker_keys = self.redis.scan_iter(match)
|
||||
|
||||
entries = []
|
||||
match = self._get_worker_key(company, user, worker_id)
|
||||
for r in self.redis.scan_iter(match):
|
||||
data = self.redis.get(r)
|
||||
for key in worker_keys:
|
||||
data = self.redis.get(key)
|
||||
if data:
|
||||
entries.append(WorkerEntry.from_json(data))
|
||||
|
||||
|
@ -152,6 +152,15 @@ _definitions {
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
key {
|
||||
description: "Worker entry key"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -159,11 +168,11 @@ _definitions {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID"
|
||||
description: "Worker ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Name"
|
||||
description: "Worker name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
@ -294,6 +303,13 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
"999.0": ${get_all."2.20"} {
|
||||
request.properties.system_tags {
|
||||
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
register {
|
||||
"2.4" {
|
||||
@ -328,6 +344,13 @@ register {
|
||||
properties {}
|
||||
}
|
||||
}
|
||||
"999.0": ${register."2.4"} {
|
||||
request.properties.system_tags {
|
||||
description: "System tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
unregister {
|
||||
"2.4" {
|
||||
@ -395,6 +418,13 @@ status_report {
|
||||
properties {}
|
||||
}
|
||||
}
|
||||
"999.0": ${status_report."2.4"} {
|
||||
request.properties.system_tags {
|
||||
description: "New system tags for the worker"
|
||||
type: array
|
||||
items: { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
get_metric_keys {
|
||||
"2.4" {
|
||||
|
@ -42,7 +42,10 @@ worker_bll = WorkerBLL()
|
||||
def get_all(call: APICall, company_id: str, request: GetAllRequest):
|
||||
call.result.data_model = GetAllResponse(
|
||||
workers=worker_bll.get_all_with_projection(
|
||||
company_id, request.last_seen, tags=request.tags
|
||||
company_id,
|
||||
request.last_seen,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
)
|
||||
|
||||
@ -66,6 +69,7 @@ def register(call: APICall, company_id, request: RegisterRequest):
|
||||
queues=queues,
|
||||
timeout=timeout,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
|
||||
|
||||
@ -84,6 +88,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
|
||||
ip=call.real_ip,
|
||||
report=request,
|
||||
tags=request.tags,
|
||||
system_tags=request.system_tags,
|
||||
)
|
||||
|
||||
|
||||
|
@ -37,6 +37,19 @@ class TestWorkersService(TestService):
|
||||
time.sleep(5)
|
||||
self._check_exists(test_worker, False)
|
||||
|
||||
def test_system_tags(self):
|
||||
test_worker = f"test_{uuid4().hex}"
|
||||
tag = uuid4().hex
|
||||
|
||||
# system_tags support
|
||||
worker = self.api.workers.get_all(tags=[tag], system_tags=["Application"]).workers[0]
|
||||
self.assertEqual(worker.id, test_worker)
|
||||
self.assertEqual(worker.tags, [tag])
|
||||
self.assertEqual(worker.system_tags, ["Application"])
|
||||
|
||||
workers = self.api.workers.get_all(tags=[tag], system_tags=["-Application"]).workers
|
||||
self.assertFalse(workers)
|
||||
|
||||
def test_filters(self):
|
||||
test_worker = f"test_{uuid4().hex}"
|
||||
self.api.workers.register(worker=test_worker, tags=["application"], timeout=3)
|
||||
|
Loading…
Reference in New Issue
Block a user