Add support for returning only valid plot events

This commit is contained in:
allegroai
2021-01-05 16:41:55 +02:00
parent 171969c5ea
commit f4ead86449
6 changed files with 190 additions and 11 deletions

View File

@@ -435,9 +435,9 @@ class TestTaskEvents(TestService):
)
self.send(event)
event = self._create_task_event("plot", task, 100)
event["metric"] = "confusion"
event.update(
event1 = self._create_task_event("plot", task, 100)
event1["metric"] = "confusion"
event1.update(
{
"plot_str": json.dumps(
{
@@ -476,14 +476,54 @@ class TestTaskEvents(TestService):
)
}
)
self.send(event)
self.send(event1)
data = self.api.events.get_task_plots(task=task)
assert len(data["plots"]) == 2
plots = self.api.events.get_task_plots(task=task).plots
self.assertEqual(
{e["plot_str"] for e in (event, event1)}, {p.plot_str for p in plots}
)
self.api.tasks.reset(task=task)
data = self.api.events.get_task_plots(task=task)
assert len(data["plots"]) == 0
plots = self.api.events.get_task_plots(task=task).plots
self.assertEqual(len(plots), 0)
@unittest.skip("this test will run only if 'validate_plot_str' is set to true")
def test_plots_validation(self):
valid_plot_str = json.dumps({"data": []})
invalid_plot_str = "Not a valid json"
task = self._temp_task()
event = self._create_task_event(
"plot", task, 0, metric="test1", plot_str=valid_plot_str
)
event1 = self._create_task_event(
"plot", task, 100, metric="test2", plot_str=invalid_plot_str
)
self.send_batch([event, event1])
res = self.api.events.get_task_plots(task=task).plots
self.assertEqual(len(res), 1)
self.assertEqual(res[0].metric, "test1")
event = self._create_task_event(
"plot",
task,
0,
metric="test1",
plot_str=valid_plot_str,
skip_validation=True,
)
event1 = self._create_task_event(
"plot",
task,
100,
metric="test2",
plot_str=invalid_plot_str,
skip_validation=True,
)
self.send_batch([event, event1])
res = self.api.events.get_task_plots(task=task).plots
self.assertEqual(len(res), 2)
self.assertEqual(set(r.metric for r in res), {"test1", "test2"})
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)