from functools import partial
from typing import Sequence, Mapping, Optional

from apiserver.es_factory import es_factory
from apiserver.tests.automated import TestService


class TestTaskPlots(TestService):
    """Automated API tests for the task plot event endpoints.

    Covers ``events.get_plot_sample``, ``events.next_plot_sample`` and
    ``events.plots``. Events are uploaded through ``events.add_batch``.
    """

    def _temp_task(self, name="test task events"):
        """Create a temporary training task via ``create_temp`` and return its result."""
        task_input = dict(name=name, type="training")
        return self.create_temp("tasks", **task_input)

    @staticmethod
    def _create_task_event(task, iteration, **kwargs):
        """Build a single plot event payload for ``events.add_batch``.

        :param task: value stored under the ``task`` key
        :param iteration: value stored under the ``iter`` key
        :param kwargs: extra event fields (metric, variant, plot_str, ...);
            they may override the default fields. A missing or falsy
            ``timestamp`` is replaced with the current timestamp in millis.
        """
        return {
            "worker": "test",
            "type": "plot",
            "task": task,
            "iter": iteration,
            **kwargs,
            # placed after **kwargs so that an explicit falsy timestamp
            # (e.g. None) cannot overwrite the generated default
            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
        }

    def test_get_plot_sample(self):
        """Check get_plot_sample for an empty task, the last and a specific iteration."""
        task = self._temp_task()
        metric = "Metric1"
        variants = ["Variant1", "Variant2"]

        # test empty
        res = self.api.events.get_plot_sample(task=task, metric=metric)
        self.assertEqual(res.min_iteration, None)
        self.assertEqual(res.max_iteration, None)
        self.assertEqual(res.events, [])

        # test existing events
        iterations = 5
        events = [
            self._create_task_event(
                task=task,
                iteration=n // len(variants),
                metric=metric,
                variant=variants[n % len(variants)],
                plot_str=f"Test plot str {n}",
            )
            for n in range(iterations * len(variants))
        ]
        self.send_batch(events)

        # if iteration is not specified then return the event from the last one
        res = self.api.events.get_plot_sample(task=task, metric=metric)
        self._assertEqualEvents(res.events, events[-len(variants) :])
        self.assertEqual(res.max_iteration, iterations - 1)
        self.assertEqual(res.min_iteration, 0)
        self.assertTrue(res.scroll_id)

        # else from the specific iteration
        iteration = 3
        res = self.api.events.get_plot_sample(
            task=task,
            metric=metric,
            iteration=iteration,
            scroll_id=res.scroll_id,
        )
        self._assertEqualEvents(
            res.events,
            events[iteration * len(variants) : (iteration + 1) * len(variants)],
        )

    def test_next_plot_sample(self):
        """Check next_plot_sample navigation within one metric and across metrics."""
        task = self._temp_task()
        metric1 = "Metric1"
        metric2 = "Metric2"
        metrics = [
            (metric1, "variant1"),
            (metric1, "variant2"),
            (metric2, "variant3"),
            (metric2, "variant4"),
        ]

        # test existing events
        # layout: 2 iterations x 4 (metric, variant) pairs, ordered by iteration
        events = [
            self._create_task_event(
                task=task,
                iteration=n,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {n}",
            )
            for n in range(2)
            for metric, variant in metrics
        ]
        self.send_batch(events)

        # single metric navigation
        # init scroll
        res = self.api.events.get_plot_sample(task=task, metric=metric1)
        self._assertEqualEvents(res.events, events[-4:-2])

        # navigate forwards
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, navigate_earlier=False
        )
        self.assertEqual(res.events, [])

        # navigate backwards
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-8:-6])
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, [])

        # all metrics navigation
        # init scroll
        res = self.api.events.get_plot_sample(
            task=task, metric=metric1, navigate_current_metric=False
        )
        self._assertEqualEvents(res.events, events[-4:-2])

        # navigate forwards
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, navigate_earlier=False
        )
        self._assertEqualEvents(res.events, events[-2:])

        # navigate backwards
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-4:-2])
        res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
        self._assertEqualEvents(res.events, events[-6:-4])

        # next_iteration
        res = self.api.events.next_plot_sample(
            task=task, scroll_id=res.scroll_id, next_iteration=True
        )
        self._assertEqualEvents(res.events, [])
        res = self.api.events.next_plot_sample(
            task=task,
            scroll_id=res.scroll_id,
            next_iteration=True,
            navigate_earlier=False,
        )
        self._assertEqualEvents(res.events, events[-4:-2])
        self.assertTrue(all(ev.iter == 1 for ev in res.events))
        res = self.api.events.next_plot_sample(
            task=task,
            scroll_id=res.scroll_id,
            next_iteration=True,
            navigate_earlier=False,
        )
        self._assertEqualEvents(res.events, [])

    def _assertEqualEvents(
        self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
    ):
        """Assert that both sequences have the same length and that the events
        match pairwise on the identifying plot event fields.
        """
        self.assertEqual(len(ev_source), len(ev_target))

        def compare_event(ev1, ev2):
            for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
                self.assertEqual(ev1[field], ev2[field])

        for e1, e2 in zip(ev_source, ev_target):
            compare_event(e1, e2)

    def test_task_plots(self):
        """Check events.plots for an empty task, after adding events and on refresh."""
        task = self._temp_task()

        # test empty
        res = self.api.events.plots(metrics=[{"task": task}], iters=5)
        self.assertFalse(res.metrics[0].iterations)
        res = self.api.events.plots(
            metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True
        )
        self.assertFalse(res.metrics[0].iterations)

        # test not empty
        metrics = {
            "Metric1": ["Variant1", "Variant2"],
            "Metric2": ["Variant3", "Variant4"],
        }
        events = [
            self._create_task_event(
                task=task,
                iteration=1,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {metric}_{variant}",
            )
            for metric, variants in metrics.items()
            for variant in variants
        ]
        self.send_batch(events)
        scroll_id = self._assertTaskMetrics(
            task=task, expected_metrics=metrics, iterations=1
        )

        # test refresh
        update = {
            "Metric2": ["Variant3", "Variant4", "Variant5"],
            "Metric3": ["VariantA", "VariantB"],
        }
        events = [
            self._create_task_event(
                task=task,
                iteration=2,
                metric=metric,
                variant=variant,
                plot_str=f"Test plot str {metric}_{variant}_2",
            )
            for metric, variants in update.items()
            for variant in variants
        ]
        self.send_batch(events)
        # without refresh the metric states are not updated
        scroll_id = self._assertTaskMetrics(
            task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id
        )

        # with refresh there are new metrics and existing ones are updated
        self._assertTaskMetrics(
            task=task,
            expected_metrics=update,
            iterations=1,
            scroll_id=scroll_id,
            refresh=True,
        )

    def _assertTaskMetrics(
        self,
        task: str,
        expected_metrics: Mapping[str, Sequence[str]],
        iterations,
        scroll_id: Optional[str] = None,
        refresh=False,
    ) -> str:
        """Request one page (``iters=1``) of task plots and verify its contents.

        :param task: task to query
        :param expected_metrics: mapping of metric name to the variants that
            every returned iteration is expected to contain
        :param iterations: expected number of returned iterations; a falsy
            value means that all the returned metrics should have none
        :param scroll_id: scroll to continue, if any
        :param refresh: passed through to ``events.plots``
        :return: the scroll id from the response
        """
        res = self.api.events.plots(
            metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh
        )
        if not iterations:
            self.assertTrue(all(m.iterations == [] for m in res.metrics))
            return res.scroll_id

        expected_variants = {
            (m, var) for m, vars_ in expected_metrics.items() for var in vars_
        }
        for metric_data in res.metrics:
            self.assertEqual(len(metric_data.iterations), iterations)
            for it_data in metric_data.iterations:
                self.assertEqual(
                    {(e.metric, e.variant) for e in it_data.events},
                    expected_variants,
                )

        return res.scroll_id

    def test_plots_navigation(self):
        """Check paging through events.plots: earlier, later, past the end, refresh."""
        task = self._temp_task()
        metric = "Metric1"
        variants = ["Variant1", "Variant2"]
        iterations = 10

        # test empty
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}],
            iters=5,
        )
        self.assertFalse(res.metrics[0].iterations)

        # create events
        events = [
            self._create_task_event(
                task=task,
                iteration=n,
                metric=metric,
                variant=variant,
                plot_str=f"{metric}_{variant}_{n}",
            )
            for n in range(iterations)
            for variant in variants
        ]
        self.send_batch(events)

        # init testing
        scroll_id = None
        assert_plots = partial(
            self._assertPlots,
            task=task,
            metric=metric,
            iterations=iterations,
            variants=len(variants),
        )

        # test forward navigation
        for page in range(3):
            scroll_id = assert_plots(scroll_id=scroll_id, expected_page=page)

        # test backwards navigation
        scroll_id = assert_plots(
            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
        )

        # beyond the latest iteration and back
        # NOTE: this queries events.plots; asking events.debug_images here
        # would always give 0 iterations since the task has only plot events,
        # and so would not exercise the plots scroll that is under test
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}],
            iters=5,
            scroll_id=scroll_id,
            navigate_earlier=False,
        )
        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
        assert_plots(scroll_id=scroll_id, expected_page=1)

        # refresh
        assert_plots(scroll_id=scroll_id, expected_page=0, refresh=True)

    def _assertPlots(
        self,
        task,
        metric,
        iterations: int,
        variants: int,
        scroll_id,
        expected_page: int,
        iters: int = 5,
        **extra_params,
    ) -> str:
        """Request a page of plots for a single task metric and verify its size.

        :param iterations: total number of iterations that were sent for the metric
        :param variants: expected number of events in every returned iteration
        :param expected_page: zero based page index, used only to compute how
            many iterations the page is expected to contain
        :param iters: page size passed to ``events.plots``
        :param extra_params: passed through to ``events.plots``
            (e.g. navigate_earlier, refresh)
        :return: the scroll id from the response
        """
        res = self.api.events.plots(
            metrics=[{"task": task, "metric": metric}],
            iters=iters,
            scroll_id=scroll_id,
            **extra_params,
        )
        data = res["metrics"][0]
        self.assertEqual(data["task"], task)
        # a page holds at most `iters` iterations; pages past the data are empty
        left_iterations = max(0, iterations - expected_page * iters)
        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
        for it in data["iterations"]:
            self.assertEqual(len(it["events"]), variants)
        return res.scroll_id

    def send_batch(self, events):
        """Upload the events via ``events.add_batch`` and return the response data."""
        _, data = self.api.send_batch("events.add_batch", events)
        return data