@attr.s(auto_attribs=True)
class MultiFieldParameters:
    """
    Parameters of a multi-field query: the names of the fields to query
    together with exactly one condition, given either as 'pattern' or
    as 'datetime'.
    """

    # Names of the fields the condition is applied to
    fields: Sequence[str]
    # Pattern condition; mutually exclusive with 'datetime'
    pattern: str = None
    # Date time condition(s); mutually exclusive with 'pattern'
    datetime: Union[list, str] = None

    def __attrs_post_init__(self):
        """
        Validate that exactly one of 'pattern' and 'datetime' was supplied.

        :raise ValueError: if neither or both of the conditions are set
        """
        supplied = sum(
            1
            for condition in (self.pattern, self.datetime)
            if condition is not None
        )
        if supplied == 0:
            raise ValueError("Either 'pattern' or 'datetime' should be provided")
        if supplied == 2:
            raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
    """
    Return the dates query for the field.

    If the data is a 2 values array and none of the values starts with a
    dates comparison operation then the simplified range query is returned.
    Otherwise a dictionary of dates conditions is returned.

    :param field: name of the field the conditions apply to
    :param data: a single condition or a list of conditions. A value that is
        not a list is treated as a one-element list
    :return: the result of cls.get_range_field_query(field, data) for the
        simplified range case, otherwise a dictionary whose keys are either
        the field name or "<field>__<modifier>" and whose values are the
        parsed date time values
    """
    conditions = data if isinstance(data, list) else [data]

    # str.startswith accepts a tuple of alternatives, so a single call covers
    # every comparison prefix defined in ACCESS_MODIFIER
    prefixes = tuple(ACCESS_MODIFIER)
    has_prefixed_value = any(
        isinstance(condition, str) and condition.startswith(prefixes)
        for condition in conditions
    )
    if len(conditions) == 2 and not has_prefixed_value:
        return cls.get_range_field_query(field, conditions)

    dict_query = {}
    for condition in conditions:
        # None (or any other non-string value) can reach this point, e.g. for
        # [">=<date>", None]. The regex can only be applied to strings, so such
        # entries are skipped instead of failing with TypeError
        if not isinstance(condition, str):
            continue

        m = ACCESS_REGEX.match(condition)
        if not m:
            continue

        try:
            value = parse_datetime(m.group("value"))
        except (ValueError, OverflowError):
            # best effort: a condition whose date cannot be parsed is ignored
            continue

        modifier = ACCESS_MODIFIER.get(m.group("prefix"))
        key = "__".join((field, modifier)) if modifier else field
        dict_query[key] = value

    return dict_query
