diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index 6249361..05366ce 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -1,7 +1,7 @@ import re from collections import namedtuple from functools import reduce -from typing import Collection, Sequence, Union, Optional, Type, Tuple +from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any from boltons.iterutils import first, bucketize, partition from dateutil.parser import parse as parse_datetime @@ -86,6 +86,7 @@ class GetMixin(PropsMixin): list_fields=("tags", "system_tags", "id"), datetime_fields=None, fields=None, + range_fields=None, ): """ :param pattern_fields: Fields for which a "string contains" condition should be generated @@ -97,6 +98,7 @@ class GetMixin(PropsMixin): self.fields = fields self.datetime_fields = datetime_fields self.list_fields = list_fields + self.range_fields = range_fields self.pattern_fields = pattern_fields class ListFieldBucketHelper: @@ -183,6 +185,32 @@ class GetMixin(PropsMixin): parameters, parameters_options ) & cls._prepare_perm_query(company, allow_public=allow_public) + @staticmethod + def _pop_matching_params( + patterns: Sequence[str], parameters: dict + ) -> Mapping[str, Any]: + """ + Pop the parameters that match the specified patterns and return + the dictionary of matching parameters + For backwards compatibility with the previous version of the code + the None values are discarded + """ + if not patterns: + return {} + + fields = set() + for pattern in patterns: + if pattern.endswith("*"): + prefix = pattern[:-1] + fields.update( + {field for field in parameters if field.startswith(prefix)} + ) + elif pattern in parameters: + fields.add(pattern) + + pairs = ((field, parameters.pop(field, None)) for field in fields) + return {k: v for k, v in pairs if v is not None} + @classmethod def _prepare_query_no_company( cls, parameters=None, parameters_options=QueryParameterOptions() @@ -212,10 +240,15 @@ class GetMixin(PropsMixin): if pattern: dict_query[field] = RegexWrapper(pattern) - for field in tuple(opts.list_fields or ()): - data = parameters.pop(field, None) - if data: - query &= cls.get_list_field_query(field, data) + for field, data in cls._pop_matching_params( + patterns=opts.list_fields, parameters=parameters + ).items(): + query &= cls.get_list_field_query(field, data) + + for field, data in cls._pop_matching_params( + patterns=opts.range_fields, parameters=parameters + ).items(): + query &= cls.get_range_field_query(field, data) for field in opts.fields or []: data = parameters.pop(field, None) @@ -259,6 +292,37 @@ class GetMixin(PropsMixin): return query & RegexQ(**dict_query) + @classmethod + def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q: + """ + Return a range query for the provided field. The data should contain min and max values + Both intervals are included. For open range queries either min or max can be None + In case the min value is None the records with missing or None value from db are included + """ + if not isinstance(data, (list, tuple)) or len(data) != 2: + raise errors.bad_request.ValidationError( + f"Min and max values should be specified for range field {field}" + ) + + min_val, max_val = data + if min_val is None and max_val is None: + raise errors.bad_request.ValidationError( + f"At least one of min or max values should be provided for field {field}" + ) + + mongoengine_field = field.replace(".", "__") + query = {} + if min_val is not None: + query[f"{mongoengine_field}__gte"] = min_val + if max_val is not None: + query[f"{mongoengine_field}__lte"] = max_val + + q = Q(**query) + if min_val is None: + q |= Q(**{mongoengine_field: None}) + + return q + @classmethod def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q: """ diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index da9fcc7..1497c19 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -79,7 +79,9 @@ DEFAULT_ARTIFACT_MODE = ArtifactModes.output class Artifact(EmbeddedDocument): key = StringField(required=True) type = StringField(required=True) - mode = StringField(choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE) + mode = StringField( + choices=get_options(ArtifactModes), default=DEFAULT_ARTIFACT_MODE + ) uri = StringField() hash = StringField() content_size = LongField() @@ -185,8 +187,18 @@ class Task(AttributedDocument): ], } get_all_query_options = GetMixin.QueryParameterOptions( - list_fields=("id", "user", "tags", "system_tags", "type", "status", "project", "parent"), - datetime_fields=("status_changed",), + list_fields=( + "id", + "user", + "tags", + "system_tags", + "type", + "status", + "project", + "parent", + ), + range_fields=("started", "active_duration", "last_metrics.*"), + datetime_fields=("status_changed", "last_update"), pattern_fields=("name", "comment"), ) diff --git a/apiserver/mongo/initialize/migration.py b/apiserver/mongo/initialize/migration.py index 0c259cb..2c66de8 100644 --- a/apiserver/mongo/initialize/migration.py +++ b/apiserver/mongo/initialize/migration.py @@ -4,7 +4,7 @@ from logging import Logger from pathlib import Path from mongoengine.connection import get_db -from semantic_version import Version +from packaging.version import Version, parse from apiserver.database import utils from apiserver.database import Database @@ -50,7 +50,7 @@ def _apply_migrations(log: Logger): try: new_scripts = { ver: path - for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py")) + for ver, path in ((parse(f.stem), f) for f in migration_dir.glob("*.py")) if ver > last_version } except ValueError as ex: diff --git a/apiserver/requirements.txt b/apiserver/requirements.txt index 1b6a130..1d98919 100644 --- a/apiserver/requirements.txt +++ b/apiserver/requirements.txt @@ -17,6 +17,7 @@ jsonschema>=2.6.0 luqum>=0.10.0 mongoengine==0.19.1 nested_dict>=1.61 +packaging==20.3 psutil>=5.6.5 pyhocon>=0.3.35 pyjwt<2.0.0