mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 19:06:55 +00:00
4cd4b2914d
Switch from sematic_version to packaging.version in db migrations
897 lines
34 KiB
Python
897 lines
34 KiB
Python
import re
|
|
from collections import namedtuple
|
|
from functools import reduce
|
|
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
|
|
|
from boltons.iterutils import first, bucketize, partition
|
|
from dateutil.parser import parse as parse_datetime
|
|
from mongoengine import Q, Document, ListField, StringField
|
|
from pymongo.command_cursor import CommandCursor
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apierrors.base import BaseError
|
|
from apiserver.config_repo import config
|
|
from apiserver.database.errors import MakeGetAllQueryError
|
|
from apiserver.database.projection import project_dict, ProjectionHelper
|
|
from apiserver.database.props import PropsMixin
|
|
from apiserver.database.query import RegexQ, RegexWrapper
|
|
from apiserver.database.utils import (
|
|
get_company_or_none_constraint,
|
|
get_fields_choices,
|
|
field_does_not_exist,
|
|
field_exists,
|
|
)
|
|
|
|
log = config.logger("dbmodel")
|
|
|
|
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
|
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
|
|
|
ABSTRACT_FLAG = {"abstract": True}
|
|
|
|
|
|
class AuthDocument(Document):
|
|
meta = ABSTRACT_FLAG
|
|
|
|
|
|
class ProperDictMixin(object):
|
|
def to_proper_dict(
|
|
self: Union["ProperDictMixin", Document],
|
|
strip_private=True,
|
|
only=None,
|
|
extra_dict=None,
|
|
) -> dict:
|
|
return self.properize_dict(
|
|
self.to_mongo(use_db_field=False).to_dict(),
|
|
strip_private=strip_private,
|
|
only=only,
|
|
extra_dict=extra_dict,
|
|
)
|
|
|
|
@classmethod
|
|
def properize_dict(
|
|
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
|
|
):
|
|
res = d
|
|
if normalize_id and "_id" in res:
|
|
res["id"] = res.pop("_id")
|
|
if strip_private:
|
|
res = {k: v for k, v in res.items() if k[0] != "_"}
|
|
if only:
|
|
res = project_dict(res, only)
|
|
if extra_dict:
|
|
res.update(extra_dict)
|
|
return res
|
|
|
|
|
|
class GetMixin(PropsMixin):
|
|
_text_score = "$text_score"
|
|
_projection_key = "projection"
|
|
_ordering_key = "order_by"
|
|
_search_text_key = "search_text"
|
|
|
|
_multi_field_param_sep = "__"
|
|
_multi_field_param_prefix = {
|
|
("_any_", "_or_"): lambda a, b: a | b,
|
|
("_all_", "_and_"): lambda a, b: a & b,
|
|
}
|
|
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
|
|
|
_field_collation_overrides = {}
|
|
|
|
class QueryParameterOptions(object):
|
|
def __init__(
|
|
self,
|
|
pattern_fields=("name",),
|
|
list_fields=("tags", "system_tags", "id"),
|
|
datetime_fields=None,
|
|
fields=None,
|
|
range_fields=None,
|
|
):
|
|
"""
|
|
:param pattern_fields: Fields for which a "string contains" condition should be generated
|
|
:param list_fields: Fields for which a "list contains" condition should be generated
|
|
:param datetime_fields: Fields for which datetime condition should be generated (see ACCESS_MODIFIER)
|
|
:param fields: Fields which which a simple equality condition should be generated (basically filters out all
|
|
other unsupported query fields)
|
|
"""
|
|
self.fields = fields
|
|
self.datetime_fields = datetime_fields
|
|
self.list_fields = list_fields
|
|
self.range_fields = range_fields
|
|
self.pattern_fields = pattern_fields
|
|
|
|
class ListFieldBucketHelper:
|
|
op_prefix = "__$"
|
|
legacy_exclude_prefix = "-"
|
|
|
|
_default = "in"
|
|
_ops = {
|
|
"not": ("nin", False),
|
|
"all": ("all", True),
|
|
"and": ("all", True),
|
|
}
|
|
_next = _default
|
|
_sticky = False
|
|
|
|
def __init__(self, legacy=False):
|
|
self._legacy = legacy
|
|
|
|
def key(self, v):
|
|
if v is None:
|
|
self._next = self._default
|
|
return self._default
|
|
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
|
self._next = self._default
|
|
return self._ops["not"][0]
|
|
elif v.startswith(self.op_prefix):
|
|
self._next, self._sticky = self._ops.get(
|
|
v[len(self.op_prefix) :], (self._default, self._sticky)
|
|
)
|
|
return None
|
|
|
|
next_ = self._next
|
|
if not self._sticky:
|
|
self._next = self._default
|
|
return next_
|
|
|
|
def value_transform(self, v):
|
|
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
|
return v[len(self.legacy_exclude_prefix) :]
|
|
return v
|
|
|
|
get_all_query_options = QueryParameterOptions()
|
|
|
|
@classmethod
|
|
def get(
|
|
cls: Union["GetMixin", Document],
|
|
company,
|
|
id,
|
|
*,
|
|
_only=None,
|
|
include_public=False,
|
|
**kwargs,
|
|
) -> "GetMixin":
|
|
q = cls.objects(
|
|
cls._prepare_perm_query(company, allow_public=include_public)
|
|
& Q(id=id, **kwargs)
|
|
)
|
|
if _only:
|
|
q = q.only(*_only)
|
|
return q.first()
|
|
|
|
@classmethod
|
|
def prepare_query(
|
|
cls,
|
|
company: str,
|
|
parameters: dict = None,
|
|
parameters_options: QueryParameterOptions = None,
|
|
allow_public=False,
|
|
):
|
|
"""
|
|
Prepare a query object based on the provided query dictionary and various fields.
|
|
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
|
:param company: Company ID (required)
|
|
:param allow_public: Allow results from public objects
|
|
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
|
Supported parameters:
|
|
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
|
|
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
|
|
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
|
|
provided fields match the provided pattern.
|
|
:return: mongoengine.Q query object
|
|
"""
|
|
return cls._prepare_query_no_company(
|
|
parameters, parameters_options
|
|
) & cls._prepare_perm_query(company, allow_public=allow_public)
|
|
|
|
@staticmethod
|
|
def _pop_matching_params(
|
|
patterns: Sequence[str], parameters: dict
|
|
) -> Mapping[str, Any]:
|
|
"""
|
|
Pop the parameters that match the specified patterns and return
|
|
the dictionary of matching parameters
|
|
For backwards compatibility with the previous version of the code
|
|
the None values are discarded
|
|
"""
|
|
if not patterns:
|
|
return {}
|
|
|
|
fields = set()
|
|
for pattern in patterns:
|
|
if pattern.endswith("*"):
|
|
prefix = pattern[:-1]
|
|
fields.update(
|
|
{field for field in parameters if field.startswith(prefix)}
|
|
)
|
|
elif pattern in parameters:
|
|
fields.add(pattern)
|
|
|
|
pairs = ((field, parameters.pop(field, None)) for field in fields)
|
|
return {k: v for k, v in pairs if v is not None}
|
|
|
|
@classmethod
|
|
def _prepare_query_no_company(
|
|
cls, parameters=None, parameters_options=QueryParameterOptions()
|
|
):
|
|
"""
|
|
Prepare a query object based on the provided query dictionary and various fields.
|
|
|
|
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
|
|
|
|
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
|
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
|
Supported parameters:
|
|
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
|
|
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
|
|
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
|
|
provided fields match the provided pattern.
|
|
:return: mongoengine.Q query object
|
|
"""
|
|
parameters_options = parameters_options or cls.get_all_query_options
|
|
dict_query = {}
|
|
query = RegexQ()
|
|
if parameters:
|
|
parameters = parameters.copy()
|
|
opts = parameters_options
|
|
for field in opts.pattern_fields:
|
|
pattern = parameters.pop(field, None)
|
|
if pattern:
|
|
dict_query[field] = RegexWrapper(pattern)
|
|
|
|
for field, data in cls._pop_matching_params(
|
|
patterns=opts.list_fields, parameters=parameters
|
|
).items():
|
|
query &= cls.get_list_field_query(field, data)
|
|
|
|
for field, data in cls._pop_matching_params(
|
|
patterns=opts.range_fields, parameters=parameters
|
|
).items():
|
|
query &= cls.get_range_field_query(field, data)
|
|
|
|
for field in opts.fields or []:
|
|
data = parameters.pop(field, None)
|
|
if data is not None:
|
|
dict_query[field] = data
|
|
|
|
for field in opts.datetime_fields or []:
|
|
data = parameters.pop(field, None)
|
|
if data is not None:
|
|
if not isinstance(data, list):
|
|
data = [data]
|
|
for d in data: # type: str
|
|
m = ACCESS_REGEX.match(d)
|
|
if not m:
|
|
continue
|
|
try:
|
|
value = parse_datetime(m.group("value"))
|
|
prefix = m.group("prefix")
|
|
modifier = ACCESS_MODIFIER.get(prefix)
|
|
f = field if not modifier else "__".join((field, modifier))
|
|
dict_query[f] = value
|
|
except (ValueError, OverflowError):
|
|
pass
|
|
|
|
for field, value in parameters.items():
|
|
for keys, func in cls._multi_field_param_prefix.items():
|
|
if field not in keys:
|
|
continue
|
|
try:
|
|
data = cls.MultiFieldParameters(**value)
|
|
except Exception:
|
|
raise MakeGetAllQueryError("incorrect field format", field)
|
|
if not data.fields:
|
|
break
|
|
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
|
sep_fields = [f.replace(".", "__") for f in data.fields]
|
|
q = reduce(
|
|
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
|
)
|
|
query = query & q
|
|
|
|
return query & RegexQ(**dict_query)
|
|
|
|
@classmethod
|
|
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
|
"""
|
|
Return a range query for the provided field. The data should contain min and max values
|
|
Both intervals are included. For open range queries either min or max can be None
|
|
In case the min value is None the records with missing or None value from db are included
|
|
"""
|
|
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
|
raise errors.bad_request.ValidationError(
|
|
f"Min and max values should be specified for range field {field}"
|
|
)
|
|
|
|
min_val, max_val = data
|
|
if min_val is None and max_val is None:
|
|
raise errors.bad_request.ValidationError(
|
|
f"At least one of min or max values should be provided for field {field}"
|
|
)
|
|
|
|
mongoengine_field = field.replace(".", "__")
|
|
query = {}
|
|
if min_val is not None:
|
|
query[f"{mongoengine_field}__gte"] = min_val
|
|
if max_val is not None:
|
|
query[f"{mongoengine_field}__lte"] = max_val
|
|
|
|
q = Q(**query)
|
|
if min_val is None:
|
|
q |= Q(**{mongoengine_field: None})
|
|
|
|
return q
|
|
|
|
@classmethod
|
|
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
|
"""
|
|
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
|
with respect to the given list field, with support for "none of empty" in case a None value
|
|
is included.
|
|
|
|
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
|
or by a preceding "__$not" value (operator)
|
|
- AND can be achieved using a preceding "__$all" or "__$and" value (operator)
|
|
"""
|
|
if not isinstance(data, (list, tuple)):
|
|
data = [data]
|
|
# raise MakeGetAllQueryError("expected list", field)
|
|
|
|
# TODO: backwards compatibility only for older API versions
|
|
helper = cls.ListFieldBucketHelper(legacy=True)
|
|
actions = bucketize(
|
|
data, key=helper.key, value_transform=helper.value_transform
|
|
)
|
|
|
|
allow_empty = None in actions.get("in", {})
|
|
mongoengine_field = field.replace(".", "__")
|
|
|
|
q = RegexQ()
|
|
for action in filter(None, actions):
|
|
q &= RegexQ(
|
|
**{
|
|
f"{mongoengine_field}__{action}": list(
|
|
set(filter(None, actions[action]))
|
|
)
|
|
}
|
|
)
|
|
|
|
if not allow_empty:
|
|
return q
|
|
|
|
return (
|
|
q
|
|
| Q(**{f"{mongoengine_field}__exists": False})
|
|
| Q(**{mongoengine_field: []})
|
|
)
|
|
|
|
@classmethod
|
|
def _prepare_perm_query(cls, company, allow_public=False):
|
|
if allow_public:
|
|
return get_company_or_none_constraint(company)
|
|
return Q(company=company)
|
|
|
|
@classmethod
|
|
def validate_order_by(cls, parameters, search_text) -> Sequence:
|
|
"""
|
|
Validate and extract order_by params as a list
|
|
"""
|
|
order_by = parameters.get(cls._ordering_key)
|
|
if not order_by:
|
|
return []
|
|
|
|
order_by = order_by if isinstance(order_by, list) else [order_by]
|
|
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
|
if not search_text and cls._text_score in order_by:
|
|
raise errors.bad_request.FieldsValueError(
|
|
"text score cannot be used in order_by when search text is not used"
|
|
)
|
|
|
|
return order_by
|
|
|
|
@classmethod
|
|
def validate_paging(
|
|
cls, parameters=None, default_page=None, default_page_size=None
|
|
):
|
|
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
|
if parameters is None:
|
|
parameters = {}
|
|
default_page = parameters.get("page", default_page)
|
|
if default_page is None:
|
|
return None, None
|
|
default_page_size = parameters.get("page_size", default_page_size)
|
|
if not default_page_size:
|
|
raise errors.bad_request.MissingRequiredFields(
|
|
"page_size is required when page is requested", field="page_size"
|
|
)
|
|
elif default_page < 0:
|
|
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
|
elif default_page_size < 1:
|
|
raise errors.bad_request.ValidationError(
|
|
"page_size must be >0", field="page_size"
|
|
)
|
|
return default_page, default_page_size
|
|
|
|
@classmethod
|
|
def get_projection(cls, parameters, override_projection=None, **__):
|
|
""" Extract a projection list from the provided dictionary. Supports an override projection. """
|
|
if override_projection is not None:
|
|
return override_projection
|
|
if not parameters:
|
|
return []
|
|
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
|
|
|
@classmethod
|
|
def split_projection(
|
|
cls, projection: Sequence[str]
|
|
) -> Tuple[Collection[str], Collection[str]]:
|
|
"""Return include and exclude lists based on passed projection and class definition"""
|
|
if projection:
|
|
include, exclude = partition(
|
|
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
|
)
|
|
else:
|
|
include, exclude = [], []
|
|
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
|
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
|
|
|
@classmethod
|
|
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
|
parameters.pop("only_fields", None)
|
|
parameters[cls._projection_key] = value
|
|
return value
|
|
|
|
@classmethod
|
|
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
|
|
return parameters.get(cls._ordering_key)
|
|
|
|
@classmethod
|
|
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
|
parameters[cls._ordering_key] = value
|
|
return value
|
|
|
|
@classmethod
|
|
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
|
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
|
|
|
@classmethod
|
|
def get_many_with_join(
|
|
cls,
|
|
company,
|
|
query_dict=None,
|
|
query_options=None,
|
|
query=None,
|
|
allow_public=False,
|
|
override_projection=None,
|
|
expand_reference_ids=True,
|
|
):
|
|
"""
|
|
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
|
requested projection. See get_many() for more info.
|
|
:param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
|
|
a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
|
|
"""
|
|
if issubclass(cls, AuthDocument):
|
|
# Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
|
|
# auth-related secrets and prevent security leaks
|
|
log.error(
|
|
f"Attempted projection of {cls.__name__} auth document (ignored)",
|
|
stack_info=True,
|
|
)
|
|
return []
|
|
|
|
override_projection = cls.get_projection(
|
|
parameters=query_dict, override_projection=override_projection
|
|
)
|
|
|
|
helper = ProjectionHelper(
|
|
doc_cls=cls,
|
|
projection=override_projection,
|
|
expand_reference_ids=expand_reference_ids,
|
|
)
|
|
|
|
# Make the main query
|
|
results = cls.get_many(
|
|
override_projection=helper.doc_projection,
|
|
company=company,
|
|
parameters=query_dict,
|
|
query_dict=query_dict,
|
|
query=query,
|
|
query_options=query_options,
|
|
allow_public=allow_public,
|
|
)
|
|
|
|
def projection_func(doc_type, projection, ids):
|
|
return doc_type.get_many_with_join(
|
|
company=company,
|
|
override_projection=projection,
|
|
query=Q(id__in=ids),
|
|
expand_reference_ids=expand_reference_ids,
|
|
allow_public=allow_public,
|
|
)
|
|
|
|
return helper.project(results, projection_func)
|
|
|
|
@classmethod
|
|
def get_many(
|
|
cls,
|
|
company,
|
|
parameters: dict = None,
|
|
query_dict: dict = None,
|
|
query_options: QueryParameterOptions = None,
|
|
query: Q = None,
|
|
allow_public=False,
|
|
override_projection: Collection[str] = None,
|
|
return_dicts=True,
|
|
):
|
|
"""
|
|
Fetch all documents matching a provided query. Supported several built-in options
|
|
(aside from those provided by the parameters):
|
|
- Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
|
|
field names. Using field names not defined in the document will cause an error.
|
|
- Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
|
|
larger than 0 and is required when specifying a page.
|
|
- Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
|
|
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
|
|
be raised.
|
|
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
|
requested, each contains only the requested projection). If False, a QuerySet object is returned
|
|
(lazy evaluated). If return_dicts is requested then the entities with the None value in order_by field
|
|
are returned last in the ordering.
|
|
:param company: Company ID (required)
|
|
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
|
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
|
|
a query. The resulting query is AND'ed with the `query` parameter (if provided).
|
|
:param query_options: query parameters options (see ParametersOptions)
|
|
:param query: Optional query object (mongoengine.Q)
|
|
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
|
argument
|
|
:param allow_public: If True, objects marked as public (no associated company) are also queried.
|
|
:return: A list of objects matching the query.
|
|
"""
|
|
if query_dict is not None:
|
|
q = cls.prepare_query(
|
|
parameters=query_dict,
|
|
company=company,
|
|
parameters_options=query_options,
|
|
allow_public=allow_public,
|
|
)
|
|
else:
|
|
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
|
_query = (q & query) if query else q
|
|
|
|
if return_dicts:
|
|
return cls._get_many_override_none_ordering(
|
|
query=_query,
|
|
parameters=parameters,
|
|
override_projection=override_projection,
|
|
)
|
|
|
|
return cls._get_many_no_company(
|
|
query=_query, parameters=parameters, override_projection=override_projection
|
|
)
|
|
|
|
@classmethod
|
|
def get_many_public(
|
|
cls, query: Q = None, projection: Collection[str] = None,
|
|
):
|
|
"""
|
|
Fetch all public documents matching a provided query.
|
|
:param query: Optional query object (mongoengine.Q).
|
|
:param projection: A list of projection fields.
|
|
:return: A list of documents matching the query.
|
|
"""
|
|
q = get_company_or_none_constraint()
|
|
_query = (q & query) if query else q
|
|
|
|
return cls._get_many_no_company(query=_query, override_projection=projection)
|
|
|
|
@classmethod
|
|
def _get_many_no_company(
|
|
cls: Union["GetMixin", Document],
|
|
query: Q,
|
|
parameters=None,
|
|
override_projection=None,
|
|
):
|
|
"""
|
|
Fetch all documents matching a provided query.
|
|
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
|
constraints to the query or that no constraints are required.
|
|
|
|
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
|
|
|
:param query: Query object (mongoengine.Q)
|
|
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
|
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
|
argument
|
|
"""
|
|
if not query:
|
|
raise ValueError("query or call_data must be provided")
|
|
|
|
parameters = parameters or {}
|
|
search_text = parameters.get(cls._search_text_key)
|
|
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
|
page, page_size = cls.validate_paging(parameters=parameters)
|
|
include, exclude = cls.split_projection(
|
|
cls.get_projection(parameters, override_projection)
|
|
)
|
|
|
|
qs = cls.objects(query)
|
|
if search_text:
|
|
qs = qs.search_text(search_text)
|
|
if order_by:
|
|
# add ordering
|
|
qs = qs.order_by(*order_by)
|
|
|
|
if include:
|
|
# add projection
|
|
qs = qs.only(*include)
|
|
|
|
if exclude:
|
|
qs = qs.exclude(*exclude)
|
|
|
|
if page is not None and page_size:
|
|
# add paging
|
|
qs = qs.skip(page * page_size).limit(page_size)
|
|
|
|
return qs
|
|
|
|
@classmethod
|
|
def _get_many_override_none_ordering(
|
|
cls: Union[Document, "GetMixin"],
|
|
query: Q = None,
|
|
parameters: dict = None,
|
|
override_projection: Collection[str] = None,
|
|
) -> Sequence[dict]:
|
|
"""
|
|
Fetch all documents matching a provided query. For the first order by field
|
|
the None values are sorted in the end regardless of the sorting order.
|
|
If the first order field is a user defined parameter (either from execution.parameters,
|
|
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
|
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
|
constraints to the query or that no constraints are required.
|
|
|
|
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
|
|
|
:param query: Query object (mongoengine.Q)
|
|
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
|
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
|
argument
|
|
"""
|
|
if not query:
|
|
raise ValueError("query or call_data must be provided")
|
|
|
|
parameters = parameters or {}
|
|
search_text = parameters.get(cls._search_text_key)
|
|
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
|
page, page_size = cls.validate_paging(parameters=parameters)
|
|
include, exclude = cls.split_projection(
|
|
cls.get_projection(parameters, override_projection)
|
|
)
|
|
|
|
query_sets = [cls.objects(query)]
|
|
if order_by:
|
|
order_field = first(
|
|
field for field in order_by if not field.startswith("$")
|
|
)
|
|
if (
|
|
order_field
|
|
and not order_field.startswith("-")
|
|
and "[" not in order_field
|
|
):
|
|
params = {}
|
|
mongo_field = order_field.replace(".", "__")
|
|
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
|
|
params["is_list"] = True
|
|
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
|
|
params["empty_value"] = ""
|
|
non_empty = query & field_exists(mongo_field, **params)
|
|
empty = query & field_does_not_exist(mongo_field, **params)
|
|
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
|
|
|
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
|
if order_field:
|
|
collation_override = first(
|
|
v
|
|
for k, v in cls._field_collation_overrides.items()
|
|
if order_field.startswith(k)
|
|
)
|
|
if collation_override:
|
|
query_sets = [
|
|
qs.collation(collation=collation_override) for qs in query_sets
|
|
]
|
|
|
|
if search_text:
|
|
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
|
|
|
if include:
|
|
# add projection
|
|
query_sets = [qs.only(*include) for qs in query_sets]
|
|
|
|
if exclude:
|
|
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
|
|
|
if page is None or not page_size:
|
|
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
|
|
|
# add paging
|
|
ret = []
|
|
start = page * page_size
|
|
for qs in query_sets:
|
|
qs_size = qs.count()
|
|
if qs_size < start:
|
|
start -= qs_size
|
|
continue
|
|
ret.extend(
|
|
obj.to_proper_dict(only=include)
|
|
for obj in qs.skip(start).limit(page_size)
|
|
)
|
|
if len(ret) >= page_size:
|
|
break
|
|
start = 0
|
|
page_size -= len(ret)
|
|
|
|
return ret
|
|
|
|
@classmethod
|
|
def get_for_writing(
|
|
cls, *args, _only: Collection[str] = None, **kwargs
|
|
) -> "GetMixin":
|
|
if _only and "company" not in _only:
|
|
_only = list(set(_only) | {"company"})
|
|
result = cls.get(*args, _only=_only, include_public=True, **kwargs)
|
|
if result and not result.company:
|
|
object_name = cls.__name__.lower()
|
|
raise errors.forbidden.NoWritePermission(
|
|
f"cannot modify public {object_name}(s), ids={(result.id,)}"
|
|
)
|
|
return result
|
|
|
|
@classmethod
|
|
def get_many_for_writing(cls, company, *args, **kwargs):
|
|
result = cls.get_many(
|
|
company=company,
|
|
*args,
|
|
**dict(return_dicts=False, **kwargs),
|
|
allow_public=True,
|
|
)
|
|
forbidden_objects = {obj.id for obj in result if not obj.company}
|
|
if forbidden_objects:
|
|
object_name = cls.__name__.lower()
|
|
raise errors.forbidden.NoWritePermission(
|
|
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
|
)
|
|
return result
|
|
|
|
|
|
class UpdateMixin(object):
|
|
__user_set_allowed_fields = None
|
|
__locked_when_published_fields = None
|
|
|
|
@classmethod
|
|
def user_set_allowed(cls):
|
|
if cls.__user_set_allowed_fields is None:
|
|
cls.__user_set_allowed_fields = dict(
|
|
get_fields_choices(cls, "user_set_allowed")
|
|
)
|
|
return cls.__user_set_allowed_fields
|
|
|
|
@classmethod
|
|
def locked_when_published(cls):
|
|
if cls.__locked_when_published_fields is None:
|
|
cls.__locked_when_published_fields = dict(
|
|
get_fields_choices(cls, "locked_when_published")
|
|
)
|
|
return cls.__locked_when_published_fields
|
|
|
|
@classmethod
|
|
def get_safe_update_dict(cls, fields):
|
|
if not fields:
|
|
return {}
|
|
valid_fields = cls.user_set_allowed()
|
|
fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
|
|
update_dict = {
|
|
field: value
|
|
for field, allowed, value in fields
|
|
if allowed is None
|
|
or (
|
|
(value in allowed)
|
|
if not isinstance(value, list)
|
|
else all(v in allowed for v in value)
|
|
)
|
|
}
|
|
return update_dict
|
|
|
|
@classmethod
|
|
def safe_update(
|
|
cls: Union["UpdateMixin", Document],
|
|
company_id,
|
|
id,
|
|
partial_update_dict,
|
|
injected_update=None,
|
|
):
|
|
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
|
if not update_dict:
|
|
return 0, {}
|
|
if injected_update:
|
|
update_dict.update(injected_update)
|
|
update_count = cls.objects(id=id, company=company_id).update(
|
|
upsert=False, **update_dict
|
|
)
|
|
return update_count, update_dict
|
|
|
|
|
|
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
|
""" Provide convenience methods for a subclass of mongoengine.Document """
|
|
|
|
@classmethod
|
|
def aggregate(
|
|
cls: Union["DbModelMixin", Document],
|
|
pipeline: Sequence[dict],
|
|
allow_disk_use=None,
|
|
**kwargs,
|
|
) -> CommandCursor:
|
|
"""
|
|
Aggregate objects of this document class according to the provided pipeline.
|
|
:param pipeline: a list of dictionaries describing the pipeline stages
|
|
:param allow_disk_use: if True, allow the server to use disk space if aggregation query cannot fit in memory.
|
|
If None, default behavior will be used (see apiserver.conf/mongo/aggregate/allow_disk_use)
|
|
:param kwargs: additional keyword arguments passed to mongoengine
|
|
:return:
|
|
"""
|
|
kwargs.update(
|
|
allowDiskUse=allow_disk_use
|
|
if allow_disk_use is not None
|
|
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
|
)
|
|
return cls.objects.aggregate(pipeline, **kwargs)
|
|
|
|
@classmethod
|
|
def set_public(
|
|
cls: Type[Document],
|
|
company_id: str,
|
|
ids: Sequence[str],
|
|
invalid_cls: Type[BaseError],
|
|
enabled: bool = True,
|
|
):
|
|
if enabled:
|
|
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
|
update = dict(set__company_origin=company_id, set__company="")
|
|
else:
|
|
items = list(
|
|
cls.objects(
|
|
id__in=ids, company__in=(None, ""), company_origin=company_id
|
|
).only("id")
|
|
)
|
|
update = dict(set__company=company_id, unset__company_origin=1)
|
|
|
|
if len(items) < len(ids):
|
|
missing = tuple(set(ids).difference(i.id for i in items))
|
|
raise invalid_cls(ids=missing)
|
|
|
|
return {"updated": cls.objects(id__in=ids).update(**update)}
|
|
|
|
|
|
def validate_id(cls, company, **kwargs):
|
|
"""
|
|
Validate existence of objects with certain IDs. within company.
|
|
:param cls: Model class to search in
|
|
:param company: Company to search in
|
|
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
|
|
it will be reported along with the name it was assigned to.
|
|
:return:
|
|
"""
|
|
ids = set(kwargs.values())
|
|
objs = list(cls.objects(company=company, id__in=ids).only("id"))
|
|
missing = ids - set(x.id for x in objs)
|
|
if not missing:
|
|
return
|
|
id_to_name = {}
|
|
for name, obj_id in kwargs.items():
|
|
id_to_name.setdefault(obj_id, []).append(name)
|
|
raise errors.bad_request.ValidationError(
|
|
"Invalid {} ids".format(cls.__name__.lower()),
|
|
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
|
)
|