Add _any_/_all_ query support for datetime fields

This commit is contained in:
clearml 2024-12-05 22:25:08 +02:00
parent 073cc96fb8
commit 57ce9446b1
7 changed files with 119 additions and 80 deletions

View File

@ -1,5 +1,5 @@
import re
from collections import namedtuple, defaultdict
from collections import defaultdict
from datetime import datetime
from functools import reduce, partial
from typing import (
@ -107,7 +107,18 @@ class GetMixin(PropsMixin):
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
@attr.s(auto_attribs=True)
class MultiFieldParameters:
    """
    Parameters of a multi-field (_any_/_all_ style) query.

    Exactly one of ``pattern`` and ``datetime`` must be supplied; the check
    runs right after attrs has populated the instance.

    :raises ValueError: when neither or both of ``pattern`` and ``datetime``
        are provided
    """

    # Names of the fields the condition is applied to
    fields: Sequence[str]
    # Pattern condition (mutually exclusive with ``datetime``)
    pattern: str = None
    # Date condition(s): a single string or a list (mutually exclusive with ``pattern``)
    datetime: Union[list, str] = None

    def __attrs_post_init__(self):
        # Count how many of the mutually exclusive conditions were supplied
        supplied = sum(
            1 for value in (self.pattern, self.datetime) if value is not None
        )
        if supplied == 0:
            raise ValueError("Either 'pattern' or 'datetime' should be provided")
        if supplied == 2:
            raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
@ -323,6 +334,8 @@ class GetMixin(PropsMixin):
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
- <any|all>: {fields: [<field1>, <field2>, ...], datetime: <datetime condition>} Will query for items where any or all
provided datetime fields match the provided condition.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
@ -376,6 +389,46 @@ class GetMixin(PropsMixin):
return cls._try_convert_to_numeric(value)
return value
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
    """
    Build the date condition(s) for a single field.

    A scalar ``data`` is treated as a one-item list.

    When ``data`` holds exactly two items and none of the non-None items
    starts with one of the ACCESS_MODIFIER keys, the pair is handed as is to
    ``cls.get_range_field_query`` and its result is returned (simplified
    range syntax).

    Otherwise every item is matched against ACCESS_REGEX and a dictionary is
    returned. Each key is either ``field`` itself or
    ``<field>__<modifier>`` (the modifier is looked up in ACCESS_MODIFIER by
    the matched prefix), and each value is the parsed datetime. Items that do
    not match the regex, or whose value cannot be parsed, are silently skipped.

    :param field: name of the field to query
    :param data: a single date condition or a list of conditions
    :return: the range query, or a dictionary of date conditions
    """
    conditions = data if isinstance(data, list) else [data]

    if len(conditions) == 2:
        has_modifier = any(
            item.startswith(op)
            for item in conditions
            if item is not None
            for op in ACCESS_MODIFIER
        )
        if not has_modifier:
            return cls.get_range_field_query(field, conditions)

    date_conditions = {}
    for item in conditions:
        match = ACCESS_REGEX.match(item)
        if not match:
            continue
        try:
            parsed = parse_datetime(match.group("value"))
        except (ValueError, OverflowError):
            # Unparsable date values are ignored
            continue
        suffix = ACCESS_MODIFIER.get(match.group("prefix"))
        key = "__".join((field, suffix)) if suffix else field
        date_conditions[key] = parsed

    return date_conditions
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@ -446,33 +499,11 @@ class GetMixin(PropsMixin):
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else:
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
dates_q = cls._get_dates_query(field, data)
if isinstance(dates_q, Q):
query &= dates_q
elif isinstance(dates_q, dict):
dict_query.update(dates_q)
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
@ -484,27 +515,40 @@ class GetMixin(PropsMixin):
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
if data.pattern is not None:
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
),
data.fields,
RegexQ(),
)
data.fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
date_fields = [field for field in data.fields if field in opts.datetime_fields]
if not date_fields:
break
q = Q()
for date_f in date_fields:
dates_q = cls._get_dates_query(date_f, data.datetime)
if isinstance(dates_q, dict):
dates_q = RegexQ(**dates_q)
q = func(q, dates_q)
query = query & q
except APIError:
raise

View File

@ -79,8 +79,8 @@ class Model(AttributedDocument):
"parent",
"metadata.*",
),
range_fields=("last_metrics.*", "last_iteration"),
datetime_fields=("last_update",),
range_fields=("created", "last_metrics.*", "last_iteration"),
datetime_fields=("last_update", "last_change"),
)
id = StringField(primary_key=True)

View File

@ -244,7 +244,7 @@ class Task(AttributedDocument):
"models.input.model",
),
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
datetime_fields=("status_changed", "last_update", "last_change"),
pattern_fields=("name", "comment", "report"),
fields=("runtime.*",),
)

View File

@ -74,7 +74,11 @@ multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified"
type: string
}
datetime {
description: "Date time conditions (applicable only to datetime fields). Either 'pattern' or 'datetime' should be specified"
type: string
}
fields {

View File

@ -1,20 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
include "_tasks_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
model {
type: object
properties {

View File

@ -1,20 +1,6 @@
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
include "_common.conf"
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
project {
type: object
properties {

View File

@ -71,6 +71,16 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# _any_/_all_ queries
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
# simplified range syntax
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
@ -80,6 +90,15 @@ class TestTasksFiltering(TestService):
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertTrue(set(tasks).issubset({t.id for t in res}))
res = self.api.tasks.get_all_ex(
**{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
).tasks
self.assertFalse(set(tasks).issubset({t.id for t in res}))
def test_range_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()