mirror of
https://github.com/clearml/clearml-server
synced 2025-01-30 18:36:52 +00:00
Add _any_/_all_ queries support for datetime fields
This commit is contained in:
parent
073cc96fb8
commit
57ce9446b1
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from collections import namedtuple, defaultdict
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
@ -107,7 +107,18 @@ class GetMixin(PropsMixin):
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
("_all_", "_and_"): lambda a, b: a & b,
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
@attr.s(auto_attribs=True)
class MultiFieldParameters:
    """
    Parameters of a multi-field (_any_/_all_) query condition.

    The condition applies to every name in 'fields' and must carry exactly one
    of 'pattern' or 'datetime'; this is enforced right after initialization.
    """

    # Names of the fields the condition is applied to
    fields: Sequence[str]
    # Pattern condition (mutually exclusive with 'datetime')
    pattern: str = None
    # Datetime condition, a single value or a list (mutually exclusive with 'pattern')
    datetime: Union[list, str] = None

    def __attrs_post_init__(self):
        """
        Validate that exactly one of 'pattern' and 'datetime' was supplied.

        :raise ValueError: if neither or both of the conditions are set
        """
        supplied = sum(
            1 for condition in (self.pattern, self.datetime) if condition is not None
        )
        if supplied == 0:
            raise ValueError("Either 'pattern' or 'datetime' should be provided")
        if supplied == 2:
            raise ValueError("Only one of the 'pattern' and 'datetime' can be provided")
