From 0722b20c1cd7c32d70781f91b8d4deaf61218fb3 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 4 Feb 2020 18:16:27 +0200 Subject: [PATCH] Fix task scalars comparison aggregation --- server/bll/event/event_metrics.py | 68 ++++++++++-- server/config/default/services/events.conf | 6 +- server/tests/automated/test_task_events.py | 121 ++++++++++++++------- 3 files changed, 140 insertions(+), 55 deletions(-) diff --git a/server/bll/event/event_metrics.py b/server/bll/event/event_metrics.py index 57bcff8..a90ab3a 100644 --- a/server/bll/event/event_metrics.py +++ b/server/bll/event/event_metrics.py @@ -4,8 +4,9 @@ from concurrent.futures import ThreadPoolExecutor from functools import partial from operator import itemgetter +from boltons.iterutils import bucketize from elasticsearch import Elasticsearch -from typing import Sequence, Tuple, Callable +from typing import Sequence, Tuple, Callable, Iterable from mongoengine import Q @@ -21,9 +22,10 @@ log = config.logger(__file__) class EventMetrics: - MAX_TASKS_COUNT = 100 + MAX_TASKS_COUNT = 50 MAX_METRICS_COUNT = 200 MAX_VARIANTS_COUNT = 500 + MAX_AGGS_ELEMENTS_COUNT = 50 def __init__(self, es: Elasticsearch): self.es = es @@ -62,6 +64,11 @@ class EventMetrics: Compare scalar metrics for different tasks per metric and variant The amount of points in each histogram should not exceed the requested samples """ + if len(task_ids) > self.MAX_TASKS_COUNT: + raise errors.BadRequest( + f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", len(task_ids) + ) + task_name_by_id = {} with translate_errors_context(): task_objs = Task.get_many( @@ -97,6 +104,31 @@ class EventMetrics: MetricInterval = Tuple[int, Sequence[TaskMetric]] MetricData = Tuple[str, dict] + def _split_metrics_by_max_aggs_count( + self, task_metrics: Sequence[TaskMetric] + ) -> Iterable[Sequence[TaskMetric]]: + """ + Return task metrics in groups where amount of task metrics in each group + is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and + variants while always preserving all their tasks in the same group + """ + if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT: + yield task_metrics + return + + tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2)) + groups = [] + for group in tm_grouped.values(): + groups.append(group) + if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT: + yield list(itertools.chain(*groups)) + groups = [] + + if groups: + yield list(itertools.chain(*groups)) + + return + def _run_get_scalar_metrics_as_parallel( self, company_id: str, @@ -126,21 +158,27 @@ class EventMetrics: if not intervals: return {} - with ThreadPoolExecutor(len(intervals)) as pool: - metrics = list( - itertools.chain.from_iterable( - pool.map( - partial( - get_func, task_ids=task_ids, es_index=es_index, key=key - ), - intervals, - ) + intervals = list( + itertools.chain.from_iterable( + zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms)) + for i, tms in intervals + ) + ) + max_concurrency = config.get("services.events.max_metrics_concurrency", 4) + with ThreadPoolExecutor(max_workers=max_concurrency) as pool: + metrics = itertools.chain.from_iterable( + pool.map( + partial( + get_func, task_ids=task_ids, es_index=es_index, key=key + ), + intervals, ) ) ret = defaultdict(dict) for metric_key, metric_values in metrics: ret[metric_key].update(metric_values) + return ret def _get_metric_intervals( @@ -310,7 +348,13 @@ class EventMetrics: "variants": { "terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT}, "aggs": { - "tasks": {"terms": {"field": "task"}, "aggs": aggregation} + "tasks": { + "terms": { + "field": "task", + "size": self.MAX_TASKS_COUNT, + }, + "aggs": aggregation, + } }, } }, diff --git a/server/config/default/services/events.conf b/server/config/default/services/events.conf index 660f204..91f5810 100644 --- a/server/config/default/services/events.conf +++ b/server/config/default/services/events.conf @@ -2,4 +2,8 @@ es_index_prefix: "events" ignore_iteration { metrics: [":monitor:machine", ":monitor:gpu"] -} \ No newline at end of file +} + +# max number of concurrent queries to ES when calculating events metrics +# should not exceed the amount of concurrent connections set in the ES driver +max_metrics_concurrency: 4 \ No newline at end of file diff --git a/server/tests/automated/test_task_events.py b/server/tests/automated/test_task_events.py index bcd5a60..4e43934 100644 --- a/server/tests/automated/test_task_events.py +++ b/server/tests/automated/test_task_events.py @@ -2,9 +2,12 @@ Comprehensive test of all(?) use cases of datasets and frames """ import json +import time import unittest from statistics import mean +from typing import Sequence + import es_factory from config import config from tests.automated import TestService @@ -16,69 +19,56 @@ class TestTaskEvents(TestService): def setUp(self, version="1.7"): super().setUp(version=version) - self.created_tasks = [] - - self.task = dict( - name="test task events", - type="training", - input=dict(mapping={}, view=dict(entries=[])), + def _temp_task(self, name="test task events"): + task_input = dict( + name=name, type="training", input=dict(mapping={}, view=dict(entries=[])), ) - res, self.task_id = self.api.send("tasks.create", self.task, extract="id") - assert res.meta.result_code == 200 - self.created_tasks.append(self.task_id) + return self.create_temp("tasks", **task_input) - def tearDown(self): - log.info("Cleanup...") - for task_id in self.created_tasks: - try: - self.api.send("tasks.delete", dict(task=task_id, force=True)) - except Exception as ex: - log.exception(ex) - - def create_task_event(self, type, iteration): + def _create_task_event(self, type_, task, iteration): return { "worker": "test", - "type": type, - "task": self.task_id, + "type": type_, + "task": task, "iter": iteration, - "timestamp": es_factory.get_timestamp_millis() + "timestamp": es_factory.get_timestamp_millis(), } - def copy_and_update(self, src_obj, new_data): + def _copy_and_update(self, src_obj, new_data): obj = src_obj.copy() obj.update(new_data) return obj def test_task_logs(self): events = [] - for iter in range(10): - log_event = self.create_task_event("log", iteration=iter) + task = self._temp_task() + for iter_ in range(10): + log_event = self._create_task_event("log", task, iteration=iter_) events.append( - self.copy_and_update( + self._copy_and_update( log_event, - {"msg": "This is a log message from test task iter " + str(iter)}, + {"msg": "This is a log message from test task iter " + str(iter_)}, ) ) # sleep so timestamp is not the same - import time - time.sleep(0.01) self.send_batch(events) - data = self.api.events.get_task_log(task=self.task_id) + data = self.api.events.get_task_log(task=task) assert len(data["events"]) == 10 - self.api.tasks.reset(task=self.task_id) - data = self.api.events.get_task_log(task=self.task_id) + self.api.tasks.reset(task=task) + data = self.api.events.get_task_log(task=task) assert len(data["events"]) == 0 def test_task_metric_value_intervals_keys(self): metric = "Metric1" variant = "Variant1" iter_count = 100 + task = self._temp_task() events = [ { - **self.create_task_event("training_stats_scalar", iteration), + **self._create_task_event("training_stats_scalar", task, iteration), "metric": metric, "variant": variant, "value": iteration, @@ -88,19 +78,65 @@ class TestTaskEvents(TestService): self.send_batch(events) for key in None, "iter", "timestamp", "iso_time": with self.subTest(key=key): - data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key) + data = self.api.events.scalar_metrics_iter_histogram(task=task, key=key) self.assertIn(metric, data) self.assertIn(variant, data[metric]) self.assertIn("x", data[metric][variant]) self.assertIn("y", data[metric][variant]) + def test_multitask_events_many_metrics(self): + tasks = [ + self._temp_task(name="test events1"), + self._temp_task(name="test events2"), + ] + iter_count = 10 + metrics_count = 10 + variants_count = 10 + events = [ + { + **self._create_task_event("training_stats_scalar", task, iteration), + "metric": f"Metric{metric_idx}", + "variant": f"Variant{variant_idx}", + "value": iteration, + } + for iteration in range(iter_count) + for task in tasks + for metric_idx in range(metrics_count) + for variant_idx in range(variants_count) + ] + self.send_batch(events) + data = self.api.events.multi_task_scalar_metrics_iter_histogram(tasks=tasks) + self._assert_metrics_and_variants( + data.metrics, + metrics=metrics_count, + variants=variants_count, + tasks=tasks, + iterations=iter_count, + ) + + def _assert_metrics_and_variants( + self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int + ): + self.assertEqual(len(data), metrics) + for m in range(metrics): + metric_data = data[f"Metric{m}"] + self.assertEqual(len(metric_data), variants) + for v in range(variants): + variant_data = metric_data[f"Variant{v}"] + self.assertEqual(len(variant_data), len(tasks)) + for t in tasks: + task_data = variant_data[t] + self.assertEqual(len(task_data["x"]), iterations) + self.assertEqual(len(task_data["y"]), iterations) + def test_task_metric_value_intervals(self): metric = "Metric1" variant = "Variant1" iter_count = 100 + task = self._temp_task() events = [ { - **self.create_task_event("training_stats_scalar", iteration), + **self._create_task_event("training_stats_scalar", task, iteration), "metric": metric, "variant": variant, "value": iteration, @@ -109,13 +145,13 @@ class TestTaskEvents(TestService): ] self.send_batch(events) - data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id) + data = self.api.events.scalar_metrics_iter_histogram(task=task) self._assert_metrics_histogram(data[metric][variant], iter_count, 100) - data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100) + data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=100) self._assert_metrics_histogram(data[metric][variant], iter_count, 100) - data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10) + data = self.api.events.scalar_metrics_iter_histogram(task=task, samples=10) self._assert_metrics_histogram(data[metric][variant], iter_count, 10) def _assert_metrics_histogram(self, data, iters, samples): @@ -130,7 +166,8 @@ class TestTaskEvents(TestService): ) def test_task_plots(self): - event = self.create_task_event("plot", 0) + task = self._temp_task() + event = self._create_task_event("plot", task, 0) event["metric"] = "roc" event.update( { @@ -179,7 +216,7 @@ class TestTaskEvents(TestService): ) self.send(event) - event = self.create_task_event("plot", 100) + event = self._create_task_event("plot", task, 100) event["metric"] = "confusion" event.update( { @@ -222,11 +259,11 @@ class TestTaskEvents(TestService): ) self.send(event) - data = self.api.events.get_task_plots(task=self.task_id) + data = self.api.events.get_task_plots(task=task) assert len(data["plots"]) == 2 - self.api.tasks.reset(task=self.task_id) - data = self.api.events.get_task_plots(task=self.task_id) + self.api.tasks.reset(task=task) + data = self.api.events.get_task_plots(task=task) assert len(data["plots"]) == 0 def send_batch(self, events):