diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 2ae635b..89ec018 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -13,7 +13,8 @@ from typing import ( TypeVar, Callable, Mapping, - Any, Union, + Any, + Union, ) from mongoengine import Q, Document @@ -1089,12 +1090,15 @@ class ProjectBLL: raise errors.bad_request.ValidationError( f"List of strings expected for the field: {field}" ) - helper = GetMixin.ListFieldBucketHelper(field, legacy=True) - actions = helper.get_actions(field_filter) - conditions[field] = { - f"${action}": list(set(actions[action])) - for action in filter(None, actions) - } + helper = GetMixin.NewListFieldBucketHelper( + field, data=field_filter, legacy=True + ) + conditions[field] = {} + for action, values in helper.actions.items(): + value = list(set(values)) + for key in reversed(action.split("__")): + value = {f"${key}": value} + conditions[field].update(value) return conditions diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index a865abe..f232662 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -1,5 +1,5 @@ import re -from collections import namedtuple +from collections import namedtuple, defaultdict from functools import reduce, partial from typing import ( Collection, @@ -11,10 +11,10 @@ from typing import ( Mapping, Any, Callable, - Dict, - List, + Generator, ) +import attr from boltons.iterutils import first, partition from dateutil.parser import parse as parse_datetime from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet @@ -22,6 +22,7 @@ from pymongo.command_cursor import CommandCursor from apiserver.apierrors import errors, APIError from apiserver.apierrors.base import BaseError +from apiserver.apierrors.errors.bad_request import FieldsValueError from apiserver.bll.redis_cache_manager import RedisCacheManager from apiserver.config_repo import config from apiserver.database import Database @@ -132,94 +133,117 @@ class GetMixin(PropsMixin): self.range_fields = range_fields self.pattern_fields = pattern_fields - class ListFieldBucketHelper: + class NewListFieldBucketHelper: op_prefix = "__$" - _legacy_exclude_prefix = "-" - _legacy_exclude_mongo_op = "nin" - - default_mongo_op = "in" - _ops = { - # op -> (mongo_op, sticky) - "not": ("nin", False), - "nop": (default_mongo_op, False), - "all": ("all", True), - "and": ("all", True), - "any": (default_mongo_op, True), - "or": (default_mongo_op, True), + _unary_operators = { + "__$not": False, + "__$nop": True, + } + _operators = { + "__$all": Q.AND, + "__$and": Q.AND, + "__$any": Q.OR, + "__$or": Q.OR, + } + default_operator = Q.OR + mongo_modifiers = { + Q.AND: {True: "all", False: "not__all"}, + Q.OR: {True: "in", False: "nin"}, } - def __init__(self, field, legacy=False): + @attr.s(auto_attribs=True) + class Term: + operator: str = None + include: bool = True + value: str = None + + def __init__(self, field: str, data: Sequence[str], legacy=False): self._field = field - self._current_op = None - self._sticky = False self._support_legacy = legacy self.allow_empty = False + self.global_operator = None + self.actions = defaultdict(list) - def _get_op(self, v: str, translate: bool = False) -> Optional[str]: - try: - op = ( - v[len(self.op_prefix) :] - if v and v.startswith(self.op_prefix) - else None - ) - if translate: - tup = self._ops.get(op, None) - return tup[0] if tup else None - return op - except AttributeError: - raise errors.bad_request.FieldsValueError( - "invalid value type, string expected", - field=self._field, - value=str(v), - ) - - def _key(self, v) -> Optional[Union[str, bool]]: - if v is None: - self.allow_empty = True - return None - - op = self._get_op(v) - if op is not None: - # operator - set state and return None - self._current_op, self._sticky = self._ops.get( - op, (self.default_mongo_op, self._sticky) - ) - return None - elif self._current_op: - current_op = self._current_op - if not self._sticky: - self._current_op = None - return current_op - elif self._support_legacy and v.startswith(self._legacy_exclude_prefix): - self._current_op = None - return False - - return self.default_mongo_op - - def get_global_op(self, data: Sequence[str]) -> int: - op_to_res = { - "in": Q.OR, - "all": Q.AND, - } - data = (x for x in data if x is not None) - first_op = ( - self._get_op(next(data, ""), translate=True) or self.default_mongo_op - ) - return op_to_res.get(first_op, self.default_mongo_op) - - def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]: - actions = {} - - for val in data: - key = self._key(val) - if key is None: + current_context = self.default_operator + for d in self._get_next_term(data): + if d.operator is not None: + current_context = d.operator + if self.global_operator is None: + self.global_operator = d.operator continue - elif self._support_legacy and key is False: - key = self._legacy_exclude_mongo_op - val = val[len(self._legacy_exclude_prefix) :] - actions.setdefault(key, []).append(val) - return actions + if self.global_operator is None: + self.global_operator = self.default_operator + + if d.value is None: + self.allow_empty = True + continue + + self.actions[self.mongo_modifiers[current_context][d.include]].append( + d.value + ) + + if self.global_operator is None: + self.global_operator = self.default_operator + + def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]: + unary_operator = None + for value in data: + if value is None: + unary_operator = None + yield self.Term() + continue + + if not isinstance(value, str): + raise FieldsValueError( + "invalid value type, string expected", + field=self._field, + value=str(value), + ) + + if value.startswith(self.op_prefix): + if unary_operator: + raise FieldsValueError( + "Value is expected after", + field=self._field, + operator=unary_operator, + ) + if value in self._unary_operators: + unary_operator = value + continue + + operator = self._operators.get(value) + if operator is None: + raise FieldsValueError( + "Unsupported operator", + field=self._field, + operator=value, + ) + yield self.Term(operator=operator) + continue + + if self._support_legacy and value.startswith("-"): + value = value[1:] + if not value: + raise FieldsValueError( + "Missing value after the exclude prefix -", + field=self._field, + value=value, + ) + unary_operator = None + yield self.Term(value=value, include=False) + continue + + term = self.Term(value=value) + if unary_operator: + term.include = self._unary_operators[unary_operator] + unary_operator = None + yield term + + if unary_operator: + raise FieldsValueError( + "Value is expected after", operator=unary_operator + ) get_all_query_options = QueryParameterOptions() @@ -507,15 +531,15 @@ class GetMixin(PropsMixin): if not isinstance(data, (list, tuple)): data = [data] - helper = cls.ListFieldBucketHelper(field, legacy=True) - global_op = helper.get_global_op(data) - actions = helper.get_actions(data) + helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True) + global_op = helper.global_operator + actions = helper.actions mongoengine_field = field.replace(".", "__") queries = [ - RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))}) - for action in filter(None, actions) + RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))}) + for action, values in actions.items() ] if not queries: diff --git a/apiserver/service_repo/apicall.py b/apiserver/service_repo/apicall.py index 74141d0..144bc3c 100644 --- a/apiserver/service_repo/apicall.py +++ b/apiserver/service_repo/apicall.py @@ -655,7 +655,11 @@ class APICall(DataContainer): } if self.content_type.lower() == JSON_CONTENT_TYPE: try: - func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii + func = ( + json.dumps + if self._json_flags.pop("ensure_ascii", True) + else json.dumps_notascii + ) res = func(res, **(self._json_flags or {})) except Exception as ex: # JSON serialization may fail, probably problem with data or error_data so pop it and try again @@ -685,8 +689,12 @@ class APICall(DataContainer): cookies=self._result.cookies, ) - def get_redacted_headers(self): - headers = self.headers.copy() + def get_redacted_headers(self, fields=None): + headers = ( + {k: v for k, v in self._headers.items() if k in fields} + if fields + else self.headers + ) if not self.requires_authorization or self.auth: # We won't log the authorization header if call shouldn't be authorized, or if it was successfully # authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully diff --git a/apiserver/services/utils.py b/apiserver/services/utils.py index 84b745d..b569aba 100644 --- a/apiserver/services/utils.py +++ b/apiserver/services/utils.py @@ -101,7 +101,7 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence): def validate_tags(tags: Sequence[str], system_tags: Sequence[str]): for values in filter(None, (tags, system_tags)): unsupported = [ - t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix) + t for t in values if t.startswith(GetMixin.NewListFieldBucketHelper.op_prefix) ] if unsupported: raise errors.bad_request.FieldsValueError( diff --git a/apiserver/tests/automated/test_users.py b/apiserver/tests/automated/test_users.py index 4ae0875..cf97428 100644 --- a/apiserver/tests/automated/test_users.py +++ b/apiserver/tests/automated/test_users.py @@ -51,7 +51,7 @@ class TestUsersService(TestService): self._assertUsers((user_2, user_3), users) # specific project - users = self.api.users.get_all_ex(active_in_projects=[project]).users + users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[project]).users self._assertUsers((user_3,), users) def _assertUsers(self, expected: Sequence, users: Sequence):