From 5d3ba4fa73efcc810db2868a2f83af9c8b1c7059 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Wed, 26 Jul 2023 18:34:30 +0300 Subject: [PATCH] Fix events.get_multitask_plots to retrieve last iterations per each task metric separately --- apiserver/apimodels/events.py | 7 + apiserver/apimodels/reports.py | 6 +- apiserver/bll/event/event_bll.py | 121 +++++++++++++++--- apiserver/schema/services/events.conf | 7 + apiserver/schema/services/reports.conf | 7 + apiserver/services/events.py | 21 +-- apiserver/services/reports.py | 1 + apiserver/tests/automated/test_subprojects.py | 2 +- apiserver/tests/automated/test_task_events.py | 30 +++++ 9 files changed, 172 insertions(+), 30 deletions(-) diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index aa201d8..8bd52d9 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -155,6 +155,13 @@ class TaskMetricsRequest(MultiTasksRequestBase): event_type: EventType = ActualEnumField(EventType, required=True) +class MultiTaskPlotsRequest(MultiTasksRequestBase): + iters: int = IntField(default=1) + scroll_id: str = StringField() + no_scroll: bool = BoolField(default=False) + last_iters_per_task_metric: bool = BoolField(default=True) + + class TaskPlotsRequest(Base): task: str = StringField(required=True) iters: int = IntField(default=1) diff --git a/apiserver/apimodels/reports.py b/apiserver/apimodels/reports.py index 6315367..79becc4 100644 --- a/apiserver/apimodels/reports.py +++ b/apiserver/apimodels/reports.py @@ -57,6 +57,10 @@ class EventsRequest(Base): metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) +class PlotEventsRequest(EventsRequest): + last_iters_per_task_metric: bool = BoolField(default=True) + + class ScalarMetricsIterHistogram(HistogramRequestBase): metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) @@ -67,7 +71,7 @@ class SingleValueMetrics(Base): class GetTasksDataRequest(Base): debug_images: EventsRequest = EmbeddedField(EventsRequest) - plots: EventsRequest = EmbeddedField(EventsRequest) + plots: PlotEventsRequest = EmbeddedField(PlotEventsRequest) scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField( ScalarMetricsIterHistogram ) diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 59ddcf1..83d07a8 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -717,6 +717,7 @@ class EventBLL(object): size=500, scroll_id=None, no_scroll=False, + last_iters_per_task_metric=False, ) -> TaskEventsResult: if scroll_id == self.empty_scroll: return TaskEventsResult() @@ -743,25 +744,47 @@ class EventBLL(object): if last_iter_count is None: must.append({"terms": {"task": task_ids}}) else: - tasks_iters = self.get_last_iters( - company_id=company_ids, - event_type=event_type, - task_id=task_ids, - iters=last_iter_count, - metrics=metrics, - ) - should = [ - { - "bool": { - "must": [ - {"term": {"task": task}}, - {"terms": {"iter": last_iters}}, - ] + if last_iters_per_task_metric: + task_metric_iters = self.get_last_iters_per_metric( + company_id=company_ids, + event_type=event_type, + task_id=task_ids, + iters=last_iter_count, + metrics=metrics, + ) + should = [ + { + "bool": { + "must": [ + {"term": {"task": task}}, + {"term": {"metric": metric}}, + {"terms": {"iter": last_iters}}, + ] + } } - } - for task, last_iters in tasks_iters.items() - if last_iters - ] + for (task, metric), last_iters in task_metric_iters.items() + if last_iters + ] + else: + tasks_iters = self.get_last_iters( + company_id=company_ids, + event_type=event_type, + task_id=task_ids, + iters=last_iter_count, + metrics=metrics, + ) + should = [ + { + "bool": { + "must": [ + {"term": {"task": task}}, + {"terms": {"iter": last_iters}}, + ] + } + } + for task, last_iters in tasks_iters.items() + if last_iters + ] if not should: return TaskEventsResult() must.append({"bool": {"should": should}}) @@ -959,6 +982,68 @@ class EventBLL(object): return iterations, vectors + def get_last_iters_per_metric( + self, + company_id: Union[str, Sequence[str]], + event_type: EventType, + task_id: Union[str, Sequence[str]], + iters: int, + metrics: MetricVariants = None, + ) -> Mapping[Tuple[str, str], Sequence]: + company_ids = [company_id] if isinstance(company_id, str) else company_id + company_ids = [ + c_id + for c_id in set(company_ids) + if not check_empty_data(self.es, c_id, event_type) + ] + if not company_ids: + return {} + + task_ids = [task_id] if isinstance(task_id, str) else task_id + must = [{"terms": {"task": task_ids}}] + if metrics: + must.append(get_metric_variants_condition(metrics)) + + es_req: dict = { + "size": 0, + "aggs": { + "tasks": { + "terms": {"field": "task"}, + "aggs": { + "metrics": { + "terms": {"field": "metric"}, + "aggs": { + "iters": { + "terms": { + "field": "iter", + "size": iters, + "order": {"_key": "desc"}, + } + } + } + } + } + } + }, + "query": {"bool": {"must": must}}, + } + + with translate_errors_context(): + es_res = search_company_events( + self.es, + company_id=company_ids, + event_type=event_type, + body=es_req, + ) + if "aggregations" not in es_res: + return {} + + return { + (tb["key"], mb["key"]): [ib["key"] for ib in mb["iters"]["buckets"]] + for tb in es_res["aggregations"]["tasks"]["buckets"] + for mb in tb["metrics"]["buckets"] + } + def get_last_iters( self, company_id: Union[str, Sequence[str]], diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 6fd89be..2a6c4b8 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -1149,6 +1149,13 @@ get_multi_task_plots { default: false } } + "999.0": ${get_multi_task_plots."2.22"} { + request.properties.last_iters_per_task_metric { + type: boolean + description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved + default: true + } + } } get_vector_metrics_and_variants { "2.1" { diff --git a/apiserver/schema/services/reports.conf b/apiserver/schema/services/reports.conf index 83ed581..c9df092 100644 --- a/apiserver/schema/services/reports.conf +++ b/apiserver/schema/services/reports.conf @@ -587,6 +587,13 @@ get_task_data { items {"$ref": "#/definitions/single_value_task_metrics"} } } + "999.0": ${get_task_data."2.25"} { + request.properties.plots.properties.last_iters_per_task_metric { + type: boolean + description: If set to 'true' and iters passed then last iterations for each task metrics are retrieved. Otherwise last iterations for the whole task are retrieved + default: true + } + } } get_all_ex { "2.23" { diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 1b856b0..63f16af 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -30,6 +30,7 @@ from apiserver.apimodels.events import ( GetVariantSampleRequest, GetMetricSamplesRequest, TaskMetric, + MultiTaskPlotsRequest, ) from apiserver.bll.event import EventBLL from apiserver.bll.event.event_common import EventType, MetricVariants, TaskCompanies @@ -554,6 +555,7 @@ def get_multi_task_plots_v1_7(call, company_id, _): def _get_multitask_plots( companies: TaskCompanies, last_iters: int, + last_iters_per_task_metric: bool, metrics: MetricVariants = None, scroll_id=None, no_scroll=True, @@ -573,6 +575,7 @@ def _get_multitask_plots( size=config.get( "services.events.events_retrieval.multi_plots_batch_size", 1000 ), + last_iters_per_task_metric=last_iters_per_task_metric, ) return_events = _get_top_iter_unique_events_per_task( result.events, max_iters=last_iters, task_names=task_names @@ -580,19 +583,17 @@ def _get_multitask_plots( return return_events, result.total_events, result.next_scroll_id -@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"]) -def get_multi_task_plots(call, company_id, _): - task_ids = call.data["tasks"] - iters = call.data.get("iters", 1) - scroll_id = call.data.get("scroll_id") - no_scroll = call.data.get("no_scroll", False) - model_events = call.data.get("model_events", False) - +@endpoint("events.get_multi_task_plots", min_version="1.8") +def get_multi_task_plots(call, company_id, request: MultiTaskPlotsRequest): companies = _get_task_or_model_index_companies( - company_id, task_ids, model_events=model_events + company_id, request.tasks, model_events=request.model_events ) return_events, total_events, next_scroll_id = _get_multitask_plots( - companies=companies, last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll, + companies=companies, + last_iters=request.iters, + scroll_id=request.scroll_id, + no_scroll=request.no_scroll, + last_iters_per_task_metric=request.last_iters_per_task_metric, ) call.result.data = dict( plots=return_events, diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py index 13172b6..5786ab7 100644 --- a/apiserver/services/reports.py +++ b/apiserver/services/reports.py @@ -264,6 +264,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): companies=companies, last_iters=request.plots.iters, metrics=_get_metric_variants_from_request(request.plots.metrics), + last_iters_per_task_metric=request.plots.last_iters_per_task_metric, )[0] if request.scalar_metrics_iter_histogram: diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 54c2c12..da26a62 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -153,7 +153,7 @@ class TestSubProjects(TestService): self.assertEqual(p.own_tasks, 0) self.assertIsNone(p.get("own_datasets")) self.assertEqual( - p.stats.active.total_tasks, 1 if p.basename != "Project4" else 0 + p.stats.active.total_tasks, 1 if p.basename != "Project2" else 0 ) def test_project_aggregations(self): diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index a9ff3cf..869b53d 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -482,6 +482,36 @@ class TestTaskEvents(TestService): mean(v for v in range(curr * interval, (curr + 1) * interval)), ) + def test_multitask_plots(self): + task1 = self._temp_task() + events = [ + self._create_task_event("plot", task1, 1, metric="A", variant="AX", plot_str="Task1_1_A_AX"), + self._create_task_event("plot", task1, 2, metric="B", variant="BX", plot_str="Task1_2_B_BX"), + self._create_task_event("plot", task1, 3, metric="B", variant="BX", plot_str="Task1_3_B_BX"), + self._create_task_event("plot", task1, 3, metric="C", variant="CX", plot_str="Task1_3_C_CX"), + ] + self.send_batch(events) + task2 = self._temp_task() + events = [ + self._create_task_event("plot", task2, 1, metric="C", variant="CX", plot_str="Task2_1_C_CX"), + self._create_task_event("plot", task2, 2, metric="A", variant="AY", plot_str="Task2_2_A_AY"), + ] + self.send_batch(events) + plots = self.api.events.get_multi_task_plots(tasks=[task1, task2]).plots + self.assertEqual(len(plots), 3) + self.assertEqual(len(plots.A), 2) + self.assertEqual(len(plots.A.AX), 1) + self.assertEqual(len(plots.A.AY), 1) + self.assertEqual(plots.A.AX[task1]["1"]["plots"][0]["plot_str"], "Task1_1_A_AX") + self.assertEqual(plots.A.AY[task2]["2"]["plots"][0]["plot_str"], "Task2_2_A_AY") + self.assertEqual(len(plots.B), 1) + self.assertEqual(len(plots.B.BX), 1) + self.assertEqual(plots.B.BX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_B_BX") + self.assertEqual(len(plots.C), 1) + self.assertEqual(len(plots.C.CX), 2) + self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX") + self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX") + def test_task_plots(self): task = self._temp_task() event = self._create_task_event("plot", task, 0)