diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index b4c934c..8649016 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -64,7 +64,7 @@ class PlotFields: class EventBLL(object): - id_fields = ("task", "iter", "metric", "variant", "key") + event_id_fields = ("task", "iter", "metric", "variant", "key") empty_scroll = "FFFF" img_source_regex = re.compile( r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]", @@ -219,13 +219,10 @@ class EventBLL(object): # force iter to be a long int iter = event.get("iter") if iter is not None: - if model_events: - iter = 0 - else: - iter = int(iter) - if iter > MAX_LONG or iter < MIN_LONG: - errors_per_type[invalid_iteration_error] += 1 - continue + iter = int(iter) + if iter > MAX_LONG or iter < MIN_LONG: + errors_per_type[invalid_iteration_error] += 1 + continue event["iter"] = iter # used to have "values" to indicate array. no need anymore @@ -487,7 +484,7 @@ class EventBLL(object): ) def _get_event_id(self, event): - id_values = (str(event[field]) for field in self.id_fields if field in event) + id_values = (str(event[field]) for field in self.event_id_fields if field in event) return hashlib.md5("-".join(id_values).encode()).hexdigest() def scroll_task_events( diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index cf0b85b..7256936 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -19,7 +19,9 @@ class TestTaskEvents(TestService): task_input = dict( name=name, type="training", input=dict(mapping={}, view=dict(entries=[])), ) - return self.create_temp("tasks", delete_paramse=self.delete_params, **task_input) + return self.create_temp( + "tasks", delete_paramse=self.delete_params, **task_input + ) def _temp_model(self, name="test model events", **kwargs): self.update_missing(kwargs, name=name, uri="file:///a/b", labels={}) @@ -104,9 +106,7 @@ class TestTaskEvents(TestService): res.variants[variant]["iter"], [x or special_iteration for x in range(iter_count)], ) - self.assertEqual( - res.variants[variant]["y"], list(range(iter_count)) - ) + self.assertEqual(res.variants[variant]["y"], list(range(iter_count))) # but not in the histogram data = self.api.events.scalar_metrics_iter_histogram(task=task) @@ -140,8 +140,7 @@ class TestTaskEvents(TestService): task=task, batch_size=100, metric=metric_param, count_total=True ) self.assertEqual( - res.variants[variant]["y"], - [y or new_value for y in range(iter_count)], + res.variants[variant]["y"], [y or new_value for y in range(iter_count)], ) task_data = self.api.tasks.get_by_id(task=task).task @@ -198,7 +197,6 @@ class TestTaskEvents(TestService): with self.api.raises(errors.bad_request.EventsNotAdded): self.send(log_event) - # send metric events and check that model data always have iteration 0 and only last data is saved events = [ { **self._create_task_event("training_stats_scalar", model, iteration), @@ -212,13 +210,15 @@ class TestTaskEvents(TestService): for variant_idx in range(5) ] self.send_batch(events) - data = self.api.events.scalar_metrics_iter_histogram(task=model, model_events=True) + data = self.api.events.scalar_metrics_iter_histogram( + task=model, model_events=True + ) self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)]) metric_data = data.Metric0 self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)]) variant_data = metric_data.Variant0 - self.assertEqual(variant_data.x, [0]) - self.assertEqual(variant_data.y, [1.0]) + self.assertEqual(variant_data.x, [0, 1]) + self.assertEqual(variant_data.y, [0.0, 1.0]) def test_error_events(self): task = self._temp_task()