Check if this is the case - if len(data) == 2 and not any( - d.startswith(mod) - for d in data - if d is not None - for mod in ACCESS_MODIFIER - ): - query &= cls.get_range_field_query(field, data) - else: - for d in data: # type: str - m = ACCESS_REGEX.match(d) - if not m: - continue - try: - value = parse_datetime(m.group("value")) - prefix = m.group("prefix") - modifier = ACCESS_MODIFIER.get(prefix) - f = ( - field - if not modifier - else "__".join((field, modifier)) - ) - dict_query[f] = value - except (ValueError, OverflowError): - pass + dates_q = cls._get_dates_query(field, data) + if isinstance(dates_q, Q): + query &= dates_q + elif isinstance(dates_q, dict): + dict_query.update(dates_q) for field, value in parameters.items(): for keys, func in cls._multi_field_param_prefix.items(): @@ -484,27 +515,40 @@ class GetMixin(PropsMixin): raise MakeGetAllQueryError("incorrect field format", field) if not data.fields: break - if any("._" in f for f in data.fields): - q = reduce( - lambda a, x: func( - a, - RegexQ( - __raw__={ - x: {"$regex": data.pattern, "$options": "i"} - } + if data.pattern is not None: + if any("._" in f for f in data.fields): + q = reduce( + lambda a, x: func( + a, + RegexQ( + __raw__={ + x: {"$regex": data.pattern, "$options": "i"} + } + ), ), - ), - data.fields, - RegexQ(), - ) + data.fields, + RegexQ(), + ) + else: + regex = RegexWrapper(data.pattern, flags=re.IGNORECASE) + sep_fields = [f.replace(".", "__") for f in data.fields] + q = reduce( + lambda a, x: func(a, RegexQ(**{x: regex})), + sep_fields, + RegexQ(), + ) else: - regex = RegexWrapper(data.pattern, flags=re.IGNORECASE) - sep_fields = [f.replace(".", "__") for f in data.fields] - q = reduce( - lambda a, x: func(a, RegexQ(**{x: regex})), - sep_fields, - RegexQ(), - ) + date_fields = [field for field in data.fields if field in opts.datetime_fields] + if not date_fields: + break + + q = Q() + for date_f in date_fields: + dates_q = cls._get_dates_query(date_f, data.datetime) + if 
isinstance(dates_q, dict): + dates_q = RegexQ(**dates_q) + q = func(q, dates_q) + query = query & q except APIError: raise diff --git a/apiserver/database/model/model.py b/apiserver/database/model/model.py index 9cf3bd1..977e3e4 100644 --- a/apiserver/database/model/model.py +++ b/apiserver/database/model/model.py @@ -79,8 +79,8 @@ class Model(AttributedDocument): "parent", "metadata.*", ), - range_fields=("last_metrics.*", "last_iteration"), - datetime_fields=("last_update",), + range_fields=("created", "last_metrics.*", "last_iteration"), + datetime_fields=("last_update", "last_change"), ) id = StringField(primary_key=True) diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 15b87f0..42c7816 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -244,7 +244,7 @@ class Task(AttributedDocument): "models.input.model", ), range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"), - datetime_fields=("status_changed", "last_update"), + datetime_fields=("status_changed", "last_update", "last_change"), pattern_fields=("name", "comment", "report"), fields=("runtime.*",), ) diff --git a/apiserver/schema/services/_common.conf b/apiserver/schema/services/_common.conf index 6982e52..f075047 100644 --- a/apiserver/schema/services/_common.conf +++ b/apiserver/schema/services/_common.conf @@ -74,7 +74,11 @@ multi_field_pattern_data { type: object properties { pattern { - description: "Pattern string (regex)" + description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified" + type: string + } + datetime { + description: "Date time conditions (applicable only to datetime fields). 
Either 'pattern' or 'datetime' should be specified" type: string } fields { diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 574a2fd..a9d02e7 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -1,20 +1,6 @@ _description: """This service provides a management interface for models (results of training tasks) stored in the system.""" _definitions { include "_tasks_common.conf" - multi_field_pattern_data { - type: object - properties { - pattern { - description: "Pattern string (regex)" - type: string - } - fields { - description: "List of field names" - type: array - items { type: string } - } - } - } model { type: object properties { diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 168cd7c..06b26d6 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -1,20 +1,6 @@ _description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions." 
_definitions { include "_common.conf" - multi_field_pattern_data { - type: object - properties { - pattern { - description: "Pattern string (regex)" - type: string - } - fields { - description: "List of field names" - type: array - items { type: string } - } - } - } project { type: object properties { diff --git a/apiserver/tests/automated/test_tasks_filtering.py b/apiserver/tests/automated/test_tasks_filtering.py index 5b52564..72e12fb 100644 --- a/apiserver/tests/automated/test_tasks_filtering.py +++ b/apiserver/tests/automated/test_tasks_filtering.py @@ -71,6 +71,16 @@ class TestTasksFiltering(TestService): ).tasks self.assertFalse(set(tasks).issubset({t.id for t in res})) + # _any_/_all_ queries + res = self.api.tasks.get_all_ex( + **{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}} + ).tasks + self.assertTrue(set(tasks).issubset({t.id for t in res})) + res = self.api.tasks.get_all_ex( + **{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}} + ).tasks + self.assertFalse(set(tasks).issubset({t.id for t in res})) + # simplified range syntax res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks self.assertTrue(set(tasks).issubset({t.id for t in res})) @@ -80,6 +90,15 @@ class TestTasksFiltering(TestService): ).tasks self.assertFalse(set(tasks).issubset({t.id for t in res})) + res = self.api.tasks.get_all_ex( + **{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}} + ).tasks + self.assertTrue(set(tasks).issubset({t.id for t in res})) + res = self.api.tasks.get_all_ex( + **{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}} + ).tasks + self.assertFalse(set(tasks).issubset({t.id for t in res})) + def test_range_queries(self): tasks = [self.temp_task() for _ in range(5)] now = datetime.utcnow()