Mirror of https://github.com/clearml/clearml-server (synced 2025-06-04 03:47:03 +00:00)

Add support for returning only valid plot events

Commit: f4ead86449
Parent: 171969c5ea
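
The change has two halves. On ingestion, plot events can be checked for valid json and large plot strings compressed before indexing; on retrieval, only events marked valid, plus events indexed before validation existed, are returned. A rough Python equivalent of the retrieval predicate behind the Elasticsearch bool/should condition added below (helper name is illustrative, not part of the commit):

    def is_returnable_plot(event: dict) -> bool:
        # Keep events explicitly marked valid, and legacy events that
        # have no "valid_plot" field because they predate validation
        return event.get("valid_plot", True) is True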
@@ -1,15 +1,18 @@
+import base64
 import hashlib
+import zlib
 from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from operator import attrgetter
-from typing import Sequence, Set, Tuple, Optional
+from typing import Sequence, Set, Tuple, Optional, Dict

 import six
 from elasticsearch import helpers
 from mongoengine import Q
 from nested_dict import nested_dict

+from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
@@ -26,10 +29,19 @@ from apiserver.tools import safe_get
 from apiserver.utilities.dicts import flatten_nested_items

+# noinspection PyTypeChecker
+from apiserver.utilities.json import loads

 EVENT_TYPES = set(map(attrgetter("value"), EventType))
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)


+class PlotFields:
+    valid_plot = "valid_plot"
+    plot_len = "plot_len"
+    plot_str = "plot_str"
+    plot_data = "plot_data"
+
+
 class EventBLL(object):
     id_fields = ("task", "iter", "metric", "variant", "key")
     empty_scroll = "FFFF"
@@ -81,6 +93,7 @@ class EventBLL(object):
             },
             allow_locked_tasks=allow_locked_tasks,
         )
+
         for event in events:
            # remove spaces from event type
            event_type = event.get("type")
@@ -162,6 +175,21 @@ class EventBLL(object):

            actions.append(es_action)

+        action: Dict[dict]
+        plot_actions = [
+            action["_source"]
+            for action in actions
+            if action["_source"]["type"] == EventType.metrics_plot.value
+        ]
+        if plot_actions:
+            self.validate_and_compress_plots(
+                plot_actions,
+                validate_json=config.get("services.events.validate_plot_str", False),
+                compression_threshold=config.get(
+                    "services.events.plot_compression_threshold", 100_000
+                ),
+            )
+
         added = 0
         if actions:
             chunk_size = 500
@@ -214,6 +242,52 @@ class EventBLL(object):
         errors_count = sum(errors_per_type.values())
         return added, errors_count, errors_per_type

+    @parallel_chunked_decorator(chunk_size=10)
+    def validate_and_compress_plots(
+        self,
+        plot_events: Sequence[dict],
+        validate_json: bool,
+        compression_threshold: int,
+    ):
+        for event in plot_events:
+            validate = validate_json and not event.pop("skip_validation", False)
+            plot_str = event.get(PlotFields.plot_str)
+            if not plot_str:
+                event[PlotFields.plot_len] = 0
+                if validate:
+                    event[PlotFields.valid_plot] = False
+                continue
+
+            plot_len = len(plot_str)
+            event[PlotFields.plot_len] = plot_len
+            if validate:
+                event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
+            if compression_threshold and plot_len >= compression_threshold:
+                event[PlotFields.plot_data] = base64.encodebytes(
+                    zlib.compress(plot_str.encode(), level=1)
+                ).decode("ascii")
+                event.pop(PlotFields.plot_str, None)
+
+    @parallel_chunked_decorator(chunk_size=10)
+    def uncompress_plots(self, plot_events: Sequence[dict]):
+        for event in plot_events:
+            plot_data = event.pop(PlotFields.plot_data, None)
+            if plot_data and event.get(PlotFields.plot_str) is None:
+                event[PlotFields.plot_str] = zlib.decompress(
+                    base64.b64decode(plot_data)
+                ).decode()
+
+    @staticmethod
+    def _is_valid_json(text: str) -> bool:
+        """Check str for valid json"""
+        if not text:
+            return False
+        try:
+            loads(text)
+        except Exception:
+            return False
+        return True
+
     def _update_last_scalar_events_for_task(self, last_events, event):
         """
         Update last_events structure with the provided event details if this event is more
@@ -423,7 +497,20 @@ class EventBLL(object):
         if not self.es.indices.exists(es_index):
             return TaskEventsResult()

-        must = []
+        plot_valid_condition = {
+            "bool": {
+                "should": [
+                    {"term": {PlotFields.valid_plot: True}},
+                    {
+                        "bool": {
+                            "must_not": {"exists": {"field": PlotFields.valid_plot}}
+                        }
+                    },
+                ]
+            }
+        }
+        must = [plot_valid_condition]
+
         if last_iterations_per_plot is None:
             must.append({"terms": {"task": tasks}})
         else:
@@ -467,6 +554,7 @@ class EventBLL(object):
         )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
+        self.uncompress_plots(events)
         return TaskEventsResult(
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )
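
Note: a minimal standalone sketch of the compression round trip performed by validate_and_compress_plots and uncompress_plots above, assuming only the standard library (function names are illustrative):

    import base64
    import zlib

    def compress_plot(plot_str: str) -> str:
        # level=1 trades compression ratio for speed, as in the commit
        return base64.encodebytes(
            zlib.compress(plot_str.encode(), level=1)
        ).decode("ascii")

    def uncompress_plot(plot_data: str) -> str:
        return zlib.decompress(base64.b64decode(plot_data)).decode()

    assert uncompress_plot(compress_plot('{"data": []}')) == '{"data": []}'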
@@ -1,6 +1,10 @@
 import functools
+import itertools
+from concurrent.futures.thread import ThreadPoolExecutor
 from operator import itemgetter
-from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
+from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable

+from boltons import iterutils
+
 from apiserver.database.model import AttributedDocument
 from apiserver.database.model.settings import Settings
@@ -78,3 +82,36 @@ class SetFieldsResolver:
 @functools.lru_cache()
 def get_server_uuid() -> Optional[str]:
     return Settings.get_by_key("server.uuid")
+
+
+def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
+    """
+    Decorates a method for parallel chunked execution. The method should have
+    one positional parameter (that is used for breaking into chunks)
+    and an arbitrary number of keyword params. The return value should be iterable.
+    The results are concatenated in the same order as the passed params.
+    """
+    if func is None:
+        return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size)
+
+    @functools.wraps(func)
+    def wrapper(self, iterable: Iterable, **kwargs):
+        assert iterutils.is_collection(
+            iterable
+        ), "The positional parameter should be an iterable for breaking into chunks"
+
+        func_with_params = functools.partial(func, self, **kwargs)
+        with ThreadPoolExecutor() as pool:
+            return list(
+                itertools.chain.from_iterable(
+                    filter(
+                        None,
+                        pool.map(
+                            func_with_params,
+                            iterutils.chunked_iter(iterable, chunk_size),
+                        ),
+                    )
+                ),
+            )
+
+    return wrapper
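
Note: a usage sketch for parallel_chunked_decorator, with a toy class standing in for EventBLL (class and method names are illustrative):

    class Worker:
        @parallel_chunked_decorator(chunk_size=2)
        def double(self, items, factor=2):
            # Called once per chunk on a thread pool; must return an iterable
            return [i * factor for i in items]

    # [1, 2, 3, 4, 5] is split into chunks [1, 2], [3, 4], [5]; the chunk
    # results are concatenated in order, yielding [2, 4, 6, 8, 10]
    print(Worker().double([1, 2, 3, 4, 5]))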
@@ -11,3 +11,10 @@ max_metrics_concurrency: 4
 events_retrieval {
     state_expiration_sec: 3600
 }
+
+# If set then plot_str is checked for valid json on plot add
+# and the result of the check is written to the DB
+validate_plot_str: false
+
+# If not 0 then plots whose size equals or exceeds this value are stored compressed in the DB
+plot_compression_threshold: 100000
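
Note: with these defaults, any plot payload of 100 KB or more is stored compressed while json validation stays off. Enabling validation and a lower threshold would be a small override in the same file (values are illustrative):

    validate_plot_str: true
    plot_compression_threshold: 50000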
@@ -6,6 +6,9 @@
         "plot_str": {
             "type": "text",
             "index": false
-        }
+        },
+        "plot_data": {
+            "type": "binary"
+        }
     }
 }
@@ -147,6 +147,10 @@
             """
             type: string
         }
+        skip_validation {
+            description: "If set then plot_str is not checked for valid json. The default is false"
+            type: boolean
+        }
     }
 }
 scalar_key_enum {
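
Note: given the skip_validation field above, a client can opt a single event out of the json check. A hedged sketch of such an events.add_batch entry (field values are illustrative; the shape follows the tests below):

    event = {
        "task": "<task id>",
        "type": "plot",
        "iter": 0,
        "metric": "confusion",
        "plot_str": '{"data": []}',
        "skip_validation": True,
    }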
@@ -435,9 +435,9 @@ class TestTaskEvents(TestService):
         )
         self.send(event)

-        event = self._create_task_event("plot", task, 100)
-        event["metric"] = "confusion"
-        event.update(
+        event1 = self._create_task_event("plot", task, 100)
+        event1["metric"] = "confusion"
+        event1.update(
             {
                 "plot_str": json.dumps(
                     {
@@ -476,14 +476,54 @@ class TestTaskEvents(TestService):
                 )
             }
         )
-        self.send(event)
+        self.send(event1)

-        data = self.api.events.get_task_plots(task=task)
-        assert len(data["plots"]) == 2
+        plots = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(
+            {e["plot_str"] for e in (event, event1)}, {p.plot_str for p in plots}
+        )

         self.api.tasks.reset(task=task)
-        data = self.api.events.get_task_plots(task=task)
-        assert len(data["plots"]) == 0
+        plots = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(plots), 0)
+
+    @unittest.skip("this test will run only if 'validate_plot_str' is set to true")
+    def test_plots_validation(self):
+        valid_plot_str = json.dumps({"data": []})
+        invalid_plot_str = "Not a valid json"
+        task = self._temp_task()
+
+        event = self._create_task_event(
+            "plot", task, 0, metric="test1", plot_str=valid_plot_str
+        )
+        event1 = self._create_task_event(
+            "plot", task, 100, metric="test2", plot_str=invalid_plot_str
+        )
+        self.send_batch([event, event1])
+        res = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(res), 1)
+        self.assertEqual(res[0].metric, "test1")
+
+        event = self._create_task_event(
+            "plot",
+            task,
+            0,
+            metric="test1",
+            plot_str=valid_plot_str,
+            skip_validation=True,
+        )
+        event1 = self._create_task_event(
+            "plot",
+            task,
+            100,
+            metric="test2",
+            plot_str=invalid_plot_str,
+            skip_validation=True,
+        )
+        self.send_batch([event, event1])
+        res = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(res), 2)
+        self.assertEqual(set(r.metric for r in res), {"test1", "test2"})

     def send_batch(self, events):
         _, data = self.api.send_batch("events.add_batch", events)