Support chart series per single resource in workers.get_stats

clearml 2025-06-04 11:51:00 +03:00
parent 1983b22157
commit f3c67ac3fd
6 changed files with 303 additions and 199 deletions
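
A minimal sketch of how a client might exercise the extended endpoint over HTTP. The apiserver address, route, and credentials below are assumptions for illustration only; the request fields themselves come from the schema added in this commit:

import requests

# assumed apiserver address, route and credentials; adjust to your deployment
resp = requests.post(
    "http://localhost:8008/workers.get_stats",
    auth=("access_key", "secret_key"),  # hypothetical credentials
    json={
        "items": [
            {"key": "gpu_usage", "aggregation": "avg"},
            {"key": "cpu_usage", "aggregation": "max"},
        ],
        "from_date": 1717486800,
        "to_date": 1717490400,
        "interval": 60,
        "split_by_resource": True,  # new flag: per-GPU series for gpu_* keys
    },
)
print(resp.json())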

View File

@@ -12,7 +12,7 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin
from apiserver.apimodels import ListField, EnumField, JsonSerializableMixin, ActualEnumField
from apiserver.config_repo import config
@@ -130,7 +130,7 @@ class AggregationType(Enum):
class StatItem(Base):
key = StringField(required=True)
aggregation = EnumField(AggregationType, default=AggregationType.avg)
aggregation = ActualEnumField(AggregationType, default=AggregationType.avg)
class GetStatsRequest(StatsReportBase):
@@ -138,17 +138,24 @@ class GetStatsRequest(StatsReportBase):
StatItem, required=True, validators=validators.Length(minimum_value=1)
)
split_by_variant = BoolField(default=False)
split_by_resource = BoolField(default=False)
class MetricResourceSeries(Base):
name = StringField()
values = ListField(float)
class AggregationStats(Base):
aggregation = EnumField(AggregationType)
dates = ListField(int)
values = ListField(float)
resource_series = ListField(MetricResourceSeries)
class MetricStats(Base):
metric = StringField()
variant = StringField()
dates = ListField(int)
stats = ListField(AggregationStats)
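
With this change dates moves from MetricStats down into AggregationStats, so each aggregation carries its own timeline, and an optional per-resource breakdown is added. A sketch of one serialized MetricStats entry under the new layout (field names from the models above, values purely illustrative):

gpu_usage_stats = {
    "metric": "gpu_usage",
    "stats": [
        {
            "aggregation": "avg",
            "dates": [1717486860, 1717486920],  # one timestamp per interval
            "values": [65.0, 55.0],             # aggregated over all GPUs
            "resource_series": [                # populated only when split_by_resource is requested
                {"name": "0", "values": [60.0, 50.0]},
                {"name": "1", "values": [70.0, 60.0]},
            ],
        }
    ],
}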

View File

@@ -1,8 +1,9 @@
from operator import attrgetter
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from typing import Optional, Sequence
from boltons.iterutils import bucketize
from apiserver.apierrors import errors
from apiserver.apierrors.errors import bad_request
from apiserver.apimodels.workers import AggregationType, GetStatsRequest, StatItem
from apiserver.bll.query import Builder as QueryBuilder
@@ -14,6 +15,7 @@ log = config.logger(__file__)
class WorkerStats:
min_chart_interval = config.get("services.workers.min_chart_interval_sec", 40)
_max_metrics_concurrency = config.get("services.events.events_retrieval.max_metrics_concurrency", 4)
def __init__(self, es):
self.es = es
@@ -23,7 +25,7 @@ class WorkerStats:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id.lower()}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
def search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
body=es_req,
@@ -51,7 +53,7 @@ class WorkerStats:
if worker_ids:
es_req["query"] = QueryBuilder.terms("worker", worker_ids)
res = self._search_company_stats(company_id, es_req)
res = self.search_company_stats(company_id, es_req)
if not res["hits"]["total"]["value"]:
raise bad_request.WorkerStatsNotFound(
@@ -65,6 +67,75 @@ class WorkerStats:
for category in res["aggregations"]["categories"]["buckets"]
}
def _get_worker_stats_per_metric(
self,
metric_item: StatItem,
company_id: str,
from_date: float,
to_date: float,
interval: int,
split_by_resource: bool,
worker_ids: Sequence[str],
):
agg_types_to_es = {
AggregationType.avg: "avg",
AggregationType.min: "min",
AggregationType.max: "max",
}
agg = {
metric_item.aggregation.value: {
agg_types_to_es[metric_item.aggregation]: {"field": "value", "missing": 0.0 }
}
}
split_by_resource = split_by_resource and metric_item.key.startswith("gpu_")
if split_by_resource:
split_aggs = {"split": {"terms": {"field": "variant"}, "aggs": agg}}
else:
split_aggs = {}
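# the extra "variant" terms aggregation yields one nested series per GPU index
# in addition to the overall per-interval aggregation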
es_req = {
"size": 0,
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
},
},
"aggs": {
**agg,
**split_aggs,
},
}
},
}
},
}
query_terms = [
QueryBuilder.dates_range(from_date, to_date),
QueryBuilder.term("metric", metric_item.key),
]
if worker_ids:
query_terms.append(QueryBuilder.terms("worker", worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context():
data = self.search_company_stats(company_id, es_req)
cutoff_date = (
to_date - 0.9 * interval
) * 1000 # do not return the point for the incomplete last interval
return self._extract_results(
data, metric_item, split_by_resource, cutoff_date
)
def get_worker_stats(self, company_id: str, request: GetStatsRequest) -> dict:
"""
Get statistics for company workers metrics in the specified time range
@@ -76,123 +147,90 @@ class WorkerStats:
from_date = request.from_date
to_date = request.to_date
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
interval = max(request.interval, self.min_chart_interval)
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
("min", AggregationType.min.value),
("max", AggregationType.max.value),
raise errors.bad_request.FieldsValueError(
"from_date must be less than to_date"
)
return {
"dates": {
"date_histogram": {
"field": "timestamp",
"fixed_interval": f"{interval}s",
"extended_bounds": {
"min": int(from_date) * 1000,
"max": int(to_date) * 1000,
}
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
for es_agg, agg_type in es_to_agg_types
},
}
}
interval = max(request.interval, self.min_chart_interval)
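# issue one ES query per requested metric item, concurrently, bounded by max_metrics_concurrency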
with ThreadPoolExecutor(self._max_metrics_concurrency) as pool:
res = list(
pool.map(
partial(
self._get_worker_stats_per_metric,
company_id=company_id,
from_date=from_date,
to_date=to_date,
interval=interval,
split_by_resource=request.split_by_resource,
worker_ids=request.worker_ids,
),
request.items,
)
)
def get_variants_agg() -> dict:
return {
"variants": {"terms": {"field": "variant"}, "aggs": get_dates_agg()}
}
ret = defaultdict(lambda: defaultdict(dict))
for workers in res:
for worker, metrics in workers.items():
for metric, stats in metrics.items():
ret[worker][metric].update(stats)
es_req = {
"size": 0,
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"aggs": get_variants_agg()
if request.split_by_variant
else get_dates_agg(),
}
},
}
},
}
query_terms = [
QueryBuilder.dates_range(from_date, to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
cutoff_date = (to_date - 0.9 * interval) * 1000 # do not return the point for the incomplete last interval
return self._extract_results(data, request.items, request.split_by_variant, cutoff_date)
return ret
@staticmethod
def _extract_results(
data: dict, request_items: Sequence[StatItem], split_by_variant: bool, cutoff_date
data: dict,
metric_item: StatItem,
split_by_resource: bool,
cutoff_date,
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only aggregation types requested by the user and return a clean dictionary
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
"""
if "aggregations" not in data:
return {}
items_by_key = bucketize(request_items, key=attrgetter("key"))
aggs_per_metric = {
key: [item.aggregation for item in items]
for key, items in items_by_key.items()
}
def extract_metric_results(metric: dict) -> dict:
aggregation = metric_item.aggregation.value
date_buckets = metric["dates"]["buckets"]
length = len(date_buckets)
while length > 0 and date_buckets[length - 1]["key"] >= cutoff_date:
length -= 1
dates = [None] * length
agg_values = [0.0] * length
resource_series = defaultdict(lambda: [0.0] * length)
for idx in range(0, length):
date = date_buckets[idx]
dates[idx] = date["key"]
if aggregation in date:
agg_values[idx] = date[aggregation]["value"] or 0.0
if split_by_resource and "split" in date:
for resource in date["split"]["buckets"]:
series = resource_series[resource["key"]]
if aggregation in resource:
series[idx] = resource[aggregation]["value"] or 0.0
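# a single resource series would duplicate the aggregated values, so it adds nothing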
if len(resource_series) == 1:
resource_series = {}
def extract_date_stats(date: dict, metric_key) -> dict:
return {
"date": date["key"],
"count": date["doc_count"],
**{agg: date[agg]["value"] or 0.0 for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
metric_or_variant: dict, metric_key: str
) -> Sequence[dict]:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
if date["key"] <= cutoff_date
]
def extract_variant_results(metric: dict) -> dict:
metric_key = metric["key"]
return {
variant["key"]: extract_metric_results(variant, metric_key)
for variant in metric["variants"]["buckets"]
}
def extract_worker_results(worker: dict) -> dict:
return {
metric["key"]: extract_variant_results(metric)
if split_by_variant
else extract_metric_results(metric, metric["key"])
for metric in worker["metrics"]["buckets"]
"dates": dates,
"values": agg_values,
**(
{"resource_series": resource_series} if resource_series else {}
),
}
return {
worker["key"]: extract_worker_results(worker)
worker["key"]: {
metric_item.key: {
metric_item.aggregation.value: extract_metric_results(worker)
}
}
for worker in data["aggregations"]["workers"]["buckets"]
}
@@ -237,7 +275,7 @@ class WorkerStats:
}
with translate_errors_context():
data = self._search_company_stats(company_id, es_req)
data = self.search_company_stats(company_id, es_req)
if "aggregations" not in data:
return {}
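
For reference, a sketch of the merged mapping that get_worker_stats now hands back to the endpoint layer (worker, then metric, then aggregation; the nesting follows _extract_results and the merge loop above, values illustrative):

ret_example = {
    "worker-1": {
        "gpu_usage": {
            "avg": {
                "dates": [1717486860000, 1717486920000],  # bucket keys from the ES date_histogram
                "values": [65.0, 55.0],
                # present only for gpu_* metrics when split_by_resource is set
                "resource_series": {"0": [60.0, 50.0], "1": [70.0, 60.0]},
            },
        },
        "cpu_usage": {
            "max": {"dates": [1717486860000, 1717486920000], "values": [20.0, 20.0]},
        },
    },
}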

View File

@@ -3,3 +3,7 @@ default_cluster_timeout_sec: 600
# The minimal sampling interval for resource dashboard and worker activity charts
min_chart_interval_sec: 40
stats {
max_metrics_concurrency: 4
}

View File

@@ -15,6 +15,26 @@ _definitions {
}
}
}
worker_stat_key {
type: string
enum: [
cpu_usage
cpu_temperature
memory_used
memory_free
gpu_usage
gpu_temperature
gpu_fraction
gpu_memory_free
gpu_memory_used
network_tx
network_rx
disk_free_home
disk_free_temp
disk_read
disk_write
]
}
aggregation_type {
type: string
enum: [ avg, min, max ]
@@ -23,8 +43,7 @@ _definitions {
stat_item {
type: object
properties {
key {
type: string
key: ${_definitions.worker_stat_key} {
description: "Name of a metric"
}
category {
@@ -38,6 +57,30 @@ _definitions {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
resource_series {
type: array
description: "Metric data per single resource. Return only if split_by_resource request parameter is set to True"
items {"$ref": "#/definitions/metric_resource_series"}
}
}
}
metric_resource_series {
type: object
properties {
name {
type: string
description: Resource name
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
@@ -56,11 +99,6 @@ _definitions {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
@@ -482,6 +520,20 @@ get_stats {
}
}
}
"2.32": ${get_stats."2.4"} {
request.properties {
split_by_variant {
description: "Obsolete, please do not use"
type: boolean
default: false
}
split_by_resource {
type: boolean
default: false
description: If set then for GPU-related keys return per-GPU charts in addition to the aggregated one
}
}
}
}
get_activity_report {
"2.4" {

View File

@@ -1,9 +1,3 @@
import itertools
from operator import attrgetter
from typing import Optional, Sequence, Union
from boltons.iterutils import bucketize
from apiserver.apierrors.errors import bad_request
from apiserver.apimodels.workers import (
WorkerRequest,
@@ -23,6 +17,7 @@ from apiserver.apimodels.workers import (
GetActivityReportResponse,
ActivityReportSeries,
GetCountRequest,
MetricResourceSeries,
)
from apiserver.bll.workers import WorkerBLL
from apiserver.config_repo import config
@@ -163,71 +158,47 @@ def get_activity_report(
@endpoint(
"workers.get_stats",
min_version="2.4",
request_data_model=GetStatsRequest,
response_data_model=GetStatsResponse,
validate_schema=True,
)
def get_stats(call: APICall, company_id, request: GetStatsRequest):
ret = worker_bll.stats.get_worker_stats(company_id, request)
def _get_variant_metric_stats(
metric: str,
agg_names: Sequence[str],
stats: Sequence[dict],
variant: Optional[str] = None,
) -> MetricStats:
stat_by_name = extract_properties_to_lists(agg_names, stats)
return MetricStats(
metric=metric,
variant=variant,
dates=stat_by_name["date"],
stats=[
AggregationStats(aggregation=name, values=aggs)
for name, aggs in stat_by_name.items()
if name != "date"
],
)
def _get_metric_stats(
metric: str, stats: Union[dict, Sequence[dict]], agg_types: Sequence[str]
) -> Sequence[MetricStats]:
"""
Return statistics for a certain metric or a list of statistic for
metric variants if break_by_variant was requested
"""
agg_names = ["date"] + list(set(agg_types))
if not isinstance(stats, dict):
# no variants were requested
return [_get_variant_metric_stats(metric, agg_names, stats)]
return [
_get_variant_metric_stats(metric, agg_names, variant_stats, variant)
for variant, variant_stats in stats.items()
]
def _get_worker_metrics(stats: dict) -> Sequence[MetricStats]:
"""
Convert the worker statistics data from the internal format of lists of structs
to a more "compact" format for json transfer (arrays of dates and arrays of values)
"""
# removed metrics that were requested but for some reason
# do not exist in stats data
metrics = [metric for metric in request.items if metric.key in stats]
aggs_by_metric = bucketize(
metrics, key=attrgetter("key"), value_transform=attrgetter("aggregation")
)
return list(
itertools.chain.from_iterable(
_get_metric_stats(metric, metric_stats, aggs_by_metric[metric])
for metric, metric_stats in stats.items()
)
def _get_agg_stats(
aggregation: str,
stats: dict,
) -> AggregationStats:
resource_series = []
if "resource_series" in stats:
for name, values in stats["resource_series"].items():
resource_series.append(
MetricResourceSeries(
name=name,
values=values
)
)
return AggregationStats(
aggregation=aggregation,
dates=stats["dates"],
values=stats["values"],
resource_series=resource_series,
)
return GetStatsResponse(
workers=[
WorkerStatistics(worker=worker, metrics=_get_worker_metrics(stats))
for worker, stats in ret.items()
WorkerStatistics(
worker=worker,
metrics=[
MetricStats(
metric=metric,
stats=[
_get_agg_stats(aggregation, a_stats)
for aggregation, a_stats in m_stats.items()
]
)
for metric, m_stats in w_stats.items()
],
)
for worker, w_stats in ret.items()
]
)

View File

@@ -1,3 +1,4 @@
import statistics
import time
from uuid import uuid4
from typing import Sequence
@@ -83,7 +84,7 @@ class TestWorkersService(TestService):
self._check_exists(test_worker, False, tags=["test"])
self._check_exists(test_worker, False, tags=["-application"])
def _simulate_workers(self, start: int) -> Sequence[str]:
def _simulate_workers(self, start: int, with_gpu: bool = False) -> dict:
"""
Two workers writing the same metrics. One for 4 seconds. Another one for 2
The first worker reports a task
@@ -93,20 +94,25 @@ class TestWorkersService(TestService):
task_id = self._create_running_task(task_name="task-1")
workers = [f"test_{uuid4().hex}", f"test_{uuid4().hex}"]
workers_stats = [
if with_gpu:
gpu_usage = [dict(gpu_usage=[60, 70]), dict(gpu_usage=[40])]
else:
gpu_usage = [{}, {}]
worker_stats = [
(
dict(cpu_usage=[10, 20], memory_used=50),
dict(cpu_usage=[5], memory_used=30),
dict(cpu_usage=[10, 20], memory_used=50, **gpu_usage[0]),
dict(cpu_usage=[5], memory_used=30, **gpu_usage[1]),
)
] * 4
workers_activity = [
worker_activity = [
(workers[0], workers[1]),
(workers[0], workers[1]),
(workers[0],),
(workers[0],),
]
timestamp = start * 1000
for ws, stats in zip(workers_activity, workers_stats):
for ws, stats in zip(worker_activity, worker_stats):
for w, s in zip(ws, stats):
data = dict(
worker=w,
@@ -118,7 +124,10 @@ class TestWorkersService(TestService):
self.api.workers.status_report(**data)
timestamp += 60*1000
return workers
return {
w: s
for w, s in zip(workers, worker_stats[0])
}
def _create_running_task(self, task_name):
task_input = dict(name=task_name, type="testing")
@@ -131,7 +140,7 @@ class TestWorkersService(TestService):
def test_get_keys(self):
workers = self._simulate_workers(int(time.time()))
time.sleep(5) # give to es time to refresh
res = self.api.workers.get_metric_keys(worker_ids=workers)
res = self.api.workers.get_metric_keys(worker_ids=list(workers))
assert {"cpu", "memory"} == set(c.name for c in res["categories"])
assert all(
c.metric_keys == ["cpu_usage"] for c in res["categories"] if c.name == "cpu"
@@ -147,7 +156,7 @@ class TestWorkersService(TestService):
def test_get_stats(self):
start = int(time.time())
workers = self._simulate_workers(start)
workers = self._simulate_workers(start, with_gpu=True)
time.sleep(5) # give to ES time to refresh
from_date = start
@@ -157,49 +166,72 @@ class TestWorkersService(TestService):
items=[
dict(key="cpu_usage", aggregation="avg"),
dict(key="cpu_usage", aggregation="max"),
dict(key="gpu_usage", aggregation="avg"),
dict(key="gpu_usage", aggregation="max"),
dict(key="memory_used", aggregation="max"),
dict(key="memory_used", aggregation="min"),
],
from_date=from_date,
to_date=to_date,
# split_by_variant=True,
interval=1,
worker_ids=workers,
worker_ids=list(workers),
)
self.assertWorkersInStats(workers, res.workers)
self.assertWorkersInStats(list(workers), res.workers)
for worker in res.workers:
self.assertEqual(
set(metric.metric for metric in worker.metrics),
{"cpu_usage", "memory_used"},
{"cpu_usage", "gpu_usage", "memory_used"},
)
for worker in res.workers:
worker_id = worker.worker
for metric, metric_stats in zip(
worker.metrics, ({"avg", "max"}, {"max", "min"})
worker.metrics, ({"avg", "max"}, {"avg", "max"}, {"max"})
):
metric_name = metric.metric
self.assertEqual(
set(stat.aggregation for stat in metric.stats), metric_stats
)
self.assertTrue(11 >= len(metric.dates) >= 10)
for stat in metric.stats:
expected = workers[worker_id][metric_name]
self.assertTrue(11 >= len(stat.dates) >= 10)
self.assertFalse(stat.get("resource_series"))
agg = stat.aggregation
if isinstance(expected, list):
if agg == "avg":
val = statistics.mean(expected)
elif agg == "min":
val = min(expected)
else:
val = max(expected)
else:
val = expected
self.assertEqual(set(stat["values"]), {val, 0})
# split by variants
# split by resources
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],
items=[dict(key="gpu_usage", aggregation="avg")],
from_date=from_date,
to_date=to_date,
split_by_variant=True,
split_by_resource=True,
interval=1,
worker_ids=workers,
worker_ids=list(workers),
)
self.assertWorkersInStats(workers, res.workers)
self.assertWorkersInStats(list(workers), res.workers)
for worker in res.workers:
worker_id = worker.worker
for metric in worker.metrics:
self.assertEqual(
set(metric.variant for metric in worker.metrics),
{"0", "1"} if worker.worker == workers[0] else {"0"},
)
self.assertTrue(11 >= len(metric.dates) >= 10)
metric_name = metric.metric
for stat in metric.stats:
expected = workers[worker_id][metric_name]
if metric_name.startswith("gpu") and len(expected) > 1:
resource_series = stat.get("resource_series")
self.assertEqual(len(resource_series), len(expected))
for rs, value in zip(resource_series, expected):
self.assertEqual(set(rs["values"]), {value, 0})
else:
self.assertEqual(stat.get("resource_series"), [])
res = self.api.workers.get_stats(
items=[dict(key="cpu_usage", aggregation="avg")],