Support __$and condition for excluded terms in get_all_ex endpoints list filters

This commit is contained in:
allegroai 2023-07-26 18:26:49 +03:00
parent 8135cf5258
commit 011164ce9b
5 changed files with 135 additions and 99 deletions
apiserver
bll/project
database/model
service_repo
services
tests/automated

View File

@ -13,7 +13,8 @@ from typing import (
TypeVar, TypeVar,
Callable, Callable,
Mapping, Mapping,
Any, Union, Any,
Union,
) )
from mongoengine import Q, Document from mongoengine import Q, Document
@ -1089,12 +1090,15 @@ class ProjectBLL:
raise errors.bad_request.ValidationError( raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}" f"List of strings expected for the field: {field}"
) )
helper = GetMixin.ListFieldBucketHelper(field, legacy=True) helper = GetMixin.NewListFieldBucketHelper(
actions = helper.get_actions(field_filter) field, data=field_filter, legacy=True
conditions[field] = { )
f"${action}": list(set(actions[action])) conditions[field] = {}
for action in filter(None, actions) for action, values in helper.actions.items():
} value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
conditions[field].update(value)
return conditions return conditions

View File

@ -1,5 +1,5 @@
import re import re
from collections import namedtuple from collections import namedtuple, defaultdict
from functools import reduce, partial from functools import reduce, partial
from typing import ( from typing import (
Collection, Collection,
@ -11,10 +11,10 @@ from typing import (
Mapping, Mapping,
Any, Any,
Callable, Callable,
Dict, Generator,
List,
) )
import attr
from boltons.iterutils import first, partition from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
@ -22,6 +22,7 @@ from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors, APIError from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError from apiserver.apierrors.base import BaseError
from apiserver.apierrors.errors.bad_request import FieldsValueError
from apiserver.bll.redis_cache_manager import RedisCacheManager from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database import Database from apiserver.database import Database
@ -132,94 +133,117 @@ class GetMixin(PropsMixin):
self.range_fields = range_fields self.range_fields = range_fields
self.pattern_fields = pattern_fields self.pattern_fields = pattern_fields
class ListFieldBucketHelper: class NewListFieldBucketHelper:
op_prefix = "__$" op_prefix = "__$"
_legacy_exclude_prefix = "-" _unary_operators = {
_legacy_exclude_mongo_op = "nin" "__$not": False,
"__$nop": True,
default_mongo_op = "in" }
_ops = { _operators = {
# op -> (mongo_op, sticky) "__$all": Q.AND,
"not": ("nin", False), "__$and": Q.AND,
"nop": (default_mongo_op, False), "__$any": Q.OR,
"all": ("all", True), "__$or": Q.OR,
"and": ("all", True), }
"any": (default_mongo_op, True), default_operator = Q.OR
"or": (default_mongo_op, True), mongo_modifiers = {
Q.AND: {True: "all", False: "not__all"},
Q.OR: {True: "in", False: "nin"},
} }
def __init__(self, field, legacy=False): @attr.s(auto_attribs=True)
class Term:
operator: str = None
include: bool = True
value: str = None
def __init__(self, field: str, data: Sequence[str], legacy=False):
self._field = field self._field = field
self._current_op = None
self._sticky = False
self._support_legacy = legacy self._support_legacy = legacy
self.allow_empty = False self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
def _get_op(self, v: str, translate: bool = False) -> Optional[str]: current_context = self.default_operator
try: for d in self._get_next_term(data):
op = ( if d.operator is not None:
v[len(self.op_prefix) :] current_context = d.operator
if v and v.startswith(self.op_prefix) if self.global_operator is None:
else None self.global_operator = d.operator
continue
if self.global_operator is None:
self.global_operator = self.default_operator
if d.value is None:
self.allow_empty = True
continue
self.actions[self.mongo_modifiers[current_context][d.include]].append(
d.value
) )
if translate:
tup = self._ops.get(op, None) if self.global_operator is None:
return tup[0] if tup else None self.global_operator = self.default_operator
return op
except AttributeError: def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
raise errors.bad_request.FieldsValueError( unary_operator = None
for value in data:
if value is None:
unary_operator = None
yield self.Term()
continue
if not isinstance(value, str):
raise FieldsValueError(
"invalid value type, string expected", "invalid value type, string expected",
field=self._field, field=self._field,
value=str(v), value=str(value),
) )
def _key(self, v) -> Optional[Union[str, bool]]: if value.startswith(self.op_prefix):
if v is None: if unary_operator:
self.allow_empty = True raise FieldsValueError(
return None "Value is expected after",
field=self._field,
op = self._get_op(v) operator=unary_operator,
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
) )
return None if value in self._unary_operators:
elif self._current_op: unary_operator = value
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
continue continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions operator = self._operators.get(value)
if operator is None:
raise FieldsValueError(
"Unsupported operator",
field=self._field,
operator=value,
)
yield self.Term(operator=operator)
continue
if self._support_legacy and value.startswith("-"):
value = value[1:]
if not value:
raise FieldsValueError(
"Missing value after the exclude prefix -",
field=self._field,
value=value,
)
unary_operator = None
yield self.Term(value=value, include=False)
continue
term = self.Term(value=value)
if unary_operator:
term.include = self._unary_operators[unary_operator]
unary_operator = None
yield term
if unary_operator:
raise FieldsValueError(
"Value is expected after", operator=unary_operator
)
get_all_query_options = QueryParameterOptions() get_all_query_options = QueryParameterOptions()
@ -507,15 +531,15 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
data = [data] data = [data]
helper = cls.ListFieldBucketHelper(field, legacy=True) helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True)
global_op = helper.get_global_op(data) global_op = helper.global_operator
actions = helper.get_actions(data) actions = helper.actions
mongoengine_field = field.replace(".", "__") mongoengine_field = field.replace(".", "__")
queries = [ queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))}) RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))})
for action in filter(None, actions) for action, values in actions.items()
] ]
if not queries: if not queries:

