Optimize Workers retrieval

Store worker statistics under the worker id rather than under the internal Redis key
Fix unit tests
allegroai 2023-11-17 09:46:44 +02:00
parent a7865ccbec
commit 4ac6f88278
3 changed files with 71 additions and 80 deletions

View File

@@ -13,8 +13,7 @@ from jsonmodels.fields import (
from jsonmodels.models import Base
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
DEFAULT_TIMEOUT = 10 * 60
from apiserver.config_repo import config
class WorkerRequest(Base):
@@ -24,7 +23,10 @@ class WorkerRequest(Base):
class RegisterRequest(WorkerRequest):
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min)
timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to
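For reference, a minimal standalone sketch of the config-driven default introduced above, assuming a plain pyhocon tree (the apiserver's config wrapper exposes a similar get(path, default); the inline config string here is illustrative only):

from pyhocon import ConfigFactory

# Hypothetical inline config; the real value comes from the services.workers config files.
conf = ConfigFactory.parse_string(
    "services { workers { default_worker_timeout_sec = 600 } }"
)
# Falls back to 10 minutes when the key is absent.
timeout = int(conf.get("services.workers.default_worker_timeout_sec", 10 * 60))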

View File

@@ -5,13 +5,13 @@ from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
from boltons.iterutils import partition
from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
@@ -30,12 +30,14 @@ from apiserver.redis_manager import redman
from apiserver.tools import safe_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@@ -68,7 +70,7 @@ class WorkerBLL:
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or []
with translate_errors_context():
@@ -141,8 +143,6 @@ class WorkerBLL:
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
@@ -150,15 +150,16 @@
entry.system_tags = system_tags
if report.machine_stats:
self._log_stats_to_es(
self.log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=entry.key,
worker_id=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue
if report.queues:
@@ -254,18 +255,15 @@ class WorkerBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(
helpers = [
WorkerConversionHelper.from_worker_entry(entry)
for entry in self.get_all(
company_id=company_id,
last_seen=last_seen,
tags=tags,
system_tags=system_tags,
),
)
)
]
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
@@ -284,9 +282,7 @@
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
task_ids = task_ids.union(
filter(
None,
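Queue.aggregate above is presumably a thin wrapper the model now exposes over mongoengine's QuerySet.aggregate; a hypothetical sketch of such a helper (QueueDoc and its field are illustrative stand-ins, and running it assumes an active mongoengine connection):

from mongoengine import Document, StringField

class QueueDoc(Document):  # hypothetical stand-in for the real Queue model
    name = StringField()

    @classmethod
    def aggregate(cls, pipeline, **kwargs):
        # Run a raw MongoDB aggregation pipeline on this document's collection;
        # results are plain dicts, keyed by "_id" in the lookup above.
        return cls.objects.aggregate(pipeline, **kwargs)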
@@ -496,12 +492,15 @@
"""Get worker entries matching the company and user, worker patterns"""
entries = []
for key in self._get_keys(
for keys in chunked_iter(
self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
),
1000,
):
data = self.redis.get(key)
data = self.redis.mget(keys)
if data:
entries.append(WorkerEntry.from_json(data))
entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries
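The hunk above is the core of the retrieval optimization: instead of one Redis GET round trip per worker key, keys are read in batches of 1000 with a single MGET each. A self-contained sketch of the pattern, assuming redis-py and boltons (key names are illustrative):

from boltons.iterutils import chunked_iter
from redis import Redis

redis = Redis()  # assumes a local Redis instance
keys = [f"worker_{i}" for i in range(2500)]  # hypothetical worker keys
entries = []
for chunk in chunked_iter(keys, 1000):
    # MGET returns values positionally; missing or expired keys yield None.
    values = redis.mget(chunk)
    entries.extend(v for v in values if v)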
@@ -510,18 +509,17 @@
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
def log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
worker_id: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
) -> int:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
:return: The number of logged documents
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -533,8 +531,7 @@
_index=es_index,
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
worker=worker_id,
task=task,
category=category,
metric=metric,
@@ -559,7 +556,7 @@
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
return added
@attr.s(auto_attribs=True)
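The new return value relies on the tuple returned by elasticsearch.helpers.bulk: by default, the count of successfully executed actions followed by error details. A minimal sketch under those assumptions (the index name and document fields are illustrative):

import elasticsearch.helpers
from elasticsearch import Elasticsearch

es = Elasticsearch()  # assumes a reachable cluster
actions = [
    {
        "_index": "worker_stats_company_2023-11",  # hypothetical index name
        "_source": {
            "timestamp": 1700000000000,
            "worker": "worker_1",  # stats are now keyed by the worker id
            "metric": "cpu_usage",
            "value": 0.42,
        },
    },
]
# raise_on_error=False collects failures in the list instead of raising on the first one.
added, errors = elasticsearch.helpers.bulk(es, actions, raise_on_error=False)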

View File

@@ -1,7 +1,6 @@
import time
from uuid import uuid4
from datetime import timedelta
from operator import attrgetter
from typing import Sequence
from apiserver.apierrors.errors import bad_request
@@ -72,7 +71,9 @@ class TestWorkersService(TestService):
self.assertEqual(worker.tags, [tag])
self.assertEqual(worker.system_tags, [system_tag])
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
workers = self.api.workers.get_all(
tags=[tag], system_tags=[f"-{system_tag}"]
).workers
self.assertFalse(workers)
def test_filters(self):
@@ -105,25 +106,23 @@ class TestWorkersService(TestService):
(workers[0],),
(workers[0],),
]
timestamp = int(utc_now_tz_aware().timestamp() * 1000)
for ws, stats in zip(workers_activity, workers_stats):
for w, s in zip(ws, stats):
data = dict(
worker=w,
timestamp=int(utc_now_tz_aware().timestamp() * 1000),
timestamp=timestamp,
machine_stats=s,
)
if w == workers[0]:
data["task"] = task_id
self.api.workers.status_report(**data)
time.sleep(1)
timestamp += 1000
res = self.api.workers.get_all(last_seen=100)
return [w.key for w in res.workers]
return workers
def _create_running_task(self, task_name):
task_input = dict(
name=task_name, type="testing"
)
task_input = dict(name=task_name, type="testing")
task_id = self.create_temp("tasks", **task_input)
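The timestamp change in the hunk above replaces per-report sleeps with synthetic report times spaced one second apart, keeping the ES documents distinct without slowing the test. A sketch of the idea with hypothetical payloads:

from datetime import datetime, timezone

def utc_now_tz_aware() -> datetime:
    # Mirrors the helper the tests use: a timezone-aware UTC "now".
    return datetime.now(timezone.utc)

timestamp = int(utc_now_tz_aware().timestamp() * 1000)  # epoch milliseconds
reports = []
for machine_stats in ({"cpu_usage": [10.0]}, {"cpu_usage": [20.0]}):  # hypothetical stats
    reports.append(dict(worker="worker_1", timestamp=timestamp, machine_stats=machine_stats))
    timestamp += 1000  # advance one second per report instead of sleeping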
@@ -132,6 +131,7 @@
def test_get_keys(self):
workers = self._simulate_workers()
time.sleep(5) # give ES time to refresh
res = self.api.workers.get_metric_keys(worker_ids=workers)
assert {"cpu", "memory"} == set(c.name for c in res["categories"])
assert all(
@@ -152,6 +152,7 @@
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(days=1)
time.sleep(5) # give ES time to refresh
# no variants
res = self.api.workers.get_stats(
items=[
@@ -166,25 +167,21 @@
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
assert all(
{"cpu_usage", "memory_used"}
== set(map(attrgetter("metric"), worker["metrics"]))
for worker in res["workers"]
self.assertWorkersInStats(workers, res.workers)
for worker in res.workers:
self.assertEqual(
set(metric.metric for metric in worker.metrics),
{"cpu_usage", "memory_used"},
)
def _check_dates_and_stats(metric, stats, worker_id) -> bool:
return set(
map(attrgetter("aggregation"), metric["stats"])
) == stats and len(metric["dates"]) == (4 if worker_id == workers[0] else 2)
assert all(
_check_dates_and_stats(metric, metric_stats, worker["worker"])
for worker in res["workers"]
for worker in res.workers:
for metric, metric_stats in zip(
worker["metrics"], ({"avg", "max"}, {"max", "min"})
)
worker.metrics, ({"avg", "max"}, {"max", "min"})
):
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
# split by variants
res = self.api.workers.get_stats(
@@ -195,20 +192,15 @@
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
self.assertWorkersInStats(workers, res.workers)
def _check_metric_and_variants(worker):
return (
all(
_check_dates_and_stats(metric, {"avg"}, worker["worker"])
for metric in worker["metrics"]
for worker in res.workers:
for metric in worker.metrics:
self.assertEqual(
set(metric.variant for metric in worker.metrics),
{"0", "1"} if worker.worker == workers[0] else {"0"},
)
and set(map(attrgetter("variant"), worker["metrics"])) == {"0", "1"}
if worker["worker"] == workers[0]
else {"0"}
)
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
@@ -217,11 +209,10 @@
interval=1,
worker_ids=["Non existing worker id"],
)
assert not res["workers"]
assert not res.workers
@staticmethod
def assertWorkersInStats(workers: Sequence[str], stats: dict):
assert set(workers) == set(map(attrgetter("worker"), stats))
def assertWorkersInStats(self, workers: Sequence[str], stats: Sequence):
self.assertEqual(set(workers), set(item.worker for item in stats))
def test_get_activity_report(self):
# test no workers data
@@ -238,6 +229,7 @@
to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(minutes=1)
time.sleep(5) # give ES time to refresh
# no variants
res = self.api.workers.get_activity_report(
from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20