mirror of
https://github.com/clearml/clearml-server
synced 2025-05-08 05:54:58 +00:00
Fix task scalars comparison aggregation
This commit is contained in:
parent
a392a0e6ff
commit
0722b20c1c
@ -4,8 +4,9 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from typing import Sequence, Tuple, Callable
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
@ -21,9 +22,10 @@ log = config.logger(__file__)
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 100
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
@ -62,6 +64,11 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", len(task_ids)
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
@ -97,6 +104,31 @@ class EventMetrics:
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
@ -126,21 +158,27 @@ class EventMetrics:
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
with ThreadPoolExecutor(len(intervals)) as pool:
|
||||
metrics = list(
|
||||
itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
get_func, task_ids=task_ids, es_index=es_index, key=key
|
||||
),
|
||||
intervals,
|
||||
)
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(
|
||||
get_func, task_ids=task_ids, es_index=es_index, key=key
|
||||
),
|
||||
intervals,
|
||||
)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def _get_metric_intervals(
|
||||
@ -310,7 +348,13 @@ class EventMetrics:
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
|
@ -2,4 +2,8 @@ es_index_prefix: "events"
|
||||
|
||||
ignore_iteration {
|
||||
metrics: [":monitor:machine", ":monitor:gpu"]
|
||||
}
|
||||
}
|
||||
|
||||
# max number of concurrent queries to ES when calculating events metrics
|
||||
# should not exceed the amount of concurrent connections set in the ES driver
|
||||
max_metrics_concurrency: 4
|
@ -2,9 +2,12 @@
|
||||
Comprehensive test of all(?) use cases of datasets and frames
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
from statistics import mean
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import es_factory
|
||||
from config import config
|
||||
from tests.automated import TestService
|
||||
@ -16,69 +19,56 @@ class TestTaskEvents(TestService):
|
||||
def setUp(self, version="1.7"):
|
||||
super().setUp(version=version)
|
||||
|
||||
self.created_tasks = []
|
||||
|
||||
self.task = dict(
|
||||
name="test task events",
|
||||
type="training",
|
||||
input=dict(mapping={}, view=dict(entries=[])),
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
|
||||
)
|
||||
res, self.task_id = self.api.send("tasks.create", self.task, extract="id")
|
||||
assert res.meta.result_code == 200
|
||||
self.created_tasks.append(self.task_id)
|
||||
return self.create_temp("tasks", **task_input)
|
||||
|
||||
def tearDown(self):
|
||||
log.info("Cleanup...")
|
||||
for task_id in self.created_tasks:
|
||||
try:
|
||||
self.api.send("tasks.delete", dict(task=task_id, force=True))
|
||||
except Exception as ex:
|
||||
log.exception(ex)
|
||||
|
||||
def create_task_event(self, type, iteration):
|
||||
def _create_task_event(self, type_, task, iteration):
|
||||
return {
|
||||
"worker": "test",
|
||||
"type": type,
|
||||
"task": self.task_id,
|
||||
"type": type_,
|
||||
"task": task,
|
||||
"iter": iteration,
|
||||
"timestamp": es_factory.get_timestamp_millis()
|
||||
"timestamp": es_factory.get_timestamp_millis(),
|
||||
}
|
||||
|
||||
def copy_and_update(self, src_obj, new_data):
|
||||
def _copy_and_update(self, src_obj, new_data):
|
||||
obj = src_obj.copy()
|
||||
obj.update(new_data)
|
||||
return obj
|
||||
|
||||
def test_task_logs(self):
|
||||
events = []
|
||||
for iter in range(10):
|
||||
log_event = self.create_task_event("log", iteration=iter)
|
||||
task = self._temp_task()
|
||||
for iter_ in range(10):
|
||||
log_event = self._create_task_event("log", task, iteration=iter_)
|
||||
events.append(
|
||||
self.copy_and_update(
|
||||
self._copy_and_update(
|
||||
log_event,
|
||||
{"msg": "This is a log message from test task iter " + str(iter)},
|
||||
{"msg": "This is a log message from test task iter " + str(iter_)},
|
||||
)
|
||||
)
|
||||
# sleep so timestamp is not the same
|
||||
import time
|
||||
|
||||
time.sleep(0.01)
|
||||
self.send_batch(events)
|
||||
|
||||
data = self.api.events.get_task_log(task=self.task_id)
|
||||
data = self.api.events.get_task_log(task=task)
|
||||
assert len(data["events"]) == 10
|
||||
|
||||
self.api.tasks.reset(task=self.task_id)
|
||||
data = self.api.events.get_task_log(task=self.task_id)
|
||||
self.api.tasks.reset(task=task)
|
||||
data = self.api.events.get_task_log(task=task)
|
||||
assert len(data["events"]) == 0
|
||||
|
||||
def test_task_metric_value_intervals_keys(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self.create_task_event("training_stats_scalar", iteration),
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
@ -88,19 +78,65 @@ class TestTaskEvents(TestService):
|
||||
self.send_batch(events)
|
||||
for key in None, "iter", "timestamp", "iso_time":
|
||||
with self.subTest(key=key):
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key)
|
||||
self.assertIn(metric, data)
|
||||
self.assertIn(variant, data[metric])
|
||||
self.assertIn("x", data[metric][variant])
|
||||
self.assertIn("y", data[metric][variant])
|
||||
|
||||
def test_multitask_events_many_metrics(self):
|
||||
tasks = [
|
||||
self._temp_task(name="test events1"),
|
||||
self._temp_task(name="test events2"),
|
||||
]
|
||||
iter_count = 10
|
||||
metrics_count = 10
|
||||
variants_count = 10
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": f"Metric{metric_idx}",
|
||||
"variant": f"Variant{variant_idx}",
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
for task in tasks
|
||||
for metric_idx in range(metrics_count)
|
||||
for variant_idx in range(variants_count)
|
||||
]
|
||||
self.send_batch(events)
|
||||
data = self.api.events.multi_task_scalar_metrics_iter_histogram(tasks=tasks)
|
||||
self._assert_metrics_and_variants(
|
||||
data.metrics,
|
||||
metrics=metrics_count,
|
||||
variants=variants_count,
|
||||
tasks=tasks,
|
||||
iterations=iter_count,
|
||||
)
|
||||
|
||||
def _assert_metrics_and_variants(
|
||||
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
|
||||
):
|
||||
self.assertEqual(len(data), metrics)
|
||||
for m in range(metrics):
|
||||
metric_data = data[f"Metric{m}"]
|
||||
self.assertEqual(len(metric_data), variants)
|
||||
for v in range(variants):
|
||||
variant_data = metric_data[f"Variant{v}"]
|
||||
self.assertEqual(len(variant_data), len(tasks))
|
||||
for t in tasks:
|
||||
task_data = variant_data[t]
|
||||
self.assertEqual(len(task_data["x"]), iterations)
|
||||
self.assertEqual(len(task_data["y"]), iterations)
|
||||
|
||||
def test_task_metric_value_intervals(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self.create_task_event("training_stats_scalar", iteration),
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
@ -109,13 +145,13 @@ class TestTaskEvents(TestService):
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=100)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
|
||||
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10)
|
||||
data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=10)
|
||||
self._assert_metrics_histogram(data[metric][variant], iter_count, 10)
|
||||
|
||||
def _assert_metrics_histogram(self, data, iters, samples):
|
||||
@ -130,7 +166,8 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
|
||||
def test_task_plots(self):
|
||||
event = self.create_task_event("plot", 0)
|
||||
task = self._temp_task()
|
||||
event = self._create_task_event("plot", task, 0)
|
||||
event["metric"] = "roc"
|
||||
event.update(
|
||||
{
|
||||
@ -179,7 +216,7 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
self.send(event)
|
||||
|
||||
event = self.create_task_event("plot", 100)
|
||||
event = self._create_task_event("plot", task, 100)
|
||||
event["metric"] = "confusion"
|
||||
event.update(
|
||||
{
|
||||
@ -222,11 +259,11 @@ class TestTaskEvents(TestService):
|
||||
)
|
||||
self.send(event)
|
||||
|
||||
data = self.api.events.get_task_plots(task=self.task_id)
|
||||
data = self.api.events.get_task_plots(task=task)
|
||||
assert len(data["plots"]) == 2
|
||||
|
||||
self.api.tasks.reset(task=self.task_id)
|
||||
data = self.api.events.get_task_plots(task=self.task_id)
|
||||
self.api.tasks.reset(task=task)
|
||||
data = self.api.events.get_task_plots(task=task)
|
||||
assert len(data["plots"]) == 0
|
||||
|
||||
def send_batch(self, events):
|
||||
|
Loading…
Reference in New Issue
Block a user