mirror of
https://github.com/clearml/clearml-server
synced 2025-04-16 13:32:07 +00:00
Support __$and condition for excluded terms in get_all_ex endpoints list filters
This commit is contained in:
parent
8135cf5258
commit
011164ce9b
@ -13,7 +13,8 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any, Union,
|
||||
Any,
|
||||
Union,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
@ -1089,12 +1090,15 @@ class ProjectBLL:
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
helper = GetMixin.ListFieldBucketHelper(field, legacy=True)
|
||||
actions = helper.get_actions(field_filter)
|
||||
conditions[field] = {
|
||||
f"${action}": list(set(actions[action]))
|
||||
for action in filter(None, actions)
|
||||
}
|
||||
helper = GetMixin.NewListFieldBucketHelper(
|
||||
field, data=field_filter, legacy=True
|
||||
)
|
||||
conditions[field] = {}
|
||||
for action, values in helper.actions.items():
|
||||
value = list(set(values))
|
||||
for key in reversed(action.split("__")):
|
||||
value = {f"${key}": value}
|
||||
conditions[field].update(value)
|
||||
|
||||
return conditions
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from collections import namedtuple, defaultdict
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
Collection,
|
||||
@ -11,10 +11,10 @@ from typing import (
|
||||
Mapping,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Generator,
|
||||
)
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
|
||||
@ -22,6 +22,7 @@ from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.apierrors.errors.bad_request import FieldsValueError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
@ -132,94 +133,117 @@ class GetMixin(PropsMixin):
|
||||
self.range_fields = range_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
class NewListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
_legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_mongo_op = "nin"
|
||||
|
||||
default_mongo_op = "in"
|
||||
_ops = {
|
||||
# op -> (mongo_op, sticky)
|
||||
"not": ("nin", False),
|
||||
"nop": (default_mongo_op, False),
|
||||
"all": ("all", True),
|
||||
"and": ("all", True),
|
||||
"any": (default_mongo_op, True),
|
||||
"or": (default_mongo_op, True),
|
||||
_unary_operators = {
|
||||
"__$not": False,
|
||||
"__$nop": True,
|
||||
}
|
||||
_operators = {
|
||||
"__$all": Q.AND,
|
||||
"__$and": Q.AND,
|
||||
"__$any": Q.OR,
|
||||
"__$or": Q.OR,
|
||||
}
|
||||
default_operator = Q.OR
|
||||
mongo_modifiers = {
|
||||
Q.AND: {True: "all", False: "not__all"},
|
||||
Q.OR: {True: "in", False: "nin"},
|
||||
}
|
||||
|
||||
def __init__(self, field, legacy=False):
|
||||
@attr.s(auto_attribs=True)
|
||||
class Term:
|
||||
operator: str = None
|
||||
include: bool = True
|
||||
value: str = None
|
||||
|
||||
def __init__(self, field: str, data: Sequence[str], legacy=False):
|
||||
self._field = field
|
||||
self._current_op = None
|
||||
self._sticky = False
|
||||
self._support_legacy = legacy
|
||||
self.allow_empty = False
|
||||
self.global_operator = None
|
||||
self.actions = defaultdict(list)
|
||||
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
try:
|
||||
op = (
|
||||
v[len(self.op_prefix) :]
|
||||
if v and v.startswith(self.op_prefix)
|
||||
else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
except AttributeError:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"invalid value type, string expected",
|
||||
field=self._field,
|
||||
value=str(v),
|
||||
)
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
if v is None:
|
||||
self.allow_empty = True
|
||||
return None
|
||||
|
||||
op = self._get_op(v)
|
||||
if op is not None:
|
||||
# operator - set state and return None
|
||||
self._current_op, self._sticky = self._ops.get(
|
||||
op, (self.default_mongo_op, self._sticky)
|
||||
)
|
||||
return None
|
||||
elif self._current_op:
|
||||
current_op = self._current_op
|
||||
if not self._sticky:
|
||||
self._current_op = None
|
||||
return current_op
|
||||
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
|
||||
self._current_op = None
|
||||
return False
|
||||
|
||||
return self.default_mongo_op
|
||||
|
||||
def get_global_op(self, data: Sequence[str]) -> int:
|
||||
op_to_res = {
|
||||
"in": Q.OR,
|
||||
"all": Q.AND,
|
||||
}
|
||||
data = (x for x in data if x is not None)
|
||||
first_op = (
|
||||
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
)
|
||||
return op_to_res.get(first_op, self.default_mongo_op)
|
||||
|
||||
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
|
||||
actions = {}
|
||||
|
||||
for val in data:
|
||||
key = self._key(val)
|
||||
if key is None:
|
||||
current_context = self.default_operator
|
||||
for d in self._get_next_term(data):
|
||||
if d.operator is not None:
|
||||
current_context = d.operator
|
||||
if self.global_operator is None:
|
||||
self.global_operator = d.operator
|
||||
continue
|
||||
elif self._support_legacy and key is False:
|
||||
key = self._legacy_exclude_mongo_op
|
||||
val = val[len(self._legacy_exclude_prefix) :]
|
||||
actions.setdefault(key, []).append(val)
|
||||
|
||||
return actions
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
|
||||
if d.value is None:
|
||||
self.allow_empty = True
|
||||
continue
|
||||
|
||||
self.actions[self.mongo_modifiers[current_context][d.include]].append(
|
||||
d.value
|
||||
)
|
||||
|
||||
if self.global_operator is None:
|
||||
self.global_operator = self.default_operator
|
||||
|
||||
def _get_next_term(self, data: Sequence[str]) -> Generator[Term, None, None]:
|
||||
unary_operator = None
|
||||
for value in data:
|
||||
if value is None:
|
||||
unary_operator = None
|
||||
yield self.Term()
|
||||
continue
|
||||
|
||||
if not isinstance(value, str):
|
||||
raise FieldsValueError(
|
||||
"invalid value type, string expected",
|
||||
field=self._field,
|
||||
value=str(value),
|
||||
)
|
||||
|
||||
if value.startswith(self.op_prefix):
|
||||
if unary_operator:
|
||||
raise FieldsValueError(
|
||||
"Value is expected after",
|
||||
field=self._field,
|
||||
operator=unary_operator,
|
||||
)
|
||||
if value in self._unary_operators:
|
||||
unary_operator = value
|
||||
continue
|
||||
|
||||
operator = self._operators.get(value)
|
||||
if operator is None:
|
||||
raise FieldsValueError(
|
||||
"Unsupported operator",
|
||||
field=self._field,
|
||||
operator=value,
|
||||
)
|
||||
yield self.Term(operator=operator)
|
||||
continue
|
||||
|
||||
if self._support_legacy and value.startswith("-"):
|
||||
value = value[1:]
|
||||
if not value:
|
||||
raise FieldsValueError(
|
||||
"Missing value after the exclude prefix -",
|
||||
field=self._field,
|
||||
value=value,
|
||||
)
|
||||
unary_operator = None
|
||||
yield self.Term(value=value, include=False)
|
||||
continue
|
||||
|
||||
term = self.Term(value=value)
|
||||
if unary_operator:
|
||||
term.include = self._unary_operators[unary_operator]
|
||||
unary_operator = None
|
||||
yield term
|
||||
|
||||
if unary_operator:
|
||||
raise FieldsValueError(
|
||||
"Value is expected after", operator=unary_operator
|
||||
)
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@ -507,15 +531,15 @@ class GetMixin(PropsMixin):
|
||||
if not isinstance(data, (list, tuple)):
|
||||
data = [data]
|
||||
|
||||
helper = cls.ListFieldBucketHelper(field, legacy=True)
|
||||
global_op = helper.get_global_op(data)
|
||||
actions = helper.get_actions(data)
|
||||
helper = cls.NewListFieldBucketHelper(field, data=data, legacy=True)
|
||||
global_op = helper.global_operator
|
||||
actions = helper.actions
|
||||
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
queries = [
|
||||
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
|
||||
for action in filter(None, actions)
|
||||
RegexQ(**{f"{mongoengine_field}__{action}": list(set(values))})
|
||||
for action, values in actions.items()
|
||||
]
|
||||
|
||||
if not queries:
|
||||
|
@ -655,7 +655,11 @@ class APICall(DataContainer):
|
||||
}
|
||||
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
||||
try:
|
||||
func = json.dumps if self._json_flags.pop("ensure_ascii", True) else json.dumps_notascii
|
||||
func = (
|
||||
json.dumps
|
||||
if self._json_flags.pop("ensure_ascii", True)
|
||||
else json.dumps_notascii
|
||||
)
|
||||
res = func(res, **(self._json_flags or {}))
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data or error_data so pop it and try again
|
||||
@ -685,8 +689,12 @@ class APICall(DataContainer):
|
||||
cookies=self._result.cookies,
|
||||
)
|
||||
|
||||
def get_redacted_headers(self):
|
||||
headers = self.headers.copy()
|
||||
def get_redacted_headers(self, fields=None):
|
||||
headers = (
|
||||
{k: v for k, v in self._headers.items() if k in fields}
|
||||
if fields
|
||||
else self.headers
|
||||
)
|
||||
if not self.requires_authorization or self.auth:
|
||||
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
|
||||
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
|
||||
|
@ -101,7 +101,7 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
|
||||
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
for values in filter(None, (tags, system_tags)):
|
||||
unsupported = [
|
||||
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
|
||||
t for t in values if t.startswith(GetMixin.NewListFieldBucketHelper.op_prefix)
|
||||
]
|
||||
if unsupported:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
|
@ -51,7 +51,7 @@ class TestUsersService(TestService):
|
||||
self._assertUsers((user_2, user_3), users)
|
||||
|
||||
# specific project
|
||||
users = self.api.users.get_all_ex(active_in_projects=[project]).users
|
||||
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[project]).users
|
||||
self._assertUsers((user_3,), users)
|
||||
|
||||
def _assertUsers(self, expected: Sequence, users: Sequence):
|
||||
|
Loading…
Reference in New Issue
Block a user