Mirror of https://github.com/clearml/clearml-server

Fix base query building
Fix schema
Improve events.scalar_metrics_iter_raw implementation

This commit is contained in:
parent e334246b46
commit cae38a365b
@@ -102,6 +102,7 @@ class LogEventsRequest(TaskEventsRequestBase):


 class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
     batch_size: int = IntField()
     key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
+    metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
     count_total: bool = BoolField(default=False)
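With the new field in place, a call to events.scalar_metrics_iter_raw now has to carry a metric selector. A hypothetical request mirroring the test added at the bottom of this commit (api_client is a stand-in for whatever API client you use; the task id is a placeholder):

    res = api_client.events.scalar_metrics_iter_raw(
        task="<task-id>",                                        # placeholder
        metric={"metric": "Metric1", "variants": ["Variant1"]},  # now required
        key="iter",
        batch_size=10_000,
        count_total=True,
    )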
@@ -18,7 +18,7 @@ events_retrieval {
     # the max amount of variants to aggregate on
     max_variants_count: 100

-    max_raw_scalars_size: 10000
+    max_raw_scalars_size: 200000

     scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
 }
@@ -1,9 +1,21 @@
 import re
 from collections import namedtuple
 from functools import reduce, partial
-from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
+from typing import (
+    Collection,
+    Sequence,
+    Union,
+    Optional,
+    Type,
+    Tuple,
+    Mapping,
+    Any,
+    Callable,
+    Dict,
+    List,
+)

-from boltons.iterutils import first, bucketize, partition
+from boltons.iterutils import first, partition
 from dateutil.parser import parse as parse_datetime
 from mongoengine import Q, Document, ListField, StringField, IntField
 from pymongo.command_cursor import CommandCursor
@@ -109,43 +121,55 @@ class GetMixin(PropsMixin):

     class ListFieldBucketHelper:
         op_prefix = "__$"
-        legacy_exclude_prefix = "-"
+        _legacy_exclude_prefix = "-"
+        _legacy_exclude_mongo_op = "nin"

-        _default = "in"
+        default_mongo_op = "in"
         _ops = {
             # op -> (mongo_op, sticky)
             "not": ("nin", False),
+            "nop": (default_mongo_op, False),
             "all": ("all", True),
             "and": ("all", True),
         }
-        _next = _default
-        _sticky = False

         def __init__(self, legacy=False):
-            self._legacy = legacy
+            self._current_op = None
+            self._sticky = False
+            self._support_legacy = legacy

-        def key(self, v) -> Optional[str]:
+        def _key(self, v) -> Optional[Union[str, bool]]:
             if v is None:
-                self._next = self._default
-                return self._default
-            elif self._legacy and v.startswith(self.legacy_exclude_prefix):
-                self._next = self._default
-                return self._ops["not"][0]
+                self._current_op = None
+                self._sticky = False
+                return self.default_mongo_op
+            elif self._current_op:
+                current_op = self._current_op
+                if not self._sticky:
+                    self._current_op = None
+                return current_op
+            elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
+                self._current_op = None
+                return False
             elif v.startswith(self.op_prefix):
-                self._next, self._sticky = self._ops.get(
-                    v[len(self.op_prefix) :], (self._default, self._sticky)
+                self._current_op, self._sticky = self._ops.get(
+                    v[len(self.op_prefix):], (self.default_mongo_op, self._sticky)
                 )
                 return None

-            next_ = self._next
-            if not self._sticky:
-                self._next = self._default
-
-            return next_
+            return self.default_mongo_op

-        def value_transform(self, v):
-            if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
-                return v[len(self.legacy_exclude_prefix) :]
-            return v
+        def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
+            actions = {}
+            for val in data:
+                key = self._key(val)
+                if key is None:
+                    continue
+                elif self._support_legacy and key is False:
+                    key = self._legacy_exclude_mongo_op
+                    val = val[len(self._legacy_exclude_prefix) :]
+                actions.setdefault(key, []).append(val)
+            return actions

     get_all_query_options = QueryParameterOptions()
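The reworked helper replaces the bucketize-based key/value_transform pair with a single get_actions pass. A quick illustration of the intended semantics, traced from the code above (input values are hypothetical; assumes the nested class is reachable via GetMixin):

    # "__$<op>" tokens select the mongo operator for the value(s) that follow;
    # non-sticky ops ("not", "nop") apply to the next value only, while sticky
    # ops ("all", "and") apply until another op token appears. In legacy mode
    # a leading "-" still means exclusion ("nin").
    helper = GetMixin.ListFieldBucketHelper(legacy=True)
    actions = helper.get_actions(["__$not", "failed", "queued", "-archived"])
    assert actions == {"nin": ["failed", "archived"], "in": ["queued"]}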
@@ -261,7 +285,9 @@ class GetMixin(PropsMixin):
         Prepare a query object based on the provided query dictionary and various fields.

         NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.

+        IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
+        used instead of Q. Otherwise we can end up with some combination that is not processed according to RegexQ rules
         :param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
         :param parameters: Query dictionary (relevant keys are those specified by the various field names parameters).
         Supported parameters:
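This implementation note is the rationale for the Q to RegexQ swaps in the hunks below. A heavily hedged sketch of the failure mode it warns about (RegexQ and RegexWrapper are internal to this repo; the exact combination mechanics are assumed here, not confirmed):

    # Assumption for illustration: RegexQ subclasses mongoengine's Q and expands
    # RegexWrapper values into real {"$regex": ...} conditions during evaluation.
    mixed = Q(system_tags="archived") & RegexQ(name=RegexWrapper("train.*"))
    safe = RegexQ(system_tags="archived") & RegexQ(name=RegexWrapper("train.*"))
    # In `mixed`, the combination is produced by the plain Q operand, so the tree
    # can be evaluated without RegexQ's special handling. That is the
    # "combination that is not processed according to RegexQ rules" from the note.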
@@ -298,7 +324,7 @@ class GetMixin(PropsMixin):
             patterns=opts.fields or [], parameters=parameters
         ).items():
             if "._" in field or "_." in field:
-                query &= Q(__raw__={field: data})
+                query &= RegexQ(__raw__={field: data})
             else:
                 dict_query[field.replace(".", "__")] = data

@@ -332,22 +358,31 @@ class GetMixin(PropsMixin):
                     break
             if any("._" in f for f in data.fields):
                 q = reduce(
-                    lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
+                    lambda a, x: func(
+                        a,
+                        RegexQ(
+                            __raw__={
+                                x: {"$regex": data.pattern, "$options": "i"}
+                            }
+                        ),
+                    ),
                     data.fields,
-                    Q()
+                    RegexQ(),
                 )
             else:
                 regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                 sep_fields = [f.replace(".", "__") for f in data.fields]
                 q = reduce(
-                    lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
+                    lambda a, x: func(a, RegexQ(**{x: regex})),
+                    sep_fields,
+                    RegexQ(),
                 )
             query = query & q

         return query & RegexQ(**dict_query)

     @classmethod
-    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Return a range query for the provided field. The data should contain min and max values
         Both intervals are included. For open range queries either min or max can be None
@@ -371,14 +406,14 @@ class GetMixin(PropsMixin):
         if max_val is not None:
             query[f"{mongoengine_field}__lte"] = max_val

-        q = Q(**query)
+        q = RegexQ(**query)
         if min_val is None:
-            q |= Q(**{mongoengine_field: None})
+            q |= RegexQ(**{mongoengine_field: None})

         return q

     @classmethod
-    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Get a proper mongoengine Q object that represents an "or" query for the provided values
         with respect to the given list field, with support for "none or empty" in case a None value
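A short usage sketch of the inclusive range semantics described in the docstring, based on the code in this hunk (field name and values are illustrative; the min/max parsing happens in the part of the method not shown):

    # closed range: last_iteration in [5, 100]
    q = GetMixin.get_range_field_query("last_iteration", ["5", "100"])
    # -> RegexQ(last_iteration__gte=5, last_iteration__lte=100)

    # open lower bound: "<= 100" OR the field is unset
    q = GetMixin.get_range_field_query("last_iteration", [None, "100"])
    # -> RegexQ(last_iteration__lte=100) | RegexQ(last_iteration=None)
    # i.e. with an open lower bound, documents missing the field also match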
@@ -392,13 +427,15 @@ class GetMixin(PropsMixin):
             data = [data]
             # raise MakeGetAllQueryError("expected list", field)

         # TODO: backwards compatibility only for older API versions
-        helper = cls.ListFieldBucketHelper(legacy=True)
-        actions = bucketize(
-            data, key=helper.key, value_transform=helper.value_transform
-        )
-
-        allow_empty = None in actions.get("in", {})
+        actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
+        allow_empty = False
+        default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op)
+        if default_op_actions and None in default_op_actions:
+            allow_empty = True
+            default_op_actions.remove(None)
+            if not default_op_actions:
+                actions.pop(cls.ListFieldBucketHelper.default_mongo_op)
+
         mongoengine_field = field.replace(".", "__")

         q = RegexQ()
@@ -412,8 +449,9 @@ class GetMixin(PropsMixin):

         return (
             q
-            | Q(**{f"{mongoengine_field}__exists": False})
-            | Q(**{mongoengine_field: []})
+            | RegexQ(**{f"{mongoengine_field}__exists": False})
+            | RegexQ(**{mongoengine_field: []})
+            | RegexQ(**{mongoengine_field: None})
         )

     @classmethod
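Together with the allow_empty handling two hunks up, the new None alternative means that a None entry in the request now matches documents whose list field is missing, empty, or explicitly null. Roughly (values hypothetical; the in/nin composition built between these hunks is simplified here):

    q = GetMixin.get_list_field_query("tags", [None, "prod", "__$not", "broken"])
    # ~ (tags in ["prod"]) combined with (tags nin ["broken"]),
    #   OR tags missing / == [] / == None   (enabled by the None entry)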
@@ -701,9 +739,7 @@ class GetMixin(PropsMixin):
                 override_collation=override_collation,
             )
             return cls.get_data_with_scroll_and_filter_support(
-                query_dict=query_dict,
-                data_getter=data_getter,
-                ret_params=ret_params,
+                query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
             )

         return cls._get_many_no_company(
@@ -1244,9 +1244,9 @@ scalar_metrics_iter_raw {
                 "$ref": "#/definitions/scalar_key_enum"
             }
             batch_size {
-                description: "The number of data points to return for this call. Optional, the default value is 5000"
+                description: "The number of data points to return for this call. Optional, the default value is 10000. Maximum batch size is 200000"
                 type: integer
-                default: 5000
+                default: 10000
             }
             count_total {
                 description: "Count the total number of data points. If false, total number of data points is not counted and null is returned"
@@ -540,7 +540,7 @@ get_all_ex {
         }
     }
     "999.0": ${get_all_ex."2.15"} {
-        response.properties.stats_with_children {
+        request.properties.stats_with_children {
             description: "If include_stats flag is set then this flag controls whether the child projects' tasks are taken into statistics or not"
             type: boolean
             default: true
@@ -1,4 +1,5 @@
 import itertools
+import math
 from collections import defaultdict
 from operator import itemgetter
 from typing import Sequence, Optional
@@ -844,10 +845,15 @@ def scalar_metrics_iter_raw(
 ):
     key = request.key or ScalarKeyEnum.iter
     scalar_key = ScalarKey.resolve(key)
+    if request.batch_size and request.batch_size < 0:
+        raise errors.bad_request.ValidationError(
+            "batch_size should be a non-negative number"
+        )
+
     if not request.scroll_id:
         from_key_value = None
         total = None
+        request.batch_size = request.batch_size or 10_000
     else:
         try:
             scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
@@ -861,9 +867,7 @@ def scalar_metrics_iter_raw(

         from_key_value = scalar_key.cast_value(scroll.from_key_value)
         total = scroll.total
-
-        scroll.request.batch_size = request.batch_size or scroll.request.batch_size
-        request = scroll.request
+        request.batch_size = request.batch_size or scroll.request.batch_size

     task_id = request.task
@@ -884,38 +888,45 @@ def scalar_metrics_iter_raw(
     batch_size = min(
         request.batch_size,
         int(
-            config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
+            config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
         ),
     )

-    res = event_bll.events_iterator.get_task_events(
-        event_type=EventType.metrics_scalar,
-        company_id=task.company,
-        task_id=task_id,
-        batch_size=batch_size,
-        navigate_earlier=False,
-        from_key_value=from_key_value,
-        metric_variants=metric_variants,
-        key=key,
-    )
+    events = []
+    for iteration in range(0, math.ceil(batch_size / 10_000)):
+        res = event_bll.events_iterator.get_task_events(
+            event_type=EventType.metrics_scalar,
+            company_id=task.company,
+            task_id=task_id,
+            batch_size=min(batch_size, 10_000),
+            navigate_earlier=False,
+            from_key_value=from_key_value,
+            metric_variants=metric_variants,
+            key=key,
+        )
+        if not res.events:
+            break
+        events.extend(res.events)
+        from_key_value = str(events[-1][scalar_key.field])

     key = str(key)

     variants = {
         variant: extract_properties_to_lists(
             ["value", scalar_key.field], events, target_keys=["y", key]
         )
-        for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
+        for variant, events in bucketize(events, key=itemgetter("variant")).items()
     }

+    call.kpis["events"] = len(events)
+
     scroll = ScalarMetricsIterRawScroll(
-        from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
+        from_key_value=str(events[-1][scalar_key.field]) if events else None,
         total=total,
         request=request,
     )

     return make_response(
-        returned=len(res.events),
+        returned=len(events),
         total=total,
         scroll_id=scroll.get_scroll_id(),
         variants=variants,
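The single bulk fetch is now a loop of up-to-10,000-event pages that advances from_key_value after each page, so a large batch_size no longer translates into one oversized query. A standalone sketch of the pattern (fetch_page is a hypothetical stand-in for event_bll.events_iterator.get_task_events):

    import math
    from typing import Callable, List, Optional

    PAGE = 10_000  # per-call cap, mirroring the loop above

    def fetch_raw_events(
        batch_size: int,
        fetch_page: Callable[..., List[dict]],
        from_key_value: Optional[str] = None,
        key_field: str = "iter",
    ) -> List[dict]:
        events: List[dict] = []
        # at most ceil(batch_size / PAGE) round trips, each capped at PAGE events
        for _ in range(math.ceil(batch_size / PAGE)):
            page = fetch_page(
                batch_size=min(batch_size, PAGE), from_key_value=from_key_value
            )
            if not page:
                break  # no more events for this task/metric
            events.extend(page)
            # resume the next page right after the last key value seen
            from_key_value = str(events[-1][key_field])
        return events

Note that, as in the handler above, each round trip is capped at PAGE events, so the last page may overshoot batch_size slightly when it is not a multiple of PAGE.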
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService


 class TestTaskEvents(TestService):
-    def setUp(self, version="2.9"):
-        super().setUp(version=version)
-
     def _temp_task(self, name="test task events"):
         task_input = dict(
             name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(task_data["x"]), iterations)
         self.assertEqual(len(task_data["y"]), iterations)

+    def test_task_metric_raw(self):
+        metric = "Metric1"
+        variant = "Variant1"
+        iter_count = 100
+        task = self._temp_task()
+        events = [
+            {
+                **self._create_task_event("training_stats_scalar", task, iteration),
+                "metric": metric,
+                "variant": variant,
+                "value": iteration,
+            }
+            for iteration in range(iter_count)
+        ]
+        self.send_batch(events)
+
+        batch_size = 15
+        metric_param = {"metric": metric, "variants": [variant]}
+        res = self.api.events.scalar_metrics_iter_raw(
+            task=task, batch_size=batch_size, metric=metric_param, count_total=True
+        )
+        self.assertEqual(res.total, len(events))
+        self.assertTrue(res.scroll_id)
+        res_iters = []
+        res_ys = []
+        calls = 0
+        while res.returned and calls < 10:
+            calls += 1
+            res_iters.extend(res.variants[variant]["iter"])
+            res_ys.extend(res.variants[variant]["y"])
+            scroll_id = res.scroll_id
+            res = self.api.events.scalar_metrics_iter_raw(
+                task=task, metric=metric_param, scroll_id=scroll_id
+            )
+
+        self.assertEqual(calls, len(events) // batch_size + 1)
+        self.assertEqual(res_iters, [ev["iter"] for ev in events])
+        self.assertEqual(res_ys, [ev["value"] for ev in events])
+
     def test_task_metric_value_intervals(self):
         metric = "Metric1"
         variant = "Variant1"