Add _any_/_all_ queries support for datetime fields

This commit is contained in:
clearml 2024-12-05 22:25:08 +02:00
parent 073cc96fb8
commit 57ce9446b1
7 changed files with 119 additions and 80 deletions

View File

@ -1,5 +1,5 @@
import re import re
from collections import namedtuple, defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from functools import reduce, partial from functools import reduce, partial
from typing import ( from typing import (
@ -107,7 +107,18 @@ class GetMixin(PropsMixin):
("_any_", "_or_"): lambda a, b: a | b, ("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b, ("_all_", "_and_"): lambda a, b: a & b,
} }
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
@attr.s(auto_attribs=True)
class MultiFieldParameters:
    """
    Parameters of a multi-field (_any_/_all_ style) query.

    Holds the list of field names together with exactly one condition:
    either a 'pattern' or a 'datetime' value. Construction fails with
    ValueError when neither or both of the conditions are supplied.
    """

    # names of the fields the condition is applied to
    fields: Sequence[str]
    # pattern condition; mutually exclusive with 'datetime'
    pattern: str = None
    # datetime condition (a single value or a list); mutually exclusive with 'pattern'
    datetime: Union[list, str] = None

    def __attrs_post_init__(self):
        """Validate that exactly one of 'pattern' and 'datetime' was provided."""
        provided = sum(
            1 for value in (self.pattern, self.datetime) if value is not None
        )
        if provided == 0:
            raise ValueError("Either 'pattern' or 'datetime' should be provided")
        if provided == 2:
            raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
_numeric_locale = {"locale": "en_US", "numericOrdering": True} _numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {} _field_collation_overrides = {}
@ -323,6 +334,8 @@ class GetMixin(PropsMixin):
specific rules on handling values). Only items matching ALL of these conditions will be retrieved. specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern. provided fields match the provided pattern.
- <any|all>: {fields: [<field1>, <field2>, ...], datetime: <datetime condition>} Will query for items where any or all
provided datetime fields match the provided condition.
:return: mongoengine.Q query object :return: mongoengine.Q query object
""" """
return cls._prepare_query_no_company( return cls._prepare_query_no_company(
@ -376,6 +389,46 @@ class GetMixin(PropsMixin):
return cls._try_convert_to_numeric(value) return cls._try_convert_to_numeric(value)
return value return value
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
    """
    Build the query for a datetime field.

    A single (non-list) value is treated as a one-element list.

    If the data holds exactly two values and none of the non-None values starts
    with one of the comparison prefixes (the keys of ACCESS_MODIFIER), the data
    is handled as a simplified range and the result of
    cls.get_range_field_query(field, data) is returned.

    Otherwise a dictionary of conditions is returned: each value is matched
    against ACCESS_REGEX, its "value" group is parsed with parse_datetime and
    stored under the field name, or under "<field>__<modifier>" when the
    "prefix" group maps to a modifier in ACCESS_MODIFIER.

    Values that are not strings (e.g. None), that do not match ACCESS_REGEX or
    whose datetime part cannot be parsed are skipped, so the returned
    dictionary may be empty.

    :param field: name of the datetime field
    :param data: a single condition or a list of conditions
    :return: the range query or a dictionary of datetime conditions
    """
    if not isinstance(data, list):
        data = [data]

    # str.startswith accepts a tuple, so all the prefixes are tested in one call
    prefixes = tuple(ACCESS_MODIFIER)
    if len(data) == 2 and not any(
        d.startswith(prefixes) for d in data if d is not None
    ):
        return cls.get_range_field_query(field, data)

    dict_query = {}
    for d in data:
        # None and other non-string entries cannot be matched by the regex;
        # skip them like any other unparsable entry instead of raising TypeError
        if not isinstance(d, str):
            continue

        m = ACCESS_REGEX.match(d)
        if not m:
            continue

        # only the datetime parsing is expected to raise here
        try:
            value = parse_datetime(m.group("value"))
        except (ValueError, OverflowError):
            continue

        modifier = ACCESS_MODIFIER.get(m.group("prefix"))
        key = "__".join((field, modifier)) if modifier else field
        dict_query[key] = value

    return dict_query
@classmethod @classmethod
def _prepare_query_no_company( def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions() cls, parameters=None, parameters_options=QueryParameterOptions()
@ -446,33 +499,11 @@ class GetMixin(PropsMixin):
for field in opts.datetime_fields or []: for field in opts.datetime_fields or []:
data = parameters.pop(field, None) data = parameters.pop(field, None)
if data is not None: if data is not None:
if not isinstance(data, list): dates_q = cls._get_dates_query(field, data)
data = [data] if isinstance(dates_q, Q):
# date time fields also support simplified range queries. Check if this is the case query &= dates_q
if len(data) == 2 and not any( elif isinstance(dates_q, dict):
d.startswith(mod) dict_query.update(dates_q)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items(): for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items(): for keys, func in cls._multi_field_param_prefix.items():
@ -484,6 +515,7 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field) raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields: if not data.fields:
break break
if data.pattern is not None:
if any("._" in f for f in data.fields): if any("._" in f for f in data.fields):
q = reduce( q = reduce(
lambda a, x: func( lambda a, x: func(
@ -505,6 +537,18 @@ class GetMixin(PropsMixin):
sep_fields, sep_fields,
RegexQ(), RegexQ(),
) )
else:
date_fields = [field for field in data.fields if field in opts.datetime_fields]
if not date_fields:
break
q = Q()
for date_f in date_fields:
dates_q = cls._get_dates_query(date_f, data.datetime)
if isinstance(dates_q, dict):
dates_q = RegexQ(**dates_q)
q = func(q, dates_q)
query = query & q query = query & q
except APIError: except APIError:
raise raise

View File

@ -79,8 +79,8 @@ class Model(AttributedDocument):
"parent", "parent",
"metadata.*", "metadata.*",
), ),
range_fields=("last_metrics.*", "last_iteration"), range_fields=("created", "last_metrics.*", "last_iteration"),
datetime_fields=("last_update",), datetime_fields=("last_update", "last_change"),
) )
id = StringField(primary_key=True) id = StringField(primary_key=True)

View File

@ -244,7 +244,7 @@ class Task(AttributedDocument):
"models.input.model", "models.input.model",
), ),
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"), range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"), datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"), pattern_fields=("name", "comment", "report"),
fields=("runtime.*",), fields=("runtime.*",),
) )

View File

@ -74,7 +74,11 @@ multi_field_pattern_data {
type: object type: object
properties { properties {
pattern { pattern {
description: "Pattern string (regex)" description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified"
type: string
}
datetime {
description: "Date time conditions (applicable only to datetime fields). Either 'pattern' or 'datetime' should be specified"
type: string type: string
} }
fields { fields {

View File

@ -1,20 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system.""" _description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions { _definitions {
include "_tasks_common.conf" include "_tasks_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
model { model {
type: object type: object
properties { properties {

View File

@ -1,20 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions." _description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions { _definitions {
include "_common.conf" include "_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
project { project {
type: object type: object
properties { properties {

View File

@ -71,6 +71,16 @@ class TestTasksFiltering(TestService):
).tasks ).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res})) self.assertFalse(set(tasks).issubset({t.id for t in res}))
# _any_/_all_ queries
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# simplified range syntax # simplified range syntax
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res})) self.assertTrue(set(tasks).issubset({t.id for t in res}))
@ -80,6 +90,15 @@ class TestTasksFiltering(TestService):
).tasks ).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res})) self.assertFalse(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
def test_range_queries(self): def test_range_queries(self):
tasks = [self.temp_task() for _ in range(5)] tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow() now = datetime.utcnow()