mirror of
https://github.com/clearml/clearml-server
synced 2025-06-03 19:36:14 +00:00
Add workers.get_count endpoint
This commit is contained in:
parent
1b650b1689
commit
db021f2863
@ -12,7 +12,7 @@ from jsonmodels.fields import (
|
|||||||
)
|
)
|
||||||
from jsonmodels.models import Base
|
from jsonmodels.models import Base
|
||||||
|
|
||||||
from apiserver.apimodels import make_default, ListField, EnumField, JsonSerializableMixin
|
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
|
||||||
|
|
||||||
DEFAULT_TIMEOUT = 10 * 60
|
DEFAULT_TIMEOUT = 10 * 60
|
||||||
|
|
||||||
@ -104,6 +104,10 @@ class GetAllResponse(Base):
|
|||||||
workers = ListField(WorkerResponseEntry)
|
workers = ListField(WorkerResponseEntry)
|
||||||
|
|
||||||
|
|
||||||
|
class GetCountRequest(GetAllRequest):
|
||||||
|
last_seen = IntField(default=0)
|
||||||
|
|
||||||
|
|
||||||
class StatsBase(Base):
|
class StatsBase(Base):
|
||||||
worker_ids = ListField(str)
|
worker_ids = ListField(str)
|
||||||
|
|
||||||
|
@ -200,6 +200,24 @@ class WorkerBLL:
|
|||||||
finally:
|
finally:
|
||||||
self._save_worker(entry)
|
self._save_worker(entry)
|
||||||
|
|
||||||
|
def get_count(
|
||||||
|
self,
|
||||||
|
company_id: str,
|
||||||
|
last_seen: Optional[int] = None,
|
||||||
|
tags: Sequence[str] = None,
|
||||||
|
system_tags: Sequence[str] = None,
|
||||||
|
):
|
||||||
|
if not last_seen:
|
||||||
|
return len(
|
||||||
|
self._get_keys(company_id, user_tags=tags, system_tags=system_tags)
|
||||||
|
)
|
||||||
|
|
||||||
|
return len(
|
||||||
|
self.get_all(
|
||||||
|
company_id, last_seen=last_seen, tags=tags, system_tags=system_tags
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
self,
|
self,
|
||||||
company_id: str,
|
company_id: str,
|
||||||
@ -396,15 +414,16 @@ class WorkerBLL:
|
|||||||
msg = "Failed saving worker entry"
|
msg = "Failed saving worker entry"
|
||||||
log.exception(msg)
|
log.exception(msg)
|
||||||
|
|
||||||
def _get(
|
def _get_keys(
|
||||||
self,
|
self,
|
||||||
company: str,
|
company: str,
|
||||||
user: str = "*",
|
user: str = "*",
|
||||||
worker_id: str = "*",
|
|
||||||
user_tags: Sequence[str] = None,
|
user_tags: Sequence[str] = None,
|
||||||
system_tags: Sequence[str] = None,
|
system_tags: Sequence[str] = None,
|
||||||
) -> Sequence[WorkerEntry]:
|
) -> Sequence[bytes]:
|
||||||
"""Get worker entries matching the company and user, worker patterns"""
|
if not (user_tags or system_tags):
|
||||||
|
match = self._get_worker_key(company, user, "*")
|
||||||
|
return list(self.redis.scan_iter(match))
|
||||||
|
|
||||||
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
def filter_by_user(in_keys: Set[bytes]) -> Set[bytes]:
|
||||||
if user == "*":
|
if user == "*":
|
||||||
@ -412,7 +431,6 @@ class WorkerBLL:
|
|||||||
user_bytes = user.encode()
|
user_bytes = user.encode()
|
||||||
return {k for k in in_keys if user_bytes in k}
|
return {k for k in in_keys if user_bytes in k}
|
||||||
|
|
||||||
if user_tags or system_tags:
|
|
||||||
worker_keys = set()
|
worker_keys = set()
|
||||||
for tags, tags_field in (
|
for tags, tags_field in (
|
||||||
(user_tags, "tags"),
|
(user_tags, "tags"),
|
||||||
@ -420,6 +438,7 @@ class WorkerBLL:
|
|||||||
):
|
):
|
||||||
if not tags:
|
if not tags:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
timestamp = int(time())
|
timestamp = int(time())
|
||||||
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
include, exclude = partition(tags, key=lambda x: x[0] != "-")
|
||||||
if include:
|
if include:
|
||||||
@ -431,9 +450,8 @@ class WorkerBLL:
|
|||||||
self.redis.zremrangebyscore(
|
self.redis.zremrangebyscore(
|
||||||
tagged_workers_key, min=0, max=timestamp
|
tagged_workers_key, min=0, max=timestamp
|
||||||
)
|
)
|
||||||
tagged_workers.update(
|
tagged_workers.update(self.redis.zrange(tagged_workers_key, 0, -1))
|
||||||
self.redis.zrange(tagged_workers_key, 0, -1)
|
|
||||||
)
|
|
||||||
tagged_workers = filter_by_user(tagged_workers)
|
tagged_workers = filter_by_user(tagged_workers)
|
||||||
worker_keys = (
|
worker_keys = (
|
||||||
worker_keys.intersection(tagged_workers)
|
worker_keys.intersection(tagged_workers)
|
||||||
@ -442,16 +460,16 @@ class WorkerBLL:
|
|||||||
)
|
)
|
||||||
if not worker_keys:
|
if not worker_keys:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if exclude:
|
if exclude:
|
||||||
if not worker_keys:
|
if not worker_keys:
|
||||||
all_workers_key = self._get_all_workers_key(company)
|
all_workers_key = self._get_all_workers_key(company)
|
||||||
self.redis.zremrangebyscore(
|
self.redis.zremrangebyscore(all_workers_key, min=0, max=timestamp)
|
||||||
all_workers_key, min=0, max=timestamp
|
|
||||||
)
|
|
||||||
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
worker_keys.update(self.redis.zrange(all_workers_key, 0, -1))
|
||||||
worker_keys = filter_by_user(worker_keys)
|
worker_keys = filter_by_user(worker_keys)
|
||||||
if not worker_keys:
|
if not worker_keys:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
for tag in exclude:
|
for tag in exclude:
|
||||||
tagged_workers_key = self._get_tagged_workers_key(
|
tagged_workers_key = self._get_tagged_workers_key(
|
||||||
company, tags_field, tag[1:]
|
company, tags_field, tag[1:]
|
||||||
@ -464,12 +482,22 @@ class WorkerBLL:
|
|||||||
)
|
)
|
||||||
if not worker_keys:
|
if not worker_keys:
|
||||||
return []
|
return []
|
||||||
else:
|
|
||||||
match = self._get_worker_key(company, user, "*")
|
return list(worker_keys)
|
||||||
worker_keys = self.redis.scan_iter(match)
|
|
||||||
|
def _get(
|
||||||
|
self,
|
||||||
|
company: str,
|
||||||
|
user: str = "*",
|
||||||
|
user_tags: Sequence[str] = None,
|
||||||
|
system_tags: Sequence[str] = None,
|
||||||
|
) -> Sequence[WorkerEntry]:
|
||||||
|
"""Get worker entries matching the company and user, worker patterns"""
|
||||||
|
|
||||||
entries = []
|
entries = []
|
||||||
for key in worker_keys:
|
for key in self._get_keys(
|
||||||
|
company, user=user, user_tags=user_tags, system_tags=system_tags
|
||||||
|
):
|
||||||
data = self.redis.get(key)
|
data = self.redis.get(key)
|
||||||
if data:
|
if data:
|
||||||
entries.append(WorkerEntry.from_json(data))
|
entries.append(WorkerEntry.from_json(data))
|
||||||
|
@ -311,6 +311,41 @@ get_all {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
get_count {
|
||||||
|
"999.0": {
|
||||||
|
description: "Returns the number of registered workers."
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
last_seen {
|
||||||
|
description: """Filter out workers not active for more than last_seen seconds.
|
||||||
|
A value or 0 or 'none' will disable the filter."""
|
||||||
|
type: integer
|
||||||
|
default: 0
|
||||||
|
}
|
||||||
|
tags {
|
||||||
|
description: The list of allowed worker tags. Prepend tag value with '-' in order to exclude
|
||||||
|
type: array
|
||||||
|
items { type: string }
|
||||||
|
}
|
||||||
|
system_tags {
|
||||||
|
description: The list of allowed worker system tags. Prepend tag value with '-' in order to exclude
|
||||||
|
type: array
|
||||||
|
items { type: string }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
count {
|
||||||
|
description: Workers count
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
register {
|
register {
|
||||||
"2.4" {
|
"2.4" {
|
||||||
description: "Register a worker in the system. Called by the Worker Daemon."
|
description: "Register a worker in the system. Called by the Worker Daemon."
|
||||||
|
@ -22,6 +22,7 @@ from apiserver.apimodels.workers import (
|
|||||||
GetActivityReportRequest,
|
GetActivityReportRequest,
|
||||||
GetActivityReportResponse,
|
GetActivityReportResponse,
|
||||||
ActivityReportSeries,
|
ActivityReportSeries,
|
||||||
|
GetCountRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.workers import WorkerBLL
|
from apiserver.bll.workers import WorkerBLL
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
@ -50,6 +51,20 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint(
|
||||||
|
"workers.get_count", request_data_model=GetCountRequest,
|
||||||
|
)
|
||||||
|
def get_all(call: APICall, company_id: str, request: GetCountRequest):
|
||||||
|
call.result.data = {
|
||||||
|
"count": worker_bll.get_count(
|
||||||
|
company_id,
|
||||||
|
request.last_seen,
|
||||||
|
tags=request.tags,
|
||||||
|
system_tags=request.system_tags,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
|
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
|
||||||
def register(call: APICall, company_id, request: RegisterRequest):
|
def register(call: APICall, company_id, request: RegisterRequest):
|
||||||
worker = request.worker
|
worker = request.worker
|
||||||
|
@ -27,6 +27,25 @@ class TestWorkersService(TestService):
|
|||||||
self.api.workers.unregister(worker=test_worker)
|
self.api.workers.unregister(worker=test_worker)
|
||||||
self._check_exists(test_worker, False)
|
self._check_exists(test_worker, False)
|
||||||
|
|
||||||
|
def test_get_count(self):
|
||||||
|
test_workers = [f"test_{uuid4().hex}" for _ in range(2)]
|
||||||
|
system_tag = f"tag_{uuid4().hex}"
|
||||||
|
for w in test_workers:
|
||||||
|
self.api.workers.register(worker=w, system_tags=[system_tag])
|
||||||
|
# total workers count include the new ones
|
||||||
|
count = self.api.workers.get_count().count
|
||||||
|
self.assertGreater(count, len(test_workers))
|
||||||
|
# filter by system tag and last seen
|
||||||
|
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
|
||||||
|
self.assertEqual(count, len(test_workers))
|
||||||
|
time.sleep(5)
|
||||||
|
# workers not seen recently
|
||||||
|
count = self.api.workers.get_count(system_tags=[system_tag], last_seen=4).count
|
||||||
|
self.assertEqual(count, 0)
|
||||||
|
# but still visible without the last seen filter
|
||||||
|
count = self.api.workers.get_count(system_tags=[system_tag]).count
|
||||||
|
self.assertEqual(count, len(test_workers))
|
||||||
|
|
||||||
def test_workers_timeout(self):
|
def test_workers_timeout(self):
|
||||||
test_worker = f"test_{uuid4().hex}"
|
test_worker = f"test_{uuid4().hex}"
|
||||||
self._check_exists(test_worker, False)
|
self._check_exists(test_worker, False)
|
||||||
|
Loading…
Reference in New Issue
Block a user