Add support for returning only valid plot events

allegroai 2021-01-05 16:41:55 +02:00
parent 171969c5ea
commit f4ead86449
6 changed files with 190 additions and 11 deletions

View File

@@ -1,15 +1,18 @@
import base64
import hashlib
import zlib
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
-from typing import Sequence, Set, Tuple, Optional
+from typing import Sequence, Set, Tuple, Optional, Dict
import six
from elasticsearch import helpers
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
@@ -26,10 +29,19 @@ from apiserver.tools import safe_get
from apiserver.utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
from apiserver.utilities.json import loads
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
class PlotFields:
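# ES event fields used for plot validation results and compressed storage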
valid_plot = "valid_plot"
plot_len = "plot_len"
plot_str = "plot_str"
plot_data = "plot_data"
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
@@ -81,6 +93,7 @@ class EventBLL(object):
},
allow_locked_tasks=allow_locked_tasks,
)
for event in events:
# remove spaces from event type
event_type = event.get("type")
@@ -162,6 +175,21 @@
actions.append(es_action)
action: Dict[str, dict]
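# collect plot events so they can be validated/compressed before indexing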
plot_actions = [
action["_source"]
for action in actions
if action["_source"]["type"] == EventType.metrics_plot.value
]
if plot_actions:
self.validate_and_compress_plots(
plot_actions,
validate_json=config.get("services.events.validate_plot_str", False),
compression_threshold=config.get(
"services.events.plot_compression_threshold", 100_000
),
)
added = 0
if actions:
chunk_size = 500
@@ -214,6 +242,52 @@
errors_count = sum(errors_per_type.values())
return added, errors_count, errors_per_type
@parallel_chunked_decorator(chunk_size=10)
def validate_and_compress_plots(
self,
plot_events: Sequence[dict],
validate_json: bool,
compression_threshold: int,
):
for event in plot_events:
validate = validate_json and not event.pop("skip_validation", False)
plot_str = event.get(PlotFields.plot_str)
if not plot_str:
event[PlotFields.plot_len] = 0
if validate:
event[PlotFields.valid_plot] = False
continue
plot_len = len(plot_str)
event[PlotFields.plot_len] = plot_len
if validate:
event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
if compression_threshold and plot_len >= compression_threshold:
event[PlotFields.plot_data] = base64.encodebytes(
zlib.compress(plot_str.encode(), level=1)
).decode("ascii")
event.pop(PlotFields.plot_str, None)
@parallel_chunked_decorator(chunk_size=10)
def uncompress_plots(self, plot_events: Sequence[dict]):
for event in plot_events:
plot_data = event.pop(PlotFields.plot_data, None)
if plot_data and event.get(PlotFields.plot_str) is None:
event[PlotFields.plot_str] = zlib.decompress(
base64.b64decode(plot_data)
).decode()
@staticmethod
def _is_valid_json(text: str) -> bool:
"""Check str for valid json"""
if not text:
return False
try:
loads(text)
except Exception:
return False
return True
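# For reference, a minimal standalone sketch of the round trip implemented by
# validate_and_compress_plots and uncompress_plots above (standard library
# only; the payload size here is illustrative):
import base64
import zlib

plot_str = '{"data": []}' * 10_000  # a large plot payload (illustrative size)

# store path: zlib level 1, then base64, as in validate_and_compress_plots
plot_data = base64.encodebytes(zlib.compress(plot_str.encode(), level=1)).decode("ascii")

# read path: base64-decode, then decompress, as in uncompress_plots
assert zlib.decompress(base64.b64decode(plot_data)).decode() == plot_str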
def _update_last_scalar_events_for_task(self, last_events, event):
"""
Update last_events structure with the provided event details if this event is more
@@ -423,7 +497,20 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
-must = []
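# a plot event passes this filter if it was validated as proper JSON, or if it
# predates validation and has no valid_plot field at all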
plot_valid_condition = {
"bool": {
"should": [
{"term": {PlotFields.valid_plot: True}},
{
"bool": {
"must_not": {"exists": {"field": PlotFields.valid_plot}}
}
},
]
}
}
+must = [plot_valid_condition]
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
else:
@ -467,6 +554,7 @@ class EventBLL(object):
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
self.uncompress_plots(events)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)

View File

@@ -1,6 +1,10 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from operator import itemgetter
-from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
+from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
from boltons import iterutils
from apiserver.database.model import AttributedDocument
from apiserver.database.model.settings import Settings
@@ -78,3 +82,36 @@ class SetFieldsResolver:
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
return Settings.get_by_key("server.uuid")
def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
"""
Decorates a method for parallel chunked execution. The method should have
one positional parameter (used for breaking the input into chunks)
and an arbitrary number of keyword params. The return value should be iterable.
The results are concatenated in the same order as the input chunks.
"""
if func is None:
return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size)
@functools.wraps(func)
def wrapper(self, iterable: Iterable, **kwargs):
assert iterutils.is_collection(
iterable
), "The positional parameter should be an iterable for breaking into chunks"
func_with_params = functools.partial(func, self, **kwargs)
with ThreadPoolExecutor() as pool:
return list(
itertools.chain.from_iterable(
filter(
None,
pool.map(
func_with_params,
iterutils.chunked_iter(iterable, chunk_size),
),
)
),
)
return wrapper
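# A hedged usage sketch of parallel_chunked_decorator; the class and method
# names below are hypothetical. The decorated method is written against a
# single chunk, while callers pass the whole collection and get the results
# concatenated back in input order:
class Enricher:
    @parallel_chunked_decorator(chunk_size=2)
    def tag_events(self, events, tag):
        # invoked once per chunk of at most 2 events, on pool threads
        return [{**event, "tag": tag} for event in events]

tagged = Enricher().tag_events([{"id": i} for i in range(5)], tag="x")
assert [event["id"] for event in tagged] == [0, 1, 2, 3, 4]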

View File

@@ -11,3 +11,10 @@ max_metrics_concurrency: 4
events_retrieval {
state_expiration_sec: 3600
}
# If set, the plot string is checked for valid JSON when a plot event is added
# and the result of the check is written to the DB
validate_plot_str: false
# If not 0, plots whose size is equal to or greater than this value are stored compressed in the DB
plot_compression_threshold: 100000
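# compressed plots are stored base64-encoded in the plot_data field in place of plot_str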

View File

@@ -6,6 +6,9 @@
"plot_str": {
"type": "text",
"index": false
},
"plot_data": {
"type": "binary"
}
}
}

View File

@@ -147,6 +147,10 @@
"""
type: string
}
skip_validation {
description: "If set, plot_str is not checked for valid JSON. The default is false"
type: boolean
}
}
}
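# For illustration, a hedged example of a plot event carrying the new flag
# (the task id is a placeholder and the other field values are illustrative):
event = {
    "task": "<task-id>",
    "type": "plot",
    "iter": 0,
    "metric": "test1",
    "plot_str": '{"data": []}',
    "skip_validation": True,  # do not check plot_str for valid JSON
}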
scalar_key_enum {

View File

@@ -435,9 +435,9 @@ class TestTaskEvents(TestService):
)
self.send(event)
-event = self._create_task_event("plot", task, 100)
-event["metric"] = "confusion"
-event.update(
+event1 = self._create_task_event("plot", task, 100)
+event1["metric"] = "confusion"
+event1.update(
{
"plot_str": json.dumps(
{
@@ -476,14 +476,54 @@
)
}
)
-self.send(event)
+self.send(event1)
-data = self.api.events.get_task_plots(task=task)
-assert len(data["plots"]) == 2
+plots = self.api.events.get_task_plots(task=task).plots
+self.assertEqual(
+{e["plot_str"] for e in (event, event1)}, {p.plot_str for p in plots}
+)
self.api.tasks.reset(task=task)
-data = self.api.events.get_task_plots(task=task)
-assert len(data["plots"]) == 0
+plots = self.api.events.get_task_plots(task=task).plots
+self.assertEqual(len(plots), 0)
@unittest.skip("this test will run only if 'validate_plot_str' is set to true")
def test_plots_validation(self):
valid_plot_str = json.dumps({"data": []})
invalid_plot_str = "Not a valid json"
task = self._temp_task()
event = self._create_task_event(
"plot", task, 0, metric="test1", plot_str=valid_plot_str
)
event1 = self._create_task_event(
"plot", task, 100, metric="test2", plot_str=invalid_plot_str
)
self.send_batch([event, event1])
res = self.api.events.get_task_plots(task=task).plots
self.assertEqual(len(res), 1)
self.assertEqual(res[0].metric, "test1")
event = self._create_task_event(
"plot",
task,
0,
metric="test1",
plot_str=valid_plot_str,
skip_validation=True,
)
event1 = self._create_task_event(
"plot",
task,
100,
metric="test2",
plot_str=invalid_plot_str,
skip_validation=True,
)
self.send_batch([event, event1])
res = self.api.events.get_task_plots(task=task).plots
self.assertEqual(len(res), 2)
self.assertEqual(set(r.metric for r in res), {"test1", "test2"})
def send_batch(self, events):
_, data = self.api.send_batch("events.add_batch", events)