Fix base query building

Fix schema
Improve events.scalar_metrics_iter_raw implementation
allegroai 2022-02-13 19:28:23 +02:00
parent e334246b46
commit cae38a365b
7 changed files with 153 additions and 69 deletions

View File

@@ -102,6 +102,7 @@ class LogEventsRequest(TaskEventsRequestBase):
 class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
+    batch_size: int = IntField()
     key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
     metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
     count_total: bool = BoolField(default=False)

View File

@@ -18,7 +18,7 @@ events_retrieval {
     # the max amount of variants to aggregate on
     max_variants_count: 100

-    max_raw_scalars_size: 10000
+    max_raw_scalars_size: 200000

     scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
 }

View File

@ -1,9 +1,21 @@
import re import re
from collections import namedtuple from collections import namedtuple
from functools import reduce, partial from functools import reduce, partial
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable from typing import (
Collection,
Sequence,
Union,
Optional,
Type,
Tuple,
Mapping,
Any,
Callable,
Dict,
List,
)
from boltons.iterutils import first, bucketize, partition from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
@@ -109,43 +121,55 @@ class GetMixin(PropsMixin):
     class ListFieldBucketHelper:
         op_prefix = "__$"
-        legacy_exclude_prefix = "-"
+        _legacy_exclude_prefix = "-"
+        _legacy_exclude_mongo_op = "nin"

-        _default = "in"
+        default_mongo_op = "in"
         _ops = {
+            # op -> (mongo_op, sticky)
             "not": ("nin", False),
+            "nop": (default_mongo_op, False),
             "all": ("all", True),
             "and": ("all", True),
         }
-        _next = _default
-        _sticky = False

         def __init__(self, legacy=False):
-            self._legacy = legacy
+            self._current_op = None
+            self._sticky = False
+            self._support_legacy = legacy

-        def key(self, v) -> Optional[str]:
+        def _key(self, v) -> Optional[Union[str, bool]]:
             if v is None:
-                self._next = self._default
-                return self._default
-            elif self._legacy and v.startswith(self.legacy_exclude_prefix):
-                self._next = self._default
-                return self._ops["not"][0]
+                self._current_op = None
+                self._sticky = False
+                return self.default_mongo_op
+            elif self._current_op:
+                current_op = self._current_op
+                if not self._sticky:
+                    self._current_op = None
+                return current_op
+            elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
+                self._current_op = None
+                return False
             elif v.startswith(self.op_prefix):
-                self._next, self._sticky = self._ops.get(
-                    v[len(self.op_prefix) :], (self._default, self._sticky)
+                self._current_op, self._sticky = self._ops.get(
+                    v[len(self.op_prefix):], (self.default_mongo_op, self._sticky)
                 )
                 return None

-            next_ = self._next
-            if not self._sticky:
-                self._next = self._default
-            return next_
+            return self.default_mongo_op

-        def value_transform(self, v):
-            if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
-                return v[len(self.legacy_exclude_prefix) :]
-            return v
+        def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
+            actions = {}
+            for val in data:
+                key = self._key(val)
+                if key is None:
+                    continue
+                elif self._support_legacy and key is False:
+                    key = self._legacy_exclude_mongo_op
+                    val = val[len(self._legacy_exclude_prefix) :]
+
+                actions.setdefault(key, []).append(val)
+
+            return actions

     get_all_query_options = QueryParameterOptions()
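
For illustration, a short sketch of how the reworked helper buckets raw query values into mongo operators via get_actions (not part of the commit; the import path is an assumption based on where GetMixin lives in this repo):

from apiserver.database.model.base import GetMixin  # assumed module path

# the helper keeps operator state between values, so use a fresh instance per call
Helper = GetMixin.ListFieldBucketHelper

# "__$not" is non-sticky: it applies only to the next value
assert Helper().get_actions(["__$not", "a", "b"]) == {"nin": ["a"], "in": ["b"]}

# "__$and" is sticky: every following value is bucketed under "all"
assert Helper().get_actions(["__$and", "a", "b"]) == {"all": ["a", "b"]}

# the legacy "-" prefix still maps to "nin"; None lands under the default "in"
# operator, which get_list_field_query later turns into the allow_empty case
assert Helper(legacy=True).get_actions(["-a", None]) == {"nin": ["a"], "in": [None]}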
@@ -261,7 +285,9 @@ class GetMixin(PropsMixin):
         Prepare a query object based on the provided query dictionary and various fields.
         NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
+        IMPLEMENTATION NOTE: Make sure that inside this function, and in any function it depends on,
+        RegexQ is always used instead of Q. Otherwise we can end up with some combination that is not
+        processed according to RegexQ rules.
         :param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
         :param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
             Supported parameters:
@ -298,7 +324,7 @@ class GetMixin(PropsMixin):
patterns=opts.fields or [], parameters=parameters patterns=opts.fields or [], parameters=parameters
).items(): ).items():
if "._" in field or "_." in field: if "._" in field or "_." in field:
query &= Q(__raw__={field: data}) query &= RegexQ(__raw__={field: data})
else: else:
dict_query[field.replace(".", "__")] = data dict_query[field.replace(".", "__")] = data
@@ -332,22 +358,31 @@ class GetMixin(PropsMixin):
                     break
             if any("._" in f for f in data.fields):
                 q = reduce(
-                    lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
+                    lambda a, x: func(
+                        a,
+                        RegexQ(
+                            __raw__={
+                                x: {"$regex": data.pattern, "$options": "i"}
+                            }
+                        ),
+                    ),
                     data.fields,
-                    Q()
+                    RegexQ(),
                 )
             else:
                 regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                 sep_fields = [f.replace(".", "__") for f in data.fields]
                 q = reduce(
-                    lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
+                    lambda a, x: func(a, RegexQ(**{x: regex})),
+                    sep_fields,
+                    RegexQ(),
                 )
             query = query & q

         return query & RegexQ(**dict_query)

     @classmethod
-    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Return a range query for the provided field. The data should contain min and max values
         Both intervals are included. For open range queries either min or max can be None
@@ -371,14 +406,14 @@ class GetMixin(PropsMixin):
         if max_val is not None:
             query[f"{mongoengine_field}__lte"] = max_val

-        q = Q(**query)
+        q = RegexQ(**query)
         if min_val is None:
-            q |= Q(**{mongoengine_field: None})
+            q |= RegexQ(**{mongoengine_field: None})
         return q

     @classmethod
-    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Get a proper mongoengine Q object that represents an "or" query for the provided values
         with respect to the given list field, with support for "none of empty" in case a None value
@@ -392,13 +427,15 @@ class GetMixin(PropsMixin):
             data = [data]
             # raise MakeGetAllQueryError("expected list", field)

-        # TODO: backwards compatibility only for older API versions
-        helper = cls.ListFieldBucketHelper(legacy=True)
-        actions = bucketize(
-            data, key=helper.key, value_transform=helper.value_transform
-        )
-        allow_empty = None in actions.get("in", {})
+        actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
+        allow_empty = False
+        default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op)
+        if default_op_actions and None in default_op_actions:
+            allow_empty = True
+            default_op_actions.remove(None)
+            if not default_op_actions:
+                actions.pop(cls.ListFieldBucketHelper.default_mongo_op)

         mongoengine_field = field.replace(".", "__")

         q = RegexQ()
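
As a worked example (a standalone sketch mirroring the added lines, not code from the commit), with the default operator name "in" inlined:

# e.g. data=["__$not", "b", None, "a"] buckets to these actions, where "in"
# is ListFieldBucketHelper.default_mongo_op
actions = {"nin": ["b"], "in": [None, "a"]}

allow_empty = False
default_op_actions = actions.get("in")
if default_op_actions and None in default_op_actions:
    allow_empty = True
    default_op_actions.remove(None)
    if not default_op_actions:  # only None was requested under "in"
        actions.pop("in")

assert allow_empty and actions == {"nin": ["b"], "in": ["a"]}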
@@ -412,8 +449,9 @@ class GetMixin(PropsMixin):
         return (
             q
-            | Q(**{f"{mongoengine_field}__exists": False})
-            | Q(**{mongoengine_field: []})
+            | RegexQ(**{f"{mongoengine_field}__exists": False})
+            | RegexQ(**{mongoengine_field: []})
+            | RegexQ(**{mongoengine_field: None})
         )

     @classmethod
@@ -701,9 +739,7 @@ class GetMixin(PropsMixin):
                 override_collation=override_collation,
             )
             return cls.get_data_with_scroll_and_filter_support(
-                query_dict=query_dict,
-                data_getter=data_getter,
-                ret_params=ret_params,
+                query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
             )

         return cls._get_many_no_company(

View File

@@ -1244,9 +1244,9 @@ scalar_metrics_iter_raw {
                 "$ref": "#/definitions/scalar_key_enum"
             }
             batch_size {
-                description: "The number of data points to return for this call. Optional, the default value is 5000"
+                description: "The number of data points to return for this call. Optional, the default value is 10000. Maximum batch size is 200000"
                 type: integer
-                default: 5000
+                default: 10000
             }
             count_total {
                 description: "Count the total number of data points. If false, total number of data points is not counted and null is returned"

View File

@@ -540,7 +540,7 @@ get_all_ex {
         }
     }
     "999.0": ${get_all_ex."2.15"} {
-        response.properties.stats_with_children {
+        request.properties.stats_with_children {
             description: "If include_stats flag is set then this flag controls whether the child projects tasks are taken into statistics or not"
             type: boolean
             default: true

View File

@ -1,4 +1,5 @@
import itertools import itertools
import math
from collections import defaultdict from collections import defaultdict
from operator import itemgetter from operator import itemgetter
from typing import Sequence, Optional from typing import Sequence, Optional
@@ -844,10 +845,15 @@ def scalar_metrics_iter_raw(
 ):
     key = request.key or ScalarKeyEnum.iter
     scalar_key = ScalarKey.resolve(key)
+    if request.batch_size and request.batch_size < 0:
+        raise errors.bad_request.ValidationError(
+            "batch_size should be non negative number"
+        )
+
     if not request.scroll_id:
         from_key_value = None
         total = None
+        request.batch_size = request.batch_size or 10_000
     else:
         try:
             scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
@@ -861,9 +867,7 @@ def scalar_metrics_iter_raw(
         from_key_value = scalar_key.cast_value(scroll.from_key_value)
         total = scroll.total
-        request.batch_size = request.batch_size or scroll.request.batch_size
+        scroll.request.batch_size = request.batch_size or scroll.request.batch_size
+        request = scroll.request

     task_id = request.task
@@ -884,38 +888,45 @@ def scalar_metrics_iter_raw(
     batch_size = min(
         request.batch_size,
         int(
-            config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
+            config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
         ),
     )

-    res = event_bll.events_iterator.get_task_events(
-        event_type=EventType.metrics_scalar,
-        company_id=task.company,
-        task_id=task_id,
-        batch_size=batch_size,
-        navigate_earlier=False,
-        from_key_value=from_key_value,
-        metric_variants=metric_variants,
-        key=key,
-    )
+    events = []
+    for iteration in range(0, math.ceil(batch_size / 10_000)):
+        res = event_bll.events_iterator.get_task_events(
+            event_type=EventType.metrics_scalar,
+            company_id=task.company,
+            task_id=task_id,
+            batch_size=min(batch_size, 10_000),
+            navigate_earlier=False,
+            from_key_value=from_key_value,
+            metric_variants=metric_variants,
+            key=key,
+        )
+        if not res.events:
+            break
+        events.extend(res.events)
+        from_key_value = str(events[-1][scalar_key.field])

     key = str(key)
     variants = {
         variant: extract_properties_to_lists(
             ["value", scalar_key.field], events, target_keys=["y", key]
         )
-        for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
+        for variant, events in bucketize(events, key=itemgetter("variant")).items()
     }

+    call.kpis["events"] = len(events)
     scroll = ScalarMetricsIterRawScroll(
-        from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
+        from_key_value=str(events[-1][scalar_key.field]) if events else None,
         total=total,
         request=request,
     )

     return make_response(
-        returned=len(res.events),
+        returned=len(events),
         total=total,
         scroll_id=scroll.get_scroll_id(),
         variants=variants,
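
The rewritten body serves one large logical batch as several sub-queries of at most 10,000 events, advancing from_key_value between them. A self-contained sketch of that pattern (fetch_page is a hypothetical stand-in for events_iterator.get_task_events; not code from the commit):

import math
from typing import Callable, List, Optional, Tuple

def fetch_in_chunks(
    fetch_page: Callable[[int, Optional[str]], List[dict]],
    batch_size: int,
    from_key_value: Optional[str] = None,
    chunk: int = 10_000,
) -> Tuple[List[dict], Optional[str]]:
    """Collect up to roughly batch_size events in ceil(batch_size / chunk)
    sub-queries, resuming each sub-query from the last key value seen."""
    events: List[dict] = []
    for _ in range(math.ceil(batch_size / chunk)):
        page = fetch_page(min(batch_size, chunk), from_key_value)
        if not page:
            break  # no more data for this task/metric
        events.extend(page)
        from_key_value = str(page[-1]["iter"])  # "iter" is the default scalar key
    return events, from_key_value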

View File

@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService

 class TestTaskEvents(TestService):
-    def setUp(self, version="2.9"):
-        super().setUp(version=version)
-
     def _temp_task(self, name="test task events"):
         task_input = dict(
             name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(task_data["x"]), iterations)
         self.assertEqual(len(task_data["y"]), iterations)

+    def test_task_metric_raw(self):
+        metric = "Metric1"
+        variant = "Variant1"
+        iter_count = 100
+        task = self._temp_task()
+        events = [
+            {
+                **self._create_task_event("training_stats_scalar", task, iteration),
+                "metric": metric,
+                "variant": variant,
+                "value": iteration,
+            }
+            for iteration in range(iter_count)
+        ]
+        self.send_batch(events)
+
+        batch_size = 15
+        metric_param = {"metric": metric, "variants": [variant]}
+        res = self.api.events.scalar_metrics_iter_raw(
+            task=task, batch_size=batch_size, metric=metric_param, count_total=True
+        )
+        self.assertEqual(res.total, len(events))
+        self.assertTrue(res.scroll_id)
+
+        res_iters = []
+        res_ys = []
+        calls = 0
+        # scroll until nothing is returned, capped at 10 calls as a safety guard
+        while res.returned and calls < 10:
+            calls += 1
+            res_iters.extend(res.variants[variant]["iter"])
+            res_ys.extend(res.variants[variant]["y"])
+            scroll_id = res.scroll_id
+            res = self.api.events.scalar_metrics_iter_raw(
+                task=task, metric=metric_param, scroll_id=scroll_id
+            )
+
+        self.assertEqual(calls, len(events) // batch_size + 1)
+        self.assertEqual(res_iters, [ev["iter"] for ev in events])
+        self.assertEqual(res_ys, [ev["value"] for ev in events])
+
     def test_task_metric_value_intervals(self):
         metric = "Metric1"
         variant = "Variant1"