mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add support for returning only valid plot events
This commit is contained in:
@@ -1,15 +1,18 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import zlib
|
||||
from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from mongoengine import Q
|
||||
from nested_dict import nested_dict
|
||||
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
@@ -26,10 +29,19 @@ from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
from apiserver.utilities.json import loads
|
||||
|
||||
EVENT_TYPES = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
|
||||
|
||||
class PlotFields:
    """Names of the plot-related keys that are set on and read from plot event dicts."""

    # Validity flag: filled from the json check of the plot string and used
    # as a term/exists filter when plot events are queried
    valid_plot = "valid_plot"

    # Length of the plot string (0 when the string is missing or empty)
    plot_len = "plot_len"

    # The plot payload as a plain string
    plot_str = "plot_str"

    # The plot payload in its zlib-compressed, base64-encoded form
    plot_data = "plot_data"
|
||||
|
||||
|
||||
class EventBLL(object):
|
||||
id_fields = ("task", "iter", "metric", "variant", "key")
|
||||
empty_scroll = "FFFF"
|
||||
@@ -81,6 +93,7 @@ class EventBLL(object):
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
event_type = event.get("type")
|
||||
@@ -162,6 +175,21 @@ class EventBLL(object):
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
action: Dict[dict]
|
||||
plot_actions = [
|
||||
action["_source"]
|
||||
for action in actions
|
||||
if action["_source"]["type"] == EventType.metrics_plot.value
|
||||
]
|
||||
if plot_actions:
|
||||
self.validate_and_compress_plots(
|
||||
plot_actions,
|
||||
validate_json=config.get("services.events.validate_plot_str", False),
|
||||
compression_threshold=config.get(
|
||||
"services.events.plot_compression_threshold", 100_000
|
||||
),
|
||||
)
|
||||
|
||||
added = 0
|
||||
if actions:
|
||||
chunk_size = 500
|
||||
@@ -214,6 +242,52 @@ class EventBLL(object):
|
||||
errors_count = sum(errors_per_type.values())
|
||||
return added, errors_count, errors_per_type
|
||||
|
||||
@parallel_chunked_decorator(chunk_size=10)
def validate_and_compress_plots(
    self,
    plot_events: Sequence[dict],
    validate_json: bool,
    compression_threshold: int,
):
    """
    Annotate plot events in place with length/validity info and compress large plots.

    For every event the plot string length is stored under PlotFields.plot_len.
    When validation applies, PlotFields.valid_plot is set according to whether
    the plot string parses as json. A plot string whose length reaches
    compression_threshold is replaced by its zlib-compressed, base64-encoded
    form stored under PlotFields.plot_data.

    :param plot_events: plot event dicts; modified in place
    :param validate_json: whether plot strings should be checked for valid json.
        A single event opts out through a truthy "skip_validation" key
    :param compression_threshold: minimal plot string length that triggers
        compression. A falsy value disables compression
    """
    for event in plot_events:
        # Always strip the per-event control flag. Popping it inside the
        # boolean expression is short-circuited whenever validate_json is
        # falsy, which would leave "skip_validation" in the event document
        skip_validation = event.pop("skip_validation", False)
        validate = validate_json and not skip_validation

        plot_str = event.get(PlotFields.plot_str)
        if not plot_str:
            event[PlotFields.plot_len] = 0
            if validate:
                # a missing or empty plot string is never valid json
                event[PlotFields.valid_plot] = False
            continue

        plot_len = len(plot_str)
        event[PlotFields.plot_len] = plot_len
        if validate:
            event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
        if compression_threshold and plot_len >= compression_threshold:
            # keep only the compressed form; the plain string is dropped
            event[PlotFields.plot_data] = base64.encodebytes(
                zlib.compress(plot_str.encode(), level=1)
            ).decode("ascii")
            event.pop(PlotFields.plot_str, None)
|
||||
|
||||
@parallel_chunked_decorator(chunk_size=10)
def uncompress_plots(self, plot_events: Sequence[dict]):
    """
    Restore the plot strings of compressed plot events, in place.

    The PlotFields.plot_data key is always removed from the event. Its content
    is base64-decoded and zlib-decompressed into PlotFields.plot_str only when
    the event does not already hold a plot string.

    :param plot_events: plot event dicts; modified in place
    """
    for plot_event in plot_events:
        compressed = plot_event.pop(PlotFields.plot_data, None)
        if not compressed:
            continue
        if plot_event.get(PlotFields.plot_str) is not None:
            # an existing plot string takes precedence over the compressed data
            continue
        raw = zlib.decompress(base64.b64decode(compressed))
        plot_event[PlotFields.plot_str] = raw.decode()
|
||||
|
||||
@staticmethod
def _is_valid_json(text: str) -> bool:
    """Return True when text is non-empty and is parsed by loads without an error"""
    is_valid = bool(text)
    if is_valid:
        try:
            loads(text)
        except Exception:
            # any failure raised while parsing marks the text as invalid
            is_valid = False
    return is_valid
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
"""
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
@@ -423,7 +497,20 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
must = []
|
||||
plot_valid_condition = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {PlotFields.valid_plot: True}},
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": PlotFields.valid_plot}}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
@@ -467,6 +554,7 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
self.uncompress_plots(events)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user