Fix events.add_batch skipping the remaining events when a single bad event is encountered

allegroai 2020-06-01 11:33:39 +03:00
parent ede5586ccc
commit b0b09616a8
5 changed files with 129 additions and 87 deletions
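In short: events.add_batch (and events.add) previously raised on the first malformed event, aborting the whole batch; after this commit, per-event failures are counted by reason, valid events are still indexed, and the call fails only when nothing at all could be added. A minimal client-side sketch of the new contract, assuming a local server and default endpoint layout (the URL, port, payload details, and absence of authentication here are assumptions for illustration, not part of this commit):

import requests

events = [
    # malformed: unknown event type, now counted instead of fatal to the batch
    {"type": "unknown type", "task": "some-task-id", "timestamp": 0, "iter": 1},
    # well-formed log event for an existing task, still gets indexed
    {"type": "log", "task": "some-task-id", "timestamp": 0, "iter": 1, "msg": "hi"},
]

# URL/port and the lack of auth are assumptions
resp = requests.post("http://localhost:8008/events.add_batch", json=events)

# Expected shape after this commit (see the api spec change below):
# {"added": 1, "errors": 1, "errors_info": {"Invalid event type unknown_type": 1}}
print(resp.json())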

View File

@@ -47,6 +47,7 @@ _error_codes = {
     128: ('invalid_task_output', 'invalid task output'),
     129: ('task_publish_in_progress', 'Task publish in progress'),
     130: ('task_not_found', 'task not found'),
+    131: ('events_not_added', 'events not added'),

     # Models
     200: ('model_error', 'general task error'),
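These numeric codes are compiled into exception classes under apierrors; the new 131 entry backs the EventsNotAdded exception that the code and tests below import as apierrors.errors.bad_request.EventsNotAdded. A rough, self-contained sketch of that kind of code-to-class mapping, assuming a type()-based generator (the actual generator is not part of this diff, so the mechanics here are assumptions):

class BaseError(Exception):
    code = 400
    subcode = 0
    msg = ""

    def __init__(self, extra_msg=None, **kwargs):
        self.info = kwargs
        super().__init__(f"{self.msg} ({extra_msg or kwargs})")


def make_error(code, subcode, name, msg):
    # "events_not_added" -> class name EventsNotAdded
    cls_name = "".join(part.capitalize() for part in name.split("_"))
    return type(cls_name, (BaseError,), {"code": code, "subcode": subcode, "msg": msg})


EventsNotAdded = make_error(400, 131, "events_not_added", "events not added")

try:
    raise EventsNotAdded(**{"Invalid task id": 2})
except BaseError as e:
    print(type(e).__name__, e)  # EventsNotAdded events not added ({'Invalid task id': 2})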

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from operator import attrgetter
-from typing import Sequence
+from typing import Sequence, Set, Tuple

 import six
 from elasticsearch import helpers
@@ -46,7 +46,22 @@ class EventBLL(object):
     def metrics(self) -> EventMetrics:
         return self._metrics

-    def add_events(self, company_id, events, worker, allow_locked_tasks=False):
+    @staticmethod
+    def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
+        """Verify that task exists and can be updated"""
+        if not task_ids:
+            return set()
+
+        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
+            query = Q(id__in=task_ids, company=company_id)
+            if not allow_locked_tasks:
+                query &= Q(status__nin=LOCKED_TASK_STATUSES)
+            res = Task.objects(query).only("id")
+            return {r.id for r in res}
+
+    def add_events(
+        self, company_id, events, worker, allow_locked_tasks=False
+    ) -> Tuple[int, int, dict]:
         actions = []
         task_ids = set()
         task_iteration = defaultdict(lambda: 0)
@@ -56,19 +71,32 @@
         task_last_events = nested_dict(
             3, dict
         )  # task_id -> metric_hash -> event_type -> MetricEvent
+        errors_per_type = defaultdict(int)
+        valid_tasks = self._get_valid_tasks(
+            company_id,
+            task_ids={event["task"] for event in events if event.get("task")},
+            allow_locked_tasks=allow_locked_tasks,
+        )

         for event in events:
             # remove spaces from event type
-            if "type" not in event:
-                raise errors.BadRequest("Event must have a 'type' field", event=event)
+            event_type = event.get("type")
+            if event_type is None:
+                errors_per_type["Event must have a 'type' field"] += 1
+                continue

-            event_type = event["type"].replace(" ", "_")
+            event_type = event_type.replace(" ", "_")
             if event_type not in EVENT_TYPES:
-                raise errors.BadRequest(
-                    "Invalid event type {}".format(event_type),
-                    event=event,
-                    types=EVENT_TYPES,
-                )
+                errors_per_type[f"Invalid event type {event_type}"] += 1
+                continue
+
+            task_id = event.get("task")
+            if task_id is None:
+                errors_per_type["Event must have a 'task' field"] += 1
+                continue
+
+            if task_id not in valid_tasks:
+                errors_per_type["Invalid task id"] += 1
+                continue

             event["type"] = event_type
@@ -114,89 +142,75 @@
             else:
                 es_action["_id"] = dbutils.id()

-            task_id = event.get("task")
-
-            if task_id is not None:
-                es_action["_routing"] = task_id
-                task_ids.add(task_id)
-
-                if (
-                    iter is not None
-                    and event.get("metric") not in self._skip_iteration_for_metric
-                ):
-                    task_iteration[task_id] = max(iter, task_iteration[task_id])
-
-                self._update_last_metric_events_for_task(
-                    last_events=task_last_events[task_id], event=event,
-                )
-
-                if event_type == EventType.metrics_scalar.value:
-                    self._update_last_scalar_events_for_task(
-                        last_events=task_last_scalar_events[task_id], event=event
-                    )
-            else:
-                es_action["_routing"] = task_id
+            es_action["_routing"] = task_id
+            task_ids.add(task_id)
+
+            if (
+                iter is not None
+                and event.get("metric") not in self._skip_iteration_for_metric
+            ):
+                task_iteration[task_id] = max(iter, task_iteration[task_id])
+
+            self._update_last_metric_events_for_task(
+                last_events=task_last_events[task_id], event=event,
+            )
+
+            if event_type == EventType.metrics_scalar.value:
+                self._update_last_scalar_events_for_task(
+                    last_events=task_last_scalar_events[task_id], event=event
+                )

             actions.append(es_action)

-        if task_ids:
-            # verify task_ids
-            with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
-                extra_msg = None
-                query = Q(id__in=task_ids, company=company_id)
-                if not allow_locked_tasks:
-                    query &= Q(status__nin=LOCKED_TASK_STATUSES)
-                    extra_msg = "or task published"
-                res = Task.objects(query).only("id")
-                if len(res) < len(task_ids):
-                    invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
-                    raise errors.bad_request.InvalidTaskId(
-                        extra_msg, company=company_id, ids=invalid_task_ids
-                    )
-
-        errors_in_bulk = []
-        added = 0
-        chunk_size = 500
-        with translate_errors_context(), TimingContext("es", "events_add_batch"):
-            # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
-            with closing(
-                helpers.streaming_bulk(
-                    self.es,
-                    actions,
-                    chunk_size=chunk_size,
-                    # thread_count=8,
-                    refresh=True,
-                )
-            ) as it:
-                for success, info in it:
-                    if success:
-                        added += chunk_size
-                    else:
-                        errors_in_bulk.append(info)
-
-            remaining_tasks = set()
-            now = datetime.utcnow()
-            for task_id in task_ids:
-                # Update related tasks. For reasons of performance, we prefer to update all of them and not only those
-                # who's events were successful
-                updated = self._update_task(
-                    company_id=company_id,
-                    task_id=task_id,
-                    now=now,
-                    iter_max=task_iteration.get(task_id),
-                    last_scalar_events=task_last_scalar_events.get(task_id),
-                    last_events=task_last_events.get(task_id),
-                )
-
-                if not updated:
-                    remaining_tasks.add(task_id)
-                    continue
-
-            if remaining_tasks:
-                TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
+        added = 0
+        if actions:
+            chunk_size = 500
+            with translate_errors_context(), TimingContext("es", "events_add_batch"):
+                # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
+                with closing(
+                    helpers.streaming_bulk(
+                        self.es,
+                        actions,
+                        chunk_size=chunk_size,
+                        # thread_count=8,
+                        refresh=True,
+                    )
+                ) as it:
+                    for success, info in it:
+                        if success:
+                            added += chunk_size
+                        else:
+                            errors_per_type["Error when indexing events batch"] += 1
+
+                remaining_tasks = set()
+                now = datetime.utcnow()
+                for task_id in task_ids:
+                    # Update related tasks. For reasons of performance, we prefer to update
+                    # all of them and not only those who's events were successful
+                    updated = self._update_task(
+                        company_id=company_id,
+                        task_id=task_id,
+                        now=now,
+                        iter_max=task_iteration.get(task_id),
+                        last_scalar_events=task_last_scalar_events.get(task_id),
+                        last_events=task_last_events.get(task_id),
+                    )
+
+                    if not updated:
+                        remaining_tasks.add(task_id)
+                        continue
+
+                if remaining_tasks:
+                    TaskBLL.set_last_update(
+                        remaining_tasks, company_id, last_update=now
+                    )

         # Compensate for always adding chunk_size on success (last chunk is probably smaller)
         added = min(added, len(actions))
-        return added, errors_in_bulk
+
+        if not added:
+            raise errors.bad_request.EventsNotAdded(**errors_per_type)
+
+        errors_count = sum(errors_per_type.values())
+        return added, errors_count, errors_per_type

     def _update_last_scalar_events_for_task(self, last_events, event):
         """

View File

@@ -258,6 +258,7 @@
                 properties {
                     added { type: integer }
                     errors { type: integer }
+                    errors_info { type: object }
                 }
             }
         }
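With errors_info added to the add/add_batch response schema, a partially rejected batch now reports roughly the following (values mirror the test case added at the end of this commit; shown as a Python literal for illustration):

response = {
    "added": 1,       # events that made it into Elasticsearch
    "errors": 3,      # total number of rejected events
    "errors_info": {  # new field: rejection reason -> count
        "Invalid event type unknown_type": 1,
        "Event must have a 'task' field": 1,
        "Invalid task id": 1,
    },
}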

View File

@@ -27,10 +27,10 @@ event_bll = EventBLL()
 def add(call: APICall, company_id, req_model):
     data = call.data.copy()
     allow_locked = data.pop("allow_locked", False)
-    added, batch_errors = event_bll.add_events(
+    added, err_count, err_info = event_bll.add_events(
         company_id, [data], call.worker, allow_locked_tasks=allow_locked
     )
-    call.result.data = dict(added=added, errors=len(batch_errors))
+    call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
     call.kpis["events"] = 1
@@ -40,8 +40,8 @@ def add_batch(call: APICall, company_id, req_model):
     if events is None or len(events) == 0:
         raise errors.bad_request.BatchContainsNoItems()

-    added, batch_errors = event_bll.add_events(company_id, events, call.worker)
-    call.result.data = dict(added=added, errors=len(batch_errors))
+    added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
+    call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
     call.kpis["events"] = len(events)
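One point worth noting for API consumers: the errors field changes meaning here. It was previously len(batch_errors), counting only failed Elasticsearch bulk items (validation failures raised before anything was indexed); it now counts every rejected event, with errors_info breaking that count down by reason. A small sketch of reading the new result (field names come from this diff; the surrounding client code is hypothetical):

def report_add_batch_result(result: dict) -> None:
    # result is the parsed response of events.add_batch after this commit
    added = result["added"]
    errors = result["errors"]                    # count of all rejected events
    errors_info = result.get("errors_info", {})  # new: rejection reason -> count
    for reason, count in errors_info.items():
        print(f"{count} event(s) rejected: {reason}")
    print(f"{added} event(s) stored, {errors} rejected in total")


report_add_batch_result({"added": 1, "errors": 1, "errors_info": {"Invalid task id": 1}})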

View File

@@ -9,6 +9,7 @@ from statistics import mean
 from typing import Sequence

 import es_factory
+from apierrors.errors.bad_request import EventsNotAdded
 from tests.automated import TestService
@@ -160,6 +161,30 @@ class TestTaskEvents(TestService):
             self.assertEqual(len(it["events"]), events_per_iter)
         return res.scroll_id

+    def test_error_events(self):
+        task = self._temp_task()
+        events = [
+            self._create_task_event("unknown type", task, iteration=1),
+            self._create_task_event("training_debug_image", task=None, iteration=1),
+            self._create_task_event(
+                "training_debug_image", task="Invalid task", iteration=1
+            ),
+        ]
+        # failure if no events added
+        with self.api.raises(EventsNotAdded):
+            self.send_batch(events)
+
+        events.append(
+            self._create_task_event("training_debug_image", task=task, iteration=1)
+        )
+        # success if at least one event added
+        res = self.send_batch(events)
+        self.assertEqual(res["added"], 1)
+        self.assertEqual(res["errors"], 3)
+        self.assertEqual(len(res["errors_info"]), 3)
+
+        res = self.api.events.get_task_events(task=task)
+        self.assertEqual(len(res.events), 1)
+
     def test_task_logs(self):
         task = self._temp_task()
         timestamp = es_factory.get_timestamp_millis()
@@ -429,7 +454,8 @@
         assert len(data["plots"]) == 0

     def send_batch(self, events):
-        self.api.send_batch("events.add_batch", events)
+        _, data = self.api.send_batch("events.add_batch", events)
+        return data

     def send(self, event):
         self.api.send("events.add", event)