Add range queries

Switch from sematic_version to packaging.version in db migrations
This commit is contained in:
allegroai 2021-05-03 17:33:47 +03:00
parent 1d55710a0b
commit 4cd4b2914d
4 changed files with 87 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
@ -86,6 +86,7 @@ class GetMixin(PropsMixin):
list_fields=("tags", "system_tags", "id"),
datetime_fields=None,
fields=None,
range_fields=None,
):
"""
:param pattern_fields: Fields for which a "string contains" condition should be generated
@ -97,6 +98,7 @@ class GetMixin(PropsMixin):
self.fields = fields
self.datetime_fields = datetime_fields
self.list_fields = list_fields
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
@ -183,6 +185,32 @@ class GetMixin(PropsMixin):
parameters, parameters_options
) & cls._prepare_perm_query(company, allow_public=allow_public)
@staticmethod
def _pop_matching_params(
patterns: Sequence[str], parameters: dict
) -> Mapping[str, Any]:
"""
Pop the parameters that match the specified patterns and return
the dictionary of matching parameters
For backwards compatibility with the previous version of the code
the None values are discarded
"""
if not patterns:
return {}
fields = set()
for pattern in patterns:
if pattern.endswith("*"):
prefix = pattern[:-1]
fields.update(
{field for field in parameters if field.startswith(prefix)}
)
elif pattern in parameters:
fields.add(pattern)
pairs = ((field, parameters.pop(field, None)) for field in fields)
return {k: v for k, v in pairs if v is not None}
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
@ -212,10 +240,15 @@ class GetMixin(PropsMixin):
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters
).items():
query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters
).items():
query &= cls.get_range_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@ -259,6 +292,37 @@ class GetMixin(PropsMixin):
return query & RegexQ(**dict_query)
@classmethod
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
In case the min value is None the records with missing or None value from db are included
"""
if not isinstance(data, (list, tuple)) or len(data) != 2:
raise errors.bad_request.ValidationError(
f"Min and max values should be specified for range field {field}"
)
min_val, max_val = data
if min_val is None and max_val is None:
raise errors.bad_request.ValidationError(
f"At least one of min or max values should be provided for field {field}"
)
mongoengine_field = field.replace(".", "__")
query = {}
if min_val is not None:
query[f"{mongoengine_field}__gte"] = min_val
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = Q(**query)
if min_val is None:
q |= Q(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
"""

View File

@ -79,7 +79,9 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE)
mode = StringField(
choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE
)
uri = StringField()
hash = StringField()
content_size = LongField()
@ -185,8 +187,18 @@ class Task(AttributedDocument):
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"),
datetime_fields=("status_changed",),
list_fields=(
"id",
"user",
"tags",
"system_tags",
"type",
"status",
"project",
"parent",
),
range_fields=("started", "active_duration", "last_metrics.*"),
datetime_fields=("status_changed", "last_update"),
pattern_fields=("name", "comment"),
)

View File

@ -4,7 +4,7 @@ from logging import Logger
from pathlib import Path
from mongoengine.connection import get_db
from semantic_version import Version
from packaging.version import Version, parse
from apiserver.database import utils
from apiserver.database import Database
@ -50,7 +50,7 @@ def _apply_migrations(log: Logger):
try:
new_scripts = {
ver: path
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
for ver, path in ((parse(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:

View File

@ -17,6 +17,7 @@ jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.19.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0