Support filtering by task or model ids in projects.get_unique_metric_variants

This commit is contained in:
allegroai 2024-01-10 15:06:21 +02:00
parent 4684fd5b74
commit 35c4061992
5 changed files with 81 additions and 3 deletions

View File

@ -33,6 +33,7 @@ class ProjectOrNoneRequest(models.Base):
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
model_metrics = fields.BoolField(default=False)
ids = fields.ListField(str)
class GetParamsRequest(ProjectOrNoneRequest):

View File

@ -239,6 +239,7 @@ class ProjectQueries:
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
ids: Sequence[str],
model_metrics: bool = False,
):
pipeline = [
@ -246,6 +247,7 @@ class ProjectQueries:
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
**({"_id": {"$in": ids}} if ids else {}),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},

View File

@ -949,6 +949,13 @@ get_unique_metric_variants {
default: false
}
}
"999.0": ${get_unique_metric_variants."2.25"} {
request.properties.ids {
description: IDs of the tasks or models to get metrics from
type: array
items {type: string}
}
}
}
get_hyperparam_values {
"2.13" {

View File

@ -380,6 +380,7 @@ def get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
ids=request.ids,
model_metrics=request.model_metrics,
)

View File

@ -16,10 +16,18 @@ class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
default_task_name = "test task events"
def _temp_task(self, name=default_task_name):
task_input = dict(name=name, type="training",)
def _temp_project(self, name=default_task_name):
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input
"projects",
name=name,
description="test",
delete_params=self.delete_params,
)
def _temp_task(self, name=default_task_name, **kwargs):
self.update_missing(kwargs, name=name, type="training")
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **kwargs
)
def _temp_model(self, name="test model events", **kwargs):
@ -122,6 +130,15 @@ class TestTaskEvents(TestService):
self.assertEqual(value.metric, metric)
self.assertEqual(value.variant, variant)
self.assertEqual(value.value, 0)
# test metrics parameter
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": metric, "variants": [variant]}]
).tasks
self.assertEqual(len(res), 1)
res = self.api.events.get_task_single_value_metrics(
tasks=[task], metrics=[{"metric": "non_existing", "variants": [variant]}]
).tasks
self.assertEqual(len(res), 0)
# update is working
task_data = self.api.tasks.get_by_id(task=task).task
@ -340,6 +357,30 @@ class TestTaskEvents(TestService):
else (None, None)
)
def test_task_unique_metric_variants(self):
project = self._temp_project()
task1 = self._temp_task(project=project)
task2 = self._temp_task(project=project)
metric1 = "Metric1"
metric2 = "Metric2"
events = [
{
**self._create_task_event("training_stats_scalar", task, 0),
"metric": metric,
"variant": "Variant",
"value": 10,
}
for task, metric in ((task1, metric1), (task2, metric2))
]
self.send_batch(events)
metrics = self.api.projects.get_unique_metric_variants(project=project).metrics
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
metrics = self.api.projects.get_unique_metric_variants(ids=[task1, task2]).metrics
self.assertEqual({m.metric for m in metrics}, {metric1, metric2})
metrics = self.api.projects.get_unique_metric_variants(ids=[task1]).metrics
self.assertEqual([m.metric for m in metrics], [metric1])
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
@ -395,6 +436,25 @@ class TestTaskEvents(TestService):
iterations=iter_count,
)
# test metrics
data = self.api.events.multi_task_scalar_metrics_iter_histogram(
tasks=tasks,
metrics=[
{
"metric": f"Metric{m_idx}",
"variants": [f"Variant{v_idx}" for v_idx in range(4)],
}
for m_idx in range(2)
],
)
self._assert_metrics_and_variants(
data.metrics,
metrics=2,
variants=4,
tasks=tasks,
iterations=iter_count,
)
def _assert_metrics_and_variants(
self, data: dict, metrics: int, variants: int, tasks: Sequence, iterations: int
):
@ -515,6 +575,13 @@ class TestTaskEvents(TestService):
self.assertEqual(plots.C.CX[task1]["3"]["plots"][0]["plot_str"], "Task1_3_C_CX")
self.assertEqual(plots.C.CX[task2]["1"]["plots"][0]["plot_str"], "Task2_1_C_CX")
# test metrics
plots = self.api.events.get_multi_task_plots(
tasks=[task1, task2], metrics=[{"metric": "A"}]
).plots
self.assertEqual(len(plots), 1)
self.assertEqual(len(plots.A), 2)
def test_task_plots(self):
task = self._temp_task()
event = self._create_task_event("plot", task, 0)