Fix single bad event causing events.add_batch to skip remaining events

allegroai 2020-06-01 11:33:39 +03:00
parent ede5586ccc
commit b0b09616a8
5 changed files with 129 additions and 87 deletions
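Summary: previously a single malformed event raised BadRequest and aborted the whole batch. events.add_batch now validates each event independently, counts failures per error reason, indexes whatever passed validation, and only fails (with the new events_not_added error) when nothing at all could be added. A standalone sketch of the per-event validation pattern introduced here, with the event-type set and the valid-task lookup stubbed as plain sets (the real code lives in EventBLL.add_events below):

from collections import defaultdict

def validate_events(events, valid_tasks, event_types):
    # Count per-reason failures instead of raising on the first bad event
    errors_per_type = defaultdict(int)
    good_events = []
    for event in events:
        event_type = event.get("type")
        if event_type is None:
            errors_per_type["Event must have a 'type' field"] += 1
            continue
        event_type = event_type.replace(" ", "_")
        if event_type not in event_types:
            errors_per_type[f"Invalid event type {event_type}"] += 1
            continue
        task_id = event.get("task")
        if task_id is None:
            errors_per_type["Event must have a 'task' field"] += 1
            continue
        if task_id not in valid_tasks:
            errors_per_type["Invalid task id"] += 1
            continue
        good_events.append(event)
    return good_events, errors_per_type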

View File

@@ -47,6 +47,7 @@ _error_codes = {
     128: ('invalid_task_output', 'invalid task output'),
     129: ('task_publish_in_progress', 'Task publish in progress'),
     130: ('task_not_found', 'task not found'),
+    131: ('events_not_added', 'events not added'),

     # Models
     200: ('model_error', 'general task error'),
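The new code 131 backs the EventsNotAdded exception raised further down when a batch yields zero indexed events. A hypothetical client-side check for it; the endpoint path and the meta/result_subcode envelope are assumptions about the server's usual response format, not shown in this commit:

import requests

def add_batch(server_url, events):
    # Hypothetical client call; path and response envelope are assumptions
    resp = requests.post(f"{server_url}/events.add_batch", json=events)
    meta = resp.json().get("meta", {})
    if meta.get("result_subcode") == 131:  # 'events_not_added'
        raise RuntimeError(f"no events were added: {meta.get('result_msg')}")
    return resp.json().get("data")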

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from operator import attrgetter
-from typing import Sequence
+from typing import Sequence, Set, Tuple

 import six
 from elasticsearch import helpers
@@ -46,7 +46,22 @@ class EventBLL(object):
     def metrics(self) -> EventMetrics:
         return self._metrics

-    def add_events(self, company_id, events, worker, allow_locked_tasks=False):
+    @staticmethod
+    def _get_valid_tasks(company_id, task_ids: Set, allow_locked_tasks=False) -> Set:
+        """Verify that tasks exist and can be updated"""
+        if not task_ids:
+            return set()
+
+        with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
+            query = Q(id__in=task_ids, company=company_id)
+            if not allow_locked_tasks:
+                query &= Q(status__nin=LOCKED_TASK_STATUSES)
+            res = Task.objects(query).only("id")
+            return {r.id for r in res}
+
+    def add_events(
+        self, company_id, events, worker, allow_locked_tasks=False
+    ) -> Tuple[int, int, dict]:
         actions = []
         task_ids = set()
         task_iteration = defaultdict(lambda: 0)
@@ -56,19 +71,32 @@ class EventBLL(object):
         task_last_events = nested_dict(
             3, dict
         )  # task_id -> metric_hash -> event_type -> MetricEvent
+        errors_per_type = defaultdict(int)
+        valid_tasks = self._get_valid_tasks(
+            company_id,
+            task_ids={event["task"] for event in events if event.get("task")},
+            allow_locked_tasks=allow_locked_tasks,
+        )

         for event in events:
             # remove spaces from event type
-            if "type" not in event:
-                raise errors.BadRequest("Event must have a 'type' field", event=event)
+            event_type = event.get("type")
+            if event_type is None:
+                errors_per_type["Event must have a 'type' field"] += 1
+                continue

-            event_type = event["type"].replace(" ", "_")
+            event_type = event_type.replace(" ", "_")
             if event_type not in EVENT_TYPES:
-                raise errors.BadRequest(
-                    "Invalid event type {}".format(event_type),
-                    event=event,
-                    types=EVENT_TYPES,
-                )
+                errors_per_type[f"Invalid event type {event_type}"] += 1
+                continue
+
+            task_id = event.get("task")
+            if task_id is None:
+                errors_per_type["Event must have a 'task' field"] += 1
+                continue
+
+            if task_id not in valid_tasks:
+                errors_per_type["Invalid task id"] += 1
+                continue

             event["type"] = event_type
@@ -114,8 +142,6 @@ class EventBLL(object):
             else:
                 es_action["_id"] = dbutils.id()

-            task_id = event.get("task")
-            if task_id is not None:
             es_action["_routing"] = task_id
             task_ids.add(task_id)

             if (
@@ -131,28 +157,11 @@ class EventBLL(object):
                 self._update_last_scalar_events_for_task(
                     last_events=task_last_scalar_events[task_id], event=event
                 )
-            else:
-                es_action["_routing"] = task_id

             actions.append(es_action)

-        if task_ids:
-            # verify task_ids
-            with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
-                extra_msg = None
-                query = Q(id__in=task_ids, company=company_id)
-                if not allow_locked_tasks:
-                    query &= Q(status__nin=LOCKED_TASK_STATUSES)
-                    extra_msg = "or task published"
-                res = Task.objects(query).only("id")
-                if len(res) < len(task_ids):
-                    invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
-                    raise errors.bad_request.InvalidTaskId(
-                        extra_msg, company=company_id, ids=invalid_task_ids
-                    )
-
-        errors_in_bulk = []
         added = 0
-        if actions:
         chunk_size = 500
         with translate_errors_context(), TimingContext("es", "events_add_batch"):
             # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
@@ -169,14 +178,13 @@ class EventBLL(object):
                 if success:
                     added += chunk_size
                 else:
-                    errors_in_bulk.append(info)
+                    errors_per_type["Error when indexing events batch"] += 1

         remaining_tasks = set()
         now = datetime.utcnow()
         for task_id in task_ids:
-            # Update related tasks. For reasons of performance, we prefer to update all of them and not only those
-            # who's events were successful
+            # Update related tasks. For reasons of performance, we prefer to update
+            # all of them and not only those whose events were successful
             updated = self._update_task(
                 company_id=company_id,
                 task_id=task_id,
@@ -191,12 +199,18 @@ class EventBLL(object):
                 continue

         if remaining_tasks:
-            TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
+            TaskBLL.set_last_update(
+                remaining_tasks, company_id, last_update=now
+            )

         # Compensate for always adding chunk_size on success (last chunk is probably smaller)
         added = min(added, len(actions))

-        return added, errors_in_bulk
+        if not added:
+            raise errors.bad_request.EventsNotAdded(**errors_per_type)
+
+        errors_count = sum(errors_per_type.values())
+        return added, errors_count, errors_per_type

     def _update_last_scalar_events_for_task(self, last_events, event):
         """

View File

@@ -258,6 +258,7 @@
         properties {
             added { type: integer }
             errors { type: integer }
+            errors_info { type: object }
         }
     }
 }
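With the schema extension above, a partially successful events.add_batch response carries the per-reason counters alongside the totals. An illustrative payload (counts match the test case added below; the key strings follow the error messages in EventBLL.add_events):

response_data = {
    "added": 1,
    "errors": 3,
    "errors_info": {
        "Invalid event type unknown_type": 1,
        "Event must have a 'task' field": 1,
        "Invalid task id": 1,
    },
}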

View File

@@ -27,10 +27,10 @@ event_bll = EventBLL()
 def add(call: APICall, company_id, req_model):
     data = call.data.copy()
     allow_locked = data.pop("allow_locked", False)
-    added, batch_errors = event_bll.add_events(
+    added, err_count, err_info = event_bll.add_events(
         company_id, [data], call.worker, allow_locked_tasks=allow_locked
     )
-    call.result.data = dict(added=added, errors=len(batch_errors))
+    call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
     call.kpis["events"] = 1
@@ -40,8 +40,8 @@ def add_batch(call: APICall, company_id, req_model):
     if events is None or len(events) == 0:
         raise errors.bad_request.BatchContainsNoItems()

-    added, batch_errors = event_bll.add_events(company_id, events, call.worker)
-    call.result.data = dict(added=added, errors=len(batch_errors))
+    added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
+    call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
     call.kpis["events"] = len(events)

View File

@@ -9,6 +9,7 @@ from statistics import mean
 from typing import Sequence

 import es_factory
+from apierrors.errors.bad_request import EventsNotAdded
 from tests.automated import TestService
@@ -160,6 +161,30 @@ class TestTaskEvents(TestService):
             self.assertEqual(len(it["events"]), events_per_iter)
         return res.scroll_id

+    def test_error_events(self):
+        task = self._temp_task()
+        events = [
+            self._create_task_event("unknown type", task, iteration=1),
+            self._create_task_event("training_debug_image", task=None, iteration=1),
+            self._create_task_event(
+                "training_debug_image", task="Invalid task", iteration=1
+            ),
+        ]
+        # failure if no events added
+        with self.api.raises(EventsNotAdded):
+            self.send_batch(events)
+
+        events.append(
+            self._create_task_event("training_debug_image", task=task, iteration=1)
+        )
+        # success if at least one event added
+        res = self.send_batch(events)
+        self.assertEqual(res["added"], 1)
+        self.assertEqual(res["errors"], 3)
+        self.assertEqual(len(res["errors_info"]), 3)
+
+        res = self.api.events.get_task_events(task=task)
+        self.assertEqual(len(res.events), 1)
+
     def test_task_logs(self):
         task = self._temp_task()
         timestamp = es_factory.get_timestamp_millis()
@@ -429,7 +454,8 @@ class TestTaskEvents(TestService):
         assert len(data["plots"]) == 0

     def send_batch(self, events):
-        self.api.send_batch("events.add_batch", events)
+        _, data = self.api.send_batch("events.add_batch", events)
+        return data

     def send(self, event):
         self.api.send("events.add", event)