Allow mixing Model and task events in the same events batch

This commit is contained in:
allegroai 2023-05-25 19:19:45 +03:00
parent 818496236b
commit de9651d761
2 changed files with 79 additions and 67 deletions

View File

@ -131,19 +131,39 @@ class EventBLL(object):
def add_events( def add_events(
self, company_id, events, worker, allow_locked=False self, company_id, events, worker, allow_locked=False
) -> Tuple[int, int, dict]: ) -> Tuple[int, int, dict]:
model_events = events[0].get("model_event", False) task_ids = set()
model_ids = set()
for event in events: for event in events:
if event.get("model_event", model_events) != model_events: if event.get("model_event", False):
raise errors.bad_request.ValidationError( model = event.pop("model", None)
"Inconsistent model_event setting in the passed events" if model is not None:
) event["task"] = model
model_ids.add(event.get("task"))
else:
event["model_event"] = False
task_ids.add(event.get("task"))
if event.pop("allow_locked", allow_locked) != allow_locked: if event.pop("allow_locked", allow_locked) != allow_locked:
raise errors.bad_request.ValidationError( raise errors.bad_request.ValidationError(
"Inconsistent allow_locked setting in the passed events" "Inconsistent allow_locked setting in the passed events"
) )
task_ids.discard(None)
model_ids.discard(None)
found_in_both = task_ids.intersection(model_ids)
if found_in_both:
raise errors.bad_request.ValidationError(
"Inconsistent model_event setting in the passed events",
tasks=found_in_both,
)
valid_models = self._get_valid_models(
company_id, model_ids=model_ids, allow_locked_models=allow_locked,
)
valid_tasks = self._get_valid_tasks(
company_id, task_ids=task_ids, allow_locked_tasks=allow_locked,
)
actions: List[dict] = [] actions: List[dict] = []
task_or_model_ids = set() used_task_ids = set()
used_model_ids = set()
task_iteration = defaultdict(lambda: 0) task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict( task_last_scalar_events = nested_dict(
3, dict 3, dict
@ -153,28 +173,6 @@ class EventBLL(object):
) # task_id -> metric_hash -> event_type -> MetricEvent ) # task_id -> metric_hash -> event_type -> MetricEvent
errors_per_type = defaultdict(int) errors_per_type = defaultdict(int)
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}" invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
if model_events:
for event in events:
model = event.pop("model", None)
if model is not None:
event["task"] = model
valid_entities = self._get_valid_models(
company_id,
model_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_models=allow_locked,
)
entity_name = "model"
else:
valid_entities = self._get_valid_tasks(
company_id,
task_ids={
event["task"] for event in events if event.get("task") is not None
},
allow_locked_tasks=allow_locked,
)
entity_name = "task"
for event in events: for event in events:
# remove spaces from event type # remove spaces from event type
@ -188,7 +186,8 @@ class EventBLL(object):
errors_per_type[f"Invalid event type {event_type}"] += 1 errors_per_type[f"Invalid event type {event_type}"] += 1
continue continue
if model_events and event_type == EventType.task_log.value: model_event = event["model_event"]
if model_event and event_type == EventType.task_log.value:
errors_per_type[f"Task log events are not supported for models"] += 1 errors_per_type[f"Task log events are not supported for models"] += 1
continue continue
@ -197,8 +196,12 @@ class EventBLL(object):
errors_per_type["Event must have a 'task' field"] += 1 errors_per_type["Event must have a 'task' field"] += 1
continue continue
if task_or_model_id not in valid_entities: if (model_event and task_or_model_id not in valid_models) or (
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1 not model_event and task_or_model_id not in valid_tasks
):
errors_per_type[
f"Invalid {'model' if model_event else 'task'} id {task_or_model_id}"
] += 1
continue continue
event["type"] = event_type event["type"] = event_type
@ -233,7 +236,6 @@ class EventBLL(object):
event["metric"] = event.get("metric") or "" event["metric"] = event.get("metric") or ""
event["variant"] = event.get("variant") or "" event["variant"] = event.get("variant") or ""
event["model_event"] = model_events
index_name = get_index_name(company_id, event_type) index_name = get_index_name(company_id, event_type)
es_action = { es_action = {
@ -242,13 +244,12 @@ class EventBLL(object):
"_source": event, "_source": event,
} }
# for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten) # for "log" events, don't assign custom _id - whatever is sent, is written (not overwritten)
if event_type != EventType.task_log.value: if event_type != EventType.task_log.value:
es_action["_id"] = self._get_event_id(event) es_action["_id"] = self._get_event_id(event)
else: else:
es_action["_id"] = dbutils.id() es_action["_id"] = dbutils.id()
task_or_model_ids.add(task_or_model_id)
if ( if (
iter is not None iter is not None
and event.get("metric") not in self._skip_iteration_for_metric and event.get("metric") not in self._skip_iteration_for_metric
@ -257,7 +258,10 @@ class EventBLL(object):
iter, task_iteration[task_or_model_id] iter, task_iteration[task_or_model_id]
) )
if not model_events: if model_event:
used_model_ids.add(task_or_model_id)
else:
used_task_ids.add(task_or_model_id)
self._update_last_metric_events_for_task( self._update_last_metric_events_for_task(
last_events=task_last_events[task_or_model_id], event=event, last_events=task_last_events[task_or_model_id], event=event,
) )
@ -302,39 +306,32 @@ class EventBLL(object):
else: else:
errors_per_type["Error when indexing events batch"] += 1 errors_per_type["Error when indexing events batch"] += 1
remaining_tasks = set() for model_id in used_model_ids:
now = datetime.utcnow() ModelBLL.update_statistics(
for task_or_model_id in task_or_model_ids: company_id=company_id,
# Update related tasks. For reasons of performance, we prefer to update model_id=model_id,
# all of them and not only those whose events were successful last_iteration_max=task_iteration.get(model_id),
if model_events: last_scalar_events=task_last_scalar_events.get(model_id),
ModelBLL.update_statistics( )
company_id=company_id, remaining_tasks = set()
model_id=task_or_model_id, now = datetime.utcnow()
last_iteration_max=task_iteration.get(task_or_model_id), for task_id in used_task_ids:
last_scalar_events=task_last_scalar_events.get( # Update related tasks. For reasons of performance, we prefer to update
task_or_model_id # all of them and not only those whose events were successful
), updated = self._update_task(
) company_id=company_id,
else: task_id=task_id,
updated = self._update_task( now=now,
company_id=company_id, iter_max=task_iteration.get(task_id),
task_id=task_or_model_id, last_scalar_events=task_last_scalar_events.get(task_id),
now=now, last_events=task_last_events.get(task_id),
iter_max=task_iteration.get(task_or_model_id), )
last_scalar_events=task_last_scalar_events.get( if not updated:
task_or_model_id remaining_tasks.add(task_id)
), continue
last_events=task_last_events.get(task_or_model_id),
)
if not updated:
remaining_tasks.add(task_or_model_id)
continue
if remaining_tasks: if remaining_tasks:
TaskBLL.set_last_update( TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
remaining_tasks, company_id, last_update=now
)
# this is for backwards compatibility with streaming bulk throwing exception on those # this is for backwards compatibility with streaming bulk throwing exception on those
invalid_iterations_count = errors_per_type.get(invalid_iteration_error) invalid_iterations_count = errors_per_type.get(invalid_iteration_error)

