Allow mixing model and task events in the same events batch

allegroai 2023-05-25 19:19:45 +03:00
parent 818496236b
commit de9651d761
2 changed files with 79 additions and 67 deletions
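
For context, a minimal illustration of what this change enables: a single batch passed to add_events may now mix regular task events with model events, each event carrying its own model_event flag, with model events referencing a model id that the handler maps onto the "task" field internally. The field names below follow the handler code in this diff; the ids are made-up placeholders and the "value" field is only assumed to match the usual scalar-event payload.

# Illustrative only: a mixed batch that add_events() accepts after this change.
# "task"/"model"/"model_event"/"type"/"metric"/"variant"/"iter" follow the diff;
# the ids and the "value" field are placeholder assumptions.
mixed_batch = [
    {
        # regular task event (model_event defaults to False)
        "task": "aabbccdd11223344",
        "type": "training_stats_scalar",
        "metric": "loss",
        "variant": "total",
        "iter": 10,
        "value": 0.37,
    },
    {
        # model event: "model" is moved into "task" and model_event stays True
        "model": "5566778899aabbcc",
        "model_event": True,
        "type": "training_stats_scalar",
        "metric": "accuracy",
        "variant": "top1",
        "iter": 10,
        "value": 0.91,
    },
]

Before this commit, a batch-wide model_events flag was derived from the first event and any mismatch raised a ValidationError, so a mixed batch like this would have been rejected.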


@@ -131,19 +131,39 @@ class EventBLL(object):
def add_events(
self, company_id, events, worker, allow_locked=False
) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False)
task_ids = set()
model_ids = set()
for event in events:
if event.get("model_event", model_events) != model_events:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events"
)
if event.get("model_event", False):
model = event.pop("model", None)
if model is not None:
event["task"] = model
model_ids.add(event.get("task"))
else:
event["model_event"] = False
task_ids.add(event.get("task"))
if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events"
)
task_ids.discard(None)
model_ids.discard(None)
found_in_both = task_ids.intersection(model_ids)
if found_in_both:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_models(
company_id, model_ids=model_ids, allow_locked_models=allow_locked,
)
valid_tasks = self._get_valid_tasks(
company_id, task_ids=task_ids, allow_locked_tasks=allow_locked,
)
actions: List[dict] = []
task_or_model_ids = set()
used_task_ids = set()
used_model_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
3, dict
@@ -153,28 +173,6 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events:
# remove spaces from event type
@@ -188,7 +186,8 @@
errors_per_type[f"Invalid event type {event_type}"] += 1
continue
if model_events and event_type == EventType.task_log.value:
model_event = event["model_event"]
if model_event and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1
continue
@@ -197,8 +196,12 @@
errors_per_type["Event must have a 'task' field"] += 1
continue
if task_or_model_id not in valid_entities:
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
if (model_event and task_or_model_id not in valid_models) or (
not model_event and task_or_model_id not in valid_tasks
):
errors_per_type[
f"Invalid {'model' if model_event else 'task'} id {task_or_model_id}"
] += 1
continue
event["type"] = event_type
@@ -233,7 +236,6 @@
event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type)
es_action = {
@@ -242,13 +244,12 @@
"_source": event,
}
# for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
# for "log" events, don't assign custom _id - whatever is sent, is written (not overwritten)
if event_type != EventType.task_log.value:
es_action["_id"] = self._get_event_id(event)
else:
es_action["_id"] = dbutils.id()
task_or_model_ids.add(task_or_model_id)
if (
iter is not None
and event.get("metric") not in self._skip_iteration_for_metric
@@ -257,7 +258,10 @@
iter, task_iteration[task_or_model_id]
)
if not model_events:
if model_event:
used_model_ids.add(task_or_model_id)
else:
used_task_ids.add(task_or_model_id)
self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event,
)
@@ -302,39 +306,32 @@
else:
errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set()
now = datetime.utcnow()
for task_or_model_id in task_or_model_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
if model_events:
ModelBLL.update_statistics(
company_id=company_id,
model_id=task_or_model_id,
last_iteration_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
)
else:
updated = self._update_task(
company_id=company_id,
task_id=task_or_model_id,
now=now,
iter_max=task_iteration.get(task_or_model_id),
last_scalar_events=task_last_scalar_events.get(
task_or_model_id
),
last_events=task_last_events.get(task_or_model_id),
)
if not updated:
remaining_tasks.add(task_or_model_id)
continue
for model_id in used_model_ids:
ModelBLL.update_statistics(
company_id=company_id,
model_id=model_id,
last_iteration_max=task_iteration.get(model_id),
last_scalar_events=task_last_scalar_events.get(model_id),
)
remaining_tasks = set()
now = datetime.utcnow()
for task_id in used_task_ids:
# Update related tasks. For reasons of performance, we prefer to update
# all of them and not only those whose events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter_max=task_iteration.get(task_id),
last_scalar_events=task_last_scalar_events.get(task_id),
last_events=task_last_events.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(
remaining_tasks, company_id, last_update=now
)
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
# this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
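
As a reading aid for the interleaved old/new lines above, here is a condensed, self-contained sketch of the per-event splitting that the new add_events performs. It is not the server code: the validation and update helpers are omitted, and a plain ValueError stands in for errors.bad_request.ValidationError.

# Simplified sketch of the new per-event routing in add_events(); helper calls
# such as _get_valid_models()/_get_valid_tasks() and the ES indexing are omitted.
def split_batch(events: list, allow_locked: bool = False):
    task_ids, model_ids = set(), set()
    for event in events:
        if event.get("model_event", False):
            model = event.pop("model", None)
            if model is not None:
                event["task"] = model  # model events are stored under the "task" field
            model_ids.add(event.get("task"))
        else:
            event["model_event"] = False
            task_ids.add(event.get("task"))
        if event.pop("allow_locked", allow_locked) != allow_locked:
            raise ValueError("Inconsistent allow_locked setting in the passed events")
    task_ids.discard(None)
    model_ids.discard(None)
    found_in_both = task_ids & model_ids
    if found_in_both:
        # the same id may not appear both as a task and as a model within one batch
        raise ValueError(f"Inconsistent model_event setting: {sorted(found_in_both)}")
    return task_ids, model_ids

With the ids split this way, model ids are validated through _get_valid_models() and task ids through _get_valid_tasks(); after indexing, ModelBLL.update_statistics() runs for every used model id and _update_task()/TaskBLL.set_last_update() for every used task id, replacing the previous all-or-nothing branch on a batch-wide model_events flag.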


@@ -195,6 +195,7 @@ class TestTaskEvents(TestService):
with self.api.raises(errors.bad_request.EventsNotAdded):
self.send(log_event)
# mixed batch
events = [
{
**self._create_task_event("training_stats_scalar", model, iteration),
@@ -207,6 +208,18 @@ class TestTaskEvents(TestService):
for metric_idx in range(5)
for variant_idx in range(5)
]
task = self._temp_task()
# noinspection PyTypeChecker
events.append(
self._create_task_event(
"log",
task=task,
iteration=0,
msg=f"This is a log message",
metric="Metric0",
variant="Variant0",
)
)
self.send_batch(events)
data = self.api.events.scalar_metrics_iter_histogram(
task=model, model_events=True
@@ -229,6 +242,8 @@ class TestTaskEvents(TestService):
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
self._assert_log_events(task=task, expected_total=1)
def test_error_events(self):
task = self._temp_task()
events = [