View File

@ -655,7 +655,11 @@ class APICall(DataContainer):
} }
if self.content_type.lower() == JSON_CONTENT_TYPE: if self.content_type.lower() == JSON_CONTENT_TYPE:
try: try:
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii func = (
json.dumps
if self._json_flags.pop("ensure_ascii", True)
else json.dumps_notascii
)
res = func(res, **(self._json_flags or {})) res = func(res, **(self._json_flags or {}))
except Exception as ex: except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again # JSON serialization may fail, probably problem with data or error_data so pop it and try again
@ -685,8 +689,12 @@ class APICall(DataContainer):
cookies=self._result.cookies, cookies=self._result.cookies,
) )
def get_redacted_headers(self): def get_redacted_headers(self, fields=None):
headers = self.headers.copy() headers = (
{k: v for k, v in self._headers.items() if k in fields}
if fields
else self.headers
)
if not self.requires_authorization or self.auth: if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully # We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully # authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully

View File

@ -101,7 +101,7 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]): def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)): for values in filter(None, (tags, system_tags)):
unsupported = [ unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix) t for t in values if t.startswith(GetMixin.NewListFieldBucketHelper.op_prefix)
] ]
if unsupported: if unsupported:
raise errors.bad_request.FieldsValueError( raise errors.bad_request.FieldsValueError(

View File

@ -51,7 +51,7 @@ class TestUsersService(TestService):
self._assertUsers((user_2, user_3), users) self._assertUsers((user_2, user_3), users)
# specific project # specific project
users = self.api.users.get_all_ex(active_in_projects=[project]).users users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[project]).users
self._assertUsers((user_3,), users) self._assertUsers((user_3,), users)
def _assertUsers(self, expected: Sequence, users: Sequence): def _assertUsers(self, expected: Sequence, users: Sequence):