Add support for field exclusion in get_all endpoints

Add support for ephemeral worker tags (valid while worker has not timed out)
This commit is contained in:
allegroai
2020-08-10 08:48:48 +03:00
parent 8c7e230898
commit cd4ce30f7c
10 changed files with 526 additions and 42 deletions

View File

@@ -1,9 +1,9 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from boltons.iterutils import first, bucketize
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
@@ -348,6 +348,17 @@ class GetMixin(PropsMixin):
return []
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def split_projection(
cls, projection: Sequence[str]
) -> Tuple[Collection[str], Collection[str]]:
"""Return include and exclude lists based on passed projection and class definition"""
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
@classmethod
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
@@ -502,7 +513,7 @@ class GetMixin(PropsMixin):
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
query: Q,
parameters=None,
override_projection=None,
):
@@ -525,7 +536,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if search_text:
@@ -533,13 +546,14 @@ class GetMixin(PropsMixin):
if order_by:
# add ordering
qs = qs.order_by(*order_by)
if only:
if include:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
qs = qs.only(*include)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
@@ -575,7 +589,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
query_sets = [cls.objects(query)]
if order_by:
@@ -612,16 +628,15 @@ class GetMixin(PropsMixin):
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if only:
if include:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
else:
exclude = set(cls.get_exclude_fields())
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
query_sets = [qs.only(*include) for qs in query_sets]
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
@@ -632,7 +647,8 @@ class GetMixin(PropsMixin):
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
)
if len(ret) >= page_size:
break

View File

@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
)
dst[path_part] = [
copy_path(path_parts[depth + 1:], s, d)
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
@@ -96,6 +96,7 @@ class _ProxyManager:
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
exclusion_prefix = "-"
@property
def doc_projection(self):
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
[]
) # Projection information for reference fields (used in join queries)
for field in projection:
field_ = field.lstrip(self.exclusion_prefix)
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
if not field_.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
if field_ == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
subfield = field_[len(ref_field) :]
if not subfield.startswith(SEP):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
ref_projection_info.append(
(
ref_field,
ref_field_cls,
("" if field_[0] == field[0] else self.exclusion_prefix)
+ subfield[1:],
)
)
break
else:
# Not a reference field, just add to the top-level projection
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
orig_field = field
if field.endswith(".*"):
field = field[:-2]
if not field:
if not field.lstrip(self.exclusion_prefix):
raise errors.bad_request.InvalidFields(
field=orig_field, object=doc_cls.__name__
)
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(
doc_cls, doc_projection.difference(ref_projection).union({"id"})
doc_cls, doc_projection.difference(ref_projection)
)
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
f
for f in doc_projection
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
not in doc_cls.get_fields()
]
if invalid_fields:
raise errors.bad_request.InvalidFields(
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
doc_projection.add(field)
doc_projection = list(doc_projection)
# If there are include fields (not only exclude) then add an id field
if (
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
and "id" not in doc_projection
):
doc_projection.append("id")
self._doc_projection = doc_projection
self._ref_projection = ref_projection
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
]
if items:
def do_projection(item):
ref_field_name, data, ids = item