From f4ead86449a79171e69cac41de242c5f11bfdd6b Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 5 Jan 2021 16:41:55 +0200 Subject: [PATCH] Add support for returning only valid plot events --- apiserver/bll/event/event_bll.py | 92 ++++++++++++++++++- apiserver/bll/util.py | 39 +++++++- apiserver/config/default/services/events.conf | 7 ++ .../elastic/mappings/events/events_plot.json | 3 + apiserver/schema/services/events.conf | 4 + apiserver/tests/automated/test_task_events.py | 56 +++++++++-- 6 files changed, 190 insertions(+), 11 deletions(-) diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index bef1407..29e99b0 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -1,15 +1,18 @@ +import base64 import hashlib +import zlib from collections import defaultdict from contextlib import closing from datetime import datetime from operator import attrgetter -from typing import Sequence, Set, Tuple, Optional +from typing import Sequence, Set, Tuple, Optional, Dict import six from elasticsearch import helpers from mongoengine import Q from nested_dict import nested_dict +from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils from apiserver.es_factory import es_factory from apiserver.apierrors import errors @@ -26,10 +29,19 @@ from apiserver.tools import safe_get from apiserver.utilities.dicts import flatten_nested_items # noinspection PyTypeChecker +from apiserver.utilities.json import loads + EVENT_TYPES = set(map(attrgetter("value"), EventType)) LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published) +class PlotFields: + valid_plot = "valid_plot" + plot_len = "plot_len" + plot_str = "plot_str" + plot_data = "plot_data" + + class EventBLL(object): id_fields = ("task", "iter", "metric", "variant", "key") empty_scroll = "FFFF" @@ -81,6 +93,7 @@ class EventBLL(object): }, allow_locked_tasks=allow_locked_tasks, ) + for event in events: # remove 
spaces from event type event_type = event.get("type") @@ -162,6 +175,21 @@ class EventBLL(object): actions.append(es_action) + action: Dict[dict] + plot_actions = [ + action["_source"] + for action in actions + if action["_source"]["type"] == EventType.metrics_plot.value + ] + if plot_actions: + self.validate_and_compress_plots( + plot_actions, + validate_json=config.get("services.events.validate_plot_str", False), + compression_threshold=config.get( + "services.events.plot_compression_threshold", 100_000 + ), + ) + added = 0 if actions: chunk_size = 500 @@ -214,6 +242,52 @@ class EventBLL(object): errors_count = sum(errors_per_type.values()) return added, errors_count, errors_per_type + @parallel_chunked_decorator(chunk_size=10) + def validate_and_compress_plots( + self, + plot_events: Sequence[dict], + validate_json: bool, + compression_threshold: int, + ): + for event in plot_events: + validate = validate_json and not event.pop("skip_validation", False) + plot_str = event.get(PlotFields.plot_str) + if not plot_str: + event[PlotFields.plot_len] = 0 + if validate: + event[PlotFields.valid_plot] = False + continue + + plot_len = len(plot_str) + event[PlotFields.plot_len] = plot_len + if validate: + event[PlotFields.valid_plot] = self._is_valid_json(plot_str) + if compression_threshold and plot_len >= compression_threshold: + event[PlotFields.plot_data] = base64.encodebytes( + zlib.compress(plot_str.encode(), level=1) + ).decode("ascii") + event.pop(PlotFields.plot_str, None) + + @parallel_chunked_decorator(chunk_size=10) + def uncompress_plots(self, plot_events: Sequence[dict]): + for event in plot_events: + plot_data = event.pop(PlotFields.plot_data, None) + if plot_data and event.get(PlotFields.plot_str) is None: + event[PlotFields.plot_str] = zlib.decompress( + base64.b64decode(plot_data) + ).decode() + + @staticmethod + def _is_valid_json(text: str) -> bool: + """Check str for valid json""" + if not text: + return False + try: + loads(text) + except Exception: 
+ return False + return True + def _update_last_scalar_events_for_task(self, last_events, event): """ Update last_events structure with the provided event details if this event is more @@ -423,7 +497,20 @@ class EventBLL(object): if not self.es.indices.exists(es_index): return TaskEventsResult() - must = [] + plot_valid_condition = { + "bool": { + "should": [ + {"term": {PlotFields.valid_plot: True}}, + { + "bool": { + "must_not": {"exists": {"field": PlotFields.valid_plot}} + } + }, + ] + } + } + must = [plot_valid_condition] + if last_iterations_per_plot is None: must.append({"terms": {"task": tasks}}) else: @@ -467,6 +554,7 @@ class EventBLL(object): ) events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) + self.uncompress_plots(events) return TaskEventsResult( events=events, next_scroll_id=next_scroll_id, total_events=total_events ) diff --git a/apiserver/bll/util.py b/apiserver/bll/util.py index ed83b8d..f845a61 100644 --- a/apiserver/bll/util.py +++ b/apiserver/bll/util.py @@ -1,6 +1,10 @@ import functools +import itertools +from concurrent.futures.thread import ThreadPoolExecutor from operator import itemgetter -from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set +from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable + +from boltons import iterutils from apiserver.database.model import AttributedDocument from apiserver.database.model.settings import Settings @@ -78,3 +82,36 @@ class SetFieldsResolver: @functools.lru_cache() def get_server_uuid() -> Optional[str]: return Settings.get_by_key("server.uuid") + + +def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100): + """ + Decorates a method for parallel chunked execution. The method should have + one positional parameter (that is used for breaking into chunks) + and arbitrary number of keyword params. 
The return value should be iterable + The results are concatenated in the same order as the passed params + """ + if func is None: + return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size) + + @functools.wraps(func) + def wrapper(self, iterable: Iterable, **kwargs): + assert iterutils.is_collection( + iterable + ), "The positional parameter should be an iterable for breaking into chunks" + + func_with_params = functools.partial(func, self, **kwargs) + with ThreadPoolExecutor() as pool: + return list( + itertools.chain.from_iterable( + filter( + None, + pool.map( + func_with_params, + iterutils.chunked_iter(iterable, chunk_size), + ), + ) + ), + ) + + return wrapper diff --git a/apiserver/config/default/services/events.conf b/apiserver/config/default/services/events.conf index 5adea87..953e964 100644 --- a/apiserver/config/default/services/events.conf +++ b/apiserver/config/default/services/events.conf @@ -11,3 +11,10 @@ max_metrics_concurrency: 4 events_retrieval { state_expiration_sec: 3600 } + +# If set, plot_str is checked for valid JSON when a plot event is added, +# and the result of the check is written to the DB +validate_plot_str: false + +# If not 0, plots whose plot_str length is equal to or greater than this threshold are stored compressed in the DB +plot_compression_threshold: 100000 \ No newline at end of file diff --git a/apiserver/elastic/mappings/events/events_plot.json b/apiserver/elastic/mappings/events/events_plot.json index 260700a..8e65c71 100644 --- a/apiserver/elastic/mappings/events/events_plot.json +++ b/apiserver/elastic/mappings/events/events_plot.json @@ -6,6 +6,9 @@ "plot_str": { "type": "text", "index": false + }, + "plot_data": { + "type": "binary" } } } diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 0368f63..18f8bab 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -147,6 +147,10 @@ """ type: string } + skip_validation { + description: "If set
then plot_str is not checked for a valid json. The default is False" + type: boolean + } } } scalar_key_enum { diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index 1e0be39..eade322 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -435,9 +435,9 @@ class TestTaskEvents(TestService): ) self.send(event) - event = self._create_task_event("plot", task, 100) - event["metric"] = "confusion" - event.update( + event1 = self._create_task_event("plot", task, 100) + event1["metric"] = "confusion" + event1.update( { "plot_str": json.dumps( { @@ -476,14 +476,54 @@ class TestTaskEvents(TestService): ) } ) - self.send(event) + self.send(event1) - data = self.api.events.get_task_plots(task=task) - assert len(data["plots"]) == 2 + plots = self.api.events.get_task_plots(task=task).plots + self.assertEqual( + {e["plot_str"] for e in (event, event1)}, {p.plot_str for p in plots} + ) self.api.tasks.reset(task=task) - data = self.api.events.get_task_plots(task=task) - assert len(data["plots"]) == 0 + plots = self.api.events.get_task_plots(task=task).plots + self.assertEqual(len(plots), 0) + + @unittest.skip("this test will run only if 'validate_plot_str' is set to true") + def test_plots_validation(self): + valid_plot_str = json.dumps({"data": []}) + invalid_plot_str = "Not a valid json" + task = self._temp_task() + + event = self._create_task_event( + "plot", task, 0, metric="test1", plot_str=valid_plot_str + ) + event1 = self._create_task_event( + "plot", task, 100, metric="test2", plot_str=invalid_plot_str + ) + self.send_batch([event, event1]) + res = self.api.events.get_task_plots(task=task).plots + self.assertEqual(len(res), 1) + self.assertEqual(res[0].metric, "test1") + + event = self._create_task_event( + "plot", + task, + 0, + metric="test1", + plot_str=valid_plot_str, + skip_validation=True, + ) + event1 = self._create_task_event( + "plot", + task, + 100, 
+ metric="test2", + plot_str=invalid_plot_str, + skip_validation=True, + ) + self.send_batch([event, event1]) + res = self.api.events.get_task_plots(task=task).plots + self.assertEqual(len(res), 2) + self.assertEqual(set(r.metric for r in res), {"test1", "test2"}) def send_batch(self, events): _, data = self.api.send_batch("events.add_batch", events)