import re
from collections import namedtuple
from functools import reduce
from typing import Collection
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
from apierrors import errors
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr
# Module-level logger, obtained from the project configuration under the name "dbmodel"
log = config.logger("dbmodel")
# Splits a string into an optional comparison prefix (">=", ">", "<=" or "<") and the remaining value
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
# Maps each comparison prefix accepted by ACCESS_REGEX to a short operator name
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
# NOTE(review): presumably a shared mongoengine `meta` value for abstract base documents - confirm against usage
ABSTRACT_FLAG = {"abstract": True}
class AuthDocument(Document):
    """
    Base class marking documents that hold auth-related data.

    Code in this module checks ``issubclass(cls, AuthDocument)`` in order to refuse
    projection (join) requests on such documents.
    """

    # NOTE(review): the class body was not visible in the reviewed text (a class statement
    # without a body is a syntax error). ABSTRACT_FLAG is defined just above and has no other
    # visible use, so it is assumed to be this class's `meta` - confirm against the repository.
    meta = ABSTRACT_FLAG
class ProperDictMixin(object):
    """ Helpers for turning a document into a plain dictionary. """

    def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
        """
        Convert this document into a dictionary processed by properize_dict().
        :param strip_private: If True, keys starting with an underscore are removed
        :param only: Optional projection; if provided, the result is passed through project_dict()
        :param extra_dict: Optional dictionary whose entries are merged into the result
        :return: The resulting dictionary
        """
        # NOTE(review): the arguments of this call were not visible in the reviewed text.
        # They are reconstructed from this method's signature and from the signature of
        # properize_dict(); the source dictionary is assumed to come from mongoengine's
        # Document.to_mongo() - confirm against the repository.
        return self.properize_dict(
            self.to_mongo(use_db_field=False).to_dict(),
            strip_private=strip_private,
            only=only,
            extra_dict=extra_dict,
        )

    @classmethod
    def properize_dict(
        cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
    ):
        """
        Normalize a raw document dictionary.
        :param d: Source dictionary. NOTE: it may be modified in place (the `_id` key is renamed, and when no
            stripping or projection takes place the very same object is updated and returned)
        :param strip_private: If True, keys starting with an underscore are removed
        :param only: Optional projection; if provided, the result is passed through project_dict()
        :param extra_dict: Optional dictionary whose entries are merged into the result
        :param normalize_id: If True, an `_id` key is renamed to `id`
        :return: The resulting dictionary
        """
        res = d
        if normalize_id and "_id" in res:
            # Rename before stripping, otherwise the ID would be dropped as a private key
            res["id"] = res.pop("_id")
        if strip_private:
            # startswith() (rather than indexing the first character) also copes with an empty key
            res = {k: v for k, v in res.items() if not k.startswith("_")}
        if only:
            res = project_dict(res, only)
        if extra_dict:
            res.update(extra_dict)
        return res
