From 0c15169668835fea32415bb19f23ff6b9cae6967 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 8 Jul 2022 17:38:31 +0300 Subject: [PATCH] Improve tests --- apiserver/tests/api_client.conf | 4 - .../tests/automated/test_batch_operations.py | 2 +- apiserver/tests/automated/test_subprojects.py | 31 +- apiserver/tests/automated/test_task_events.py | 80 +++++ apiserver/tests/automated/test_task_plots.py | 321 ++++++++++++++++++ apiserver/tests/automated/test_workers.py | 15 +- 6 files changed, 437 insertions(+), 16 deletions(-) delete mode 100644 apiserver/tests/api_client.conf create mode 100644 apiserver/tests/automated/test_task_plots.py diff --git a/apiserver/tests/api_client.conf b/apiserver/tests/api_client.conf deleted file mode 100644 index d0f7847..0000000 --- a/apiserver/tests/api_client.conf +++ /dev/null @@ -1,4 +0,0 @@ -{ - api_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" - secret_key: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" -} \ No newline at end of file diff --git a/apiserver/tests/automated/test_batch_operations.py b/apiserver/tests/automated/test_batch_operations.py index 98b5356..83cd2cd 100644 --- a/apiserver/tests/automated/test_batch_operations.py +++ b/apiserver/tests/automated/test_batch_operations.py @@ -20,7 +20,7 @@ class TestBatchOperations(TestService): ids = [*tasks, missing_id] # enqueue - res = self.api.tasks.enqueue_many(ids=ids, queue_name="test") + res = self.api.tasks.enqueue_many(ids=ids, queue_name="test batch") self._assert_succeeded(res, tasks) self._assert_failed(res, [missing_id]) data = self.api.tasks.get_all_ex(id=ids).tasks diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index fc1aa47..a143cf5 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -20,11 +20,12 @@ class TestSubProjects(TestService): base_url=f"http://localhost:8008/v2.13", ) - child = self._temp_project(name="Aggregation/Pr1", client=user2_client) + basename = "Pr1" + child = self._temp_project(name=f"Aggregation/{basename}", client=user2_client) project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id child_project = self.api.projects.get_all_ex(id=[child]).projects[0] self.assertEqual(child_project.parent.id, project) - + self.assertEqual(child_project.basename, basename) user = self.api.users.get_current_user().user.id # test aggregations on project with empty subprojects @@ -42,12 +43,17 @@ class TestSubProjects(TestService): # test aggregations with non-empty subprojects task1 = self._temp_task(project=child) self._temp_task(project=child, parent=task1) + user2_task = self._temp_task(project=child, client=user2_client) framework = "Test framework" self._temp_model(project=child, framework=framework) res = self.api.users.get_all_ex(active_in_projects=[project]) self._assert_ids(res.users, [user]) - res = self.api.projects.get_all_ex(id=[project], active_users=[user]) + res = self.api.projects.get_all_ex(id=[project], include_stats=True) self._assert_ids(res.projects, [project]) + self.assertEqual(res.projects[0].stats.active.total_tasks, 3) + res = self.api.projects.get_all_ex(id=[project], active_users=[user], include_stats=True) + self._assert_ids(res.projects, [project]) + self.assertEqual(res.projects[0].stats.active.total_tasks, 2) res = self.api.projects.get_task_parents(projects=[project]) self._assert_ids(res.parents, [task1]) res = self.api.models.get_frameworks(projects=[project]) @@ -70,8 +76,11 @@ class TestSubProjects(TestService): # update with self.api.raises(errors.bad_request.CannotUpdateProjectLocation): self.api.projects.update(project=project1, name="Root2/Pr2") - res = self.api.projects.update(project=project1, name="Root1/Pr2") + new_basename = "Pr2" + res = self.api.projects.update(project=project1, name=f"Root1/{new_basename}") self.assertEqual(res.updated, 1) + res = self.api.projects.get_by_id(project=project1) + self.assertEqual(res.project.basename, new_basename) res = self.api.projects.get_by_id(project=project1_child) self.assertEqual(res.project.name, "Root1/Pr2/Pr2") @@ -80,6 +89,7 @@ class TestSubProjects(TestService): self.assertEqual(res.moved, 2) res = self.api.projects.get_by_id(project=project1_child) self.assertEqual(res.project.name, "Root2/Pr2/Pr2") + self.assertEqual(res.project.basename, "Pr2") # merge project_with_task, (active, archived) = self._temp_project_with_tasks( @@ -102,6 +112,7 @@ class TestSubProjects(TestService): self.assertEqual(res.moved_projects, 1) res = self.api.projects.get_by_id(project=project_with_task) self.assertEqual(res.project.name, "Root2/Pr2/Pr4") + self.assertEqual(res.project.basename, "Pr4") with self.api.raises(errors.bad_request.InvalidProjectId): self.api.projects.get_by_id(project=merge_source) @@ -156,6 +167,11 @@ class TestSubProjects(TestService): self.assertEqual([p.id for p in res], [project1]) res = self.api.projects.get_all_ex(name="project1", parent=[project1]).projects self.assertEqual([p.id for p in res], [project2]) + # basename search + res = self.api.projects.get_all_ex( + basename="project2", shallow_search=True + ).projects + self.assertEqual(res, []) # global search finds all or below the specified level res = self.api.projects.get_all_ex(name="project1").projects @@ -163,7 +179,9 @@ class TestSubProjects(TestService): project4 = self._temp_project(name="project1/project2/project1") res = self.api.projects.get_all_ex(name="project1", parent=[project2]).projects self.assertEqual([p.id for p in res], [project4]) - + # basename search + res = self.api.projects.get_all_ex(basename="project2").projects + self.assertEqual([p.id for p in res], [project2]) self.api.projects.delete(project=project1, force=True) def test_get_all_with_check_own_contents(self): @@ -249,13 +267,14 @@ class TestSubProjects(TestService): **kwargs, ) - def _temp_task(self, **kwargs): + def _temp_task(self, client=None, **kwargs): return self.create_temp( "tasks", delete_params=self.delete_params, type="testing", name=db_id(), input=dict(view=dict()), + client=client, **kwargs, ) diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index ca28b08..fea05a3 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -67,6 +67,86 @@ class TestTaskEvents(TestService): ), ) + def test_task_single_value_metrics(self): + metric = "Metric1" + variant = "Variant1" + iter_count = 10 + task = self._temp_task() + special_iteration = -(2 ** 31) + events = [ + { + **self._create_task_event( + "training_stats_scalar", task, iteration or special_iteration + ), + "metric": metric, + "variant": variant, + "value": iteration, + } + for iteration in range(iter_count) + ] + self.send_batch(events) + + # special iteration is present in the events retrieval + metric_param = {"metric": metric, "variants": [variant]} + res = self.api.events.scalar_metrics_iter_raw( + task=task, batch_size=100, metric=metric_param, count_total=True + ) + self.assertEqual(res.returned, iter_count) + self.assertEqual(res.total, iter_count) + self.assertEqual( + res.variants[variant]["iter"], + [x or special_iteration for x in range(iter_count)], + ) + self.assertEqual( + res.variants[variant]["y"], list(range(iter_count)) + ) + + # but not in the histogram + data = self.api.events.scalar_metrics_iter_histogram(task=task) + self.assertEqual(data[metric][variant]["x"], list(range(1, iter_count))) + + # new api + res = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks + self.assertEqual(len(res), 1) + data = res[0] + self.assertEqual(data.task, task) + self.assertEqual(len(data["values"]), 1) + value = data["values"][0] + self.assertEqual(value.metric, metric) + self.assertEqual(value.variant, variant) + self.assertEqual(value.value, 0) + + # update is working + task_data = self.api.tasks.get_by_id(task=task).task + last_metrics = first(first(task_data.last_metrics.values()).values()) + self.assertEqual(last_metrics.value, iter_count - 1) + new_value = 1000 + new_event = { + **self._create_task_event("training_stats_scalar", task, special_iteration), + "metric": metric, + "variant": variant, + "value": new_value, + } + self.send(new_event) + + res = self.api.events.scalar_metrics_iter_raw( + task=task, batch_size=100, metric=metric_param, count_total=True + ) + self.assertEqual( + res.variants[variant]["y"], + [y or new_value for y in range(iter_count)], + ) + + task_data = self.api.tasks.get_by_id(task=task).task + last_metrics = first(first(task_data.last_metrics.values()).values()) + self.assertEqual(last_metrics.value, new_value) + + data = self.api.events.get_task_single_value_metrics(tasks=[task]).tasks[0] + self.assertEqual(data.task, task) + self.assertEqual(len(data["values"]), 1) + value = data["values"][0] + self.assertEqual(value.value, new_value) + def test_last_scalar_metrics(self): metric = "Metric1" variant = "Variant1" diff --git a/apiserver/tests/automated/test_task_plots.py b/apiserver/tests/automated/test_task_plots.py new file mode 100644 index 0000000..f2af4c9 --- /dev/null +++ b/apiserver/tests/automated/test_task_plots.py @@ -0,0 +1,321 @@ +from functools import partial +from typing import Sequence, Mapping, Optional + +from apiserver.es_factory import es_factory +from apiserver.tests.automated import TestService + + +class TestTaskPlots(TestService): + def _temp_task(self, name="test task events"): + task_input = dict( + name=name, type="training", input=dict(mapping={}, view=dict(entries=[])), + ) + return self.create_temp("tasks", **task_input) + + @staticmethod + def _create_task_event(task, iteration, **kwargs): + return { + "worker": "test", + "type": "plot", + "task": task, + "iter": iteration, + "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(), + **kwargs, + } + + def test_get_plot_sample(self): + task = self._temp_task() + metric = "Metric1" + variant = "Variant1" + + # test empty + res = self.api.events.get_plot_sample( + task=task, metric=metric, variant=variant + ) + self.assertEqual(res.min_iteration, None) + self.assertEqual(res.max_iteration, None) + self.assertEqual(res.event, None) + + # test existing events + iterations = 10 + events = [ + self._create_task_event( + task=task, + iteration=n, + metric=metric, + variant=variant, + plot_str=f"Test plot str {n}", + ) + for n in range(iterations) + ] + self.send_batch(events) + + # if iteration is not specified then return the event from the last one + res = self.api.events.get_plot_sample( + task=task, metric=metric, variant=variant + ) + self._assertEqualEvent(res.event, events[-1]) + self.assertEqual(res.max_iteration, iterations - 1) + self.assertEqual(res.min_iteration, 0) + self.assertTrue(res.scroll_id) + + # else from the specific iteration + iteration = 8 + res = self.api.events.get_plot_sample( + task=task, + metric=metric, + variant=variant, + iteration=iteration, + scroll_id=res.scroll_id, + ) + self._assertEqualEvent(res.event, events[iteration]) + + def test_next_plot_sample(self): + task = self._temp_task() + metric1 = "Metric1" + variant1 = "Variant1" + metric2 = "Metric2" + variant2 = "Variant2" + metrics = [(metric1, variant1), (metric2, variant2)] + # test existing events + events = [ + self._create_task_event( + task=task, + iteration=n, + metric=metric, + variant=variant, + plot_str=f"Test plot str {n}", + ) + for n in range(2) + for metric, variant in metrics + ] + self.send_batch(events) + + # single metric navigation + # init scroll + res = self.api.events.get_plot_sample( + task=task, metric=metric1, variant=variant1 + ) + self._assertEqualEvent(res.event, events[-2]) + + # navigate forwards + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id, navigate_earlier=False + ) + self.assertEqual(res.event, None) + + # navigate backwards + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id + ) + self._assertEqualEvent(res.event, events[-4]) + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id + ) + self._assertEqualEvent(res.event, None) + + # all metrics navigation + # init scroll + res = self.api.events.get_plot_sample( + task=task, metric=metric1, variant=variant1, navigate_current_metric=False + ) + self._assertEqualEvent(res.event, events[-2]) + + # navigate forwards + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id, navigate_earlier=False + ) + self._assertEqualEvent(res.event, events[-1]) + + # navigate backwards + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id + ) + self._assertEqualEvent(res.event, events[-2]) + res = self.api.events.next_plot_sample( + task=task, scroll_id=res.scroll_id + ) + self._assertEqualEvent(res.event, events[-3]) + + def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]): + if ev2 is None: + self.assertIsNone(ev1) + return + self.assertIsNotNone(ev1) + for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"): + self.assertEqual(ev1[field], ev2[field]) + + def test_task_plots(self): + task = self._temp_task() + + # test empty + res = self.api.events.plots(metrics=[{"task": task}], iters=5) + self.assertFalse(res.metrics[0].iterations) + res = self.api.events.plots( + metrics=[{"task": task}], iters=5, scroll_id=res.scroll_id, refresh=True + ) + self.assertFalse(res.metrics[0].iterations) + + # test not empty + metrics = { + "Metric1": ["Variant1", "Variant2"], + "Metric2": ["Variant3", "Variant4"], + } + events = [ + self._create_task_event( + task=task, + iteration=1, + metric=metric, + variant=variant, + plot_str=f"Test plot str {metric}_{variant}", + ) + for metric, variants in metrics.items() + for variant in variants + ] + self.send_batch(events) + scroll_id = self._assertTaskMetrics( + task=task, expected_metrics=metrics, iterations=1 + ) + + # test refresh + update = { + "Metric2": ["Variant3", "Variant4", "Variant5"], + "Metric3": ["VariantA", "VariantB"], + } + events = [ + self._create_task_event( + task=task, + iteration=2, + metric=metric, + variant=variant, + plot_str=f"Test plot str {metric}_{variant}_2", + ) + for metric, variants in update.items() + for variant in variants + ] + self.send_batch(events) + # without refresh the metric states are not updated + scroll_id = self._assertTaskMetrics( + task=task, expected_metrics=metrics, iterations=0, scroll_id=scroll_id + ) + + # with refresh there are new metrics and existing ones are updated + self._assertTaskMetrics( + task=task, + expected_metrics=update, + iterations=1, + scroll_id=scroll_id, + refresh=True, + ) + + def _assertTaskMetrics( + self, + task: str, + expected_metrics: Mapping[str, Sequence[str]], + iterations, + scroll_id: str = None, + refresh=False, + ) -> str: + res = self.api.events.plots( + metrics=[{"task": task}], iters=1, scroll_id=scroll_id, refresh=refresh + ) + if not iterations: + self.assertTrue(all(m.iterations == [] for m in res.metrics)) + return res.scroll_id + + expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_) + for metric_data in res.metrics: + self.assertEqual(len(metric_data.iterations), iterations) + for it_data in metric_data.iterations: + self.assertEqual( + set((e.metric, e.variant) for e in it_data.events), expected_variants + ) + + return res.scroll_id + + def test_plots_navigation(self): + task = self._temp_task() + metric = "Metric1" + variants = ["Variant1", "Variant2"] + iterations = 10 + + # test empty + res = self.api.events.plots( + metrics=[{"task": task, "metric": metric}], iters=5, + ) + self.assertFalse(res.metrics[0].iterations) + + # create events + events = [ + self._create_task_event( + task=task, + iteration=n, + metric=metric, + variant=variant, + plot_str=f"{metric}_{variant}_{n}", + ) + for n in range(iterations) + for variant in variants + ] + self.send_batch(events) + + # init testing + scroll_id = None + assert_plots = partial( + self._assertPlots, + task=task, + metric=metric, + iterations=iterations, + variants=len(variants) + ) + + # test forward navigation + for page in range(3): + scroll_id = assert_plots(scroll_id=scroll_id, expected_page=page) + + # test backwards navigation + scroll_id = assert_plots( + scroll_id=scroll_id, expected_page=0, navigate_earlier=False + ) + + # beyond the latest iteration and back + res = self.api.events.debug_images( + metrics=[{"task": task, "metric": metric}], + iters=5, + scroll_id=scroll_id, + navigate_earlier=False, + ) + self.assertEqual(len(res["metrics"][0]["iterations"]), 0) + assert_plots(scroll_id=scroll_id, expected_page=1) + + # refresh + assert_plots(scroll_id=scroll_id, expected_page=0, refresh=True) + + def _assertPlots( + self, + task, + metric, + iterations: int, + variants: int, + scroll_id, + expected_page: int, + iters: int = 5, + **extra_params, + ) -> str: + res = self.api.events.plots( + metrics=[{"task": task, "metric": metric}], + iters=iters, + scroll_id=scroll_id, + **extra_params, + ) + data = res["metrics"][0] + self.assertEqual(data["task"], task) + left_iterations = max(0, iterations - expected_page * iters) + self.assertEqual(len(data["iterations"]), min(iters, left_iterations)) + for it in data["iterations"]: + self.assertEqual(len(it["events"]), variants) + return res.scroll_id + + def send_batch(self, events): + _, data = self.api.send_batch("events.add_batch", events) + return data diff --git a/apiserver/tests/automated/test_workers.py b/apiserver/tests/automated/test_workers.py index e7d315f..7735d5e 100644 --- a/apiserver/tests/automated/test_workers.py +++ b/apiserver/tests/automated/test_workers.py @@ -12,11 +12,8 @@ log = config.logger(__file__) class TestWorkersService(TestService): - def setUp(self, version="2.4"): - super().setUp(version=version) - - def _check_exists(self, worker: str, exists: bool = True): - workers = self.api.workers.get_all(last_seen=100).workers + def _check_exists(self, worker: str, exists: bool = True, tags: list = None): + workers = self.api.workers.get_all(last_seen=100, tags=tags).workers found = any(w for w in workers if w.id == worker) assert exists == found @@ -40,6 +37,14 @@ class TestWorkersService(TestService): time.sleep(5) self._check_exists(test_worker, False) + def test_filters(self): + test_worker = f"test_{uuid4().hex}" + self.api.workers.register(worker=test_worker, tags=["application"], timeout=3) + self._check_exists(test_worker) + self._check_exists(test_worker, tags=["application", "test"]) + self._check_exists(test_worker, False, tags=["test"]) + self._check_exists(test_worker, False, tags=["-application"]) + def _simulate_workers(self) -> Sequence[str]: """ Two workers writing the same metrics. One for 4 seconds. Another one for 2