Model events are fully supported

This commit is contained in:
allegroai
2023-05-25 19:17:40 +03:00
parent 2e4e060a82
commit 58465fbc17
19 changed files with 341 additions and 228 deletions

View File

@@ -113,66 +113,75 @@ class TestReports(TestService):
def test_reports_task_data(self):
report_task = self._temp_report(name="Rep1")
non_report_task = self._temp_task(name="hello")
debug_image_events = [
self._create_task_event(
task=non_report_task,
type_="training_debug_image",
iteration=1,
metric=f"Metric_{m}",
variant=f"Variant_{v}",
url=f"{m}_{v}",
for model_events in (False, True):
if model_events:
non_report_task = self._temp_model(name="hello")
event_args = {"model_event": True}
else:
non_report_task = self._temp_task(name="hello")
event_args = {}
debug_image_events = [
self._create_task_event(
task=non_report_task,
type_="training_debug_image",
iteration=1,
metric=f"Metric_{m}",
variant=f"Variant_{v}",
url=f"{m}_{v}",
**event_args,
)
for m in range(2)
for v in range(2)
]
plot_events = [
self._create_task_event(
task=non_report_task,
type_="plot",
iteration=1,
metric=f"Metric_{m}",
variant=f"Variant_{v}",
plot_str=f"Hello plot",
**event_args,
)
for m in range(2)
for v in range(2)
]
self.send_batch([*debug_image_events, *plot_events])
res = self.api.reports.get_task_data(
id=[non_report_task], only_fields=["name"], model_events=model_events
)
for m in range(2)
for v in range(2)
]
plot_events = [
self._create_task_event(
task=non_report_task,
type_="plot",
iteration=1,
metric=f"Metric_{m}",
variant=f"Variant_{v}",
plot_str=f"Hello plot",
self.assertEqual(len(res.tasks), 1)
self.assertEqual(res.tasks[0].id, non_report_task)
self.assertFalse(any(field in res for field in ("plots", "debug_images")))
res = self.api.reports.get_task_data(
id=[non_report_task],
only_fields=["name"],
debug_images={"metrics": []},
plots={"metrics": [{"metric": "Metric_1"}]},
model_events=model_events,
)
for m in range(2)
for v in range(2)
]
self.send_batch([*debug_image_events, *plot_events])
self.assertEqual(len(res.debug_images), 1)
task_events = res.debug_images[0]
self.assertEqual(task_events.task, non_report_task)
self.assertEqual(len(task_events.iterations), 1)
self.assertEqual(len(task_events.iterations[0].events), 4)
res = self.api.reports.get_task_data(
id=[non_report_task], only_fields=["name"],
)
self.assertEqual(len(res.tasks), 1)
self.assertEqual(res.tasks[0].id, non_report_task)
self.assertFalse(any(field in res for field in ("plots", "debug_images")))
res = self.api.reports.get_task_data(
id=[non_report_task],
only_fields=["name"],
debug_images={"metrics": []},
plots={"metrics": [{"metric": "Metric_1"}]},
)
self.assertEqual(len(res.debug_images), 1)
task_events = res.debug_images[0]
self.assertEqual(task_events.task, non_report_task)
self.assertEqual(len(task_events.iterations), 1)
self.assertEqual(len(task_events.iterations[0].events), 4)
self.assertEqual(len(res.plots), 1)
for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
tasks = nested_get(res.plots, (m, v))
self.assertEqual(len(tasks), 1)
task_plots = tasks[non_report_task]
self.assertEqual(len(task_plots), 1)
iter_plots = task_plots["1"]
self.assertEqual(iter_plots.name, "hello")
self.assertEqual(len(iter_plots.plots), 1)
ev = iter_plots.plots[0]
self.assertEqual(ev["metric"], m)
self.assertEqual(ev["variant"], v)
self.assertEqual(ev["task"], non_report_task)
self.assertEqual(ev["iter"], 1)
self.assertEqual(len(res.plots), 1)
for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
tasks = nested_get(res.plots, (m, v))
self.assertEqual(len(tasks), 1)
task_plots = tasks[non_report_task]
self.assertEqual(len(task_plots), 1)
iter_plots = task_plots["1"]
self.assertEqual(iter_plots.name, "hello")
self.assertEqual(len(iter_plots.plots), 1)
ev = iter_plots.plots[0]
self.assertEqual(ev["metric"], m)
self.assertEqual(ev["variant"], v)
self.assertEqual(ev["task"], non_report_task)
self.assertEqual(ev["iter"], 1)
@staticmethod
def _create_task_event(type_, task, iteration, **kwargs):
@@ -185,12 +194,14 @@ class TestReports(TestService):
**kwargs,
}
delete_params = {"force": True}
def _temp_report(self, name, **kwargs):
return self.create_temp(
"reports",
name=name,
object_name="task",
delete_params={"force": True},
delete_params=self.delete_params,
**kwargs,
)
@@ -199,10 +210,16 @@ class TestReports(TestService):
"tasks",
name=name,
type="training",
delete_params={"force": True},
delete_params=self.delete_params,
**kwargs,
)
def _temp_model(self, name="test model events", **kwargs):
self.update_missing(
kwargs, name=name, uri="file:///a/b", labels={}, ready=False
)
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)
return data

View File

@@ -16,9 +16,7 @@ class TestTaskEvents(TestService):
delete_params = dict(can_fail=True, force=True)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training",
)
task_input = dict(name=name, type="training",)
return self.create_temp(
"tasks", delete_paramse=self.delete_params, **task_input
)
@@ -220,6 +218,17 @@ class TestTaskEvents(TestService):
self.assertEqual(variant_data.x, [0, 1])
self.assertEqual(variant_data.y, [0.0, 1.0])
model_data = self.api.models.get_all_ex(
id=[model], only_fields=["last_metrics", "last_iteration"]
).models[0]
metric_data = first(first(model_data.last_metrics.values()).values())
self.assertEqual(1, model_data.last_iteration)
self.assertEqual(1, metric_data.value)
self.assertEqual(1, metric_data.max_value)
self.assertEqual(1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
def test_error_events(self):
task = self._temp_task()
events = [