class GetMixin(PropsMixin):
    """
    Mixin adding company-constrained query helpers (single get, multi get, paging, ordering and
    projection) to a document class.
    """

    # NOTE(review): the reviewed text of this class had no indentation and was missing every line
    # that contained no whitespace (decorators, closing brackets, `try:`/`else:`/`continue`/`pass`,
    # single-token arguments). Such lines were re-created from the surrounding code; every
    # reconstruction that is more than punctuation carries its own NOTE(review) and must be confirmed.

    # Internal ordering key that the public "@text_score" keyword is translated to
    _text_score = "$text_score"

    # Key of the ordering entry in a parameters dictionary
    _ordering_key = "order_by"

    _multi_field_param_sep = "__"
    # Accepted multi-field parameter names, mapped to the operator combining the per-field conditions
    _multi_field_param_prefix = {
        ("_any_", "_or_"): lambda a, b: a | b,
        ("_all_", "_and_"): lambda a, b: a & b,
    }
    MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")

    class QueryParameterOptions(object):
        """ Describes how each supported query parameter is converted into a query condition. """

        # NOTE(review): only the `list_fields` parameter was visible in the reviewed text. The
        # other parameters are taken from the docstring and the attribute assignments below;
        # their defaults are reconstructed (`pattern_fields` is iterated without a None guard
        # by _prepare_query_no_company(), the other two are guarded) - confirm.
        def __init__(
            self,
            pattern_fields=("name",),
            list_fields=("tags", "id"),
            datetime_fields=None,
            fields=None,
        ):
            """
            :param pattern_fields: Fields for which a "string contains" condition should be generated
            :param list_fields: Fields for which a "list contains" condition should be generated
            :param datetime_fields: Fields for which datetime condition should be generated (see ACCESS_MODIFIER)
            :param fields: Fields for which a simple equality condition should be generated (basically filters out all
                other unsupported query fields)
            """
            self.fields = fields
            self.datetime_fields = datetime_fields
            self.list_fields = list_fields
            self.pattern_fields = pattern_fields

    # Default options used when the caller does not provide any
    get_all_query_options = QueryParameterOptions()

    @classmethod
    def get(
        cls, company, id, *, _only=None, include_public=False, **kwargs
    ) -> "GetMixin":
        """
        Get a single document by ID.
        :param company: Company ID
        :param id: Document ID
        :param _only: Optional collection of field names to load
        :param include_public: Passed as `allow_public` to _prepare_perm_query()
        :param kwargs: Additional equality conditions
        :return: The first matching document (the result of QuerySet.first())
        """
        q = cls.objects(
            cls._prepare_perm_query(company, allow_public=include_public)
            & Q(id=id, **kwargs)
        )
        if _only:
            q = q.only(*_only)
        return q.first()

    # NOTE(review): `cls` and `allow_public=False` were not visible in the reviewed text;
    # `allow_public` is documented below and used in the body - confirm its position and default.
    @classmethod
    def prepare_query(
        cls,
        company: str,
        parameters: dict = None,
        parameters_options: QueryParameterOptions = None,
        allow_public=False,
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.
        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param company: Company ID (required)
        :param allow_public: Allow results from public objects
        :param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
            Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
            specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
            provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        return cls._prepare_query_no_company(
            parameters, parameters_options
        ) & cls._prepare_perm_query(company, allow_public=allow_public)

    # NOTE(review): because the default for `parameters_options` is an (always truthy) instance,
    # the `or cls.get_all_query_options` fallback below only applies when None is passed
    # explicitly (as prepare_query() does). Left unchanged to preserve the interface - confirm
    # whether a None default was intended.
    @classmethod
    def _prepare_query_no_company(
        cls, parameters=None, parameters_options=QueryParameterOptions()
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.

        NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.

        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
            Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
            specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
            provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        parameters_options = parameters_options or cls.get_all_query_options
        dict_query = {}
        query = RegexQ()
        if parameters:
            # Work on a copy: handled entries are popped so that only the multi-field
            # parameters remain for the last loop
            parameters = parameters.copy()
            opts = parameters_options

            for field in opts.pattern_fields:
                pattern = parameters.pop(field, None)
                if pattern:
                    dict_query[field] = RegexWrapper(pattern)

            for field in tuple(opts.list_fields or ()):
                data = parameters.pop(field, None)
                if not data:
                    continue
                if not isinstance(data, (list, tuple)):
                    raise MakeGetAllQueryError("expected list", field)
                # Values prefixed with "-" are exclusions, all other values are inclusions
                exclude = [t for t in data if t.startswith("-")]
                excluded = set(exclude)
                # De-duplicate while keeping the caller's order, so the generated query is deterministic
                include = list(dict.fromkeys(t for t in data if t not in excluded))
                mongoengine_field = field.replace(".", "__")
                if include:
                    dict_query[f"{mongoengine_field}__in"] = include
                if exclude:
                    dict_query[f"{mongoengine_field}__nin"] = [t[1:] for t in exclude]

            for field in opts.fields or []:
                data = parameters.pop(field, None)
                if data is not None:
                    dict_query[field] = data

            for field in opts.datetime_fields or []:
                data = parameters.pop(field, None)
                if data is None:
                    continue
                if not isinstance(data, list):
                    data = [data]
                for d in data:  # type: str
                    m = ACCESS_REGEX.match(d)
                    if not m:
                        continue
                    try:
                        value = parse_datetime(m.group("value"))
                    except (ValueError, OverflowError):
                        # Best effort: a value that cannot be parsed as a datetime is ignored
                        continue
                    modifier = ACCESS_MODIFIER.get(m.group("prefix"))
                    f = field if not modifier else "__".join((field, modifier))
                    dict_query[f] = value

            for field, value in parameters.items():
                for keys, func in cls._multi_field_param_prefix.items():
                    if field not in keys:
                        continue
                    try:
                        data = cls.MultiFieldParameters(**value)
                    except Exception:
                        raise MakeGetAllQueryError("incorrect field format", field)
                    if not data.fields:
                        # Nothing to match against; no other prefix group can match this name
                        break
                    regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                    sep_fields = [f.replace(".", "__") for f in data.fields]
                    q = reduce(
                        lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
                    )
                    query = query & q

        return query & RegexQ(**dict_query)

    @classmethod
    def _prepare_perm_query(cls, company, allow_public=False):
        """
        Build the company constraint of a query.
        :param company: Company ID
        :param allow_public: If True, the constraint returned by get_company_or_none_constraint() is used instead of
            a plain company equality condition
        :return: mongoengine.Q query object
        """
        if allow_public:
            return get_company_or_none_constraint(company)
        return Q(company=company)

    @classmethod
    def validate_paging(
        cls, parameters=None, default_page=None, default_page_size=None
    ):
        """
        Validate and extract paging info from the provided dictionary. Supports default values.
        :param parameters: Dictionary that may contain `page` and `page_size`
        :param default_page: Page used when `parameters` has no `page` entry
        :param default_page_size: Page size used when `parameters` has no `page_size` entry
        :return: A (page, page_size) tuple, or (None, None) if no page was requested
        """
        if parameters is None:
            parameters = {}
        default_page = parameters.get("page", default_page)
        if default_page is None:
            return None, None
        default_page_size = parameters.get("page_size", default_page_size)
        if not default_page_size:
            raise errors.bad_request.MissingRequiredFields(
                "page_size is required when page is requested", field="page_size"
            )
        elif default_page < 0:
            raise errors.bad_request.ValidationError("page must be >=0", field="page")
        elif default_page_size < 1:
            raise errors.bad_request.ValidationError(
                "page_size must be >0", field="page_size"
            )
        return default_page, default_page_size

    @classmethod
    def get_projection(cls, parameters, override_projection=None, **__):
        """
        Extract a projection list from the provided dictionary. Supports an override projection.
        :param parameters: Dictionary that may contain `projection` or `only_fields`
        :param override_projection: If not None, returned as is
        :return: The projection (an empty list if `parameters` is empty)
        """
        if override_projection is not None:
            return override_projection
        if not parameters:
            return []
        return parameters.get("projection") or parameters.get("only_fields", [])

    @classmethod
    def set_default_ordering(cls, parameters, value):
        """ Set the ordering entry of `parameters` (in place) to `value`, unless a non-empty one is present. """
        parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value

    # NOTE(review): none of this method's parameters were visible in the reviewed text. The names
    # below are the ones used by the visible body and docstring plus those needed for the calls
    # to get_many(); their order and defaults are reconstructed - confirm against callers.
    @classmethod
    def get_many_with_join(
        cls,
        company,
        query_dict=None,
        query_options=None,
        query=None,
        allow_public=False,
        override_projection=None,
        expand_reference_ids=True,
    ):
        """
        Fetch all documents matching a provided query with support for joining referenced documents according to the
        requested projection. See get_many() for more info.
        :param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
            a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
        """
        if issubclass(cls, AuthDocument):
            # Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
            # auth-related secrets and prevent security leaks
            # NOTE(review): only the message of this call was visible; the logging method and
            # the `stack_info` argument are reconstructed - confirm.
            log.error(
                f"Attempted projection of {cls.__name__} auth document (ignored)",
                stack_info=True,
            )
            return []

        override_projection = cls.get_projection(
            parameters=query_dict, override_projection=override_projection
        )

        # NOTE(review): the arguments of ProjectionHelper(), get_many() and the nested
        # get_many_with_join() call below were not visible in the reviewed text, nor was the
        # `doc_projection` attribute; all are reconstructed - confirm against database.projection.
        helper = ProjectionHelper(
            doc_cls=cls,
            projection=override_projection,
            expand_reference_ids=expand_reference_ids,
        )

        # Make the main query
        results = cls.get_many(
            override_projection=helper.doc_projection,
            company=company,
            parameters=query_dict,
            query_dict=query_dict,
            query=query,
            query_options=query_options,
            allow_public=allow_public,
        )

        def projection_func(doc_type, projection, ids):
            return doc_type.get_many_with_join(
                company=company,
                override_projection=projection,
                query=Q(id__in=ids),
                expand_reference_ids=expand_reference_ids,
                allow_public=allow_public,
            )

        return helper.project(results, projection_func)

    # NOTE(review): `cls`, `company`, `allow_public` and `return_dicts` were not visible in the
    # reviewed text; they are documented below and required by the body. Their positions and
    # defaults are reconstructed - confirm against callers that pass arguments positionally.
    @classmethod
    def get_many(
        cls,
        company,
        parameters: dict = None,
        query_dict: dict = None,
        query_options: QueryParameterOptions = None,
        query: Q = None,
        allow_public=False,
        override_projection: Collection[str] = None,
        return_dicts=True,
    ):
        """
        Fetch all documents matching a provided query. Supported several built-in options
         (aside from those provided by the parameters):
            - Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
                field names. Using field names not defined in the document will cause an error.
            - Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
                larger than 0 and is required when specifying a page.
            - Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
                `@text_score` keyword. A text index must be defined on the document type, otherwise an error will
                be raised.
        :param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
            requested, each contains only the requested projection).
            If False, a QuerySet object is returned (lazy evaluated)
        :param company: Company ID (required)
        :param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
        :param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
            a query. The resulting query is AND'ed with the `query` parameter (if provided).
        :param query_options: query parameters options (see QueryParameterOptions)
        :param query: Optional query object (mongoengine.Q)
        :param override_projection: A list of projection fields overriding any projection specified in `parameters`
        :param allow_public: If True, objects marked as public (no associated company) are also queried.
        :return: A list of objects matching the query.
        """
        if query_dict is not None:
            q = cls.prepare_query(
                parameters=query_dict,
                company=company,
                parameters_options=query_options,
                allow_public=allow_public,
            )
        else:
            q = cls._prepare_perm_query(company, allow_public=allow_public)
        _query = (q & query) if query else q

        return cls._get_many_no_company(
            query=_query,
            parameters=parameters,
            override_projection=override_projection,
            return_dicts=return_dicts,
        )

    @classmethod
    def _get_many_no_company(
        cls, query, parameters=None, override_projection=None, return_dicts=True
    ):
        """
        Fetch all documents matching a provided query.
        This is a company-less version for internal uses. We assume the caller has either added any necessary
        constraints to the query or that no constraints are required.

        NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.

        :param query: Query object (mongoengine.Q)
        :param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
            requested, each contains only the requested projection).
            If False, a QuerySet object is returned (lazy evaluated)
        :param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
        :param override_projection: A list of projection fields overriding any projection specified in `parameters`
        """
        parameters = parameters or {}

        if not query:
            raise ValueError("query or call_data must be provided")

        page, page_size = cls.validate_paging(parameters=parameters)

        order_by = parameters.get(cls._ordering_key)
        if order_by:
            # Normalize to a list of field names and translate the public text-score keyword
            order_by = list(order_by) if isinstance(order_by, (list, tuple)) else [order_by]
            order_by = [cls._text_score if x == "@text_score" else x for x in order_by]

        search_text = parameters.get("search_text")

        only = cls.get_projection(parameters, override_projection)

        if not search_text and order_by and cls._text_score in order_by:
            raise errors.bad_request.FieldsValueError(
                "text score cannot be used in order_by when search text is not used"
            )

        qs = cls.objects(query)
        if search_text:
            qs = qs.search_text(search_text)
        if order_by:
            # add ordering (order_by is always a list at this point)
            qs = qs.order_by(*order_by)
        if only:
            # add projection
            qs = qs.only(*only)
        else:
            # NOTE(review): the `else:` line is reconstructed (it was not visible in the reviewed text) - confirm.
            # `only` may be None here, hence the guard
            exclude = set(cls.get_exclude_fields()).difference(only or ())
            if exclude:
                qs = qs.exclude(*exclude)
        if page is not None and page_size:
            # add paging
            qs = qs.skip(page * page_size).limit(page_size)

        if return_dicts:
            return [obj.to_proper_dict(only=only) for obj in qs]
        return qs

    @classmethod
    def get_for_writing(
        cls, *args, _only: Collection[str] = None, **kwargs
    ) -> "GetMixin":
        """
        Get a single document (see get()) that is about to be modified.
        :param _only: Optional collection of field names to load; `company` is always added to it
        :raises errors.forbidden.NoWritePermission: if the document found has no company
        :return: The document found (the result of get())
        """
        if _only and "company" not in _only:
            # The company field is required for the check below
            _only = list(set(_only) | {"company"})
        result = cls.get(*args, _only=_only, include_public=True, **kwargs)
        if result and not result.company:
            object_name = cls.__name__.lower()
            raise errors.forbidden.NoWritePermission(
                f"cannot modify public {object_name}(s), ids={(result.id,)}"
            )
        return result

    @classmethod
    def get_many_for_writing(cls, company, *args, **kwargs):
        """
        Get multiple documents (see get_many()) that are about to be modified.
        :param company: Company ID
        :raises errors.forbidden.NoWritePermission: if any of the documents found has no company
        :return: The QuerySet returned by get_many()
        """
        # NOTE(review): only the `return_dicts=False` argument was visible in the reviewed text.
        # `allow_public=True` is reconstructed by analogy with get_for_writing() (the check below
        # is only meaningful if company-less documents can be returned) - confirm.
        # `company` is passed positionally so that extra positional arguments cannot collide with it.
        result = cls.get_many(
            company, *args, return_dicts=False, allow_public=True, **kwargs
        )
        forbidden_objects = {obj.id for obj in result if not obj.company}
        if forbidden_objects:
            object_name = cls.__name__.lower()
            raise errors.forbidden.NoWritePermission(
                f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
            )
        return result
class UpdateMixin(object):
    """ Helpers for updating a document using only the fields a user is allowed to set. """

    @classmethod
    def user_set_allowed(cls):
        """
        Get the fields of this class that carry a `user_set_allowed` attribute.
        The result is computed once per class and cached on the class itself.
        :return: Dictionary built from get_fields_with_attr(cls, "user_set_allowed"); get_safe_update_dict()
            treats each value as either None or a container of allowed values
        """
        # A single explicit attribute name is used for both reading and writing the cache.
        # A double-underscore attribute written as `cls.__name` inside the class body is
        # name-mangled, so a getattr() lookup using the unmangled string never finds it.
        # The class's own __dict__ is consulted so that a subclass does not pick up the
        # cache of its parent.
        cache_attr = "_user_set_allowed_fields"
        res = cls.__dict__.get(cache_attr)
        if res is None:
            res = dict(get_fields_with_attr(cls, "user_set_allowed"))
            setattr(cls, cache_attr, res)
        return res

    @classmethod
    def get_safe_update_dict(cls, fields):
        """
        Filter an update dictionary, keeping only entries a user is allowed to set.
        :param fields: Mapping of field name to requested value
        :return: Dictionary of the entries whose field appears in user_set_allowed() and whose value is
            acceptable: any value if the allowed values are None, otherwise the value (or, for a list, each
            of its items) must be one of the allowed values
        """
        if not fields:
            return {}
        update_dict = {}
        for field, allowed in cls.user_set_allowed().items():
            if field not in fields:
                continue
            value = fields[field]
            if allowed is None:
                permitted = True
            elif isinstance(value, list):
                permitted = all(v in allowed for v in value)
            else:
                permitted = value in allowed
            if permitted:
                update_dict[field] = value
        return update_dict

    @classmethod
    def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
        """
        Update a single document using only the entries a user is allowed to set.
        :param company_id: Company ID the document must belong to
        :param id: Document ID
        :param partial_update_dict: Requested update; filtered using get_safe_update_dict()
        :param injected_update: Optional entries added to the update without filtering. They are only applied
            if at least one requested entry passed the filtering
        :return: A tuple of (the result of the update call, the update dictionary applied), or (0, {}) if no
            requested entry passed the filtering
        """
        update_dict = cls.get_safe_update_dict(partial_update_dict)
        if not update_dict:
            return 0, {}
        if injected_update:
            update_dict.update(injected_update)
        update_count = cls.objects(id=id, company=company_id).update(
            upsert=False, **update_dict
        )
        return update_count, update_dict
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
    """ Provide convenience methods for a subclass of mongoengine.Document """

    @classmethod
    def validate_id(cls, company, **kwargs):
        """
        Validate existence of objects with certain IDs. within company.
        :param cls: Model class to search in
        :param company: Company to search in
        :param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
                        it will be reported along with the name it was assigned to.
        :raises errors.bad_request.ValidationError: if at least one of the IDs was not found. The error carries
            a <name>=<ID> entry for every name whose ID is missing
        :return: None
        """
        ids = set(kwargs.values())
        objs = list(cls.objects(company=company, id__in=ids).only("id"))
        missing = ids - {x.id for x in objs}
        if not missing:
            return
        # The same ID may have been passed under several names; report all of them
        id_to_name = {}
        for name, obj_id in kwargs.items():
            id_to_name.setdefault(obj_id, []).append(name)
        raise errors.bad_request.ValidationError(
            "Invalid {} ids".format(cls.__name__.lower()),
            **{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
        )