Fix bad field values might cause ugly server exception to be returned

This commit is contained in:
allegroai 2022-12-21 18:33:28 +02:00
parent ae4c33fa0e
commit 54ce6c34c6

View File

@ -20,7 +20,7 @@ from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField, IntField from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors from apiserver.apierrors import errors, APIError
from apiserver.apierrors.base import BaseError from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config from apiserver.config_repo import config
@ -148,20 +148,26 @@ class GetMixin(PropsMixin):
"or": (default_mongo_op, True), "or": (default_mongo_op, True),
} }
def __init__(self, legacy=False): def __init__(self, field, legacy=False):
self._field = field
self._current_op = None self._current_op = None
self._sticky = False self._sticky = False
self._support_legacy = legacy self._support_legacy = legacy
self.allow_empty = False self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]: def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = ( try:
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None op = (
) v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
if translate: )
tup = self._ops.get(op, None) if translate:
return tup[0] if tup else None tup = self._ops.get(op, None)
return op return tup[0] if tup else None
return op
except AttributeError:
raise errors.bad_request.FieldsValueError(
"invalid value type, string expected", field=self._field, value=str(v)
)
def _key(self, v) -> Optional[Union[str, bool]]: def _key(self, v) -> Optional[Union[str, bool]]:
if v is None: if v is None:
@ -347,97 +353,106 @@ class GetMixin(PropsMixin):
parameters_options = parameters_options or cls.get_all_query_options parameters_options = parameters_options or cls.get_all_query_options
dict_query = {} dict_query = {}
query = RegexQ() query = RegexQ()
if parameters: field = None
parameters = { # noinspection PyBroadException
k: cls._get_fixed_field_value(k, v) for k, v in parameters.items() try:
} if parameters:
opts = parameters_options parameters = {
for field in opts.pattern_fields: k: cls._get_fixed_field_value(k, v) for k, v in parameters.items()
pattern = parameters.pop(field, None) }
if pattern: opts = parameters_options
dict_query[field] = RegexWrapper(pattern) for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field, data in cls._pop_matching_params( for field, data in cls._pop_matching_params(
patterns=opts.list_fields, parameters=parameters patterns=opts.list_fields, parameters=parameters
).items(): ).items():
query &= cls.get_list_field_query(field, data) query &= cls.get_list_field_query(field, data)
for field, data in cls._pop_matching_params( for field, data in cls._pop_matching_params(
patterns=opts.range_fields, parameters=parameters patterns=opts.range_fields, parameters=parameters
).items(): ).items():
query &= cls.get_range_field_query(field, data) query &= cls.get_range_field_query(field, data)
for field, data in cls._pop_matching_params( for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters patterns=opts.fields or [], parameters=parameters
).items(): ).items():
if "._" in field or "_." in field: if "._" in field or "_." in field:
query &= RegexQ(__raw__={field: data}) query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
# date time fields also support simplified range queries. Check if this is the case
if len(data) == 2 and not any(
d.startswith(mod)
for d in data
if d is not None
for mod in ACCESS_MODIFIER
):
query &= cls.get_range_field_query(field, data)
else: else:
for d in data: # type: str dict_query[field.replace(".", "__")] = data
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items(): for field in opts.datetime_fields or []:
for keys, func in cls._multi_field_param_prefix.items(): data = parameters.pop(field, None)
if field not in keys: if data is not None:
continue if not isinstance(data, list):
try: data = [data]
data = cls.MultiFieldParameters(**value) # date time fields also support simplified range queries. Check if this is the case
except Exception: if len(data) == 2 and not any(
raise MakeGetAllQueryError("incorrect field format", field) d.startswith(mod)
if not data.fields: for d in data
break if d is not None
if any("._" in f for f in data.fields): for mod in ACCESS_MODIFIER
q = reduce( ):
lambda a, x: func( query &= cls.get_range_field_query(field, data)
a, else:
RegexQ( for d in data: # type: str
__raw__={ m = ACCESS_REGEX.match(d)
x: {"$regex": data.pattern, "$options": "i"} if not m:
} continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = (
field
if not modifier
else "__".join((field, modifier))
)
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
if field not in keys:
continue
try:
data = cls.MultiFieldParameters(**value)
except Exception:
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
), ),
), data.fields,
data.fields, RegexQ(),
RegexQ(), )
) else:
else: regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE) sep_fields = [f.replace(".", "__") for f in data.fields]
sep_fields = [f.replace(".", "__") for f in data.fields] q = reduce(
q = reduce( lambda a, x: func(a, RegexQ(**{x: regex})),
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields,
sep_fields, RegexQ(),
RegexQ(), )
) query = query & q
query = query & q except APIError:
raise
except Exception as ex:
raise errors.bad_request.FieldsValueError(
"failed parsing query field", error=str(ex), **({"field": field} if field else {})
)
return query & RegexQ(**dict_query) return query & RegexQ(**dict_query)
@ -486,7 +501,7 @@ class GetMixin(PropsMixin):
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
data = [data] data = [data]
helper = cls.ListFieldBucketHelper(legacy=True) helper = cls.ListFieldBucketHelper(field, legacy=True)
global_op = helper.get_global_op(data) global_op = helper.get_global_op(data)
actions = helper.get_actions(data) actions = helper.get_actions(data)