Add support for returning only valid plot events

This commit is contained in:
allegroai
2021-01-05 16:41:55 +02:00
parent 171969c5ea
commit f4ead86449
6 changed files with 190 additions and 11 deletions

View File

@@ -1,15 +1,18 @@
import base64
import hashlib
import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional
from typing import Sequence, Set, Tuple, Optional, Dict
import six
from elasticsearch import helpers
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
@@ -26,10 +29,19 @@ from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
from apiserver.utilities.json import loads
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
class PlotFields:
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
@@ -81,6 +93,7 @@ class EventBLL(object):
},
allow_locked_tasks=allow_locked_tasks,
)
for event in events:
# remove spaces from event type
event_type = event.get("type")
@@ -162,6 +175,21 @@ class EventBLL(object):
actions.append(es_action)
action: Dict[dict]
plot_actions = [
action["_source"]
for action in actions
if action["_source"]["type"] == EventType.metrics_plot.value
]
if plot_actions:
self.validate_and_compress_plots(
plot_actions,
validate_json=config.get("services.events.validate_plot_str", False),
compression_threshold=config.get(
"services.events.plot_compression_threshold", 100_000
),
)
added = 0
if actions:
chunk_size = 500
@@ -214,6 +242,52 @@ class EventBLL(object):
errors_count = sum(errors_per_type.values())
return added, errors_count, errors_per_type
@parallel_chunked_decorator(chunk_size=10)
def validate_and_compress_plots(
self,
plot_events: Sequence[dict],
validate_json: bool,
compression_threshold: int,
):
for event in plot_events:
validate = validate_json and not event.pop("skip_validation", False)
plot_str = event.get(PlotFields.plot_str)
if not plot_str:
event[PlotFields.plot_len] = 0
if validate:
event[PlotFields.valid_plot] = False
continue
plot_len = len(plot_str)
event[PlotFields.plot_len] = plot_len
if validate:
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
if compression_threshold and plot_len >= compression_threshold:
event[PlotFields.plot_data] = base64.encodebytes(
zlib.compress(plot_str.encode(), level=1)
).decode("ascii")
event.pop(PlotFields.plot_str, None)
@parallel_chunked_decorator(chunk_size=10)
def uncompress_plots(self, plot_events: Sequence[dict]):
for event in plot_events:
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()
@staticmethod
def _is_valid_json(text: str) -> bool:
"""Check str for valid json"""
if not text:
return False
try:
loads(text)
except Exception:
return False
return True
def _update_last_scalar_events_for_task(self, last_events, event):
"""
Update last_events structure with the provided event details if this event is more
@@ -423,7 +497,20 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
must = []
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
must = [plot_valid_condition]
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
else:
@@ -467,6 +554,7 @@ class EventBLL(object):
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)