Mirror of https://github.com/clearml/clearml-server (synced 2025-06-09 01:45:38 +00:00)

Commit f4ead86449 (parent 171969c5ea): Add support for returning only valid plot events
@@ -1,15 +1,18 @@
+import base64
 import hashlib
+import zlib
 from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from operator import attrgetter
-from typing import Sequence, Set, Tuple, Optional
+from typing import Sequence, Set, Tuple, Optional, Dict
 
 import six
 from elasticsearch import helpers
 from mongoengine import Q
 from nested_dict import nested_dict
 
+from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
@@ -26,10 +29,19 @@ from apiserver.tools import safe_get
 from apiserver.utilities.dicts import flatten_nested_items
 
 # noinspection PyTypeChecker
+from apiserver.utilities.json import loads
+
 EVENT_TYPES = set(map(attrgetter("value"), EventType))
 LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
 
 
+class PlotFields:
+    valid_plot = "valid_plot"
+    plot_len = "plot_len"
+    plot_str = "plot_str"
+    plot_data = "plot_data"
+
+
 class EventBLL(object):
     id_fields = ("task", "iter", "metric", "variant", "key")
     empty_scroll = "FFFF"
@@ -81,6 +93,7 @@ class EventBLL(object):
             },
             allow_locked_tasks=allow_locked_tasks,
         )
+
         for event in events:
             # remove spaces from event type
             event_type = event.get("type")
@@ -162,6 +175,21 @@ class EventBLL(object):
 
             actions.append(es_action)
 
+        action: Dict
+        plot_actions = [
+            action["_source"]
+            for action in actions
+            if action["_source"]["type"] == EventType.metrics_plot.value
+        ]
+        if plot_actions:
+            self.validate_and_compress_plots(
+                plot_actions,
+                validate_json=config.get("services.events.validate_plot_str", False),
+                compression_threshold=config.get(
+                    "services.events.plot_compression_threshold", 100_000
+                ),
+            )
+
         added = 0
         if actions:
             chunk_size = 500
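
Only events of type metrics_plot go through the new validation/compression pass; other event types are untouched. The comprehension collects references to the same _source dicts held by actions, so the in-place changes made by validate_and_compress_plots are reflected in the documents later sent to Elasticsearch. A toy illustration of this (action shapes simplified, and assuming the plot event type serializes to "plot", as in the tests below):

    actions = [
        {"_source": {"type": "plot", "plot_str": "{}"}},
        {"_source": {"type": "training_stats_scalar", "value": 0.5}},
    ]
    plot_actions = [a["_source"] for a in actions if a["_source"]["type"] == "plot"]
    plot_actions[0]["plot_len"] = 2  # mutation is visible through actions as well
    assert actions[0]["_source"]["plot_len"] == 2
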
@@ -214,6 +242,52 @@ class EventBLL(object):
         errors_count = sum(errors_per_type.values())
         return added, errors_count, errors_per_type
 
+    @parallel_chunked_decorator(chunk_size=10)
+    def validate_and_compress_plots(
+        self,
+        plot_events: Sequence[dict],
+        validate_json: bool,
+        compression_threshold: int,
+    ):
+        for event in plot_events:
+            validate = validate_json and not event.pop("skip_validation", False)
+            plot_str = event.get(PlotFields.plot_str)
+            if not plot_str:
+                event[PlotFields.plot_len] = 0
+                if validate:
+                    event[PlotFields.valid_plot] = False
+                continue
+
+            plot_len = len(plot_str)
+            event[PlotFields.plot_len] = plot_len
+            if validate:
+                event[PlotFields.valid_plot] = self._is_valid_json(plot_str)
+            if compression_threshold and plot_len >= compression_threshold:
+                event[PlotFields.plot_data] = base64.encodebytes(
+                    zlib.compress(plot_str.encode(), level=1)
+                ).decode("ascii")
+                event.pop(PlotFields.plot_str, None)
+
+    @parallel_chunked_decorator(chunk_size=10)
+    def uncompress_plots(self, plot_events: Sequence[dict]):
+        for event in plot_events:
+            plot_data = event.pop(PlotFields.plot_data, None)
+            if plot_data and event.get(PlotFields.plot_str) is None:
+                event[PlotFields.plot_str] = zlib.decompress(
+                    base64.b64decode(plot_data)
+                ).decode()
+
+    @staticmethod
+    def _is_valid_json(text: str) -> bool:
+        """Check str for valid json"""
+        if not text:
+            return False
+        try:
+            loads(text)
+        except Exception:
+            return False
+        return True
+
     def _update_last_scalar_events_for_task(self, last_events, event):
         """
         Update last_events structure with the provided event details if this event is more
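
On the write path the plot payload is compressed with zlib at the fastest level and stored as base64 text; on the read path uncompress_plots reverses both steps. base64.b64decode accepts the newline-wrapped output of base64.encodebytes because non-alphabet characters are discarded by default. A minimal standalone sketch of the round trip (the event dict and sizes are invented for illustration):

    import base64
    import zlib

    # Illustrative plot event; field names mirror PlotFields above.
    event = {"plot_str": '{"data": [' + "0, " * 60_000 + "0]}"}

    # Write path (as in validate_and_compress_plots): replace a large plot_str
    # with the base64-encoded zlib compression and record the original length.
    plot_str = event.pop("plot_str")
    event["plot_len"] = len(plot_str)
    event["plot_data"] = base64.encodebytes(
        zlib.compress(plot_str.encode(), level=1)
    ).decode("ascii")

    # Read path (as in uncompress_plots): restore the original plot_str.
    restored = zlib.decompress(base64.b64decode(event.pop("plot_data"))).decode()
    assert restored == plot_str
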
@@ -423,7 +497,20 @@ class EventBLL(object):
         if not self.es.indices.exists(es_index):
             return TaskEventsResult()
 
-        must = []
+        plot_valid_condition = {
+            "bool": {
+                "should": [
+                    {"term": {PlotFields.valid_plot: True}},
+                    {
+                        "bool": {
+                            "must_not": {"exists": {"field": PlotFields.valid_plot}}
+                        }
+                    },
+                ]
+            }
+        }
+        must = [plot_valid_condition]
+
         if last_iterations_per_plot is None:
             must.append({"terms": {"task": tasks}})
         else:
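
The should clause keeps an event when valid_plot is true or when the field is absent altogether, so plot events indexed before this change (which carry no valid_plot field) are still returned; only events explicitly marked invalid are filtered out. The same logic expressed as a plain Python predicate, shown only to spell out the query semantics:

    def passes_plot_valid_condition(doc: dict) -> bool:
        # Mirrors the bool/should filter above: either the term clause
        # (valid_plot == True) or the must_not/exists clause (field absent).
        if "valid_plot" not in doc:
            return True  # legacy event written before validation existed
        return doc["valid_plot"] is True
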
@@ -467,6 +554,7 @@ class EventBLL(object):
         )
 
         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
+        self.uncompress_plots(events)
         return TaskEventsResult(
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )
@@ -1,6 +1,10 @@
 import functools
+import itertools
+from concurrent.futures.thread import ThreadPoolExecutor
 from operator import itemgetter
-from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
+from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set, Iterable
 
+from boltons import iterutils
+
 from apiserver.database.model import AttributedDocument
 from apiserver.database.model.settings import Settings
@@ -78,3 +82,36 @@ class SetFieldsResolver:
 @functools.lru_cache()
 def get_server_uuid() -> Optional[str]:
     return Settings.get_by_key("server.uuid")
+
+
+def parallel_chunked_decorator(func: Callable = None, chunk_size: int = 100):
+    """
+    Decorates a method for parallel chunked execution. The method should have
+    one positional parameter (that is used for breaking into chunks)
+    and an arbitrary number of keyword params. The return value should be iterable.
+    The results are concatenated in the same order as the passed params.
+    """
+    if func is None:
+        return functools.partial(parallel_chunked_decorator, chunk_size=chunk_size)
+
+    @functools.wraps(func)
+    def wrapper(self, iterable: Iterable, **kwargs):
+        assert iterutils.is_collection(
+            iterable
+        ), "The positional parameter should be an iterable for breaking into chunks"
+
+        func_with_params = functools.partial(func, self, **kwargs)
+        with ThreadPoolExecutor() as pool:
+            return list(
+                itertools.chain.from_iterable(
+                    filter(
+                        None,
+                        pool.map(
+                            func_with_params,
+                            iterutils.chunked_iter(iterable, chunk_size),
+                        ),
+                    )
+                ),
+            )
+
+    return wrapper
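
A hypothetical usage sketch (the class and method are invented, not from the repo): the positional iterable is split into chunks of chunk_size, each chunk runs on a thread-pool worker with the same keyword arguments, filter(None, ...) drops empty chunk results, and because pool.map yields results in submission order the concatenated output preserves the input order:

    class Scaler:
        @parallel_chunked_decorator(chunk_size=2)
        def scale(self, values, factor=1):
            # Called once per chunk, e.g. with [1, 2]; must return an iterable.
            return [v * factor for v in values]

    print(Scaler().scale([1, 2, 3, 4, 5], factor=3))  # -> [3, 6, 9, 12, 15]
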
@@ -11,3 +11,10 @@ max_metrics_concurrency: 4
 events_retrieval {
     state_expiration_sec: 3600
 }
+
+# If set then plot_str is checked for valid json on plot add
+# and the result of the check is written to the DB
+validate_plot_str: false
+
+# If not 0 then plots whose size is equal or greater than this value are stored compressed in the DB
+plot_compression_threshold: 100000
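
Both keys are read in add_events via config.get with matching defaults (false and 100000), so deployments that do not override the config see no change in behavior. The threshold semantics, restated as a small Python sketch:

    # Mirrors the check in validate_and_compress_plots: compress only when
    # the threshold is non-zero and the plot string is at least that long,
    # so plot_compression_threshold: 0 disables compression entirely.
    def should_compress(plot_len: int, threshold: int = 100_000) -> bool:
        return bool(threshold) and plot_len >= threshold
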
@@ -6,6 +6,9 @@
       "plot_str": {
         "type": "text",
         "index": false
+      },
+      "plot_data": {
+        "type": "binary"
       }
     }
   }
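
The binary mapping type expects its values as base64-encoded strings and is neither stored nor searchable by default, which is why validate_and_compress_plots base64-encodes the compressed payload before indexing. An illustrative shape of a stored, compressed plot event (all values are invented placeholders):

    compressed_plot_event = {
        "task": "<task-id>",
        "type": "plot",
        "metric": "confusion",
        "iter": 100,
        "plot_len": 250_000,         # length of the original plot_str
        "valid_plot": True,          # written only when validation is enabled
        "plot_data": "eJzT0yMA...",  # base64 of the zlib-compressed plot_str
    }
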
@@ -147,6 +147,10 @@
                 """
                 type: string
             }
+            skip_validation {
+                description: "If set then plot_str is not checked for a valid json. The default is False"
+                type: boolean
+            }
         }
     }
     scalar_key_enum {
@@ -435,9 +435,9 @@ class TestTaskEvents(TestService):
         )
         self.send(event)
 
-        event = self._create_task_event("plot", task, 100)
-        event["metric"] = "confusion"
-        event.update(
+        event1 = self._create_task_event("plot", task, 100)
+        event1["metric"] = "confusion"
+        event1.update(
             {
                 "plot_str": json.dumps(
                     {
@@ -476,14 +476,54 @@ class TestTaskEvents(TestService):
                 )
             }
         )
-        self.send(event)
+        self.send(event1)
 
-        data = self.api.events.get_task_plots(task=task)
-        assert len(data["plots"]) == 2
+        plots = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(
+            {e["plot_str"] for e in (event, event1)}, {p.plot_str for p in plots}
+        )
 
         self.api.tasks.reset(task=task)
-        data = self.api.events.get_task_plots(task=task)
-        assert len(data["plots"]) == 0
+        plots = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(plots), 0)
+
+    @unittest.skip("this test will run only if 'validate_plot_str' is set to true")
+    def test_plots_validation(self):
+        valid_plot_str = json.dumps({"data": []})
+        invalid_plot_str = "Not a valid json"
+        task = self._temp_task()
+
+        event = self._create_task_event(
+            "plot", task, 0, metric="test1", plot_str=valid_plot_str
+        )
+        event1 = self._create_task_event(
+            "plot", task, 100, metric="test2", plot_str=invalid_plot_str
+        )
+        self.send_batch([event, event1])
+        res = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(res), 1)
+        self.assertEqual(res[0].metric, "test1")
+
+        event = self._create_task_event(
+            "plot",
+            task,
+            0,
+            metric="test1",
+            plot_str=valid_plot_str,
+            skip_validation=True,
+        )
+        event1 = self._create_task_event(
+            "plot",
+            task,
+            100,
+            metric="test2",
+            plot_str=invalid_plot_str,
+            skip_validation=True,
+        )
+        self.send_batch([event, event1])
+        res = self.api.events.get_task_plots(task=task).plots
+        self.assertEqual(len(res), 2)
+        self.assertEqual(set(r.metric for r in res), {"test1", "test2"})
 
     def send_batch(self, events):
         _, data = self.api.send_batch("events.add_batch", events)