Optimize Workers retrieval

Store worker statistics under worker id and not internal redis key
Fix unit tests
This commit is contained in:
allegroai 2023-11-17 09:46:44 +02:00
parent a7865ccbec
commit 4ac6f88278
3 changed files with 71 additions and 80 deletions

View File

@@ -13,8 +13,7 @@ from jsonmodels.fields import (
from jsonmodels.models import Base from jsonmodels.models import Base
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
from apiserver.config_repo import config
DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base): class WorkerRequest(Base):
@@ -24,7 +23,10 @@ class WorkerRequest(Base):
class RegisterRequest(WorkerRequest): class RegisterRequest(WorkerRequest):
timeout = IntField(default=0) # registration timeout in seconds (if not specified, default is 10min) timeout = IntField(
default=int(config.get("services.workers.default_worker_timeout_sec", 10 * 60))
)
""" registration timeout in seconds (default is 10min) """
queues = ListField(six.string_types) # list of queues this worker listens to queues = ListField(six.string_types) # list of queues this worker listens to

View File

@@ -5,13 +5,13 @@ from typing import Sequence, Set, Optional
import attr import attr
import elasticsearch.helpers import elasticsearch.helpers
from boltons.iterutils import partition from boltons.iterutils import partition, chunked_iter
from pyhocon import ConfigTree
from apiserver.es_factory import es_factory from apiserver.es_factory import es_factory
from apiserver.apierrors import APIError from apiserver.apierrors import APIError
from apiserver.apierrors.errors import bad_request, server_error from apiserver.apierrors.errors import bad_request, server_error
from apiserver.apimodels.workers import ( from apiserver.apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry, IdNameEntry,
WorkerEntry, WorkerEntry,
StatusReportRequest, StatusReportRequest,
@@ -30,12 +30,14 @@ from apiserver.redis_manager import redman
from apiserver.tools import safe_get from apiserver.tools import safe_get
from .stats import WorkerStats from .stats import WorkerStats
log = config.logger(__file__) log = config.logger(__file__)
class WorkerBLL: class WorkerBLL:
def __init__(self, es=None, redis=None): def __init__(self, es=None, redis=None):
self.es_client = es or es_factory.connect("workers") self.es_client = es or es_factory.connect("workers")
self.config = config.get("services.workers", ConfigTree())
self.redis = redis or redman.connection("workers") self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client) self._stats = WorkerStats(self.es_client)
@@ -68,7 +70,7 @@ class WorkerBLL:
""" """
key = WorkerBLL._get_worker_key(company_id, user_id, worker) key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT timeout = timeout or int(self.config.get("default_worker_timeout_sec", 10 * 60))
queues = queues or [] queues = queues or []
with translate_errors_context(): with translate_errors_context():
@@ -141,8 +143,6 @@ class WorkerBLL:
try: try:
entry.ip = ip entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None: if tags is not None:
entry.tags = tags entry.tags = tags
@@ -150,15 +150,16 @@ class WorkerBLL:
entry.system_tags = system_tags entry.system_tags = system_tags
if report.machine_stats: if report.machine_stats:
self._log_stats_to_es( self.log_stats_to_es(
company_id=company_id, company_id=company_id,
company_name=entry.company.name, worker_id=report.worker,
worker=entry.key,
timestamp=report.timestamp, timestamp=report.timestamp,
task=report.task, task=report.task,
machine_stats=report.machine_stats, machine_stats=report.machine_stats,
) )
now = datetime.utcnow()
entry.last_activity_time = now
entry.queue = report.queue entry.queue = report.queue
if report.queues: if report.queues:
@@ -254,18 +255,15 @@ class WorkerBLL:
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
) -> Sequence[WorkerResponseEntry]: ) -> Sequence[WorkerResponseEntry]:
helpers = [
helpers = list( WorkerConversionHelper.from_worker_entry(entry)
map( for entry in self.get_all(
WorkerConversionHelper.from_worker_entry, company_id=company_id,
self.get_all( last_seen=last_seen,
company_id=company_id, tags=tags,
last_seen=last_seen, system_tags=system_tags,
tags=tags,
system_tags=system_tags,
),
) )
) ]
task_ids = set(filter(None, (helper.task_id for helper in helpers))) task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set( all_queues = set(
@@ -284,9 +282,7 @@ class WorkerBLL:
} }
}, },
] ]
queues_info = { queues_info = {res["_id"]: res for res in Queue.aggregate(projection)}
res["_id"]: res for res in Queue.objects.aggregate(projection)
}
task_ids = task_ids.union( task_ids = task_ids.union(
filter( filter(
None, None,
@@ -496,12 +492,15 @@ class WorkerBLL:
"""Get worker entries matching the company and user, worker patterns""" """Get worker entries matching the company and user, worker patterns"""
entries = [] entries = []
for key in self._get_keys( for keys in chunked_iter(
company, user=user, user_tags=user_tags, system_tags=system_tags self._get_keys(
company, user=user, user_tags=user_tags, system_tags=system_tags
),
1000,
): ):
data = self.redis.get(key) data = self.redis.mget(keys)
if data: if data:
entries.append(WorkerEntry.from_json(data)) entries.extend(WorkerEntry.from_json(d) for d in data if d)
return entries return entries
@@ -510,18 +509,17 @@ class WorkerBLL:
"""Get the index name suffix for storing current month data""" """Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m") return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es( def log_stats_to_es(
self, self,
company_id: str, company_id: str,
company_name: str, worker_id: str,
worker: str,
timestamp: int, timestamp: int,
task: str, task: str,
machine_stats: MachineStats, machine_stats: MachineStats,
) -> bool: ) -> int:
""" """
Actually writing the worker statistics to Elastic Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise :return: The amount of logged documents
""" """
es_index = ( es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}" f"{self._stats.worker_stats_prefix_for_company(company_id)}"
@@ -533,8 +531,7 @@ class WorkerBLL:
_index=es_index, _index=es_index,
_source=dict( _source=dict(
timestamp=timestamp, timestamp=timestamp,
worker=worker, worker=worker_id,
company=company_name,
task=task, task=task,
category=category, category=category,
metric=metric, metric=metric,
@@ -559,7 +556,7 @@ class WorkerBLL:
es_res = elasticsearch.helpers.bulk(self.es_client, actions) es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2] added, errors = es_res[:2]
return (added == len(actions)) and not errors return added
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)

View File

@@ -1,7 +1,6 @@
import time import time
from uuid import uuid4 from uuid import uuid4
from datetime import timedelta from datetime import timedelta
from operator import attrgetter
from typing import Sequence from typing import Sequence
from apiserver.apierrors.errors import bad_request from apiserver.apierrors.errors import bad_request
@@ -72,7 +71,9 @@ class TestWorkersService(TestService):
self.assertEqual(worker.tags, [tag]) self.assertEqual(worker.tags, [tag])
self.assertEqual(worker.system_tags, [system_tag]) self.assertEqual(worker.system_tags, [system_tag])
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers workers = self.api.workers.get_all(
tags=[tag], system_tags=[f"-{system_tag}"]
).workers
self.assertFalse(workers) self.assertFalse(workers)
def test_filters(self): def test_filters(self):
@@ -105,25 +106,23 @@ class TestWorkersService(TestService):
(workers[0],), (workers[0],),
(workers[0],), (workers[0],),
] ]
timestamp = int(utc_now_tz_aware().timestamp() * 1000)
for ws, stats in zip(workers_activity, workers_stats): for ws, stats in zip(workers_activity, workers_stats):
for w, s in zip(ws, stats): for w, s in zip(ws, stats):
data = dict( data = dict(
worker=w, worker=w,
timestamp=int(utc_now_tz_aware().timestamp() * 1000), timestamp=timestamp,
machine_stats=s, machine_stats=s,
) )
if w == workers[0]: if w == workers[0]:
data["task"] = task_id data["task"] = task_id
self.api.workers.status_report(**data) self.api.workers.status_report(**data)
time.sleep(1) timestamp += 1000
res = self.api.workers.get_all(last_seen=100) return workers
return [w.key for w in res.workers]
def _create_running_task(self, task_name): def _create_running_task(self, task_name):
task_input = dict( task_input = dict(name=task_name, type="testing")
name=task_name, type="testing"
)
task_id = self.create_temp("tasks", **task_input) task_id = self.create_temp("tasks", **task_input)
@@ -132,6 +131,7 @@ class TestWorkersService(TestService):
def test_get_keys(self): def test_get_keys(self):
workers = self._simulate_workers() workers = self._simulate_workers()
time.sleep(5) # give to es time to refresh
res = self.api.workers.get_metric_keys(worker_ids=workers) res = self.api.workers.get_metric_keys(worker_ids=workers)
assert {"cpu", "memory"} == set(c.name for c in res["categories"]) assert {"cpu", "memory"} == set(c.name for c in res["categories"])
assert all( assert all(
@@ -152,6 +152,7 @@ class TestWorkersService(TestService):
to_date = utc_now_tz_aware() + timedelta(seconds=10) to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(days=1) from_date = to_date - timedelta(days=1)
time.sleep(5) # give to ES time to refresh
# no variants # no variants
res = self.api.workers.get_stats( res = self.api.workers.get_stats(
items=[ items=[
@@ -166,25 +167,21 @@ class TestWorkersService(TestService):
interval=1, interval=1,
worker_ids=workers, worker_ids=workers,
) )
self.assertWorkersInStats(workers, res["workers"]) self.assertWorkersInStats(workers, res.workers)
assert all( for worker in res.workers:
{"cpu_usage", "memory_used"} self.assertEqual(
== set(map(attrgetter("metric"), worker["metrics"])) set(metric.metric for metric in worker.metrics),
for worker in res["workers"] {"cpu_usage", "memory_used"},
)
def _check_dates_and_stats(metric, stats, worker_id) -> bool:
return set(
map(attrgetter("aggregation"), metric["stats"])
) == stats and len(metric["dates"]) == (4 if worker_id == workers[0] else 2)
assert all(
_check_dates_and_stats(metric, metric_stats, worker["worker"])
for worker in res["workers"]
for metric, metric_stats in zip(
worker["metrics"], ({"avg", "max"}, {"max", "min"})
) )
)
for worker in res.workers:
for metric, metric_stats in zip(
worker.metrics, ({"avg", "max"}, {"max", "min"})
):
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
# split by variants # split by variants
res = self.api.workers.get_stats( res = self.api.workers.get_stats(
@@ -195,20 +192,15 @@ class TestWorkersService(TestService):
interval=1, interval=1,
worker_ids=workers, worker_ids=workers,
) )
self.assertWorkersInStats(workers, res["workers"]) self.assertWorkersInStats(workers, res.workers)
def _check_metric_and_variants(worker): for worker in res.workers:
return ( for metric in worker.metrics:
all( self.assertEqual(
_check_dates_and_stats(metric, {"avg"}, worker["worker"]) set(metric.variant for metric in worker.metrics),
for metric in worker["metrics"] {"0", "1"} if worker.worker == workers[0] else {"0"},
) )
and set(map(attrgetter("variant"), worker["metrics"])) == {"0", "1"} self.assertEqual(len(metric.dates), 4 if worker.worker == workers[0] else 2)
if worker["worker"] == workers[0]
else {"0"}
)
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
res = self.api.workers.get_stats( res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")], items=[dict(key="cpu_usage", aggregation="avg")],
@@ -217,11 +209,10 @@ class TestWorkersService(TestService):
interval=1, interval=1,
worker_ids=["Non existing worker id"], worker_ids=["Non existing worker id"],
) )
assert not res["workers"] assert not res.workers
@staticmethod def assertWorkersInStats(self, workers: Sequence[str], stats: Sequence):
def assertWorkersInStats(workers: Sequence[str], stats: dict): self.assertEqual(set(workers), set(item.worker for item in stats))
assert set(workers) == set(map(attrgetter("worker"), stats))
def test_get_activity_report(self): def test_get_activity_report(self):
# test no workers data # test no workers data
@@ -238,6 +229,7 @@ class TestWorkersService(TestService):
to_date = utc_now_tz_aware() + timedelta(seconds=10) to_date = utc_now_tz_aware() + timedelta(seconds=10)
from_date = to_date - timedelta(minutes=1) from_date = to_date - timedelta(minutes=1)
time.sleep(5) # give to es time to refresh
# no variants # no variants
res = self.api.workers.get_activity_report( res = self.api.workers.get_activity_report(
from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20 from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20