clearml-server/apiserver/tests/automated/test_task_plots.py
2023-05-25 19:13:10 +03:00

338 lines
11 KiB
Python

from functools import partial
from typing import Sequence, Mapping, Optional
from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService
class TestTaskPlots(TestService):
    """Integration tests for the plot-related ``events.*`` API endpoints."""

    def _temp_task(self, name="test task events"):
        """Create a temporary training task and return its id."""
        task_input = dict(name=name, type="training")
        return self.create_temp("tasks", **task_input)

    @staticmethod
    def _create_task_event(task, iteration, **kwargs):
        """Build a ``plot`` event dict for ``task`` at ``iteration``.

        Extra keyword arguments (metric, variant, plot_str, ...) are merged into
        the event. When no truthy ``timestamp`` is supplied, the current time in
        milliseconds from ``es_factory.get_timestamp_millis()`` is used.
        """
        return {
            "worker": "test",
            "type": "plot",
            "task": task,
            "iter": iteration,
            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
            **kwargs,
        }

    def test_get_plot_sample(self):
        task = self._temp_task()
        metric = "Metric1"
        variants = ["Variant1", "Variant2"]

        # test empty
        res = self.api.events.get_plot_sample(task=task, metric=metric)
        self.assertEqual(res.min_iteration, None)
        self.assertEqual(res.max_iteration, None)
        self.assertEqual(res.events, [])

        # test existing events: every iteration gets one event per variant,
        # so events[i * len(variants):(i + 1) * len(variants)] is iteration i
        iterations = 5
        events = [
            self._create_task_event(
                task=task,
                iteration=n // len(variants),
                metric=metric,
                variant=variants[n % len(variants)],
                plot_str=f"Test plot str {n}",
            )
            for n in range(iterations * len(variants))
        ]
        self.send_batch(events)

        # if iteration is not specified then return the event from the last one
        res = self.api.events.get_plot_sample(task=task, metric=metric)
        self._assertEqualEvents(res.events, events[-len(variants) :])
        self.assertEqual(res.max_iteration, iterations - 1)
        self.assertEqual(res.min_iteration, 0)
        self.assertTrue(res.scroll_id)

        # else from the specific iteration
        iteration = 3
        res = self.api.events.get_plot_sample(
            task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
        )
        self._assertEqualEvents(
            res.events,
            events[iteration * len(variants) : (iteration + 1) * len(variants)],
        )

    def test_next_plot_sample(self):
        task = self._temp_task()
        metric1 = "Metric1"
        metric2 = "Metric2"
        metrics = [
            (metric1, "variant1"),
            (metric1, "variant2"),
            (metric2, "variant3"),
            (metric2, "variant4"),
        ]

        # test existing events. Events are ordered by iteration, then by the
        # (metric, variant) pairs above: events[-4:-2] is iteration 1 of
        # Metric1, events[-2:] is iteration 1 of Metric2, events[-8:-6] is
        # iteration 0 of Metric1 and events[-6:-4] is iteration 0 of Metric2
        events = [
            self._create_task_event(
                task=task,
                iteration=n,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {n}",
            )
            for n in range(2)
            for metric, variant in metrics
        ]
        self.send_batch(events)

        # single metric navigation
        # init scroll
        res = self.api.events.get_plot_sample(task=task, metric=metric1)
        self._assertEqualEvents(res.events, events[-4:-2])

        # navigate forwards
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, navigate_earlier=False
        )
        self._assertEqualEvents(res.events, [])

        # navigate backwards
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-8:-6])
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, [])

        # all metrics navigation
        # init scroll
        res = self.api.events.get_plot_sample(
            task=task, metric=metric1, navigate_current_metric=False
        )
        self._assertEqualEvents(res.events, events[-4:-2])

        # navigate forwards
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, navigate_earlier=False
        )
        self._assertEqualEvents(res.events, events[-2:])

        # navigate backwards
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-4:-2])
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-6:-4])

        # next_iteration
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, next_iteration=True
        )
        self._assertEqualEvents(res.events, [])
        res = self.api.events.next_plot_sample(
            task=task,
            scroll_id=res.scroll_id,
            next_iteration=True,
            navigate_earlier=False,
        )
        self._assertEqualEvents(res.events, events[-4:-2])
        self.assertTrue(all(ev.iter == 1 for ev in res.events))
        res = self.api.events.next_plot_sample(
            task=task,
            scroll_id=res.scroll_id,
            next_iteration=True,
            navigate_earlier=False,
        )
        self._assertEqualEvents(res.events, [])

    def _assertEqualEvents(
        self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
    ):
        """Assert that two event sequences match pairwise on the key fields.

        Only the fields listed below are compared; server-added fields on the
        returned events are ignored.
        """
        self.assertEqual(len(ev_source), len(ev_target))

        def compare_event(idx, ev1, ev2):
            for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
                self.assertEqual(
                    ev1[field], ev2[field], msg=f"event #{idx}, field '{field}'"
                )

        for idx, (e1, e2) in enumerate(zip(ev_source, ev_target)):
            compare_event(idx, e1, e2)

    def test_task_plots(self):
        task = self._temp_task()

        # test empty
        res = self.api.events.plots(metrics=[{"task": task}], iters=5)
        self.assertFalse(res.metrics[0].iterations)
        res = self.api.events.plots(
            metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
        )
        self.assertFalse(res.metrics[0].iterations)

        # test not empty
        metrics = {
            "Metric1": ["Variant1", "Variant2"],
            "Metric2": ["Variant3", "Variant4"],
        }
        events = [
            self._create_task_event(
                task=task,
                iteration=1,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {metric}_{variant}",
            )
            for metric, variants in metrics.items()
            for variant in variants
        ]
        self.send_batch(events)
        scroll_id = self._assertTaskMetrics(
            task=task, expected_metrics=metrics, iterations=1
        )

        # test refresh
        update = {
            "Metric2": ["Variant3", "Variant4", "Variant5"],
            "Metric3": ["VariantA", "VariantB"],
        }
        events = [
            self._create_task_event(
                task=task,
                iteration=2,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {metric}_{variant}_2",
            )
            for metric, variants in update.items()
            for variant in variants
        ]
        self.send_batch(events)

        # without refresh the metric states are not updated
        scroll_id = self._assertTaskMetrics(
            task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
        )

        # with refresh there are new metrics and existing ones are updated
        self._assertTaskMetrics(
            task=task,
            expected_metrics=update,
            iterations=1,
            scroll_id=scroll_id,
            refresh=True,
        )

    def _assertTaskMetrics(
        self,
        task: str,
        expected_metrics: Mapping[str, Sequence[str]],
        iterations: int,
        scroll_id: Optional[str] = None,
        refresh: bool = False,
    ) -> str:
        """Request one iteration of ``task`` plots and check its contents.

        :param task: id of the task whose plots are requested
        :param expected_metrics: metric name -> variant names expected in
            every returned iteration
        :param iterations: expected number of iterations in each returned
            entry; when 0, every returned entry must have no iterations
        :param scroll_id: scroll id from a previous call, or None to start
        :param refresh: forwarded to ``events.plots``
        :return: the scroll id to continue the navigation with
        """
        res = self.api.events.plots(
            metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
        )
        if not iterations:
            for metric_data in res.metrics:
                self.assertFalse(metric_data.iterations)
            return res.scroll_id

        # An empty response would otherwise make the loop below pass vacuously
        self.assertTrue(res.metrics)
        expected_variants = {
            (metric, variant)
            for metric, variants in expected_metrics.items()
            for variant in variants
        }
        for metric_data in res.metrics:
            self.assertEqual(len(metric_data.iterations), iterations)
            for it_data in metric_data.iterations:
                self.assertEqual(
                    {(e.metric, e.variant) for e in it_data.events},
                    expected_variants,
                )

        return res.scroll_id

    def test_plots_navigation(self):
        task = self._temp_task()
        metric = "Metric1"
        variants = ["Variant1", "Variant2"]
        iterations = 10

        # test empty
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}], iters=5,
        )
        self.assertFalse(res.metrics[0].iterations)

        # create events
        events = [
            self._create_task_event(
                task=task,
                iteration=n,
                metric=metric,
                variant=variant,
                plot_str=f"{metric}_{variant}_{n}",
            )
            for n in range(iterations)
            for variant in variants
        ]
        self.send_batch(events)

        # init testing
        scroll_id = None
        assert_plots = partial(
            self._assertPlots,
            task=task,
            metric=metric,
            iterations=iterations,
            variants=len(variants),
        )

        # test forward navigation
        for page in range(3):
            scroll_id = assert_plots(scroll_id=scroll_id, expected_page=page)

        # test backwards navigation
        scroll_id = assert_plots(
            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
        )

        # beyond the latest iteration and back. This must query the plots
        # endpoint: the task has no debug images, so querying debug_images here
        # would return nothing regardless of the plots scroll state
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}],
            iters=5,
            scroll_id=scroll_id,
            navigate_earlier=False,
        )
        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
        assert_plots(scroll_id=scroll_id, expected_page=1)

        # refresh
        assert_plots(scroll_id=scroll_id, expected_page=0, refresh=True)

    def _assertPlots(
        self,
        task,
        metric,
        iterations: int,
        variants: int,
        scroll_id,
        expected_page: int,
        iters: int = 5,
        **extra_params,
    ) -> str:
        """Request one page of plots and check it matches ``expected_page``.

        The expected page size is ``min(iters, iterations - expected_page * iters)``
        (never negative), and every returned iteration must hold ``variants``
        events. Extra keyword arguments are forwarded to ``events.plots``.

        :return: the scroll id to continue the navigation with
        """
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}],
            iters=iters,
            scroll_id=scroll_id,
            **extra_params,
        )
        data = res["metrics"][0]
        self.assertEqual(data["task"], task)
        left_iterations = max(0, iterations - expected_page * iters)
        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
        for it in data["iterations"]:
            self.assertEqual(len(it["events"]), variants)
        return res.scroll_id

    def send_batch(self, events):
        """Send ``events`` through ``events.add_batch`` and return the response data."""
        _, data = self.api.send_batch("events.add_batch", events)
        return data