diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py
index 422b6c8..a1ad685 100644
--- a/apiserver/apimodels/events.py
+++ b/apiserver/apimodels/events.py
@@ -102,6 +102,7 @@ class LogEventsRequest(TaskEventsRequestBase):
 
 
 class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
+    batch_size: int = IntField()
     key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
     metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
     count_total: bool = BoolField(default=False)
diff --git a/apiserver/config/default/services/events.conf b/apiserver/config/default/services/events.conf
index 8053568..8e44c3e 100644
--- a/apiserver/config/default/services/events.conf
+++ b/apiserver/config/default/services/events.conf
@@ -18,7 +18,7 @@ events_retrieval {
     # the max amount of variants to aggregate on
     max_variants_count: 100
 
-    max_raw_scalars_size: 10000
+    max_raw_scalars_size: 200000
 
     scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
 }
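The two changes above work together: clients may now pass an explicit batch_size on scalar_metrics_iter_raw requests, and the server-side cap on raw scalar retrieval is raised to 200000. A minimal sketch of the intended clamping behavior, assuming the defaults documented later in this patch (the names here are illustrative, not part of the diff):

    # Sketch: a requested batch_size falls back to the documented default,
    # then is clamped to the events_retrieval.max_raw_scalars_size cap.
    DEFAULT_BATCH_SIZE = 10_000
    MAX_RAW_SCALARS_SIZE = 200_000

    def effective_batch_size(requested=None):
        return min(requested or DEFAULT_BATCH_SIZE, MAX_RAW_SCALARS_SIZE)

    assert effective_batch_size() == 10_000
    assert effective_batch_size(1_000_000) == 200_000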
diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py
index aba7b9d..bd7d625 100644
--- a/apiserver/database/model/base.py
+++ b/apiserver/database/model/base.py
@@ -1,9 +1,21 @@
 import re
 from collections import namedtuple
 from functools import reduce, partial
-from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
+from typing import (
+    Collection,
+    Sequence,
+    Union,
+    Optional,
+    Type,
+    Tuple,
+    Mapping,
+    Any,
+    Callable,
+    Dict,
+    List,
+)
 
-from boltons.iterutils import first, bucketize, partition
+from boltons.iterutils import first, partition
 from dateutil.parser import parse as parse_datetime
 from mongoengine import Q, Document, ListField, StringField, IntField
 from pymongo.command_cursor import CommandCursor
@@ -109,43 +121,55 @@ class GetMixin(PropsMixin):
     class ListFieldBucketHelper:
         op_prefix = "__$"
-        legacy_exclude_prefix = "-"
+        _legacy_exclude_prefix = "-"
+        _legacy_exclude_mongo_op = "nin"
 
-        _default = "in"
+        default_mongo_op = "in"
         _ops = {
+            # op -> (mongo_op, sticky)
             "not": ("nin", False),
+            "nop": (default_mongo_op, False),
             "all": ("all", True),
             "and": ("all", True),
         }
 
-        _next = _default
-        _sticky = False
-
         def __init__(self, legacy=False):
-            self._legacy = legacy
+            self._current_op = None
+            self._sticky = False
+            self._support_legacy = legacy
 
-        def key(self, v) -> Optional[str]:
+        def _key(self, v) -> Optional[Union[str, bool]]:
             if v is None:
-                self._next = self._default
-                return self._default
-            elif self._legacy and v.startswith(self.legacy_exclude_prefix):
-                self._next = self._default
-                return self._ops["not"][0]
+                self._current_op = None
+                self._sticky = False
+                return self.default_mongo_op
+            elif self._current_op:
+                current_op = self._current_op
+                if not self._sticky:
+                    self._current_op = None
+                return current_op
+            elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
+                self._current_op = None
+                return False
             elif v.startswith(self.op_prefix):
-                self._next, self._sticky = self._ops.get(
-                    v[len(self.op_prefix) :], (self._default, self._sticky)
+                self._current_op, self._sticky = self._ops.get(
+                    v[len(self.op_prefix):], (self.default_mongo_op, self._sticky)
                 )
                 return None
 
-            next_ = self._next
-            if not self._sticky:
-                self._next = self._default
+            return self.default_mongo_op
 
-            return next_
-
-        def value_transform(self, v):
-            if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
-                return v[len(self.legacy_exclude_prefix) :]
-            return v
+        def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
+            actions = {}
+            for val in data:
+                key = self._key(val)
+                if key is None:
+                    continue
+                elif self._support_legacy and key is False:
+                    key = self._legacy_exclude_mongo_op
+                    val = val[len(self._legacy_exclude_prefix) :]
+                actions.setdefault(key, []).append(val)
+            return actions
 
     get_all_query_options = QueryParameterOptions()
 
@@ -261,7 +285,9 @@ class GetMixin(PropsMixin):
         Prepare a query object based on the provided query dictionary and various fields.
         NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
-
+        IMPLEMENTATION NOTE: Make sure that RegexQ is always used instead of Q, both inside this function
+        and in the functions it depends on. Otherwise we can end up with a combination that is not
+        processed according to RegexQ rules.
         :param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
         :param parameters: Query dictionary (relevant keys are these specified by the various field names
             parameters). Supported parameters:
@@ -298,7 +324,7 @@ class GetMixin(PropsMixin):
             patterns=opts.fields or [], parameters=parameters
         ).items():
             if "._" in field or "_." in field:
-                query &= Q(__raw__={field: data})
+                query &= RegexQ(__raw__={field: data})
             else:
                 dict_query[field.replace(".", "__")] = data
 
@@ -332,22 +358,31 @@ class GetMixin(PropsMixin):
                     break
             if any("._" in f for f in data.fields):
                 q = reduce(
-                    lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
+                    lambda a, x: func(
+                        a,
+                        RegexQ(
+                            __raw__={
+                                x: {"$regex": data.pattern, "$options": "i"}
+                            }
+                        ),
+                    ),
                     data.fields,
-                    Q()
+                    RegexQ(),
                 )
             else:
                 regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                 sep_fields = [f.replace(".", "__") for f in data.fields]
                 q = reduce(
-                    lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
+                    lambda a, x: func(a, RegexQ(**{x: regex})),
+                    sep_fields,
+                    RegexQ(),
                 )
             query = query & q
 
         return query & RegexQ(**dict_query)
 
     @classmethod
-    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Return a range query for the provided field. The data should contain min and max values
         Both intervals are included. For open range queries either min or max can be None
@@ -371,14 +406,14 @@ class GetMixin(PropsMixin):
         if max_val is not None:
             query[f"{mongoengine_field}__lte"] = max_val
 
-        q = Q(**query)
+        q = RegexQ(**query)
         if min_val is None:
-            q |= Q(**{mongoengine_field: None})
+            q |= RegexQ(**{mongoengine_field: None})
         return q
 
     @classmethod
-    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
+    def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
         """
         Get a proper mongoengine Q object that represents an "or" query for the provided values
         with respect to the given list field, with support for "none of empty" in case a None value
@@ -392,13 +427,15 @@ class GetMixin(PropsMixin):
             data = [data]
             # raise MakeGetAllQueryError("expected list", field)
 
-        # TODO: backwards compatibility only for older API versions
-        helper = cls.ListFieldBucketHelper(legacy=True)
-        actions = bucketize(
-            data, key=helper.key, value_transform=helper.value_transform
-        )
+        actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
+        allow_empty = False
+        default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op)
+        if default_op_actions and None in default_op_actions:
+            allow_empty = True
+            default_op_actions.remove(None)
+            if not default_op_actions:
+                actions.pop(cls.ListFieldBucketHelper.default_mongo_op)
 
-        allow_empty = None in actions.get("in", {})
         mongoengine_field = field.replace(".", "__")
 
         q = RegexQ()
@@ -412,8 +449,9 @@ class GetMixin(PropsMixin):
 
         return (
             q
-            | Q(**{f"{mongoengine_field}__exists": False})
-            | Q(**{mongoengine_field: []})
+            | RegexQ(**{f"{mongoengine_field}__exists": False})
+            | RegexQ(**{mongoengine_field: []})
+            | RegexQ(**{mongoengine_field: None})
         )
 
     @classmethod
@@ -701,9 +739,7 @@ class GetMixin(PropsMixin):
                 override_collation=override_collation,
             )
             return cls.get_data_with_scroll_and_filter_support(
-                query_dict=query_dict,
-                data_getter=data_getter,
-                ret_params=ret_params,
+                query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
             )
 
         return cls._get_many_no_company(
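To make the rewritten helper's semantics concrete, here is a hedged sketch of what get_actions is expected to return for a few representative inputs (a fresh helper per call, mirroring its use in get_list_field_query; the input values are hypothetical):

    def actions(data):
        # the helper is stateful, so build a new one per query
        return GetMixin.ListFieldBucketHelper(legacy=True).get_actions(data)

    actions(["a", "b"])            # {"in": ["a", "b"]}
    actions(["__$not", "a", "b"])  # {"nin": ["a"], "in": ["b"]}  ("not" is not sticky)
    actions(["__$all", "a", "b"])  # {"all": ["a", "b"]}          ("all" is sticky)
    actions(["-a", "b"])           # {"nin": ["a"], "in": ["b"]}  (legacy "-" exclude)
    actions([None, "a"])           # {"in": [None, "a"]}; the None later enables empty-field matching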
diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf
index 5746b74..024c77d 100644
--- a/apiserver/schema/services/events.conf
+++ b/apiserver/schema/services/events.conf
@@ -1244,9 +1244,9 @@ scalar_metrics_iter_raw {
                     "$ref": "#/definitions/scalar_key_enum"
                 }
                 batch_size {
-                    description: "The number of data points to return for this call. Optional, the default value is 5000"
+                    description: "The number of data points to return for this call. Optional, the default value is 10000. Maximum batch size is 200000"
                     type: integer
-                    default: 5000
+                    default: 10000
                 }
                 count_total {
                     description: "Count the total number of data points. If false, total number of data points is not counted and null is returned"
diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf
index 4552b6d..fd5b38f 100644
--- a/apiserver/schema/services/projects.conf
+++ b/apiserver/schema/services/projects.conf
@@ -540,7 +540,7 @@ get_all_ex {
             }
         }
         "999.0": ${get_all_ex."2.15"} {
-            response.properties.stats_with_children {
+            request.properties.stats_with_children {
                 description: "If include_stats flag is set then this flag controls whether the child projects tasks are taken into statistics or not"
                 type: boolean
                 default: true
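With the stats_with_children flag moved from the response schema to the request schema where it belongs, a client-side call would look roughly like this (hypothetical sketch; api stands for whatever APIClient instance the caller already has):

    # Include aggregate statistics but ignore tasks of child projects:
    projects = api.projects.get_all_ex(include_stats=True, stats_with_children=False)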
diff --git a/apiserver/services/events.py b/apiserver/services/events.py
index 8df8c91..01cf14b 100644
--- a/apiserver/services/events.py
+++ b/apiserver/services/events.py
@@ -1,4 +1,5 @@
 import itertools
+import math
 from collections import defaultdict
 from operator import itemgetter
 from typing import Sequence, Optional
@@ -844,10 +845,15 @@ def scalar_metrics_iter_raw(
 ):
     key = request.key or ScalarKeyEnum.iter
     scalar_key = ScalarKey.resolve(key)
+    if request.batch_size and request.batch_size < 0:
+        raise errors.bad_request.ValidationError(
+            "batch_size should be a non-negative number"
+        )
 
     if not request.scroll_id:
         from_key_value = None
         total = None
+        request.batch_size = request.batch_size or 10_000
     else:
         try:
             scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
@@ -861,9 +867,7 @@ def scalar_metrics_iter_raw(
 
         from_key_value = scalar_key.cast_value(scroll.from_key_value)
         total = scroll.total
-
-        scroll.request.batch_size = request.batch_size or scroll.request.batch_size
-        request = scroll.request
+        request.batch_size = request.batch_size or scroll.request.batch_size
 
     task_id = request.task
 
@@ -884,38 +888,45 @@ def scalar_metrics_iter_raw(
     batch_size = min(
         request.batch_size,
         int(
-            config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
+            config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
         ),
     )
 
-    res = event_bll.events_iterator.get_task_events(
-        event_type=EventType.metrics_scalar,
-        company_id=task.company,
-        task_id=task_id,
-        batch_size=batch_size,
-        navigate_earlier=False,
-        from_key_value=from_key_value,
-        metric_variants=metric_variants,
-        key=key,
-    )
+    events = []
+    for iteration in range(0, math.ceil(batch_size / 10_000)):
+        res = event_bll.events_iterator.get_task_events(
+            event_type=EventType.metrics_scalar,
+            company_id=task.company,
+            task_id=task_id,
+            batch_size=min(batch_size, 10_000),
+            navigate_earlier=False,
+            from_key_value=from_key_value,
+            metric_variants=metric_variants,
+            key=key,
+        )
+        if not res.events:
+            break
+        events.extend(res.events)
+        from_key_value = str(events[-1][scalar_key.field])
 
     key = str(key)
-
     variants = {
         variant: extract_properties_to_lists(
             ["value", scalar_key.field], events, target_keys=["y", key]
         )
-        for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
+        for variant, events in bucketize(events, key=itemgetter("variant")).items()
     }
 
+    call.kpis["events"] = len(events)
+
     scroll = ScalarMetricsIterRawScroll(
-        from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
+        from_key_value=str(events[-1][scalar_key.field]) if events else None,
         total=total,
         request=request,
     )
 
     return make_response(
-        returned=len(res.events),
+        returned=len(events),
         total=total,
         scroll_id=scroll.get_scroll_id(),
         variants=variants,
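The handler above now satisfies large batches in chunks of up to 10000 events against the events iterator, advancing from_key_value after each chunk and stopping early when a chunk comes back empty. The resulting fetch plan is plain arithmetic (illustrative values, not from the diff):

    import math

    # e.g. a capped batch_size of 25_000 translates into at most
    # ceil(25_000 / 10_000) == 3 round trips of up to 10_000 events each
    batch_size = 25_000
    max_fetches = math.ceil(batch_size / 10_000)   # -> 3
    per_fetch = min(batch_size, 10_000)            # -> 10_000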
diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py
index 984ebcc..ca28b08 100644
--- a/apiserver/tests/automated/test_task_events.py
+++ b/apiserver/tests/automated/test_task_events.py
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
 
 
 class TestTaskEvents(TestService):
-    def setUp(self, version="2.9"):
-        super().setUp(version=version)
-
     def _temp_task(self, name="test task events"):
         task_input = dict(
             name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
         self.assertEqual(len(task_data["x"]), iterations)
         self.assertEqual(len(task_data["y"]), iterations)
 
+    def test_task_metric_raw(self):
+        metric = "Metric1"
+        variant = "Variant1"
+        iter_count = 100
+        task = self._temp_task()
+        events = [
+            {
+                **self._create_task_event("training_stats_scalar", task, iteration),
+                "metric": metric,
+                "variant": variant,
+                "value": iteration,
+            }
+            for iteration in range(iter_count)
+        ]
+        self.send_batch(events)
+
+        batch_size = 15
+        metric_param = {"metric": metric, "variants": [variant]}
+        res = self.api.events.scalar_metrics_iter_raw(
+            task=task, batch_size=batch_size, metric=metric_param, count_total=True
+        )
+        self.assertEqual(res.total, len(events))
+        self.assertTrue(res.scroll_id)
+        res_iters = []
+        res_ys = []
+        calls = 0
+        while res.returned and calls < 10:
+            calls += 1
+            res_iters.extend(res.variants[variant]["iter"])
+            res_ys.extend(res.variants[variant]["y"])
+            scroll_id = res.scroll_id
+            res = self.api.events.scalar_metrics_iter_raw(
+                task=task, metric=metric_param, scroll_id=scroll_id
+            )
+
+        self.assertEqual(calls, len(events) // batch_size + 1)
+        self.assertEqual(res_iters, [ev["iter"] for ev in events])
+        self.assertEqual(res_ys, [ev["value"] for ev in events])
+
     def test_task_metric_value_intervals(self):
         metric = "Metric1"
         variant = "Variant1"
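The new test drives the same paging loop a real client would use; stripped of the assertions, the pattern is (sketch, assuming an API client such as the test's self.api and a hypothetical consume() handler):

    res = api.events.scalar_metrics_iter_raw(task=task_id, metric=metric_param)
    while res.returned:
        consume(res.variants)   # process the current page of raw scalars
        res = api.events.scalar_metrics_iter_raw(
            task=task_id, metric=metric_param, scroll_id=res.scroll_id
        )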