clearml-server/apiserver/bll/event/event_metrics.py

505 lines
17 KiB
Python

import itertools
import math
from collections import defaultdict
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Mapping
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
get_max_metric_and_variant_counts,
SINGLE_SCALAR_ITERATION,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.tools import safe_get
# Module-level logger obtained from the application config for this file
log = config.logger(__file__)
# Computes scalar metric histograms and metric listings for tasks
# from the events stored in Elasticsearch
class EventMetrics:
# Maximal amount of (metric, variant) pairs placed into one interval group
# (checked in _group_task_metric_intervals)
MAX_AGGS_ELEMENTS_COUNT = 50
# Maximal total amount of samples accumulated by one interval group
# (checked in _group_task_metric_intervals)
MAX_SAMPLE_BUCKETS = 6000
def __init__(self, es: Elasticsearch):
# Elasticsearch client used by all the queries issued from this instance
self.es = es
def get_scalar_metrics_average_per_iter(
    self,
    company_id: str,
    task_id: str,
    samples: int,
    key: ScalarKeyEnum,
    metric_variants: MetricVariants = None,
) -> dict:
    """
    Build the scalar histogram of a single task, per metric and variant.
    The histogram of each variant is requested with at most 'samples' points.
    Returns an empty dict when the company has no scalar events data.
    """
    scalar_events = EventType.metrics_scalar
    no_data = check_empty_data(
        self.es, company_id=company_id, event_type=scalar_events
    )
    if no_data:
        return {}

    resolved_key = ScalarKey.resolve(key)
    return self._get_scalar_average_per_iter_core(
        task_id=task_id,
        company_id=company_id,
        event_type=scalar_events,
        samples=samples,
        key=resolved_key,
        metric_variants=metric_variants,
    )
def _get_scalar_average_per_iter_core(
    self,
    task_id: str,
    company_id: str,
    event_type: EventType,
    samples: int,
    key: ScalarKey,
    run_parallel: bool = True,
    metric_variants: MetricVariants = None,
) -> dict:
    """
    Compute the task histograms in three steps: calculate the interval of
    every metric variant, group the variants by interval and query
    every group, then merge the group results per metric.
    When run_parallel is set the groups are queried from a thread pool,
    otherwise they are queried one after another in the calling thread.
    Returns a mapping of metric to its variants data
    (an empty dict when no intervals were found).
    """
    metric_intervals = self._get_task_metric_intervals(
        company_id=company_id,
        event_type=event_type,
        task_id=task_id,
        samples=samples,
        field=key.field,
        metric_variants=metric_variants,
    )
    if not metric_intervals:
        return {}

    groups = self._group_task_metric_intervals(metric_intervals)

    def query_group(group):
        return self._get_scalar_average(
            group,
            task_id=task_id,
            company_id=company_id,
            event_type=event_type,
            key=key,
        )

    if run_parallel:
        with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
            # All the groups are submitted here; leaving the 'with' block
            # waits for the workers, the results are read afterwards
            per_group = pool.map(query_group, groups)
    else:
        per_group = map(query_group, groups)

    merged = defaultdict(dict)
    for group_metrics in per_group:
        for metric_name, variants in group_metrics:
            # Variants of the same metric may arrive from different groups
            merged[metric_name].update(variants)

    return merged
