diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 933feb4..2d10ab8 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base): class GetUniqueMetricsRequest(ProjectOrNoneRequest): model_metrics = fields.BoolField(default=False) + ids = fields.ListField(str) class GetParamsRequest(ProjectOrNoneRequest): diff --git a/apiserver/bll/project/project_queries.py b/apiserver/bll/project/project_queries.py index 3e95103..5fd05b9 100644 --- a/apiserver/bll/project/project_queries.py +++ b/apiserver/bll/project/project_queries.py @@ -239,6 +239,7 @@ class ProjectQueries: company_id, project_ids: Sequence[str], include_subprojects: bool, + ids: Sequence[str], model_metrics: bool = False, ): pipeline = [ @@ -246,6 +247,7 @@ class ProjectQueries: "$match": { **cls._get_company_constraint(company_id), **cls._get_project_constraint(project_ids, include_subprojects), + **({"_id": {"$in": ids}} if ids else {}), } }, {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}}, diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index ffb20dd..f782aa7 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -949,6 +949,13 @@ get_unique_metric_variants { default: false } } + "999.0": ${get_unique_metric_variants."2.25"} { + request.properties.ids { + description: IDs of the tasks or models to get metrics from + type: array + items {type: string} + } + } } get_hyperparam_values { "2.13" { diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index b0d701b..ee6c75e 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -380,6 +380,7 @@ def get_unique_metric_variants( company_id, [request.project] if request.project else None, include_subprojects=request.include_subprojects, + ids=request.ids, model_metrics=request.model_metrics, ) diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index e42b9a6..f9974ca 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -16,10 +16,18 @@ class TestTaskEvents(TestService): delete_params = dict(can_fail=True, force=True) default_task_name = "test task events" - def _temp_task(self, name=default_task_name): - task_input = dict(name=name, type="training",) + def _temp_project(self, name=default_task_name): return self.create_temp( - "tasks", delete_paramse=self.delete_params, **task_input + "projects", + name=name, + description="test", + delete_params=self.delete_params, + ) + + def _temp_task(self, name=default_task_name, **kwargs): + self.update_missing(kwargs, name=name, type="training") + return self.create_temp( + "tasks", delete_paramse=self.delete_params, **kwargs ) def _temp_model(self, name="test model events", **kwargs): @@ -122,6 +130,15 @@ class TestTaskEvents(TestService): self.assertEqual(value.metric, metric) self.assertEqual(value.variant, variant) self.assertEqual(value.value, 0) + # test metrics parameter + res = self.api.events.get_task_single_value_metrics( + tasks=[task], metrics=[{"metric": metric, "variants": [variant]}] + ).tasks + self.assertEqual(len(res), 1) + res = self.api.events.get_task_single_value_metrics( + tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}] + ).tasks + self.assertEqual(len(res), 0) # update is working task_data = self.api.tasks.get_by_id(task=task).task @@ -340,6 +357,30 @@ class TestTaskEvents(TestService): else (None, None) ) + def test_task_unique_metric_variants(self): + project = self._temp_project() + task1 = self._temp_task(project=project) + task2 = self._temp_task(project=project) + metric1 = "Metric1" + metric2 = "Metric2" + events = [ + { + **self._create_task_event("training_stats_scalar", task, 0), + "metric": metric, + "variant": "Variant", + "value": 10, + } + for task, metric in ((task1, metric1), (task2, metric2)) + ] + self.send_batch(events) + + metrics = self.api.projects.get_unique_metric_variants(project=project).metrics + self.assertEqual({m.metric for m in metrics}, {metric1, metric2}) + metrics = self.api.projects.get_unique_metric_variants(ids=[task1, task2]).metrics + self.assertEqual({m.metric for m in metrics}, {metric1, metric2}) + metrics = self.api.projects.get_unique_metric_variants(ids=[task1]).metrics + self.assertEqual([m.metric for m in metrics], [metric1]) + def test_task_metric_value_intervals_keys(self): metric = "Metric1" variant = "Variant1" @@ -395,6 +436,25 @@ class TestTaskEvents(TestService): iterations=iter_count, ) + # test metrics + data = self.api.events.multi_task_scalar_metrics_iter_histogram( + tasks=tasks, + metrics=[ + { + "metric": f"Metric{m_idx}", + "variants": [f"Variant{v_idx}" for v_idx in range(4)], + } + for m_idx in range(2) + ], + ) + self._assert_metrics_and_variants( + data.metrics, + metrics=2, + variants=4, + tasks=tasks, + iterations=iter_count, + ) + def _assert_metrics_and_variants( self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int ): @@ -515,6 +575,13 @@ class TestTaskEvents(TestService): self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX") self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX") + # test metrics + plots = self.api.events.get_multi_task_plots( + tasks=[task1, task2], metrics=[{"metric": "A"}] + ).plots + self.assertEqual(len(plots), 1) + self.assertEqual(len(plots.A), 2) + def test_task_plots(self): task = self._temp_task() event = self._create_task_event("plot", task, 0)