Fix base query building

Fix schema
Improve events.scalar_metrics_iter_raw implementation
This commit is contained in:
allegroai 2022-02-13 19:28:23 +02:00
parent e334246b46
commit cae38a365b
7 changed files with 153 additions and 69 deletions

View File

@ -102,6 +102,7 @@ class LogEventsRequest(TaskEventsRequestBase):
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
    """Request schema for events.scalar_metrics_iter_raw (paged raw scalar retrieval for a task)."""

    # Number of data points to return per call; the service applies its own
    # default when unset (see the endpoint implementation).
    batch_size: int = IntField()
    # Axis to iterate over; defaults to iteration number.
    key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
    # Metric (with optional variants) to retrieve; mandatory.
    metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
    # When True, the total number of matching data points is counted and returned.
    count_total: bool = BoolField(default=False)

View File

@ -18,7 +18,7 @@ events_retrieval {
# the max amount of variants to aggregate on
max_variants_count: 100
max_raw_scalars_size: 10000
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
}

View File

@ -1,9 +1,21 @@
import re
from collections import namedtuple
from functools import reduce, partial
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
from typing import (
Collection,
Sequence,
Union,
Optional,
Type,
Tuple,
Mapping,
Any,
Callable,
Dict,
List,
)
from boltons.iterutils import first, bucketize, partition
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor
@ -109,43 +121,55 @@ class GetMixin(PropsMixin):
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
_default = "in"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
}
_next = _default
_sticky = False
def __init__(self, legacy=False):
self._legacy = legacy
self._current_op = None
self._sticky = False
self._support_legacy = legacy
def key(self, v) -> Optional[str]:
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"][0]
self._current_op = None
self._sticky = False
return self.default_mongo_op
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
elif v.startswith(self.op_prefix):
self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
self._current_op, self._sticky = self._ops.get(
v[len(self.op_prefix):], (self.default_mongo_op, self._sticky)
)
return None
next_ = self._next
if not self._sticky:
self._next = self._default
return self.default_mongo_op
return next_
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
get_all_query_options = QueryParameterOptions()
@ -261,7 +285,9 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the function it depends on RegexQ is always
used instead of Q. Otherwise we can end up with some combination that is not processed according to
RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
@ -298,7 +324,7 @@ class GetMixin(PropsMixin):
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= Q(__raw__={field: data})
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
@ -332,22 +358,31 @@ class GetMixin(PropsMixin):
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
data.fields,
Q()
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
@ -371,14 +406,14 @@ class GetMixin(PropsMixin):
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = Q(**query)
q = RegexQ(**query)
if min_val is None:
q |= Q(**{mongoengine_field: None})
q |= RegexQ(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
@ -392,13 +427,15 @@ class GetMixin(PropsMixin):
data = [data]
# raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
allow_empty = False
default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op)
if default_op_actions and None in default_op_actions:
allow_empty = True
default_op_actions.remove(None)
if not default_op_actions:
actions.pop(cls.ListFieldBucketHelper.default_mongo_op)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
@ -412,8 +449,9 @@ class GetMixin(PropsMixin):
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
| RegexQ(**{f"{mongoengine_field}__exists": False})
| RegexQ(**{mongoengine_field: []})
| RegexQ(**{mongoengine_field: None})
)
@classmethod
@ -701,9 +739,7 @@ class GetMixin(PropsMixin):
override_collation=override_collation,
)
return cls.get_data_with_scroll_and_filter_support(
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)
return cls._get_many_no_company(

View File

@ -1244,9 +1244,9 @@ scalar_metrics_iter_raw {
"$ref": "#/definitions/scalar_key_enum"
}
batch_size {
description: "The number of data points to return for this call. Optional, the default value is 5000"
description: "The number of data points to return for this call. Optional, the default value is 10000. Maximum batch size is 200000"
type: integer
default: 5000
default: 10000
}
count_total {
description: "Count the total number of data points. If false, total number of data points is not counted and null is returned"

View File

@ -540,7 +540,7 @@ get_all_ex {
}
}
"999.0": ${get_all_ex."2.15"} {
response.properties.stats_with_children {
request.properties.stats_with_children {
description: "If include_stats flag is set then this flag controls whether the child projects tasks are taken into statistics or not"
type: boolean
default: true

View File

@ -1,4 +1,5 @@
import itertools
import math
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
@ -844,10 +845,15 @@ def scalar_metrics_iter_raw(
):
key = request.key or ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if request.batch_size and request.batch_size < 0:
raise errors.bad_request.ValidationError(
"batch_size should be non negative number"
)
if not request.scroll_id:
from_key_value = None
total = None
request.batch_size = request.batch_size or 10_000
else:
try:
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
@ -861,9 +867,7 @@ def scalar_metrics_iter_raw(
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
request = scroll.request
request.batch_size = request.batch_size or scroll.request.batch_size
task_id = request.task
@ -884,38 +888,45 @@ def scalar_metrics_iter_raw(
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
),
)
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=False,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=key,
)
events = []
for iteration in range(0, math.ceil(batch_size / 10_000)):
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
batch_size=min(batch_size, 10_000),
navigate_earlier=False,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=key,
)
if not res.events:
break
events.extend(res.events)
from_key_value = str(events[-1][scalar_key.field])
key = str(key)
variants = {
variant: extract_properties_to_lists(
["value", scalar_key.field], events, target_keys=["y", key]
)
for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
for variant, events in bucketize(events, key=itemgetter("variant")).items()
}
call.kpis["events"] = len(events)
scroll = ScalarMetricsIterRawScroll(
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
from_key_value=str(events[-1][scalar_key.field]) if events else None,
total=total,
request=request,
)
return make_response(
returned=len(res.events),
returned=len(events),
total=total,
scroll_id=scroll.get_scroll_id(),
variants=variants,

View File

@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestTaskEvents(TestService):
def setUp(self, version="2.9"):
    """Initialize the test service, pinning the API version used for calls."""
    base = super()
    base.setUp(version=version)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
self.assertEqual(len(task_data["x"]), iterations)
self.assertEqual(len(task_data["y"]), iterations)
def test_task_metric_raw(self):
    """Exercise events.scalar_metrics_iter_raw: send scalar events for one
    metric/variant and read them all back through the paging scroll API."""
    metric = "Metric1"
    variant = "Variant1"
    iter_count = 100
    task = self._temp_task()
    events = [
        {
            **self._create_task_event("training_stats_scalar", task, iteration),
            "metric": metric,
            "variant": variant,
            "value": iteration,
        }
        for iteration in range(iter_count)
    ]
    self.send_batch(events)

    batch_size = 15
    metric_param = {"metric": metric, "variants": [variant]}
    res = self.api.events.scalar_metrics_iter_raw(
        task=task, batch_size=batch_size, metric=metric_param, count_total=True
    )
    self.assertEqual(res.total, len(events))
    self.assertTrue(res.scroll_id)

    res_iters = []
    res_ys = []
    calls = 0
    # BUGFIX: the guard was "while res.returned or calls > 10" — the "or"
    # clause never bounds the loop and would spin forever if the service
    # kept returning data. Use "and calls < 10" as a safety cap on the
    # number of paging calls (7 are expected for 100 events / batch 15).
    while res.returned and calls < 10:
        calls += 1
        res_iters.extend(res.variants[variant]["iter"])
        res_ys.extend(res.variants[variant]["y"])
        scroll_id = res.scroll_id
        res = self.api.events.scalar_metrics_iter_raw(
            task=task, metric=metric_param, scroll_id=scroll_id
        )
    # 6 full pages of 15 plus one final partial page of 10
    self.assertEqual(calls, len(events) // batch_size + 1)
    # Events must come back complete and in send order.
    self.assertEqual(res_iters, [ev["iter"] for ev in events])
    self.assertEqual(res_ys, [ev["value"] for ev in events])
def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"