Support __$and condition for excluded terms in get_all_ex endpoints list filters

This commit is contained in:
allegroai 2023-07-26 18:26:49 +03:00
parent 8135cf5258
commit 011164ce9b
5 changed files with 135 additions and 99 deletions

View File

@ -13,7 +13,8 @@ from typing import (
TypeVar,
Callable,
Mapping,
Any, Union,
Any,
Union,
)
from mongoengine import Q, Document
@ -1089,12 +1090,15 @@ class ProjectBLL:
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
helper = GetMixin.ListFieldBucketHelper(field, legacy=True)
actions = helper.get_actions(field_filter)
conditions[field] = {
f"${action}": list(set(actions[action]))
for action in filter(None, actions)
}
helper = GetMixin.NewListFieldBucketHelper(
field, data=field_filter, legacy=True
)
conditions[field] = {}
for action, values in helper.actions.items():
value = list(set(values))
for key in reversed(action.split("__")):
value = {f"${key}": value}
conditions[field].update(value)
return conditions

View File

@ -1,5 +1,5 @@
import re
from collections import namedtuple
from collections import namedtuple, defaultdict
from functools import reduce, partial
from typing import (
Collection,
@ -11,10 +11,10 @@ from typing import (
Mapping,
Any,
Callable,
Dict,
List,
Generator,
)
import attr
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
@ -22,6 +22,7 @@ from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError
from apiserver.apierrors.errors.bad_request import FieldsValueError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
@ -132,94 +133,117 @@ class GetMixin(PropsMixin):
self.range_fields = range_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
class NewListFieldBucketHelper:
op_prefix = "__$"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
_unary_operators = {
"__$not": False,
"__$nop": True,
}
_operators = {
"__$all": Q.AND,
"__$and": Q.AND,
"__$any": Q.OR,
"__$or": Q.OR,
}
default_operator = Q.OR
mongo_modifiers = {
Q.AND: {True: "all", False: "not__all"},
Q.OR: {True: "in", False: "nin"},
}
def __init__(self, field, legacy=False):
@attr.s(auto_attribs=True)
class Term:
operator: str = None
include: bool = True
value: str = None
def __init__(self, field: str, data: Sequence[str], legacy=False):
self._field = field
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
self.global_operator = None
self.actions = defaultdict(list)
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
try:
op = (
v[len(self.op_prefix) :]
if v and v.startswith(self.op_prefix)
else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected",
field=self._field,
value=str(v),
)
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self.allow_empty = True
return None
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
current_context = self.default_operator
for d in self._get_next_term(data):
if d.operator is not None:
current_context = d.operator
if self.global_operator is None:
self.global_operator = d.operator
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
if self.global_operator is None:
self.global_operator = self.default_operator
if d.value is None:
self.allow_empty = True
continue
self.actions[self.mongo_modifiers[current_context][d.include]].append(
d.value
)
if self.global_operator is None:
self.global_operator = self.default_operator
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
unary_operator = None
for value in data:
if value is None:
unary_operator = None
yield self.Term()
continue
if not isinstance(value, str):
raise FieldsValueError(
"invalid value type, string expected",
field=self._field,
value=str(value),
)
if value.startswith(self.op_prefix):
if unary_operator:
raise FieldsValueError(
"Value is expected after",
field=self._field,
operator=unary_operator,
)
if value in self._unary_operators:
unary_operator = value
continue
operator = self._operators.get(value)
if operator is None:
raise FieldsValueError(
"Unsupported operator",
field=self._field,
operator=value,
)
yield self.Term(operator=operator)
continue
if self._support_legacy and value.startswith("-"):
value = value[1:]
if not value:
raise FieldsValueError(
"Missing value after the exclude prefix -",
field=self._field,
value=value,
)
unary_operator = None
yield self.Term(value=value, include=False)
continue
term = self.Term(value=value)
if unary_operator:
term.include = self._unary_operators[unary_operator]
unary_operator = None
yield term
if unary_operator:
raise FieldsValueError(
"Value is expected after", operator=unary_operator
)
get_all_query_options = QueryParameterOptions()
@ -507,15 +531,15 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)):
data = [data]
helper = cls.ListFieldBucketHelper(field, legacy=True)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True)
global_op = helper.global_operator
actions = helper.actions
mongoengine_field = field.replace(".", "__")
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))})
for action, values in actions.items()
]
if not queries:

View File

@ -655,7 +655,11 @@ class APICall(DataContainer):
}
if self.content_type.lower() == JSON_CONTENT_TYPE:
try:
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii
func = (
json.dumps
if self._json_flags.pop("ensure_ascii", True)
else json.dumps_notascii
)
res = func(res, **(self._json_flags or {}))
except Exception as ex:
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
@ -685,8 +689,12 @@ class APICall(DataContainer):
cookies=self._result.cookies,
)
def get_redacted_headers(self):
headers = self.headers.copy()
def get_redacted_headers(self, fields=None):
headers = (
{k: v for k, v in self._headers.items() if k in fields}
if fields
else self.headers
)
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully

View File

@ -101,7 +101,7 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
t for t in values if t.startswith(GetMixin.NewListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(

View File

@ -51,7 +51,7 @@ class TestUsersService(TestService):
self._assertUsers((user_2, user_3), users)
# specific project
users = self.api.users.get_all_ex(active_in_projects=[project]).users
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[project]).users
self._assertUsers((user_3,), users)
def _assertUsers(self, expected: Sequence, users: Sequence):