def compare_scalar_metrics_average_per_iter(
    self,
    company_id,
    task_ids: Sequence[str],
    samples,
    key: ScalarKeyEnum,
    allow_public=True,
):
    """
    Compare scalar metrics for different tasks per metric and variant
    The amount of points in each histogram should not exceed the requested samples

    Returns a nested mapping metric -> variant -> task id -> variant data,
    where the "name" of every variant data is set to the task name.
    Returns an empty dict when no task ids were passed or when there is
    no scalar events data for the tasks company.
    Raises InvalidTaskId when some of the tasks could not be retrieved
    or when the tasks do not share the same index company.
    """
    if not task_ids:
        # Nothing to compare. Without this guard the company lookup below
        # would call next() on an empty set and fail with StopIteration
        return {}

    # Repeated ids must not be counted as missing tasks (and should not be
    # queried twice), so they are dropped while keeping the original order
    unique_task_ids = list(dict.fromkeys(task_ids))

    with translate_errors_context():
        task_objs = Task.get_many(
            company=company_id,
            query=Q(id__in=unique_task_ids),
            allow_public=allow_public,
            override_projection=("id", "name", "company", "company_origin"),
            return_dicts=False,
        )
        if len(task_objs) < len(unique_task_ids):
            invalid = tuple(set(unique_task_ids) - {r.id for r in task_objs})
            raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
        task_name_by_id = {t.id: t.name for t in task_objs}

    companies = {t.get_index_company() for t in task_objs}
    if len(companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )

    event_type = EventType.metrics_scalar
    # From here on the events are searched under the index company of the tasks
    company_id = next(iter(companies))
    if check_empty_data(self.es, company_id=company_id, event_type=event_type):
        return {}

    get_scalar_average_per_iter = partial(
        self._get_scalar_average_per_iter_core,
        company_id=company_id,
        event_type=event_type,
        samples=samples,
        key=ScalarKey.resolve(key),
        # The parallelism is applied per task here, so each task
        # queries its interval groups sequentially
        run_parallel=False,
    )
    with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
        task_metrics = zip(
            unique_task_ids,
            pool.map(get_scalar_average_per_iter, unique_task_ids),
        )

    res = defaultdict(lambda: defaultdict(dict))
    for task_id, task_data in task_metrics:
        task_name = task_name_by_id[task_id]
        for metric_key, metric_data in task_data.items():
            for variant_key, variant_data in metric_data.items():
                variant_data["name"] = task_name
                res[metric_key][variant_key][task_id] = variant_data

    return res
def get_task_single_value_metrics(
    self, company_id: str, task_ids: Sequence[str]
) -> Mapping[str, dict]:
    """
    Collect the events that the requested tasks reported for the single
    value iteration (SINGLE_SCALAR_ITERATION).
    Returns a mapping of task id to the list of its events, where every
    event is reduced to the metric, variant, value and timestamp fields.
    Returns an empty dict when the company has no scalar events data.
    """
    no_data = check_empty_data(
        self.es, company_id=company_id, event_type=EventType.metrics_scalar
    )
    if no_data:
        return {}

    with TimingContext("es", "get_task_single_value_metrics"):
        task_events = self._get_task_single_value_metrics(company_id, task_ids)

    reported_fields = ("metric", "variant", "value", "timestamp")
    events_by_task = bucketize(task_events, itemgetter("task"))

    result = {}
    for task, events in events_by_task.items():
        result[task] = [
            {field: event.get(field) for field in reported_fields}
            for event in events
        ]
    return result
def _get_task_single_value_metrics(
    self, company_id: str, task_ids: Sequence[str]
) -> Sequence[dict]:
    """
    Search the scalar events of the passed tasks whose iteration equals
    SINGLE_SCALAR_ITERATION and return their source documents.
    A single request of up to 10000 hits is issued.
    Returns an empty list when the reported total of hits is zero.
    """
    conditions = [
        {"terms": {"task": task_ids}},
        {"term": {"iter": SINGLE_SCALAR_ITERATION}},
    ]
    es_req = {
        "size": 10000,
        "query": {"bool": {"must": conditions}},
    }
    with translate_errors_context():
        es_res = search_company_events(
            body=es_req,
            es=self.es,
            company_id=company_id,
            event_type=EventType.metrics_scalar,
        )

    hits = es_res["hits"]
    if not hits["total"]["value"]:
        return []

    return [hit["_source"] for hit in hits["hits"]]
