diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py
index a6b775b..ce0fee3 100644
--- a/apiserver/bll/event/event_bll.py
+++ b/apiserver/bll/event/event_bll.py
@@ -131,19 +131,39 @@ class EventBLL(object):
     def add_events(
         self, company_id, events, worker, allow_locked=False
     ) -> Tuple[int, int, dict]:
-        model_events = events[0].get("model_event", False)
+        task_ids = set()
+        model_ids = set()
         for event in events:
-            if event.get("model_event", model_events) != model_events:
-                raise errors.bad_request.ValidationError(
-                    "Inconsistent model_event setting in the passed events"
-                )
+            if event.get("model_event", False):
+                model = event.pop("model", None)
+                if model is not None:
+                    event["task"] = model
+                model_ids.add(event.get("task"))
+            else:
+                event["model_event"] = False
+                task_ids.add(event.get("task"))
             if event.pop("allow_locked", allow_locked) != allow_locked:
                 raise errors.bad_request.ValidationError(
                     "Inconsistent allow_locked setting in the passed events"
                 )
+        task_ids.discard(None)
+        model_ids.discard(None)
+        found_in_both = task_ids.intersection(model_ids)
+        if found_in_both:
+            raise errors.bad_request.ValidationError(
+                "Inconsistent model_event setting in the passed events",
+                tasks=found_in_both,
+            )
+        valid_models = self._get_valid_models(
+            company_id, model_ids=model_ids, allow_locked_models=allow_locked,
+        )
+        valid_tasks = self._get_valid_tasks(
+            company_id, task_ids=task_ids, allow_locked_tasks=allow_locked,
+        )
 
         actions: List[dict] = []
-        task_or_model_ids = set()
+        used_task_ids = set()
+        used_model_ids = set()
         task_iteration = defaultdict(lambda: 0)
         task_last_scalar_events = nested_dict(
             3, dict
@@ -153,28 +173,6 @@ class EventBLL(object):
         )  # task_id -> metric_hash -> event_type -> MetricEvent
         errors_per_type = defaultdict(int)
         invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
-        if model_events:
-            for event in events:
-                model = event.pop("model", None)
-                if model is not None:
-                    event["task"] = model
-            valid_entities = self._get_valid_models(
-                company_id,
-                model_ids={
-                    event["task"] for event in events if event.get("task") is not None
-                },
-                allow_locked_models=allow_locked,
-            )
-            entity_name = "model"
-        else:
-            valid_entities = self._get_valid_tasks(
-                company_id,
-                task_ids={
-                    event["task"] for event in events if event.get("task") is not None
-                },
-                allow_locked_tasks=allow_locked,
-            )
-            entity_name = "task"
 
         for event in events:
             # remove spaces from event type
@@ -188,7 +186,8 @@ class EventBLL(object):
                 errors_per_type[f"Invalid event type {event_type}"] += 1
                 continue
 
-            if model_events and event_type == EventType.task_log.value:
+            model_event = event["model_event"]
+            if model_event and event_type == EventType.task_log.value:
                 errors_per_type[f"Task log events are not supported for models"] += 1
                 continue
 
@@ -197,8 +196,12 @@ class EventBLL(object):
                 errors_per_type["Event must have a 'task' field"] += 1
                 continue
 
-            if task_or_model_id not in valid_entities:
-                errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
+            if (model_event and task_or_model_id not in valid_models) or (
+                not model_event and task_or_model_id not in valid_tasks
+            ):
+                errors_per_type[
+                    f"Invalid {'model' if model_event else 'task'} id {task_or_model_id}"
+                ] += 1
                 continue
 
             event["type"] = event_type
@@ -233,7 +236,6 @@ class EventBLL(object):
 
             event["metric"] = event.get("metric") or ""
             event["variant"] = event.get("variant") or ""
-            event["model_event"] = model_events
 
             index_name = get_index_name(company_id, event_type)
             es_action = {
@@ -242,13 +244,12 @@ class EventBLL(object):
                 "_source": event,
             }
 
-            # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
+            # for "log" events, don't assign custom _id - whatever is sent, is written (not overwritten)
             if event_type != EventType.task_log.value:
                 es_action["_id"] = self._get_event_id(event)
             else:
                 es_action["_id"] = dbutils.id()
 
-            task_or_model_ids.add(task_or_model_id)
             if (
                 iter is not None
                 and event.get("metric") not in self._skip_iteration_for_metric
@@ -257,7 +258,10 @@ class EventBLL(object):
                     iter, task_iteration[task_or_model_id]
                 )
 
-            if not model_events:
+            if model_event:
+                used_model_ids.add(task_or_model_id)
+            else:
+                used_task_ids.add(task_or_model_id)
                 self._update_last_metric_events_for_task(
                     last_events=task_last_events[task_or_model_id], event=event,
                 )
@@ -302,39 +306,32 @@ class EventBLL(object):
             else:
                 errors_per_type["Error when indexing events batch"] += 1
 
-        remaining_tasks = set()
-        now = datetime.utcnow()
-        for task_or_model_id in task_or_model_ids:
-            # Update related tasks. For reasons of performance, we prefer to update
-            # all of them and not only those who's events were successful
-            if model_events:
-                ModelBLL.update_statistics(
-                    company_id=company_id,
-                    model_id=task_or_model_id,
-                    last_iteration_max=task_iteration.get(task_or_model_id),
-                    last_scalar_events=task_last_scalar_events.get(
-                        task_or_model_id
-                    ),
-                )
-            else:
-                updated = self._update_task(
-                    company_id=company_id,
-                    task_id=task_or_model_id,
-                    now=now,
-                    iter_max=task_iteration.get(task_or_model_id),
-                    last_scalar_events=task_last_scalar_events.get(
-                        task_or_model_id
-                    ),
-                    last_events=task_last_events.get(task_or_model_id),
-                )
-                if not updated:
-                    remaining_tasks.add(task_or_model_id)
-                    continue
+        for model_id in used_model_ids:
+            ModelBLL.update_statistics(
+                company_id=company_id,
+                model_id=model_id,
+                last_iteration_max=task_iteration.get(model_id),
+                last_scalar_events=task_last_scalar_events.get(model_id),
+            )
+        remaining_tasks = set()
+        now = datetime.utcnow()
+        for task_id in used_task_ids:
+            # Update related tasks. For reasons of performance, we prefer to update
+            # all of them and not only those whose events were successful
+            updated = self._update_task(
+                company_id=company_id,
+                task_id=task_id,
+                now=now,
+                iter_max=task_iteration.get(task_id),
+                last_scalar_events=task_last_scalar_events.get(task_id),
+                last_events=task_last_events.get(task_id),
+            )
+            if not updated:
+                remaining_tasks.add(task_id)
+                continue
 
-        if remaining_tasks:
-            TaskBLL.set_last_update(
-                remaining_tasks, company_id, last_update=now
-            )
+        if remaining_tasks:
+            TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
 
         # this is for backwards compatibility with streaming bulk throwing exception on those
         invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py
index 8dbdc7f..6ca7f44 100644
--- a/apiserver/tests/automated/test_task_events.py
+++ b/apiserver/tests/automated/test_task_events.py
@@ -195,6 +195,7 @@ class TestTaskEvents(TestService):
         with self.api.raises(errors.bad_request.EventsNotAdded):
             self.send(log_event)
 
+        # mixed batch
         events = [
             {
                 **self._create_task_event("training_stats_scalar", model, iteration),
@@ -207,6 +208,18 @@ class TestTaskEvents(TestService):
             for metric_idx in range(5)
             for variant_idx in range(5)
         ]
+        task = self._temp_task()
+        # noinspection PyTypeChecker
+        events.append(
+            self._create_task_event(
+                "log",
+                task=task,
+                iteration=0,
+                msg=f"This is a log message",
+                metric="Metric0",
+                variant="Variant0",
+            )
+        )
         self.send_batch(events)
         data = self.api.events.scalar_metrics_iter_histogram(
             task=model, model_events=True
@@ -229,6 +242,8 @@ class TestTaskEvents(TestService):
         self.assertEqual(0, metric_data.min_value)
         self.assertEqual(0, metric_data.min_value_iteration)
 
+        self._assert_log_events(task=task, expected_total=1)
+
     def test_error_events(self):
         task = self._temp_task()
         events = [
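
The routing that the patched add_events now performs can be restated as a small standalone sketch, which may help when reviewing the change: model events are keyed by their "model" id (copied into the "task" field for downstream use), all other events are keyed by "task", and a batch where the same id appears in both groups is rejected. The function name partition_events and the plain ValueError below are illustrative only and are not part of the patch (the server raises errors.bad_request.ValidationError).

# Illustrative sketch of the per-event routing introduced above; not part of the patch.
from typing import Iterable, Set, Tuple


def partition_events(events: Iterable[dict]) -> Tuple[Set[str], Set[str]]:
    task_ids, model_ids = set(), set()
    for event in events:
        if event.get("model_event", False):
            # model events reuse the "task" field downstream
            model = event.pop("model", None)
            if model is not None:
                event["task"] = model
            model_ids.add(event.get("task"))
        else:
            event["model_event"] = False
            task_ids.add(event.get("task"))
    task_ids.discard(None)
    model_ids.discard(None)
    if task_ids & model_ids:
        # the patch raises errors.bad_request.ValidationError here
        raise ValueError("Inconsistent model_event setting in the passed events")
    return task_ids, model_ids


# Example: a mixed batch with one model event and one task log event,
# mirroring the "mixed batch" test case added above (ids are placeholders).
batch = [
    {"model_event": True, "model": "model-id", "type": "training_stats_scalar"},
    {"task": "task-id", "type": "log"},
]
print(partition_events(batch))  # ({'task-id'}, {'model-id'})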