View File

@ -195,6 +195,7 @@ class TestTaskEvents(TestService):
with self.api.raises(errors.bad_request.EventsNotAdded): with self.api.raises(errors.bad_request.EventsNotAdded):
self.send(log_event) self.send(log_event)
# mixed batch
events = [ events = [
{ {
**self._create_task_event("training_stats_scalar", model, iteration), **self._create_task_event("training_stats_scalar", model, iteration),
@ -207,6 +208,18 @@ class TestTaskEvents(TestService):
for metric_idx in range(5) for metric_idx in range(5)
for variant_idx in range(5) for variant_idx in range(5)
] ]
task = self._temp_task()
# noinspection PyTypeChecker
events.append(
self._create_task_event(
"log",
task=task,
iteration=0,
msg=f"This is a log message",
metric="Metric0",
variant="Variant0",
)
)
self.send_batch(events) self.send_batch(events)
data = self.api.events.scalar_metrics_iter_histogram( data = self.api.events.scalar_metrics_iter_histogram(
task=model, model_events=True task=model, model_events=True
@ -229,6 +242,8 @@ class TestTaskEvents(TestService):
self.assertEqual(0, metric_data.min_value) self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration) self.assertEqual(0, metric_data.min_value_iteration)
self._assert_log_events(task=task, expected_total=1)
def test_error_events(self): def test_error_events(self):
task = self._temp_task() task = self._temp_task()
events = [ events = [