# (metric, variant, interval, samples) as produced by _build_metric_interval
MetricInterval = Tuple[str, str, int, int]
# (interval, [(metric, variant), ...]) as produced by _group_task_metric_intervals
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@classmethod
def _group_task_metric_intervals(
cls, intervals: Sequence[MetricInterval]
) -> Sequence[MetricIntervalGroup]:
"""
Group task metric intervals so that the following conditions are meat:
- All the metrics in the same group have the same interval (with 10% rounding)
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
"""
metric_interval_groups = []
interval_group = []
group_interval_upper_bound = 0
group_max_interval = 0
group_samples = 0
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
if (
interval > group_interval_upper_bound
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
):
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
interval_group = []
group_max_interval = interval
group_interval_upper_bound = interval + int(interval * 0.1)
group_samples = 0
interval_group.append((metric, variant))
group_samples += size
group_max_interval = max(group_max_interval, interval)
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
return metric_interval_groups
def _get_task_metric_intervals(
    self,
    company_id: str,
    event_type: EventType,
    task_id: str,
    samples: int,
    field: str = "iter",
    metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
    """
    Calculate the interval of every task metric variant so that the
    resulting amount of points does not exceed samples.
    The count, minimum and maximum of 'field' are aggregated per metric
    variant and passed to _build_metric_interval.
    Return the list of metric variant intervals as the following tuple:
    (metric, variant, interval, samples)
    An empty list is returned when the response has no aggregations.
    """
    conditions = self._task_conditions(task_id)
    if metric_variants:
        conditions.append(get_metric_variants_condition(metric_variants))
    query = {"bool": {"must": conditions}}

    search_args = {
        "es": self.es,
        "company_id": company_id,
        "event_type": event_type,
    }
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query, **search_args
    )
    # NOTE(review): the variants limit is halved here; the reason is not
    # visible in this file - confirm against get_max_metric_and_variant_counts
    max_variants = int(max_variants // 2)

    variant_stats = {
        "count": {"value_count": {"field": field}},
        "min_index": {"min": {"field": field}},
        "max_index": {"max": {"field": field}},
    }
    variants_agg = {
        "terms": {
            "field": "variant",
            "size": max_variants,
            "order": {"_key": "asc"},
        },
        "aggs": variant_stats,
    }
    metrics_agg = {
        "terms": {
            "field": "metric",
            "size": max_metrics,
            "order": {"_key": "asc"},
        },
        "aggs": {"variants": variants_agg},
    }
    # size 0: only the aggregations are needed, not the events themselves
    es_req = {"size": 0, "query": query, "aggs": {"metrics": metrics_agg}}

    with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
        es_res = search_company_events(body=es_req, **search_args)

    aggs_result = es_res.get("aggregations")
    if not aggs_result:
        return []

    metric_intervals = []
    for metric_bucket in aggs_result["metrics"]["buckets"]:
        for variant_bucket in metric_bucket["variants"]["buckets"]:
            metric_intervals.append(
                self._build_metric_interval(
                    metric_bucket["key"],
                    variant_bucket["key"],
                    variant_bucket,
                    samples,
                )
            )
    return metric_intervals
@staticmethod
def _build_metric_interval(
    metric: str, variant: str, data: dict, samples: int
) -> Tuple[str, str, int, int]:
    """
    Calculate the index interval of a metric variant so that the
    total amount of intervals does not exceed the samples.
    'data' is read at the count/value, min_index/value and max_index/value paths.
    Return (metric, variant, interval, resulting amount of intervals).
    When the variant has less values than samples the interval is 1
    and the amount is the values count.
    """
    values_count = safe_get(data, "count/value", default=0)
    if values_count < samples:
        return metric, variant, 1, values_count

    first_index = safe_get(data, "min_index/value", default=0)
    last_index = safe_get(data, "max_index/value", default=first_index)
    index_span = float(last_index - first_index + 1)

    step = max(1, math.ceil(index_span / samples))
    points = math.ceil(index_span / step)
    return metric, variant, step, points
# (metric, {variant: variant data}) as produced by _get_scalar_average
MetricData = Tuple[str, dict]
def _get_scalar_average(
    self,
    metrics_interval: MetricIntervalGroup,
    task_id: str,
    company_id: str,
    event_type: EventType,
    key: ScalarKey,
) -> Sequence[MetricData]:
    """
    Retrieve scalar histograms per several metric variants that share the same interval

    metrics_interval is the (interval, [(metric, variant), ...]) group to query.
    The per-variant aggregation is taken from key.get_aggregation(interval)
    and extended with the average of the "value" field.
    Return the list of (metric, {variant: variant data}) pairs, where the
    variant data holds the variant name and the output of
    key.get_iterations_data for the variant bucket.
    An empty list is returned when the response has no aggregations.
    """
    interval, metric_variants = metrics_interval
    aggregation = self._add_aggregation_average(key.get_aggregation(interval))
    query = self._get_task_metrics_query(task_id=task_id, metrics=metric_variants)
    search_args = dict(
        es=self.es, company_id=company_id, event_type=event_type
    )
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query, **search_args,
    )
    # NOTE(review): the variants limit is halved here; the reason is not
    # visible in this file - confirm against get_max_metric_and_variant_counts
    max_variants = int(max_variants // 2)
    es_req = {
        # Only the aggregations are needed, not the events themselves
        "size": 0,
        "query": query,
        "aggs": {
            "metrics": {
                "terms": {
                    "field": "metric",
                    "size": max_metrics,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "variants": {
                        "terms": {
                            "field": "variant",
                            "size": max_variants,
                            "order": {"_key": "asc"},
                        },
                        "aggs": aggregation,
                    }
                },
            }
        },
    }

    with translate_errors_context():
        es_res = search_company_events(body=es_req, **search_args)

    aggs_result = es_res.get("aggregations")
    if not aggs_result:
        # Return an empty sequence (and not a dict) so that the result
        # type is the same on every path, as the annotation declares
        return []

    return [
        (
            metric["key"],
            {
                variant["key"]: {
                    "name": variant["key"],
                    **key.get_iterations_data(variant),
                }
                for variant in metric["variants"]["buckets"]
            },
        )
        for metric in aggs_result["metrics"]["buckets"]
    ]
@staticmethod
def _add_aggregation_average(aggregation):
average_agg = {"avg_val": {"avg": {"field": "value"}}}
return {
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
for key, value in aggregation.items()
}
@staticmethod
def _task_conditions(task_id: str) -> list:
    """
    Return a new list of the query conditions that select the task events
    whose iteration is greater than SINGLE_SCALAR_ITERATION
    """
    task_condition = {"term": {"task": task_id}}
    iteration_condition = {"range": {"iter": {"gt": SINGLE_SCALAR_ITERATION}}}
    return [task_condition, iteration_condition]
@classmethod
def _get_task_metrics_query(
    cls, task_id: str, metrics: Sequence[Tuple[str, str]],
):
    """
    Build the bool query for the task events (see _task_conditions).
    When metrics are passed the events are additionally limited to
    the passed (metric, variant) pairs.
    """
    conditions = cls._task_conditions(task_id)
    if not metrics:
        return {"bool": {"must": conditions}}

    pair_conditions = []
    for metric, variant in metrics:
        pair_terms = [
            {"term": {"metric": metric}},
            {"term": {"variant": variant}},
        ]
        pair_conditions.append({"bool": {"must": pair_terms}})

    conditions.append({"bool": {"should": pair_conditions}})
    return {"bool": {"must": conditions}}
def get_task_metrics(
    self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence:
    """
    Retrieve, for each of the requested tasks, the metrics that
    reported events of the requested type.
    The tasks are queried from a thread pool.
    Returns the list of (task id, task metrics) pairs in the order of task_ids.
    Note that an empty dict (not a list) is returned when the company
    has no data for the event type.
    """
    if check_empty_data(self.es, company_id, event_type):
        return {}

    def metrics_for(task_id):
        return self._get_task_metrics(
            task_id, company_id=company_id, event_type=event_type
        )

    with ThreadPoolExecutor(EventSettings.max_workers) as pool:
        per_task_metrics = pool.map(metrics_for, task_ids)

    return list(zip(task_ids, per_task_metrics))
def _get_task_metrics(
    self, task_id: str, company_id: str, event_type: EventType
) -> Sequence:
    """
    Return the metric names found in the task events (see _task_conditions)
    of the passed event type, in ascending order of the metric name.
    At most EventSettings.max_es_buckets metrics are requested.
    """
    query = {"bool": {"must": self._task_conditions(task_id)}}
    metrics_agg = {
        "terms": {
            "field": "metric",
            "size": EventSettings.max_es_buckets,
            "order": {"_key": "asc"},
        }
    }
    # size 0: only the aggregation buckets are needed
    es_req = {"size": 0, "query": query, "aggs": {"metrics": metrics_agg}}

    with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
        es_res = search_company_events(
            self.es, company_id=company_id, event_type=event_type, body=es_req
        )

    buckets = safe_get(es_res, "aggregations/metrics/buckets", default=[])
    return [bucket["key"] for bucket in buckets]