def get_multi_task_metrics(
    self, companies: TaskCompanies, event_type: EventType
) -> Mapping[str, list]:
    """
    For the requested tasks return reported metrics and variants

    :param companies: mapping of company to the task objects (each exposing ``id``)
        whose events should be analyzed
    :param event_type: type of the events to collect the metrics from
    :return: mapping of metric name to the list of its variants, merged over
        all the passed companies
    """
    if not companies:
        return {}

    tasks_ids = {
        company: [t.id for t in tasks] for company, tasks in companies.items()
    }
    # One aggregation query per company, issued concurrently
    with ThreadPoolExecutor(EventSettings.max_workers) as pool:
        companies_res: Sequence = list(
            pool.map(
                partial(self._get_multi_task_metrics, event_type=event_type),
                tasks_ids.items(),
            )
        )

    if len(companies_res) == 1:
        # Single company: the aggregation already orders buckets by key
        return companies_res[0]

    merged = defaultdict(set)
    for company_res in companies_res:
        for metric, variants in company_res.items():
            merged[metric].update(variants)

    # Sort so that the merged result does not depend on set iteration order
    # and matches the key-ascending ordering of the single company path
    return {metric: sorted(merged[metric]) for metric in sorted(merged)}


def _get_multi_task_metrics(
    self, company_tasks: Tuple[str, Sequence[str]], event_type: EventType
) -> Mapping[str, list]:
    """
    Collect the metrics and variants reported for the tasks of one company

    :param company_tasks: (company id, task ids) pair. Passed as a single tuple
        so that the method can be mapped directly over dict items
    :param event_type: type of the events to collect the metrics from
    :return: mapping of metric name to the list of its variants. Empty mapping
        if the company has no data for the event type or nothing was aggregated
    """
    company_id, task_ids = company_tasks
    if check_empty_data(self.es, company_id, event_type):
        return {}

    search_args = dict(
        es=self.es,
        company_id=company_id,
        event_type=event_type,
    )
    query = QueryBuilder.terms("task", task_ids)
    # The bucket limits are computed for this exact query and used as the
    # "size" of the terms aggregations below
    max_metrics, max_variants = get_max_metric_and_variant_counts(
        query=query,
        **search_args,
    )
    es_req = {
        "size": 0,  # only the aggregations are needed, not the matching events
        "query": query,
        "aggs": {
            "metrics": {
                "terms": {
                    "field": "metric",
                    "size": max_metrics,
                    "order": {"_key": "asc"},
                },
                "aggs": {
                    "variants": {
                        "terms": {
                            "field": "variant",
                            "size": max_variants,
                            "order": {"_key": "asc"},
                        },
                    }
                },
            }
        },
    }

    es_res = search_company_events(
        body=es_req,
        **search_args,
    )
    aggs_result = es_res.get("aggregations")
    if not aggs_result:
        return {}

    return {
        metric_bucket["key"]: [
            variant_bucket["key"]
            for variant_bucket in metric_bucket["variants"]["buckets"]
        ]
        for metric_bucket in aggs_result["metrics"]["buckets"]
    }
@endpoint("events.get_multi_task_metrics")
def get_multi_task_metrics(call: APICall, company_id, request: MultiTaskMetricsRequest):
    """
    Return the unique metrics and variants reported for the passed tasks (or models)

    The result is placed in call.result.data as
    {"metrics": [{"metric": <name>, "variants": [<variant>, ...]}, ...]}
    with both the metrics and the variants of each metric sorted in ascending order.

    NOTE(review): the request schema for this endpoint names the models flag
    "model_metrics" while this handler reads request.model_events - confirm
    which of the two names is the intended one
    """
    companies = _get_task_or_model_index_companies(
        company_id, request.tasks, model_events=request.model_events
    )
    if not companies:
        # Report the empty result the same way as the regular one below
        # instead of relying on the handler return value
        call.result.data = {"metrics": []}
        return

    metrics = event_bll.metrics.get_multi_task_metrics(
        companies=companies, event_type=request.event_type
    )
    # Metric names are the unique keys of the mapping so iterating over the
    # sorted keys yields the entries ordered by "metric"
    call.result.data = {
        "metrics": [
            {"metric": metric, "variants": sorted(metrics[metric])}
            for metric in sorted(metrics)
        ]
    }
def _assert_multitask_metrics(
    self, tasks: Sequence[str], metrics: Sequence[str], event_type: str = None
):
    """
    Call events.get_multi_task_metrics and verify the returned metrics

    :param tasks: ids of the tasks to query
    :param metrics: the expected metric names, in the expected order
    :param event_type: event type to pass to the endpoint. When empty the
        parameter is omitted from the call altogether
    """
    params = {"tasks": tasks}
    if event_type:
        params["event_type"] = event_type
    res = self.api.events.get_multi_task_metrics(**params).metrics

    self.assertEqual([r.metric for r in res], metrics)
    # Check every metric separately so that a failure names the offending
    # metric and shows the actual variants
    for r in res:
        self.assertEqual(
            r.variants,
            ["Test variant"],
            msg=f"Unexpected variants for metric {r.metric!r} (event_type={event_type!r})",
        )