Initial commit

allegroai
2019-06-11 00:24:35 +03:00
parent 6eea80c4a2
commit a6344bad57
138 changed files with 15951 additions and 0 deletions

server/database/__init__.py Normal file

@@ -0,0 +1,58 @@
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection
from config import config
from .defs import Database
from .utils import get_items
log = config.logger(__file__)
strict = config.get('apiserver.mongo.strict', True)
_entries = []
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
@property
def health_alias(self):
return '__health__' + self.alias
def initialize():
db_entries = config.get('hosts.mongo', {})
missing = []
log.info('Initializing database connections')
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
try:
entry.validate()
log.info('Registering connection to %(alias)s (%(host)s)' % entry.to_struct())
register_connection(alias=alias, host=entry.host)
_entries.append(entry)
except ValidationError as ex:
raise Exception('Invalid database entry `%s`: %s' % (key, ex.args[0]))
if missing:
raise ValueError('Missing database configuration for %s' % ', '.join(missing))
def get_entries():
return _entries
def get_aliases():
return [entry.alias for entry in get_entries()]
def reconnect():
for entry in get_entries():
get_connection(entry.alias, reconnect=True)
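
A minimal startup sketch, assuming a HOCON configuration whose hosts.mongo section is keyed by the Database entries above (host URLs are illustrative):

# config sketch:
#   hosts.mongo.backend { host: "mongodb://mongo:27017/backend" }
#   hosts.mongo.auth { host: "mongodb://mongo:27017/auth" }
from database import initialize

initialize()  # registers the "backend-db" and "auth-db" mongoengine aliases, or raises on a missing/invalid entry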

server/database/defs.py Normal file

@@ -0,0 +1,10 @@
class Database(object):
""" Database names for our different DB instances """
backend = 'backend-db'
''' Used for all backend objects (tasks, models etc.) '''
auth = 'auth-db'
''' Used for all authentication and permission objects '''

server/database/errors.py Normal file

@@ -0,0 +1,189 @@
import re
from contextlib import contextmanager
from functools import wraps
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
ValidationError,
NotUniqueError,
FieldDoesNotExist,
InvalidDocumentError,
LookUpError,
InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError
from apierrors import errors
class MakeGetAllQueryError(Exception):
def __init__(self, error, field):
super(MakeGetAllQueryError, self).__init__(f"{error}: field={field}")
self.error = error
self.field = field
class ParseCallError(Exception):
def __init__(self, msg, **kwargs):
super(ParseCallError, self).__init__(msg)
self.params = kwargs
def throws_default_error(err_cls):
"""
Used to make functions (Exception, str) -> Optional[str] searching for specialized error messages raise those
messages in ``err_cls``. If the decorated function does not find a suitable error message,
the underlying exception is returned.
:param err_cls: Error class (generated by apierrors)
"""
def decorator(func):
@wraps(func)
def wrapper(self, e, message, **kwargs):
extra_info = func(self, e, message, **kwargs)
raise err_cls(message, err=e, extra_info=extra_info)
return wrapper
return decorator
class ElasticErrorsHandler(object):
@classmethod
@throws_default_error(errors.server_error.DataError)
def bulk_error(cls, e, _, **__):
if not e.errors:
return
# Else try returning a better error string
for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
return reason
class MongoEngineErrorsHandler(object):
# NotUniqueError
__not_unique_regex = re.compile(
r"collection:\s(?P<collection>[\w.]+)\sindex:\s(?P<index>\w+)\sdup\skey:\s{(?P<values>[^\}]+)\}"
)
__not_unique_value_regex = re.compile(r':\s"(?P<value>[^"]+)"')
__id_index = "_id_"
__index_sep_regex = re.compile(r"_[0-9]+_?")
# FieldDoesNotExist
__not_exist_fields_regex = re.compile(r'"{(?P<fields>.+?)}".+?"(?P<document>.+?)"')
__not_exist_field_regex = re.compile(r"'(?P<field>\w+)'")
@classmethod
def validation_error(cls, e: ValidationError, message, **_):
# Thrown when a document is validated. Documents are validated by default on save and on update
err_dict = e.errors or {e.field_name: e.message}
raise errors.bad_request.DataValidationError(message, **err_dict)
@classmethod
def not_unique_error(cls, e, message, **_):
# Thrown when a save/update violates a unique index constraint
m = cls.__not_unique_regex.search(str(e))
if not m:
raise errors.bad_request.ExpectedUniqueData(message, err=str(e))
values = cls.__not_unique_value_regex.findall(m.group("values"))
index = m.group("index")
if index == cls.__id_index:
fields = "id"
else:
fields = cls.__index_sep_regex.split(index)[:-1]
raise errors.bad_request.ExpectedUniqueData(
message, **dict(zip(fields, values))
)
@classmethod
def field_does_not_exist(cls, e, message, **kwargs):
# Strict mode. Unknown fields encountered in loaded document(s)
field_does_not_exist_cls = kwargs.get(
"field_does_not_exist_cls", errors.server_error.InconsistentData
)
m = cls.__not_exist_fields_regex.search(str(e))
params = {}
if m:
params["document"] = m.group("document")
fields = cls.__not_exist_field_regex.findall(m.group("fields"))
if fields:
if len(fields) > 1:
params["fields"] = "(%s)" % ", ".join(fields)
else:
params["field"] = fields[0]
raise field_does_not_exist_cls(message, **params)
@classmethod
@throws_default_error(errors.server_error.DataError)
def invalid_document_error(cls, e, message, **_):
# Reverse_delete_rule used in reference field
pass
@classmethod
def lookup_error(cls, e, message, **_):
raise errors.bad_request.InvalidFields(
"probably an invalid field name or unsupported nested field",
replacement_msg="Lookup error",
err=str(e),
)
@classmethod
@throws_default_error(errors.bad_request.InvalidRegexError)
def invalid_regex_error(cls, e, _, **__):
if e.args and e.args[0] == "unexpected end of regular expression":
raise errors.bad_request.InvalidRegexError(e.args[0])
@classmethod
@throws_default_error(errors.server_error.InternalError)
def invalid_query_error(cls, e, message, **_):
pass
@contextmanager
def translate_errors_context(message=None, **kwargs):
"""
A context manager that translates MongoEngine's and Elastic thrown errors into our apierrors classes,
with an appropriate message.
"""
try:
if message:
message = "while " + message
yield True
except ValidationError as e:
MongoEngineErrorsHandler.validation_error(e, message, **kwargs)
except NotUniqueError as e:
MongoEngineErrorsHandler.not_unique_error(e, message, **kwargs)
except FieldDoesNotExist as e:
MongoEngineErrorsHandler.field_does_not_exist(e, message, **kwargs)
except InvalidDocumentError as e:
MongoEngineErrorsHandler.invalid_document_error(e, message, **kwargs)
except LookUpError as e:
MongoEngineErrorsHandler.lookup_error(e, message, **kwargs)
except re.error as e:
MongoEngineErrorsHandler.invalid_regex_error(e, message, **kwargs)
except InvalidQueryError as e:
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
    except NotMasterError as e:
        # must come before PyMongoError, since NotMasterError is one of its subclasses
        raise errors.server_error.InternalError(message, err=str(e))
    except PyMongoError as e:
        raise errors.server_error.InternalError(message, err=str(e))
except MakeGetAllQueryError as e:
raise errors.bad_request.ValidationError(e.error, field=e.field)
except ParseCallError as e:
raise errors.bad_request.FieldsValueError(e.args[0], **e.params)
except JsonschemaValidationError as e:
if len(e.args) >= 2:
raise errors.bad_request.ValidationError(e.args[0], reason=e.args[1])
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception as ex:
raise
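
A usage sketch for the context manager above (the saved document is hypothetical):

from database.errors import translate_errors_context

with translate_errors_context("saving task"):
    task.save()  # e.g. a NotUniqueError raised here surfaces as errors.bad_request.ExpectedUniqueData, "while saving task"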

server/database/fields.py Normal file

@@ -0,0 +1,237 @@
import re
from sys import maxsize
import six
from mongoengine import (
EmbeddedDocumentListField,
ListField,
FloatField,
StringField,
EmbeddedDocumentField,
SortedListField,
MapField,
DictField,
)
class LengthRangeListField(ListField):
def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
self.__min_length = min_length
self.__max_length = max_length
super(LengthRangeListField, self).__init__(field, **kwargs)
    def validate(self, value):
        min_len, count, max_len = self.__min_length, len(value), self.__max_length
        if not min_len <= count <= max_len:
            self.error("Item count %d is outside the range [%d, %d]" % (count, min_len, max_len))
super(LengthRangeListField, self).validate(value)
class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
def __init__(self, field=None, *args, **kwargs):
super(LengthRangeEmbeddedDocumentListField, self).__init__(
EmbeddedDocumentField(field), *args, **kwargs
)
class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
def __init__(self, document_type, key, **kwargs):
"""
Create a unique embedded document list field for a document type with a unique comparison key func/property
:param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
:param key: A callable to extract a key from each item
"""
        if not callable(key):
            raise TypeError("key must be callable")
        self.__key = key
        super(UniqueEmbeddedDocumentListField, self).__init__(document_type, **kwargs)
def validate(self, value):
if len({self.__key(i) for i in value}) != len(value):
self.error("Items with duplicate key exist in the list")
super(UniqueEmbeddedDocumentListField, self).validate(value)
def object_to_key_value_pairs(obj):
if isinstance(obj, dict):
return [(key, object_to_key_value_pairs(value)) for key, value in obj.items()]
if isinstance(obj, list):
return list(map(object_to_key_value_pairs, obj))
return obj
class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
"""
A sorted list of embedded documents
"""
def to_mongo(self, value, use_db_field=True, fields=None):
value = super(EmbeddedDocumentSortedListField, self).to_mongo(
value, use_db_field, fields
)
return sorted(value, key=object_to_key_value_pairs)
class LengthRangeSortedListField(LengthRangeListField, SortedListField):
pass
class CustomFloatField(FloatField):
def __init__(self, greater_than=None, **kwargs):
self.greater_than = greater_than
super(CustomFloatField, self).__init__(**kwargs)
def validate(self, value):
super(CustomFloatField, self).validate(value)
if self.greater_than is not None and value <= self.greater_than:
self.error("Float value must be greater than %s" % str(self.greater_than))
# TODO: bucket name should be at most 63 characters....
aws_s3_bucket_only_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
)
aws_s3_url_with_bucket_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
)
non_aws_s3_regex = (
r"^s3://"
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
)
google_gs_bucket_only_regex = (
r"^gs://"
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
)
file_regex = r"^file://"
generic_url_regex = (
r"^%s://" # scheme placeholder
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
)
path_suffix = r"(?:/?|[/?]\S+)$"
file_path_suffix = r"(?:/\S*[^/]+)$"
class _RegexURLField(StringField):
_regex = []
def __init__(self, regex, **kwargs):
super(_RegexURLField, self).__init__(**kwargs)
regex = regex if isinstance(regex, (tuple, list)) else [regex]
self._regex = [
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
for e in regex
]
    def validate(self, value):
        # the value must match at least one of the supported URL expressions
        if not any(regex.match(value) for regex in self._regex):
            self.error("Invalid URL: {}".format(value))
            return
class OutputDestinationField(_RegexURLField):
""" A field representing task output URL """
schemes = ["s3", "gs", "file"]
_expressions = (
aws_s3_bucket_only_regex + path_suffix,
aws_s3_url_with_bucket_regex + path_suffix,
non_aws_s3_regex + path_suffix,
google_gs_bucket_only_regex + path_suffix,
file_regex + path_suffix,
)
def __init__(self, **kwargs):
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
class SupportedURLField(_RegexURLField):
""" A field representing a model URL """
schemes = ["s3", "gs", "file", "http", "https"]
_expressions = tuple(
pattern + file_path_suffix
for pattern in (
aws_s3_bucket_only_regex,
aws_s3_url_with_bucket_regex,
non_aws_s3_regex,
google_gs_bucket_only_regex,
file_regex,
(generic_url_regex % "http"),
(generic_url_regex % "https"),
)
)
def __init__(self, **kwargs):
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
class StrippedStringField(StringField):
def __init__(
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
):
super(StrippedStringField, self).__init__(
regex, max_length, min_length, **kwargs
)
self._strip_chars = strip_chars
def __set__(self, instance, value):
if value is not None:
try:
value = value.strip(self._strip_chars)
except AttributeError:
pass
super(StrippedStringField, self).__set__(instance, value)
def prepare_query_value(self, op, value):
if not isinstance(op, six.string_types):
return value
if value is not None:
value = value.strip(self._strip_chars)
return super(StrippedStringField, self).prepare_query_value(op, value)
def contains_empty_key(d):
"""
Helper function to recursively determine if any key in a
dictionary is empty (based on mongoengine.fields.key_not_string)
"""
    for k, v in list(d.items()):
        if not k or (isinstance(v, dict) and contains_empty_key(v)):
            return True
    return False
class SafeMapField(MapField):
def validate(self, value):
super(SafeMapField, self).validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField):
def validate(self, value):
super(SafeDictField, self).validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")

server/database/model/__init__.py Normal file

@@ -0,0 +1,56 @@
from mongoengine import Document, StringField
from apierrors import errors
from database.model.base import DbModelMixin, ABSTRACT_FLAG
from database.model.company import Company
from database.model.user import User
class AttributedDocument(DbModelMixin, Document):
"""
Represents objects which are attributed to a company and a user or to "no one".
Company must be required since it can be used as unique field.
"""
meta = ABSTRACT_FLAG
company = StringField(required=True, reference_field=Company)
user = StringField(reference_field=User)
    def is_public(self) -> bool:
        # an empty company is the "public" marker (see PrivateDocument below and get_company_or_none_constraint)
        return not bool(self.company)
class PrivateDocument(AttributedDocument):
"""
Represents documents which always belong to a single company
"""
meta = ABSTRACT_FLAG
# can not have an empty string as this is the "public" marker
company = StringField(required=True, reference_field=Company, min_length=1)
user = StringField(reference_field=User, required=True)
def is_public(self) -> bool:
return False
def validate_id(cls, company, **kwargs):
"""
    Validate existence of objects with certain IDs within a company.
:param cls: Model class to search in
:param company: Company to search in
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
it will be reported along with the name it was assigned to.
:return:
"""
ids = set(kwargs.values())
objs = list(cls.objects(company=company, id__in=ids).only('id'))
missing = ids - set(x.id for x in objs)
if not missing:
return
id_to_name = {}
for name, obj_id in kwargs.items():
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
'Invalid {} ids'.format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
)
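
A usage sketch for validate_id() (the document class, field names and IDs are illustrative):

# validate_id(Model, company=company_id, parent=parent_model_id, task=task_id)
# a missing ID raises errors.bad_request.ValidationError("Invalid model ids", <field name>=<missing id>)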

server/database/model/auth.py Normal file

@@ -0,0 +1,72 @@
from mongoengine import (
StringField,
EmbeddedDocument,
EmbeddedDocumentListField,
EmailField,
DateTimeField,
)
from database import Database, strict
from database.model import DbModelMixin
from database.model.base import AuthDocument
from database.utils import get_options
class Entities(object):
company = "company"
task = "task"
user = "user"
model = "model"
class Role(object):
system = "system"
""" Internal system component """
root = "root"
""" Root admin (person) """
admin = "admin"
""" Company administrator """
superuser = "superuser"
""" Company super user """
user = "user"
""" Company user """
annotator = "annotator"
""" Annotator with limited access"""
@classmethod
def get_system_roles(cls) -> set:
return {cls.system, cls.root}
@classmethod
def get_company_roles(cls) -> set:
return set(get_options(cls)) - cls.get_system_roles()
class Credentials(EmbeddedDocument):
key = StringField(required=True)
secret = StringField(required=True)
class User(DbModelMixin, AuthDocument):
meta = {"db_alias": Database.auth, "strict": strict}
id = StringField(primary_key=True)
name = StringField(unique_with="company")
created = DateTimeField()
""" User auth entry creation time """
validated = DateTimeField()
""" Last validation (login) time """
role = StringField(required=True, choices=get_options(Role), default=Role.user)
""" User role """
company = StringField(required=True)
""" Company this user belongs to """
credentials = EmbeddedDocumentListField(Credentials, default=list)
""" Credentials generated for this user """
email = EmailField(unique=True, required=True)
""" Email uniquely identifying the user """

server/database/model/base.py Normal file

@@ -0,0 +1,529 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
from apierrors import errors
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr
log = config.logger("dbmodel")
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
ABSTRACT_FLAG = {"abstract": True}
class AuthDocument(Document):
meta = ABSTRACT_FLAG
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
only=only,
extra_dict=extra_dict,
)
@classmethod
def properize_dict(
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
):
res = d
if normalize_id and "_id" in res:
res["id"] = res.pop("_id")
if strip_private:
res = {k: v for k, v in res.items() if k[0] != "_"}
if only:
res = project_dict(res, only)
if extra_dict:
res.update(extra_dict)
return res
class GetMixin(PropsMixin):
_text_score = "$text_score"
_ordering_key = "order_by"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
class QueryParameterOptions(object):
def __init__(
self,
pattern_fields=("name",),
list_fields=("tags", "id"),
datetime_fields=None,
fields=None,
):
"""
:param pattern_fields: Fields for which a "string contains" condition should be generated
:param list_fields: Fields for which a "list contains" condition should be generated
:param datetime_fields: Fields for which datetime condition should be generated (see ACCESS_MODIFIER)
            :param fields: Fields for which a simple equality condition should be generated (basically filters out all
other unsupported query fields)
"""
self.fields = fields
self.datetime_fields = datetime_fields
self.list_fields = list_fields
self.pattern_fields = pattern_fields
get_all_query_options = QueryParameterOptions()
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
& Q(id=id, **kwargs)
)
if _only:
q = q.only(*_only)
return q.first()
@classmethod
def prepare_query(
cls,
company: str,
parameters: dict = None,
parameters_options: QueryParameterOptions = None,
allow_public=False,
):
"""
Prepare a query object based on the provided query dictionary and various fields.
        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
:param company: Company ID (required)
:param allow_public: Allow results from public objects
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
parameters, parameters_options
) & cls._prepare_perm_query(company, allow_public=allow_public)
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
):
"""
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
:return: mongoengine.Q query object
"""
parameters_options = parameters_options or cls.get_all_query_options
dict_query = {}
query = RegexQ()
if parameters:
parameters = parameters.copy()
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
for field in opts.fields or []:
data = parameters.pop(field, None)
if data is not None:
dict_query[field] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = field if not modifier else "__".join((field, modifier))
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
if field not in keys:
continue
try:
data = cls.MultiFieldParameters(**value)
except Exception:
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:
return get_company_or_none_constraint(company)
return Q(company=company)
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
if override_projection is not None:
return override_projection
if not parameters:
return []
return parameters.get("projection") or parameters.get("only_fields", [])
@classmethod
def set_default_ordering(cls, parameters, value):
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value
@classmethod
def get_many_with_join(
cls,
company,
query_dict=None,
query_options=None,
query=None,
allow_public=False,
override_projection=None,
expand_reference_ids=True,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
requested projection. See get_many() for more info.
:param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
"""
if issubclass(cls, AuthDocument):
# Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
# auth-related secrets and prevent security leaks
log.error(
f"Attempted projection of {cls.__name__} auth document (ignored)",
stack_info=True,
)
return []
override_projection = cls.get_projection(
parameters=query_dict, override_projection=override_projection
)
helper = ProjectionHelper(
doc_cls=cls,
projection=override_projection,
expand_reference_ids=expand_reference_ids,
)
# Make the main query
results = cls.get_many(
override_projection=helper.doc_projection,
company=company,
parameters=query_dict,
query_dict=query_dict,
query=query,
query_options=query_options,
allow_public=allow_public,
)
def projection_func(doc_type, projection, ids):
return doc_type.get_many_with_join(
company=company,
override_projection=projection,
query=Q(id__in=ids),
expand_reference_ids=expand_reference_ids,
allow_public=allow_public,
)
return helper.project(results, projection_func)
@classmethod
def get_many(
cls,
company,
parameters: dict = None,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
):
"""
        Fetch all documents matching a provided query. Supports several built-in options
(aside from those provided by the parameters):
- Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
field names. Using field names not defined in the document will cause an error.
- Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
larger than 0 and is required when specifying a page.
- Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
be raised.
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param company: Company ID (required)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
a query. The resulting query is AND'ed with the `query` parameter (if provided).
        :param query_options: query parameters options (see QueryParameterOptions)
:param query: Optional query object (mongoengine.Q)
        :param override_projection: A list of projection fields overriding any projection specified in the `parameters`
argument
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:return: A list of objects matching the query.
"""
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
return cls._get_many_no_company(
query=_query,
parameters=parameters,
override_projection=override_projection,
return_dicts=return_dicts,
)
@classmethod
def _get_many_no_company(
cls, query, parameters=None, override_projection=None, return_dicts=True
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
:param query: Query object (mongoengine.Q)
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
        :param override_projection: A list of projection fields overriding any projection specified in the `parameters`
argument
"""
parameters = parameters or {}
if not query:
raise ValueError("query or call_data must be provided")
page, page_size = cls.validate_paging(parameters=parameters)
order_by = parameters.get(cls._ordering_key)
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
search_text = parameters.get("search_text")
only = cls.get_projection(parameters, override_projection)
if not search_text and order_by and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
qs = cls.objects(query)
if search_text:
qs = qs.search_text(search_text)
if order_by:
# add ordering
qs = (
qs.order_by(order_by)
if isinstance(order_by, string_types)
else qs.order_by(*order_by)
)
if only:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
if return_dicts:
return [obj.to_proper_dict(only=only) for obj in qs]
return qs
@classmethod
def get_for_writing(
cls, *args, _only: Collection[str] = None, **kwargs
) -> "GetMixin":
if _only and "company" not in _only:
_only = list(set(_only) | {"company"})
result = cls.get(*args, _only=_only, include_public=True, **kwargs)
if result and not result.company:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={(result.id,)}"
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):
@classmethod
def user_set_allowed(cls):
        # note: a dunder attribute name would be mangled on assignment but not inside getattr(), breaking the cache
        res = cls.__dict__.get("_user_set_allowed_fields")
        if res is None:
            res = cls._user_set_allowed_fields = dict(
                get_fields_with_attr(cls, "user_set_allowed")
            )
return res
@classmethod
def get_safe_update_dict(cls, fields):
if not fields:
return {}
valid_fields = cls.user_set_allowed()
fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
update_dict = {
field: value
for field, allowed, value in fields
if allowed is None
or (
(value in allowed)
if not isinstance(value, list)
else all(v in allowed for v in value)
)
}
return update_dict
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
if injected_update:
update_dict.update(injected_update)
update_count = cls.objects(id=id, company=company_id).update(
upsert=False, **update_dict
)
return update_count, update_dict
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
pass
def validate_id(cls, company, **kwargs):
"""
    Validate existence of objects with certain IDs within a company.
:param cls: Model class to search in
:param company: Company to search in
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
it will be reported along with the name it was assigned to.
:return:
"""
ids = set(kwargs.values())
objs = list(cls.objects(company=company, id__in=ids).only("id"))
missing = ids - set(x.id for x in objs)
if not missing:
return
id_to_name = {}
for name, obj_id in kwargs.items():
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
)
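
A query sketch based on the get_many()/prepare_query() docstrings above, using the Task document as an example (all field values are illustrative):

results = Task.get_many(
    company=company_id,
    query_dict={
        "name": "^train",  # pattern field: matched as a regex
        "tags": ["prod", "-archived"],  # list field: include "prod", exclude "archived"
        "_any_": {"fields": ["name", "comment"], "pattern": "baseline"},
    },
    parameters={"order_by": "-created", "page": 0, "page_size": 20},
)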

server/database/model/company.py Normal file

@@ -0,0 +1,25 @@
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
class CompanyDefaults(EmbeddedDocument):
cluster = StringField()
class Company(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
name = StrippedStringField(unique=True, min_length=3)
defaults = EmbeddedDocumentField(CompanyDefaults)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
""" Override default behavior since a 'company' constraint is not supported for this document... """
return Q()

server/database/model/model.py Normal file

@@ -0,0 +1,56 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from database import Database, strict
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
from database.model.task.task import Task
from database.model.user import User
class Model(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
{
'name': '%s.model.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$parent',
'$task',
'$project',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'parent': 5,
'task': 3,
'project': 3,
}
}
],
}
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field='Model', required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
uri = SupportedURLField(default='', user_set_allowed=True)
framework = StringField()
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)

server/database/model/model_labels.py Normal file

@@ -0,0 +1,11 @@
from mongoengine import MapField, IntField
class ModelLabels(MapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
self.error("Same label id appears more than once in model labels")

server/database/model/project.py Normal file

@@ -0,0 +1,39 @@
from mongoengine import StringField, DateTimeField, ListField
from database import Database, strict
from database.fields import OutputDestinationField, StrippedStringField
from database.model import AttributedDocument
from database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
)
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
"default_language": "english",
"weights": {"name": 10, "id": 10, "description": 10},
}
],
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True,
unique_with=AttributedDocument.company.name,
min_length=3,
sparse=True,
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list)
default_output_destination = OutputDestinationField()
last_update = DateTimeField()

server/database/model/task/metrics.py Normal file

@@ -0,0 +1,14 @@
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField
class MetricEvent(EmbeddedDocument):
    metric = StringField(required=True)
variant = StringField(required=True)
type = StringField(required=True)
timestamp = DateTimeField(default=0, required=True)
iter = LongField()
value = DynamicField(required=True)
@classmethod
def from_dict(cls, **kwargs):
return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
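
A usage sketch for from_dict() (the payload is illustrative): unknown keys are filtered out, so raw event dictionaries can be passed as-is.

event = MetricEvent.from_dict(
    metric="loss", variant="total", type="scalar", iter=10, value=0.27,
    worker="worker-1",  # not a MetricEvent field, so it is silently dropped
)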

server/database/model/task/output.py Normal file

@@ -0,0 +1,16 @@
from mongoengine import EmbeddedDocument, StringField
from database.utils import get_options
from database.fields import OutputDestinationField
class Result(object):
success = 'success'
failure = 'failure'
class Output(EmbeddedDocument):
destination = OutputDestinationField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

server/database/model/task/task.py Normal file

@@ -0,0 +1,132 @@
from enum import Enum
from mongoengine import (
StringField,
EmbeddedDocumentField,
EmbeddedDocument,
DateTimeField,
IntField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField, SafeMapField, SafeDictField
from database.model import AttributedDocument
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .output import Output
DEFAULT_LAST_ITERATION = 0
class TaskStatus(object):
created = 'created'
in_progress = 'in_progress'
stopped = 'stopped'
publishing = 'publishing'
published = 'published'
closed = 'closed'
failed = 'failed'
unknown = 'unknown'
class TaskStatusMessage(object):
stopping = 'stopping'
class TaskTags(object):
development = 'development'
class Script(EmbeddedDocument):
binary = StringField(default='python')
repository = StringField(required=True)
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
working_dir = StringField()
requirements = SafeDictField()
class Execution(EmbeddedDocument):
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field='Model')
model_desc = SafeMapField(StringField(default=''))
model_labels = ModelLabels()
framework = StringField()
queue = StringField()
''' Queue ID where task was queued '''
class TaskType(object):
training = 'training'
testing = 'testing'
class Task(AttributedDocument):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
'created',
'started',
'completed',
{
'name': '%s.task.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$execution.model',
'$output.model',
'$script.repository',
'$script.entry_point',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'execution.model': 2,
'output.model': 2,
'script.repository': 1,
'script.entry_point': 1,
},
},
],
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True, user_set_allowed=True, sparse=False, min_length=3
)
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
class TaskVisibility(Enum):
active = 'active'
archived = 'archived'

server/database/model/user.py Normal file

@@ -0,0 +1,21 @@
from mongoengine import Document, StringField
from database import Database, strict
from database.fields import SafeDictField
from database.model import DbModelMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)
name = StringField(required=True, user_set_allowed=True)
family_name = StringField(user_set_allowed=True)
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = SafeDictField(default=dict, exclude_by_default=True)

server/database/projection.py Normal file

@@ -0,0 +1,269 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
import dpath
from apierrors import errors
from database.props import PropsMixin
def project_dict(data, projection, separator='.'):
"""
Project partial data from a dictionary into a new dictionary
:param data: Input dictionary
:param projection: List of dictionary paths (each a string with field names separated using a separator)
:param separator: Separator (default is '.')
:return: A new dictionary containing only the projected parts from the original dictionary
"""
assert isinstance(data, dict)
result = {}
def copy_path(path_parts, source, destination):
src, dst = source, destination
try:
for depth, path_part in enumerate(path_parts[:-1]):
src_part = src[path_part]
if isinstance(src_part, dict):
src = src_part
dst = dst.setdefault(path_part, {})
elif isinstance(src_part, (list, tuple)):
if path_part not in dst:
dst[path_part] = [{} for _ in range(len(src_part))]
elif not isinstance(dst[path_part], (list, tuple)):
raise TypeError('Incompatible destination type %s for %s (list expected)'
% (type(dst), separator.join(path_parts[:depth + 1])))
elif not len(dst[path_part]) == len(src_part):
raise ValueError('Destination list length differs from source length for %s'
% separator.join(path_parts[:depth + 1]))
dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
for s, d in zip(src_part, dst[path_part])]
return destination
else:
raise TypeError('Unsupported projection type %s for %s'
% (type(src), separator.join(path_parts[:depth + 1])))
last_part = path_parts[-1]
dst[last_part] = src[last_part]
except KeyError:
# Projection field not in source, no biggie.
pass
return destination
for projection_path in sorted(projection):
copy_path(
path_parts=projection_path.split(separator),
source=data,
destination=result)
return result
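# Example: projecting nested and list-valued paths
#   project_dict({"a": {"b": 1, "c": 2}, "d": [{"x": 1, "y": 2}]}, ["a.b", "d.x"])
#   returns {"a": {"b": 1}, "d": [{"x": 1}]}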
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
@property
def doc_projection(self):
return self._doc_projection
def __init__(self, doc_cls, projection, expand_reference_ids=False):
super(ProjectionHelper, self).__init__()
self._should_expand_reference_ids = expand_reference_ids
self._doc_cls = doc_cls
self._doc_projection = None
self._ref_projection = None
self._parse_projection(projection)
def _collect_projection_fields(self, doc_cls, projection):
"""
Collect projection for the given document into immediate document projection and reference documents projection
:param doc_cls: Document class
:param projection: List of projection fields
:return: A tuple of document projection and reference fields information
"""
doc_projection = set() # Projection fields for this class (used in the main query)
ref_projection_info = [] # Projection information for reference fields (used in join queries)
for field in projection:
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
if not subfield.startswith('.'):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
break
else:
# Not a reference field, just add to the top-level projection
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
orig_field = field
if field.endswith('.*'):
field = field[:-2]
if not field:
raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
doc_projection.add(field)
return doc_projection, ref_projection_info
def _parse_projection(self, projection):
"""
Prepare the projection data structures for get_many_with_join().
:param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (e.g. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
should be returned (only relevant for fields that represent sub-documents or referenced documents)
:type projection: list of strings
:returns A tuple of (class fields projection, reference fields projection)
"""
doc_cls = self._doc_cls
assert issubclass(doc_cls, PropsMixin)
if not projection:
return [], {}
doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)
def normalize_cls_projection(cls_, fields):
""" Normalize projection for this class and group (expand *, for once) """
if '*' in fields:
return list(fields.difference('*').union(cls_.get_fields()))
return list(fields)
def compute_ref_cls_projection(cls_, group):
""" Compute inner projection for this class and group """
subfields = set([x[2] for x in group if x[2]])
return normalize_cls_projection(cls_, subfields)
def sort_key(proj_info):
return proj_info[:2]
        # Aggregate by reference field. We'll leave out '*' from the projected items since it is expanded during normalization
ref_projection = {
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
}
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))
        # Make sure that in case one or more fields are subfields of another field, we only use the top-level field.
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
# won't return some of the data we need.
# This way, we make sure to use the most inclusive field that contains all requested subfields.
projection_set = set(doc_projection)
doc_projection = [
field
for field in doc_projection
if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
]
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
if invalid_fields:
raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)
if ref_projection:
# Join mode - use both normal projection fields and top-level reference fields
doc_projection = set(doc_projection)
for field in set(ref_projection).difference(doc_projection):
if any(f for f in doc_projection if field.startswith(f)):
continue
doc_projection.add(field)
doc_projection = list(doc_projection)
self._doc_projection = doc_projection
self._ref_projection = ref_projection
@staticmethod
def _search(doc_cls, obj, path, only_values=True):
""" Call dpath.search with yielded=True, collect result values """
norm_path = doc_cls.get_dpath_translated_path(path)
return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
def project(self, results, projection_func):
"""
Perform projection on query results, using the provided projection func.
:param results: A list of results dictionaries on which projection should be performed
:param projection_func: A callable that receives a document type, list of ids and projection and returns query
results. This callable is used in order to perform sub-queries during projection
:return: Modified results (in-place)
"""
cls = self._doc_cls
ref_projection = self._ref_projection
if ref_projection:
# Join mode - get results for each reference fields projection required (this is the join step)
# Note: this is a recursive step, so we support nested reference fields
def do_projection(item):
ref_field_name, data = item
res = {}
ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
for res in results))))
if ids:
doc_type = data['cls']
doc_only = list(filter(None, data['only']))
doc_only = list({'id'} | set(doc_only)) if doc_only else None
res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
data['res'] = res
items = list(ref_projection.items())
if len(ref_projection) == 1:
do_projection(items[0])
else:
for _ in self.pool.map(do_projection, items):
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
# will be raised when its value is retrieved from the map() iterator
pass
def do_expand_reference_ids(result, skip_fields=None):
ref_fields = cls.get_reference_fields()
if skip_fields:
ref_fields = set(ref_fields) - set(skip_fields)
self._expand_reference_fields(cls, result, ref_fields)
def merge_projection_result(result):
for ref_field_name, data in ref_projection.items():
res = data.get('res')
if not res:
self._expand_reference_fields(cls, result, [ref_field_name])
continue
ref_ids = self._search(cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
obj = res.get(value) or {'id': value}
dpath.new(result, path, obj, separator='.')
# any reference field not projected should be expanded
do_expand_reference_ids(result, skip_fields=list(ref_projection))
update_func = merge_projection_result if ref_projection else \
do_expand_reference_ids if self._should_expand_reference_ids else None
if update_func:
for result in results:
update_func(result)
return results
@classmethod
def _expand_reference_fields(cls, doc_cls, result, fields):
for ref_field_name in fields:
ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
dpath.set(
result,
path,
{'id': value} if value else {},
separator='.')
@classmethod
def expand_reference_ids(cls, doc_cls, result):
cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())

server/database/props.py Normal file

@@ -0,0 +1,142 @@
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document
from database.fields import (
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr
class PropsMixin(object):
__cached_fields = None
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
cls.__cached_fields = get_fields(cls)
return cls.__cached_fields
@classmethod
def get_fields_with_instance(cls, doc_cls):
if cls.__cached_fields_with_instance is None:
cls.__cached_fields_with_instance = {}
if doc_cls not in cls.__cached_fields_with_instance:
cls.__cached_fields_with_instance[doc_cls] = get_fields(
doc_cls, return_instance=True
)
return cls.__cached_fields_with_instance[doc_cls]
@staticmethod
def _get_fields_with_attr(cls_, attr):
""" Get all fields with the specified attribute (supports nested fields) """
res = get_fields_and_attr(cls_, attr=attr)
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
return cls_.owner_document
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
for field, embedded_doc_field in get_fields(
cls_, of_type=doc_cls, return_instance=True
):
embedded_doc_cls = embedded_doc_field_getter(
embedded_doc_field
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
}
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
return fields
@classmethod
def _translate_fields_path(cls, parts):
current_cls = cls
translated_parts = []
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
                    'Invalid path (non-document encountered at %s)' % parts[:depth]
)
try:
field_name, field = next(
(k, v)
for k, v in cls.get_fields_with_instance(current_cls)
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth])
translated_parts.append(part)
if isinstance(field, EmbeddedDocumentField):
current_cls = field.document_type
elif isinstance(
field,
(
EmbeddedDocumentListField,
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
),
):
current_cls = field.field.document_type
translated_parts.append('*')
else:
current_cls = None
return translated_parts
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:
with cls.__cached_dpath_computed_fields_lock:
parts = path.split(separator)
translated = cls._translate_fields_path(parts)
result = separator.join(translated)
cls.__cached_dpath_computed_fields[path] = result
return cls.__cached_dpath_computed_fields[path]

server/database/query.py Normal file

@@ -0,0 +1,63 @@
import copy
import re
from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
class RegexWrapper(object):
def __init__(self, pattern, flags=None):
super(RegexWrapper, self).__init__()
self.pattern = pattern
self.flags = flags
@property
def regex(self):
return re.compile(self.pattern, self.flags if self.flags is not None else 0)
class RegexMixin(object):
def to_query(self, document):
query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document))
return query
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
return self
if self.empty:
return other
return RegexQCombination(operation, [self, other])
class RegexQCombination(RegexMixin, QCombination):
pass
class RegexQ(RegexMixin, Q):
pass
class RegexQueryCompilerVisitor(QueryCompilerVisitor):
"""
    Improved mongoengine compiled-queries visitor class that supports compiled regex expressions as part of the query.
    We need this class since mongoengine's Q (QNode) class uses copy.deepcopy() as part of the tree simplification
    stage, which does not support re.compile objects (since Python 2.5).
    This class allows users to provide regex strings wrapped in RegexWrapper instances, which are lazily evaluated
    to re.compile instances just before being visited for compilation (this is done after the simplification stage)
"""
def visit_query(self, query):
query = copy.deepcopy(query)
query.query = self._transform_query(query.query)
return super(RegexQueryCompilerVisitor, self).visit_query(query)
def _transform_query(self, query):
return {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in query.items()}
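
A usage sketch (the queried document and fields are illustrative):

import re

q = RegexQ(name=RegexWrapper("^train", flags=re.IGNORECASE)) & RegexQ(company=company_id)
results = Task.objects(q)  # the wrapper is compiled to a regex only at query-compile time, after simplification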

server/database/utils.py Normal file

@@ -0,0 +1,160 @@
import hashlib
from inspect import ismethod, getmembers
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
from mongoengine.base import BaseField
from .errors import translate_errors_context, ParseCallError
def get_fields(cls, of_type=BaseField, return_instance=False):
""" get field names from a class containing mongoengine fields """
res = []
for cls_ in reversed(cls.mro()):
res.extend([k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)])
return res
def get_fields_and_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = {}
for cls_ in reversed(cls.mro()):
res.update({k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)})
return res
def _get_field_choices(name, field):
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
obj = field.document_type_obj
n, choices = _get_field_choices(field.name, obj.field)
return '%s__%s' % (name, n), choices
elif issubclass(type(field), ListField):
return name, field.field.choices
return name, field.choices
def get_fields_with_attr(cls, attr, default=False):
fields = []
for field_name, field in cls._fields.items():
if not getattr(field, attr, default):
continue
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
fields.extend((('%s__%s' % (field_name, name), choices)
for name, choices in get_fields_with_attr(field.document_type, attr, default)))
elif issubclass(type(field), ListField):
fields.append((field_name, field.field.choices))
else:
fields.append((field_name, field.choices))
return fields
def get_items(cls):
""" get key/value items from an enum-like class (members represent enumeration key/value) """
res = {
k: v
for k, v in getmembers(cls)
if not (k.startswith("_") or ismethod(v))
}
return res
def get_options(cls):
""" get options from an enum-like class (members represent enumeration key/value) """
return list(get_items(cls).values())
# return a dictionary of items which:
# 1. are in the call_data
# 2. are in the fields dictionary, and their value in the call_data matches the type in fields
# 3. are in the cls_fields
def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
if not isinstance(fields, dict):
# fields should be key=>type dict
fields = {k: None for k in fields}
fields = {k: v for k, v in fields.items() if k in cls_fields}
res = {}
with translate_errors_context('parsing call data'):
for field, desc in fields.items():
value = call_data.get(field)
if value is None:
if not discard_none_values and field in call_data:
# we'll keep the None value in case the field actually exists in the call data
res[field] = None
continue
if desc:
if callable(desc):
desc(value)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
raise ParseCallError('expecting %s' % desc.__name__, field=field)
if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
res[field] = value
return res
def init_cls_from_base(cls, instance):
return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})
def get_company_or_none_constraint(company=None):
return Q(company__in=(company, None, '')) | Q(company__exists=False)
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
"""
Creates a query object used for finding a field that doesn't exist, or has None or an empty value.
:param field: Field name
:param empty_value: The empty value to test for (None means no specific empty value will be used)
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
the length of the array will be used (len==0 means empty)
:return:
"""
query = (Q(**{f"{field}__exists": False}) |
Q(**{f"{field}__in": {empty_value, None}}))
if is_list:
query |= Q(**{f"{field}__size": 0})
return query
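# Usage sketch (document classes are illustrative):
#   Task.objects(field_does_not_exist("parent"))
#   Model.objects(field_does_not_exist("tags", is_list=True))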
def get_subkey(d, key_path, default=None):
    """ Get a key from a nested dictionary. key_path is a '.' separated string of keys used to traverse
        the nested dictionary.
    """
    keys = key_path.split('.')
    for i, key in enumerate(keys):
        if not isinstance(d, dict):
            raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
        d = d.get(key)
        if d is None:
            return default
    return d
def id():
return str(uuid4()).replace("-", "")
def hash_field_name(s):
""" Hash field name into a unique safe string """
return hashlib.md5(s.encode()).hexdigest()
def merge_dicts(*dicts):
base = {}
for dct in dicts:
base.update(dct)
return base
def filter_fields(cls, fields):
"""From the fields dictionary return only the fields that match cls fields"""
return {key: fields[key] for key in fields if key in get_fields(cls)}
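
A quick sketch of the enum-like helpers above, using the TaskStatus class from the task model:

get_items(TaskStatus)    # e.g. {"created": "created", "failed": "failed", ...}
get_options(TaskStatus)  # e.g. ["created", "failed", ...]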