|
||||
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {}
|
||||
@ -323,6 +334,8 @@ class GetMixin(PropsMixin):
|
||||
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
|
||||
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
|
||||
provided fields match the provided pattern.
|
||||
- <any|all>: {fields: [<field1>, <field2>, ...], datetime: <datetime condition>} Will query for items where any or all
|
||||
provided datetime fields match the provided condition.
|
||||
:return: mongoengine.Q query object
|
||||
"""
|
||||
return cls._prepare_query_no_company(
|
||||
@ -376,6 +389,46 @@ class GetMixin(PropsMixin):
|
||||
return cls._try_convert_to_numeric(value)
|
||||
return value
|
||||
|
||||
@classmethod
def _get_dates_query(cls, field: str, data: Union[list, str]) -> Union[Q, dict]:
    """
    Return dates query for the field

    If the data is a list of exactly 2 values and none of the values starts with
    a date comparison operator then return the simplified range query.
    Otherwise return the dictionary of dates conditions.

    :param field: name of the datetime field the conditions apply to
    :param data: a single condition or a list of conditions. A list item may be
        None (the range check below explicitly tolerates it)
    :return: the result of cls.get_range_field_query for the simplified range
        syntax, otherwise a dict that maps "<field>" or "<field>__<modifier>"
        to the parsed datetime value. Items that cannot be interpreted are skipped
    """
    if not isinstance(data, list):
        data = [data]

    # Two values without any comparison prefix are treated as a [from, to] range
    if len(data) == 2 and not any(
        d.startswith(mod)
        for d in data
        if d is not None
        for mod in ACCESS_MODIFIER
    ):
        return cls.get_range_field_query(field, data)

    dict_query = {}
    for d in data:
        # None (or any other non-string item) cannot be matched against the regex;
        # skip it like any other unusable condition instead of failing the whole
        # query (assumes ACCESS_REGEX is a compiled 're' pattern)
        if not isinstance(d, str):
            continue

        m = ACCESS_REGEX.match(d)
        if not m:
            continue

        try:
            value = parse_datetime(m.group("value"))
        except (ValueError, OverflowError):
            # best effort: a condition with an unparsable date is ignored
            continue

        modifier = ACCESS_MODIFIER.get(m.group("prefix"))
        # no modifier means a plain comparison on the field itself
        f = field if not modifier else "__".join((field, modifier))
        dict_query[f] = value

    return dict_query
|
||||
|
||||
@classmethod
|
||||
def _prepare_query_no_company(
|
||||
cls, parameters=None, parameters_options=QueryParameterOptions()
|
||||
@ -446,33 +499,11 @@ class GetMixin(PropsMixin):
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
# date time fields also support simplified range queries. Check if this is the case
|
||||
if len(data) == 2 and not any(
|
||||
d.startswith(mod)
|
||||
for d in data
|
||||
if d is not None
|
||||
for mod in ACCESS_MODIFIER
|
||||
):
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
else:
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = (
|
||||
field
|
||||
if not modifier
|
||||
else "__".join((field, modifier))
|
||||
)
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
dates_q = cls._get_dates_query(field, data)
|
||||
if isinstance(dates_q, Q):
|
||||
query &= dates_q
|
||||
elif isinstance(dates_q, dict):
|
||||
dict_query.update(dates_q)
|
||||
|
||||
for field, value in parameters.items():
|
||||
for keys, func in cls._multi_field_param_prefix.items():
|
||||
@ -484,27 +515,40 @@ class GetMixin(PropsMixin):
|
||||
raise MakeGetAllQueryError("incorrect field format", field)
|
||||
if not data.fields:
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
if data.pattern is not None:
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
),
|
||||
),
|
||||
),
|
||||
data.fields,
|
||||
RegexQ(),
|
||||
)
|
||||
data.fields,
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
date_fields = [field for field in data.fields if field in opts.datetime_fields]
|
||||
if not date_fields:
|
||||
break
|
||||
|
||||
q = Q()
|
||||
for date_f in date_fields:
|
||||
dates_q = cls._get_dates_query(date_f, data.datetime)
|
||||
if isinstance(dates_q, dict):
|
||||
dates_q = RegexQ(**dates_q)
|
||||
q = func(q, dates_q)
|
||||
|
||||
query = query & q
|
||||
except APIError:
|
||||
raise
|
||||
|
@ -79,8 +79,8 @@ class Model(AttributedDocument):
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
range_fields=("last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("last_update",),
|
||||
range_fields=("created", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("last_update", "last_change"),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
|
@ -244,7 +244,7 @@ class Task(AttributedDocument):
|
||||
"models.input.model",
|
||||
),
|
||||
range_fields=("created", "started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
datetime_fields=("status_changed", "last_update", "last_change"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("runtime.*",),
|
||||
)
|
||||
|
@ -74,7 +74,11 @@ multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
description: "Pattern string (regex). Either 'pattern' or 'datetime' should be specified"
|
||||
type: string
|
||||
}
|
||||
datetime {
|
||||
description: "Date time conditions (applicable only to datetime fields). Either 'pattern' or 'datetime' should be specified"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
|
@ -1,20 +1,6 @@
|
||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||
_definitions {
|
||||
include "_tasks_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
model {
|
||||
type: object
|
||||
properties {
|
||||
|
@ -1,20 +1,6 @@
|
||||
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
project {
|
||||
type: object
|
||||
properties {
|
||||
|
@ -71,6 +71,16 @@ class TestTasksFiltering(TestService):
|
||||
).tasks
|
||||
self.assertFalse(set(tasks).issubset({t.id for t in res}))
|
||||
|
||||
# _any_/_all_ queries
|
||||
res = self.api.tasks.get_all_ex(
|
||||
**{"_any_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
|
||||
).tasks
|
||||
self.assertTrue(set(tasks).issubset({t.id for t in res}))
|
||||
res = self.api.tasks.get_all_ex(
|
||||
**{"_all_": {"datetime": f">={now.isoformat()}", "fields": ["last_update", "status_changed"]}}
|
||||
).tasks
|
||||
self.assertFalse(set(tasks).issubset({t.id for t in res}))
|
||||
|
||||
# simplified range syntax
|
||||
res = self.api.tasks.get_all_ex(last_update=[now.isoformat(), None]).tasks
|
||||
self.assertTrue(set(tasks).issubset({t.id for t in res}))
|
||||
@ -80,6 +90,15 @@ class TestTasksFiltering(TestService):
|
||||
).tasks
|
||||
self.assertFalse(set(tasks).issubset({t.id for t in res}))
|
||||
|
||||
res = self.api.tasks.get_all_ex(
|
||||
**{"_any_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
|
||||
).tasks
|
||||
self.assertTrue(set(tasks).issubset({t.id for t in res}))
|
||||
res = self.api.tasks.get_all_ex(
|
||||
**{"_all_": {"datetime": [now.isoformat(), None], "fields": ["last_update", "status_changed"]}}
|
||||
).tasks
|
||||
self.assertFalse(set(tasks).issubset({t.id for t in res}))
|
||||
|
||||
def test_range_queries(self):
|
||||
tasks = [self.temp_task() for _ in range(5)]
|
||||
now = datetime.utcnow()
|
||||
|
Loading…
Reference in New Issue
Block a user