diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 66e07ae..31c25e8 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -438,10 +438,8 @@ class EventBLL(object): last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb key conflicts due to invalid characters and/or long field names. """ - metric = event.get("metric") - variant = event.get("variant") - if not (metric and variant): - return + metric = event.get("metric") or "" + variant = event.get("variant") or "" metric_hash = dbutils.hash_field_name(metric) variant_hash = dbutils.hash_field_name(variant) @@ -486,9 +484,9 @@ class EventBLL(object): recent than the currently stored event for its metric/event_type combination. last_events contains [metric_name -> event_type -> event] """ - metric = event.get("metric") + metric = event.get("metric") or "" event_type = event.get("type") - if not (metric and event_type): + if not event_type: return timestamp = last_events[metric][event_type].get("timestamp", None) diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index 3cc3963..d9998ea 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -193,33 +193,33 @@ class TestTaskEvents(TestService): def test_last_scalar_metrics(self): metric = "Metric1" - variant = "Variant1" - iter_count = 100 - task = self._temp_task() - events = [ - { - **self._create_task_event("training_stats_scalar", task, iteration), - "metric": metric, - "variant": variant, - "value": iteration, - } - for iteration in range(iter_count) - ] - # send 2 batches to check the interaction with already stored db value - # each batch contains multiple iterations - self.send_batch(events[:50]) - self.send_batch(events[50:]) + for variant in ("Variant1", None): + iter_count = 100 + task = self._temp_task() + events = [ + { + **self._create_task_event("training_stats_scalar", task, iteration), + "metric": metric, + "variant": variant, + "value": iteration, + } + for iteration in range(iter_count) + ] + # send 2 batches to check the interaction with already stored db value + # each batch contains multiple iterations + self.send_batch(events[:50]) + self.send_batch(events[50:]) - task_data = self.api.tasks.get_by_id(task=task).task - metric_data = first(first(task_data.last_metrics.values()).values()) - self.assertEqual(iter_count - 1, metric_data.value) - self.assertEqual(iter_count - 1, metric_data.max_value) - self.assertEqual(iter_count - 1, metric_data.max_value_iteration) - self.assertEqual(0, metric_data.min_value) - self.assertEqual(0, metric_data.min_value_iteration) + task_data = self.api.tasks.get_by_id(task=task).task + metric_data = first(first(task_data.last_metrics.values()).values()) + self.assertEqual(iter_count - 1, metric_data.value) + self.assertEqual(iter_count - 1, metric_data.max_value) + self.assertEqual(iter_count - 1, metric_data.max_value_iteration) + self.assertEqual(0, metric_data.min_value) + self.assertEqual(0, metric_data.min_value_iteration) - res = self.api.events.get_task_latest_scalar_values(task=task) - self.assertEqual(iter_count - 1, res.last_iter) + res = self.api.events.get_task_latest_scalar_values(task=task) + self.assertEqual(iter_count - 1, res.last_iter) def test_model_events(self): model = self._temp_model(ready=False)