mirror of https://github.com/clearml/clearml-server

Initial commit
server/database/__init__.py (Normal file, 58 lines)
@@ -0,0 +1,58 @@
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection

from config import config
from .defs import Database
from .utils import get_items

log = config.logger(__file__)

strict = config.get('apiserver.mongo.strict', True)

_entries = []


class DatabaseEntry(models.Base):
    host = StringField(required=True)
    alias = StringField()

    @property
    def health_alias(self):
        return '__health__' + self.alias


def initialize():
    db_entries = config.get('hosts.mongo', {})
    missing = []
    log.info('Initializing database connections')
    for key, alias in get_items(Database).items():
        if key not in db_entries:
            missing.append(key)
            continue
        entry = DatabaseEntry(alias=alias, **db_entries.get(key))
        try:
            entry.validate()
            log.info('Registering connection to %(alias)s (%(host)s)' % entry.to_struct())
            register_connection(alias=alias, host=entry.host)

            _entries.append(entry)
        except ValidationError as ex:
            raise Exception('Invalid database entry `%s`: %s' % (key, ex.args[0]))
    if missing:
        raise ValueError('Missing database configuration for %s' % ', '.join(missing))


def get_entries():
    return _entries


def get_aliases():
    return [entry.alias for entry in get_entries()]


def reconnect():
    for entry in get_entries():
        get_connection(entry.alias, reconnect=True)
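For context, a minimal usage sketch of the module above. The config layout (a `hosts.mongo` section keyed by the attribute names of `Database`, each holding at least a `host` entry) is inferred from `initialize()` and is an assumption, not shown in this commit:

# Hypothetical startup code; assumes a config such as:
#   hosts.mongo.backend = {"host": "mongodb://mongo:27017/backend"}
#   hosts.mongo.auth    = {"host": "mongodb://mongo:27017/auth"}
from database import initialize, get_aliases

initialize()          # registers one mongoengine connection per Database entry
print(get_aliases())  # -> ['backend-db', 'auth-db']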
server/database/defs.py (Normal file, 10 lines)
@@ -0,0 +1,10 @@
class Database(object):
    """ Database names for our different DB instances """

    backend = 'backend-db'
    ''' Used for all backend objects (tasks, models etc.) '''

    auth = 'auth-db'
    ''' Used for all authentication and permission objects '''
server/database/errors.py (Normal file, 189 lines)
@@ -0,0 +1,189 @@
import re
from contextlib import contextmanager
from functools import wraps

import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
    ValidationError,
    NotUniqueError,
    FieldDoesNotExist,
    InvalidDocumentError,
    LookUpError,
    InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError

from apierrors import errors


class MakeGetAllQueryError(Exception):
    def __init__(self, error, field):
        super(MakeGetAllQueryError, self).__init__(f"{error}: field={field}")
        self.error = error
        self.field = field


class ParseCallError(Exception):
    def __init__(self, msg, **kwargs):
        super(ParseCallError, self).__init__(msg)
        self.params = kwargs


def throws_default_error(err_cls):
    """
    Used to make functions of the form (Exception, str) -> Optional[str] that search for a specialized
    error message always raise ``err_cls``: the wrapped function's return value (if any) is attached as
    extra info, and the original exception is passed along as ``err``.
    :param err_cls: Error class (generated by apierrors)
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, e, message, **kwargs):
            extra_info = func(self, e, message, **kwargs)
            raise err_cls(message, err=e, extra_info=extra_info)

        return wrapper

    return decorator


class ElasticErrorsHandler(object):
    @classmethod
    @throws_default_error(errors.server_error.DataError)
    def bulk_error(cls, e, _, **__):
        if not e.errors:
            return

        # Else try returning a better error string
        for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
            return reason


class MongoEngineErrorsHandler(object):
    # NotUniqueError
    __not_unique_regex = re.compile(
        r"collection:\s(?P<collection>[\w.]+)\sindex:\s(?P<index>\w+)\sdup\skey:\s{(?P<values>[^\}]+)\}"
    )
    __not_unique_value_regex = re.compile(r':\s"(?P<value>[^"]+)"')
    __id_index = "_id_"
    __index_sep_regex = re.compile(r"_[0-9]+_?")

    # FieldDoesNotExist
    __not_exist_fields_regex = re.compile(r'"{(?P<fields>.+?)}".+?"(?P<document>.+?)"')
    __not_exist_field_regex = re.compile(r"'(?P<field>\w+)'")

    @classmethod
    def validation_error(cls, e: ValidationError, message, **_):
        # Thrown when a document is validated. Documents are validated by default on save and on update
        err_dict = e.errors or {e.field_name: e.message}
        raise errors.bad_request.DataValidationError(message, **err_dict)

    @classmethod
    def not_unique_error(cls, e, message, **_):
        # Thrown when a save/update violates a unique index constraint
        m = cls.__not_unique_regex.search(str(e))
        if not m:
            raise errors.bad_request.ExpectedUniqueData(message, err=str(e))
        values = cls.__not_unique_value_regex.findall(m.group("values"))
        index = m.group("index")
        if index == cls.__id_index:
            fields = "id"
        else:
            fields = cls.__index_sep_regex.split(index)[:-1]
        raise errors.bad_request.ExpectedUniqueData(
            message, **dict(zip(fields, values))
        )

    @classmethod
    def field_does_not_exist(cls, e, message, **kwargs):
        # Strict mode. Unknown fields encountered in loaded document(s)
        field_does_not_exist_cls = kwargs.get(
            "field_does_not_exist_cls", errors.server_error.InconsistentData
        )
        m = cls.__not_exist_fields_regex.search(str(e))
        params = {}
        if m:
            params["document"] = m.group("document")
            fields = cls.__not_exist_field_regex.findall(m.group("fields"))
            if fields:
                if len(fields) > 1:
                    params["fields"] = "(%s)" % ", ".join(fields)
                else:
                    params["field"] = fields[0]
        raise field_does_not_exist_cls(message, **params)

    @classmethod
    @throws_default_error(errors.server_error.DataError)
    def invalid_document_error(cls, e, message, **_):
        # Reverse_delete_rule used in reference field
        pass

    @classmethod
    def lookup_error(cls, e, message, **_):
        raise errors.bad_request.InvalidFields(
            "probably an invalid field name or unsupported nested field",
            replacement_msg="Lookup error",
            err=str(e),
        )

    @classmethod
    @throws_default_error(errors.bad_request.InvalidRegexError)
    def invalid_regex_error(cls, e, _, **__):
        if e.args and e.args[0] == "unexpected end of regular expression":
            raise errors.bad_request.InvalidRegexError(e.args[0])

    @classmethod
    @throws_default_error(errors.server_error.InternalError)
    def invalid_query_error(cls, e, message, **_):
        pass


@contextmanager
def translate_errors_context(message=None, **kwargs):
    """
    A context manager that translates errors thrown by MongoEngine and Elastic into our apierrors classes,
    with an appropriate message.
    """
    try:
        if message:
            message = "while " + message
        yield True
    except ValidationError as e:
        MongoEngineErrorsHandler.validation_error(e, message, **kwargs)
    except NotUniqueError as e:
        MongoEngineErrorsHandler.not_unique_error(e, message, **kwargs)
    except FieldDoesNotExist as e:
        MongoEngineErrorsHandler.field_does_not_exist(e, message, **kwargs)
    except InvalidDocumentError as e:
        MongoEngineErrorsHandler.invalid_document_error(e, message, **kwargs)
    except LookUpError as e:
        MongoEngineErrorsHandler.lookup_error(e, message, **kwargs)
    except re.error as e:
        MongoEngineErrorsHandler.invalid_regex_error(e, message, **kwargs)
    except InvalidQueryError as e:
        MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
    except (NotMasterError, PyMongoError) as e:
        # NotMasterError is a PyMongoError subclass, so one handler covers both
        # (a separate NotMasterError clause after PyMongoError would be unreachable)
        raise errors.server_error.InternalError(message, err=str(e))
    except MakeGetAllQueryError as e:
        raise errors.bad_request.ValidationError(e.error, field=e.field)
    except ParseCallError as e:
        raise errors.bad_request.FieldsValueError(e.args[0], **e.params)
    except JsonschemaValidationError as e:
        if len(e.args) >= 2:
            raise errors.bad_request.ValidationError(e.args[0], reason=e.args[1])
        raise errors.bad_request.ValidationError(e.args[0])
    except BulkIndexError as e:
        ElasticErrorsHandler.bulk_error(e, message, **kwargs)
    except ElasticsearchException as e:
        raise errors.server_error.DataError(e, message, **kwargs)
    except InvalidKeyName:
        raise errors.server_error.DataError("invalid empty key encountered in data")
    except Exception:
        raise
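A short sketch of how callers are expected to wrap database operations with `translate_errors_context` (the `Task` model used here is defined later in this commit; the operation itself is illustrative):

from database.errors import translate_errors_context
from database.model.task.task import Task

def rename_task(task_id, company_id, new_name):
    # Any ValidationError / NotUniqueError / PyMongoError raised inside is re-raised
    # as the matching apierrors class, with the message prefixed by "while renaming a task"
    with translate_errors_context("renaming a task"):
        return Task.objects(id=task_id, company=company_id).update(name=new_name)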
server/database/fields.py (Normal file, 237 lines)
@@ -0,0 +1,237 @@
import re
from sys import maxsize

import six
from mongoengine import (
    EmbeddedDocumentListField,
    ListField,
    FloatField,
    StringField,
    EmbeddedDocumentField,
    SortedListField,
    MapField,
    DictField,
)


class LengthRangeListField(ListField):
    def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
        self.__min_length = min_length
        self.__max_length = max_length
        super(LengthRangeListField, self).__init__(field, **kwargs)

    def validate(self, value):
        min, val, max = self.__min_length, len(value), self.__max_length
        if not min <= val <= max:
            self.error("Item count %d exceeds range [%d, %d]" % (val, min, max))
        super(LengthRangeListField, self).validate(value)


class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
    def __init__(self, field=None, *args, **kwargs):
        super(LengthRangeEmbeddedDocumentListField, self).__init__(
            EmbeddedDocumentField(field), *args, **kwargs
        )


class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
    def __init__(self, document_type, key, **kwargs):
        """
        Create a unique embedded document list field for a document type with a unique comparison key func/property
        :param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
        :param key: A callable to extract a key from each item
        """
        if not callable(key):
            raise KeyError("key must be callable")
        self.__key = key
        # Pass kwargs through to the base class (the original dropped them silently)
        super(UniqueEmbeddedDocumentListField, self).__init__(document_type, **kwargs)

    def validate(self, value):
        if len({self.__key(i) for i in value}) != len(value):
            self.error("Items with duplicate key exist in the list")
        super(UniqueEmbeddedDocumentListField, self).validate(value)


def object_to_key_value_pairs(obj):
    if isinstance(obj, dict):
        return [(key, object_to_key_value_pairs(value)) for key, value in obj.items()]
    if isinstance(obj, list):
        return list(map(object_to_key_value_pairs, obj))
    return obj


class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
    """
    A sorted list of embedded documents
    """

    def to_mongo(self, value, use_db_field=True, fields=None):
        value = super(EmbeddedDocumentSortedListField, self).to_mongo(
            value, use_db_field, fields
        )
        return sorted(value, key=object_to_key_value_pairs)


class LengthRangeSortedListField(LengthRangeListField, SortedListField):
    pass


class CustomFloatField(FloatField):
    def __init__(self, greater_than=None, **kwargs):
        self.greater_than = greater_than
        super(CustomFloatField, self).__init__(**kwargs)

    def validate(self, value):
        super(CustomFloatField, self).validate(value)

        if self.greater_than is not None and value <= self.greater_than:
            self.error("Float value must be greater than %s" % str(self.greater_than))


# TODO: bucket name should be at most 63 characters
aws_s3_bucket_only_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
)

aws_s3_url_with_bucket_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))"  # domain...
)

non_aws_s3_regex = (
    r"^s3://"
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|"  # ...or ipv4
    r"\[?[A-F0-9]*:[A-F0-9:]+\]?)"  # ...or ipv6
    r"(?::\d+)?"  # optional port
    r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))"  # bucket name
)

google_gs_bucket_only_regex = (
    r"^gs://"
    r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)"  # bucket name
)

file_regex = r"^file://"

generic_url_regex = (
    r"^%s://"  # scheme placeholder
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|"  # ...or ipv4
    r"\[?[A-F0-9]*:[A-F0-9:]+\]?)"  # ...or ipv6
    r"(?::\d+)?"  # optional port
)

path_suffix = r"(?:/?|[/?]\S+)$"
file_path_suffix = r"(?:/\S*[^/]+)$"


class _RegexURLField(StringField):
    _regex = []

    def __init__(self, regex, **kwargs):
        super(_RegexURLField, self).__init__(**kwargs)
        regex = regex if isinstance(regex, (tuple, list)) else [regex]
        self._regex = [
            re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
            for e in regex
        ]

    def validate(self, value):
        # Check first if the scheme is valid
        if not any(regex.match(value) for regex in self._regex):
            self.error("Invalid URL: {}".format(value))
            return


class OutputDestinationField(_RegexURLField):
    """ A field representing a task output URL """

    schemes = ["s3", "gs", "file"]
    _expressions = (
        aws_s3_bucket_only_regex + path_suffix,
        aws_s3_url_with_bucket_regex + path_suffix,
        non_aws_s3_regex + path_suffix,
        google_gs_bucket_only_regex + path_suffix,
        file_regex + path_suffix,
    )

    def __init__(self, **kwargs):
        super(OutputDestinationField, self).__init__(self._expressions, **kwargs)


class SupportedURLField(_RegexURLField):
    """ A field representing a model URL """

    schemes = ["s3", "gs", "file", "http", "https"]

    _expressions = tuple(
        pattern + file_path_suffix
        for pattern in (
            aws_s3_bucket_only_regex,
            aws_s3_url_with_bucket_regex,
            non_aws_s3_regex,
            google_gs_bucket_only_regex,
            file_regex,
            (generic_url_regex % "http"),
            (generic_url_regex % "https"),
        )
    )

    def __init__(self, **kwargs):
        super(SupportedURLField, self).__init__(self._expressions, **kwargs)


class StrippedStringField(StringField):
    def __init__(
        self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
    ):
        super(StrippedStringField, self).__init__(
            regex, max_length, min_length, **kwargs
        )
        self._strip_chars = strip_chars

    def __set__(self, instance, value):
        if value is not None:
            try:
                value = value.strip(self._strip_chars)
            except AttributeError:
                pass
        super(StrippedStringField, self).__set__(instance, value)

    def prepare_query_value(self, op, value):
        if not isinstance(op, six.string_types):
            return value
        if value is not None:
            value = value.strip(self._strip_chars)
        return super(StrippedStringField, self).prepare_query_value(op, value)


def contains_empty_key(d):
    """
    Helper function to recursively determine if any key in a
    dictionary is empty (based on mongoengine.fields.key_not_string)
    """
    for k, v in list(d.items()):
        if not k or (isinstance(v, dict) and contains_empty_key(v)):
            return True


class SafeMapField(MapField):
    def validate(self, value):
        super(SafeMapField, self).validate(value)

        if contains_empty_key(value):
            self.error("Empty keys are not allowed in a MapField")


class SafeDictField(DictField):
    def validate(self, value):
        super(SafeDictField, self).validate(value)

        if contains_empty_key(value):
            self.error("Empty keys are not allowed in a DictField")
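To illustrate the validation behavior of these custom fields, a small sketch (the `Example` document and its field mix are hypothetical):

from mongoengine import Document, StringField, ValidationError
from database.fields import LengthRangeListField, StrippedStringField, SafeDictField

class Example(Document):
    tags = LengthRangeListField(StringField(), min_length=1, max_length=3)
    name = StrippedStringField()  # whitespace stripped on assignment
    data = SafeDictField()        # rejects empty keys at any nesting level

doc = Example(tags=[], data={"outer": {"": 1}})
doc.name = "  spaces trimmed  "
assert doc.name == "spaces trimmed"
try:
    doc.validate()
except ValidationError:
    pass  # tags has 0 items (below min_length) and data contains an empty key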
server/database/model/__init__.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
from mongoengine import Document, StringField

from apierrors import errors
from database.model.base import DbModelMixin, ABSTRACT_FLAG
from database.model.company import Company
from database.model.user import User


class AttributedDocument(DbModelMixin, Document):
    """
    Represents objects which are attributed to a company and a user or to "no one".
    Company must be required since it can be used as a unique field.
    """
    meta = ABSTRACT_FLAG
    company = StringField(required=True, reference_field=Company)
    user = StringField(reference_field=User)

    def is_public(self) -> bool:
        # An empty company string is the "public" marker
        # (consistent with get_for_writing/get_many_for_writing in base.py)
        return not bool(self.company)


class PrivateDocument(AttributedDocument):
    """
    Represents documents which always belong to a single company
    """
    meta = ABSTRACT_FLAG
    # can not have an empty string as this is the "public" marker
    company = StringField(required=True, reference_field=Company, min_length=1)
    user = StringField(reference_field=User, required=True)

    def is_public(self) -> bool:
        return False


def validate_id(cls, company, **kwargs):
    """
    Validate existence of objects with certain IDs within a company.
    :param cls: Model class to search in
    :param company: Company to search in
    :param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
        it will be reported along with the name it was assigned to.
    :return:
    """
    ids = set(kwargs.values())
    objs = list(cls.objects(company=company, id__in=ids).only('id'))
    missing = ids - set(x.id for x in objs)
    if not missing:
        return
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    raise errors.bad_request.ValidationError(
        'Invalid {} ids'.format(cls.__name__.lower()),
        **{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
    )
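A usage sketch for `validate_id` (the `Project` model is defined later in this commit; the IDs are placeholders):

from database.model import validate_id
from database.model.project import Project

# Each keyword maps a name to an object ID; any ID with no matching Project in the
# company is reported back, e.g. ValidationError("Invalid project ids", source="p1")
validate_id(Project, company="company-id", source="p1", target="p2")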
server/database/model/auth.py (Normal file, 72 lines)
@@ -0,0 +1,72 @@
from mongoengine import (
    StringField,
    EmbeddedDocument,
    EmbeddedDocumentListField,
    EmailField,
    DateTimeField,
)

from database import Database, strict
from database.model import DbModelMixin
from database.model.base import AuthDocument
from database.utils import get_options


class Entities(object):
    company = "company"
    task = "task"
    user = "user"
    model = "model"


class Role(object):
    system = "system"
    """ Internal system component """
    root = "root"
    """ Root admin (person) """
    admin = "admin"
    """ Company administrator """
    superuser = "superuser"
    """ Company super user """
    user = "user"
    """ Company user """
    annotator = "annotator"
    """ Annotator with limited access """

    @classmethod
    def get_system_roles(cls) -> set:
        return {cls.system, cls.root}

    @classmethod
    def get_company_roles(cls) -> set:
        return set(get_options(cls)) - cls.get_system_roles()


class Credentials(EmbeddedDocument):
    key = StringField(required=True)
    secret = StringField(required=True)


class User(DbModelMixin, AuthDocument):
    meta = {"db_alias": Database.auth, "strict": strict}

    id = StringField(primary_key=True)
    name = StringField(unique_with="company")

    created = DateTimeField()
    """ User auth entry creation time """

    validated = DateTimeField()
    """ Last validation (login) time """

    role = StringField(required=True, choices=get_options(Role), default=Role.user)
    """ User role """

    company = StringField(required=True)
    """ Company this user belongs to """

    credentials = EmbeddedDocumentListField(Credentials, default=list)
    """ Credentials generated for this user """

    email = EmailField(unique=True, required=True)
    """ Email uniquely identifying the user """
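For clarity, what the two `Role` classmethods evaluate to, assuming `get_options` collects the class's public string attributes (consistent with how it is used for `choices=` throughout this commit):

from database.model.auth import Role

assert Role.get_system_roles() == {"system", "root"}
assert Role.get_company_roles() == {"admin", "superuser", "user", "annotator"}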
server/database/model/base.py (Normal file, 529 lines)
@@ -0,0 +1,529 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection

from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types

from apierrors import errors
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr

log = config.logger("dbmodel")

ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}

ABSTRACT_FLAG = {"abstract": True}


class AuthDocument(Document):
    meta = ABSTRACT_FLAG


class ProperDictMixin(object):
    def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
        return self.properize_dict(
            self.to_mongo(use_db_field=False).to_dict(),
            strip_private=strip_private,
            only=only,
            extra_dict=extra_dict,
        )

    @classmethod
    def properize_dict(
        cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
    ):
        res = d
        if normalize_id and "_id" in res:
            res["id"] = res.pop("_id")
        if strip_private:
            res = {k: v for k, v in res.items() if k[0] != "_"}
        if only:
            res = project_dict(res, only)
        if extra_dict:
            res.update(extra_dict)
        return res


class GetMixin(PropsMixin):
    _text_score = "$text_score"

    _ordering_key = "order_by"

    _multi_field_param_sep = "__"
    _multi_field_param_prefix = {
        ("_any_", "_or_"): lambda a, b: a | b,
        ("_all_", "_and_"): lambda a, b: a & b,
    }
    MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")

    class QueryParameterOptions(object):
        def __init__(
            self,
            pattern_fields=("name",),
            list_fields=("tags", "id"),
            datetime_fields=None,
            fields=None,
        ):
            """
            :param pattern_fields: Fields for which a "string contains" condition should be generated
            :param list_fields: Fields for which a "list contains" condition should be generated
            :param datetime_fields: Fields for which a datetime condition should be generated (see ACCESS_MODIFIER)
            :param fields: Fields for which a simple equality condition should be generated (basically filters
                out all other unsupported query fields)
            """
            self.fields = fields
            self.datetime_fields = datetime_fields
            self.list_fields = list_fields
            self.pattern_fields = pattern_fields

    get_all_query_options = QueryParameterOptions()

    @classmethod
    def get(
        cls, company, id, *, _only=None, include_public=False, **kwargs
    ) -> "GetMixin":
        q = cls.objects(
            cls._prepare_perm_query(company, allow_public=include_public)
            & Q(id=id, **kwargs)
        )
        if _only:
            q = q.only(*_only)
        return q.first()

    @classmethod
    def prepare_query(
        cls,
        company: str,
        parameters: dict = None,
        parameters_options: QueryParameterOptions = None,
        allow_public=False,
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.
        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param company: Company ID (required)
        :param allow_public: Allow results from public objects
        :param parameters: Query dictionary (relevant keys are those specified by the various field names
            parameters). Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions
              for specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any
              or all provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        return cls._prepare_query_no_company(
            parameters, parameters_options
        ) & cls._prepare_perm_query(company, allow_public=allow_public)

    @classmethod
    def _prepare_query_no_company(
        cls, parameters=None, parameters_options=QueryParameterOptions()
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.

        NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.

        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param parameters: Query dictionary (relevant keys are those specified by the various field names
            parameters). Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions
              for specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any
              or all provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        parameters_options = parameters_options or cls.get_all_query_options
        dict_query = {}
        query = RegexQ()
        if parameters:
            parameters = parameters.copy()
            opts = parameters_options
            for field in opts.pattern_fields:
                pattern = parameters.pop(field, None)
                if pattern:
                    dict_query[field] = RegexWrapper(pattern)

            for field in tuple(opts.list_fields or ()):
                data = parameters.pop(field, None)
                if data:
                    if not isinstance(data, (list, tuple)):
                        raise MakeGetAllQueryError("expected list", field)
                    exclude = [t for t in data if t.startswith("-")]
                    include = list(set(data).difference(exclude))
                    mongoengine_field = field.replace(".", "__")
                    if include:
                        dict_query[f"{mongoengine_field}__in"] = include
                    if exclude:
                        dict_query[f"{mongoengine_field}__nin"] = [
                            t[1:] for t in exclude
                        ]

            for field in opts.fields or []:
                data = parameters.pop(field, None)
                if data is not None:
                    dict_query[field] = data

            for field in opts.datetime_fields or []:
                data = parameters.pop(field, None)
                if data is not None:
                    if not isinstance(data, list):
                        data = [data]
                    for d in data:  # type: str
                        m = ACCESS_REGEX.match(d)
                        if not m:
                            continue
                        try:
                            value = parse_datetime(m.group("value"))
                            prefix = m.group("prefix")
                            modifier = ACCESS_MODIFIER.get(prefix)
                            f = field if not modifier else "__".join((field, modifier))
                            dict_query[f] = value
                        except (ValueError, OverflowError):
                            pass

            for field, value in parameters.items():
                for keys, func in cls._multi_field_param_prefix.items():
                    if field not in keys:
                        continue
                    try:
                        data = cls.MultiFieldParameters(**value)
                    except Exception:
                        raise MakeGetAllQueryError("incorrect field format", field)
                    if not data.fields:
                        break
                    regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                    sep_fields = [f.replace(".", "__") for f in data.fields]
                    q = reduce(
                        lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
                    )
                    query = query & q

        return query & RegexQ(**dict_query)

    @classmethod
    def _prepare_perm_query(cls, company, allow_public=False):
        if allow_public:
            return get_company_or_none_constraint(company)
        return Q(company=company)

    @classmethod
    def validate_paging(
        cls, parameters=None, default_page=None, default_page_size=None
    ):
        """ Validate and extract paging info from the provided dictionary. Supports default values. """
        if parameters is None:
            parameters = {}
        default_page = parameters.get("page", default_page)
        if default_page is None:
            return None, None
        default_page_size = parameters.get("page_size", default_page_size)
        if not default_page_size:
            raise errors.bad_request.MissingRequiredFields(
                "page_size is required when page is requested", field="page_size"
            )
        elif default_page < 0:
            raise errors.bad_request.ValidationError("page must be >=0", field="page")
        elif default_page_size < 1:
            raise errors.bad_request.ValidationError(
                "page_size must be >0", field="page_size"
            )
        return default_page, default_page_size

    @classmethod
    def get_projection(cls, parameters, override_projection=None, **__):
        """ Extract a projection list from the provided dictionary. Supports an override projection. """
        if override_projection is not None:
            return override_projection
        if not parameters:
            return []
        return parameters.get("projection") or parameters.get("only_fields", [])

    @classmethod
    def set_default_ordering(cls, parameters, value):
        parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value

    @classmethod
    def get_many_with_join(
        cls,
        company,
        query_dict=None,
        query_options=None,
        query=None,
        allow_public=False,
        override_projection=None,
        expand_reference_ids=True,
    ):
        """
        Fetch all documents matching a provided query with support for joining referenced documents according
        to the requested projection. See get_many() for more info.
        :param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
            a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
        """
        if issubclass(cls, AuthDocument):
            # Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
            # auth-related secrets and prevent security leaks
            log.error(
                f"Attempted projection of {cls.__name__} auth document (ignored)",
                stack_info=True,
            )
            return []

        override_projection = cls.get_projection(
            parameters=query_dict, override_projection=override_projection
        )

        helper = ProjectionHelper(
            doc_cls=cls,
            projection=override_projection,
            expand_reference_ids=expand_reference_ids,
        )

        # Make the main query
        results = cls.get_many(
            override_projection=helper.doc_projection,
            company=company,
            parameters=query_dict,
            query_dict=query_dict,
            query=query,
            query_options=query_options,
            allow_public=allow_public,
        )

        def projection_func(doc_type, projection, ids):
            return doc_type.get_many_with_join(
                company=company,
                override_projection=projection,
                query=Q(id__in=ids),
                expand_reference_ids=expand_reference_ids,
                allow_public=allow_public,
            )

        return helper.project(results, projection_func)

    @classmethod
    def get_many(
        cls,
        company,
        parameters: dict = None,
        query_dict: dict = None,
        query_options: QueryParameterOptions = None,
        query: Q = None,
        allow_public=False,
        override_projection: Collection[str] = None,
        return_dicts=True,
    ):
        """
        Fetch all documents matching a provided query. Supports several built-in options
        (aside from those provided by the parameters):
        - Ordering: using query field `order_by` which can contain a string or a list of strings corresponding
          to field names. Using field names not defined in the document will cause an error.
        - Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must
          be larger than 0 and is required when specifying a page.
        - Text search: using query field `search_text`. If used, text score can be used in the ordering, using
          the `@text_score` keyword. A text index must be defined on the document type, otherwise an error will
          be raised.
        :param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection
            was requested, each contains only the requested projection). If False, a QuerySet object is returned
            (lazily evaluated).
        :param company: Company ID (required)
        :param parameters: Parameters dict from which paging, ordering and searching parameters are extracted.
        :param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to
            produce a query. The resulting query is AND'ed with the `query` parameter (if provided).
        :param query_options: query parameters options (see QueryParameterOptions)
        :param query: Optional query object (mongoengine.Q)
        :param override_projection: A list of projection fields overriding any projection specified in the
            `parameters` argument
        :param allow_public: If True, objects marked as public (no associated company) are also queried.
        :return: A list of objects matching the query.
        """
        if query_dict is not None:
            q = cls.prepare_query(
                parameters=query_dict,
                company=company,
                parameters_options=query_options,
                allow_public=allow_public,
            )
        else:
            q = cls._prepare_perm_query(company, allow_public=allow_public)
        _query = (q & query) if query else q

        return cls._get_many_no_company(
            query=_query,
            parameters=parameters,
            override_projection=override_projection,
            return_dicts=return_dicts,
        )

    @classmethod
    def _get_many_no_company(
        cls, query, parameters=None, override_projection=None, return_dicts=True
    ):
        """
        Fetch all documents matching a provided query.
        This is a company-less version for internal uses. We assume the caller has either added any necessary
        constraints to the query or that no constraints are required.

        NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.

        :param query: Query object (mongoengine.Q)
        :param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection
            was requested, each contains only the requested projection). If False, a QuerySet object is returned
            (lazily evaluated).
        :param parameters: Parameters dict from which paging, ordering and searching parameters are extracted.
        :param override_projection: A list of projection fields overriding any projection specified in the
            `parameters` argument
        """
        parameters = parameters or {}

        if not query:
            raise ValueError("query or call_data must be provided")

        page, page_size = cls.validate_paging(parameters=parameters)

        order_by = parameters.get(cls._ordering_key)
        if order_by:
            order_by = order_by if isinstance(order_by, list) else [order_by]
            order_by = [cls._text_score if x == "@text_score" else x for x in order_by]

        search_text = parameters.get("search_text")

        only = cls.get_projection(parameters, override_projection)

        if not search_text and order_by and cls._text_score in order_by:
            raise errors.bad_request.FieldsValueError(
                "text score cannot be used in order_by when search text is not used"
            )

        qs = cls.objects(query)
        if search_text:
            qs = qs.search_text(search_text)
        if order_by:
            # add ordering
            qs = (
                qs.order_by(order_by)
                if isinstance(order_by, string_types)
                else qs.order_by(*order_by)
            )
        if only:
            # add projection
            qs = qs.only(*only)
        else:
            exclude = set(cls.get_exclude_fields()).difference(only)
            if exclude:
                qs = qs.exclude(*exclude)
        if page is not None and page_size:
            # add paging
            qs = qs.skip(page * page_size).limit(page_size)

        if return_dicts:
            return [obj.to_proper_dict(only=only) for obj in qs]
        return qs

    @classmethod
    def get_for_writing(
        cls, *args, _only: Collection[str] = None, **kwargs
    ) -> "GetMixin":
        if _only and "company" not in _only:
            _only = list(set(_only) | {"company"})
        result = cls.get(*args, _only=_only, include_public=True, **kwargs)
        if result and not result.company:
            object_name = cls.__name__.lower()
            raise errors.forbidden.NoWritePermission(
                f"cannot modify public {object_name}(s), ids={(result.id,)}"
            )
        return result

    @classmethod
    def get_many_for_writing(cls, company, *args, **kwargs):
        result = cls.get_many(
            company=company,
            *args,
            **dict(return_dicts=False, **kwargs),
            allow_public=True,
        )
        forbidden_objects = {obj.id for obj in result if not obj.company}
        if forbidden_objects:
            object_name = cls.__name__.lower()
            raise errors.forbidden.NoWritePermission(
                f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
            )
        return result


class UpdateMixin(object):
    @classmethod
    def user_set_allowed(cls):
        res = getattr(cls, "__user_set_allowed_fields", None)
        if res is None:
            res = cls.__user_set_allowed_fields = dict(
                get_fields_with_attr(cls, "user_set_allowed")
            )
        return res

    @classmethod
    def get_safe_update_dict(cls, fields):
        if not fields:
            return {}
        valid_fields = cls.user_set_allowed()
        fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
        update_dict = {
            field: value
            for field, allowed, value in fields
            if allowed is None
            or (
                (value in allowed)
                if not isinstance(value, list)
                else all(v in allowed for v in value)
            )
        }
        return update_dict

    @classmethod
    def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
        update_dict = cls.get_safe_update_dict(partial_update_dict)
        if not update_dict:
            return 0, {}
        if injected_update:
            update_dict.update(injected_update)
        update_count = cls.objects(id=id, company=company_id).update(
            upsert=False, **update_dict
        )
        return update_count, update_dict


class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
    """ Provides convenience methods for a subclass of mongoengine.Document """

    pass


def validate_id(cls, company, **kwargs):
    """
    Validate existence of objects with certain IDs within a company.
    :param cls: Model class to search in
    :param company: Company to search in
    :param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
        it will be reported along with the name it was assigned to.
    :return:
    """
    ids = set(kwargs.values())
    objs = list(cls.objects(company=company, id__in=ids).only("id"))
    missing = ids - set(x.id for x in objs)
    if not missing:
        return
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    raise errors.bad_request.ValidationError(
        "Invalid {} ids".format(cls.__name__.lower()),
        **{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
    )
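To make the query-dictionary semantics concrete, a sketch of a `get_many` call. `SomeModel` stands for any `DbModelMixin` document, and the presence of `created` in the options' `datetime_fields` is an illustrative assumption:

query_dict = {
    "name": "resnet",                            # pattern field -> regex "contains" match
    "tags": ["prod", "-archived"],               # list field -> tags__in=['prod'], tags__nin=['archived']
    "created": [">=2019-01-01", "<2019-02-01"],  # datetime field -> created__gte / created__lt
    "_any_": {"pattern": "baseline", "fields": ["name", "comment"]},  # OR over several fields
}
results = SomeModel.get_many(
    company="company-id",
    query_dict=query_dict,
    parameters={"order_by": "-created", "page": 0, "page_size": 20},
)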
server/database/model/company.py (Normal file, 25 lines)
@@ -0,0 +1,25 @@
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q

from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin


class CompanyDefaults(EmbeddedDocument):
    cluster = StringField()


class Company(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(unique=True, min_length=3)
    defaults = EmbeddedDocumentField(CompanyDefaults)

    @classmethod
    def _prepare_perm_query(cls, company, allow_public=False):
        """ Override default behavior since a 'company' constraint is not supported for this document """
        return Q()
server/database/model/model.py (Normal file, 56 lines)
@@ -0,0 +1,56 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField

from database import Database, strict
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
from database.model.task.task import Task
from database.model.user import User


class Model(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
        'indexes': [
            {
                'name': '%s.model.main_text_index' % Database.backend,
                'fields': [
                    '$name',
                    '$id',
                    '$comment',
                    '$parent',
                    '$task',
                    '$project',
                ],
                'default_language': 'english',
                'weights': {
                    'name': 10,
                    'id': 10,
                    'comment': 10,
                    'parent': 5,
                    'task': 3,
                    'project': 3,
                }
            }
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(user_set_allowed=True, min_length=3)
    parent = StringField(reference_field='Model', required=False)
    user = StringField(required=True, reference_field=User)
    company = StringField(required=True, reference_field=Company)
    project = StringField(reference_field=Project, user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    task = StringField(reference_field=Task)
    comment = StringField(user_set_allowed=True)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    uri = SupportedURLField(default='', user_set_allowed=True)
    framework = StringField()
    design = SafeDictField()
    labels = ModelLabels()
    ready = BooleanField(required=True)
    ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
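The text index above is what enables `search_text` queries through `GetMixin.get_many`; a sketch (values are illustrative):

from database.model.model import Model

# Full-text search over the weighted index fields, best matches first:
models = Model.get_many(
    company="company-id",
    query_dict={},
    parameters={"search_text": "resnet", "order_by": "@text_score"},
)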
server/database/model/model_labels.py (Normal file, 11 lines)
@@ -0,0 +1,11 @@
from mongoengine import MapField, IntField


class ModelLabels(MapField):
    def __init__(self, *args, **kwargs):
        super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)

    def validate(self, value):
        super(ModelLabels, self).validate(value)
        if value and (len(set(value.values())) < len(value)):
            self.error("Same label id appears more than once in model labels")
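A quick illustration of the uniqueness check (values are hypothetical):

labels = ModelLabels()
labels.validate({"cat": 0, "dog": 1})  # passes: label ids are unique
labels.validate({"cat": 0, "dog": 0})  # raises: same label id appears twice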
server/database/model/project.py (Normal file, 39 lines)
@@ -0,0 +1,39 @@
from mongoengine import StringField, DateTimeField, ListField

from database import Database, strict
from database.fields import OutputDestinationField, StrippedStringField
from database.model import AttributedDocument
from database.model.base import GetMixin


class Project(AttributedDocument):

    get_all_query_options = GetMixin.QueryParameterOptions(
        pattern_fields=("name", "description"), list_fields=("tags", "id")
    )

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            {
                "name": "%s.project.main_text_index" % Database.backend,
                "fields": ["$name", "$id", "$description"],
                "default_language": "english",
                "weights": {"name": 10, "id": 10, "description": 10},
            }
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(
        required=True,
        unique_with=AttributedDocument.company.name,
        min_length=3,
        sparse=True,
    )
    description = StringField(required=True)
    created = DateTimeField(required=True)
    tags = ListField(StringField(required=True), default=list)
    default_output_destination = OutputDestinationField()
    last_update = DateTimeField()
server/database/model/task/metrics.py (Normal file, 14 lines)
@@ -0,0 +1,14 @@
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField


class MetricEvent(EmbeddedDocument):
    metric = StringField(required=True)
    variant = StringField(required=True)
    type = StringField(required=True)
    timestamp = DateTimeField(default=0, required=True)
    iter = LongField()
    value = DynamicField(required=True)

    @classmethod
    def from_dict(cls, **kwargs):
        return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
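`from_dict` drops unknown keys before construction, which is convenient when ingesting raw event payloads; a sketch (the payload is illustrative):

event = MetricEvent.from_dict(
    metric="loss", variant="train", type="scalar",
    timestamp=0, iter=10, value=0.37,
    worker="gpu-01",  # not a MetricEvent field -> silently dropped
)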
server/database/model/task/output.py (Normal file, 16 lines)
@@ -0,0 +1,16 @@
from mongoengine import EmbeddedDocument, StringField

from database.fields import OutputDestinationField
from database.utils import get_options


class Result(object):
    success = 'success'
    failure = 'failure'


class Output(EmbeddedDocument):
    destination = OutputDestinationField()
    model = StringField(reference_field='Model')
    error = StringField(user_set_allowed=True)
    result = StringField(choices=get_options(Result))
server/database/model/task/task.py (Normal file, 132 lines)
@@ -0,0 +1,132 @@
from enum import Enum

from mongoengine import (
    StringField,
    EmbeddedDocumentField,
    EmbeddedDocument,
    DateTimeField,
    IntField,
    ListField,
)

from database import Database, strict
from database.fields import StrippedStringField, SafeMapField, SafeDictField
from database.model import AttributedDocument
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .output import Output

DEFAULT_LAST_ITERATION = 0


class TaskStatus(object):
    created = 'created'
    in_progress = 'in_progress'
    stopped = 'stopped'
    publishing = 'publishing'
    published = 'published'
    closed = 'closed'
    failed = 'failed'
    unknown = 'unknown'


class TaskStatusMessage(object):
    stopping = 'stopping'


class TaskTags(object):
    development = 'development'


class Script(EmbeddedDocument):
    binary = StringField(default='python')
    repository = StringField(required=True)
    tag = StringField()
    branch = StringField()
    version_num = StringField()
    entry_point = StringField(required=True)
    working_dir = StringField()
    requirements = SafeDictField()


class Execution(EmbeddedDocument):
    test_split = IntField(default=0)
    parameters = SafeDictField(default=dict)
    model = StringField(reference_field='Model')
    model_desc = SafeMapField(StringField(default=''))
    model_labels = ModelLabels()
    framework = StringField()

    queue = StringField()
    ''' Queue ID where task was queued '''


class TaskType(object):
    training = 'training'
    testing = 'testing'


class Task(AttributedDocument):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
        'indexes': [
            'created',
            'started',
            'completed',
            {
                'name': '%s.task.main_text_index' % Database.backend,
                'fields': [
                    '$name',
                    '$id',
                    '$comment',
                    '$execution.model',
                    '$output.model',
                    '$script.repository',
                    '$script.entry_point',
                ],
                'default_language': 'english',
                'weights': {
                    'name': 10,
                    'id': 10,
                    'comment': 10,
                    'execution.model': 2,
                    'output.model': 2,
                    'script.repository': 1,
                    'script.entry_point': 1,
                },
            },
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(
        required=True, user_set_allowed=True, sparse=False, min_length=3
    )

    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField()
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    parent = StringField()
    project = StringField(reference_field=Project, user_set_allowed=True)
    output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    script = EmbeddedDocumentField(Script)
    last_update = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))


class TaskVisibility(Enum):
    active = 'active'
    archived = 'archived'
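Tying this model back to `UpdateMixin`: only fields declared with `user_set_allowed=True` survive `safe_update`; a sketch with placeholder IDs:

updated_count, applied = Task.safe_update(
    company_id="company-id",
    id="task-id",
    partial_update_dict={
        "comment": "tuned the learning rate",  # user_set_allowed -> applied
        "status": "published",                 # not user_set_allowed -> silently dropped
    },
)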
server/database/model/user.py (Normal file, 21 lines)
@@ -0,0 +1,21 @@
from mongoengine import Document, StringField

from database import Database, strict
from database.fields import SafeDictField
from database.model import DbModelMixin
from database.model.company import Company


class User(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
    }

    id = StringField(primary_key=True)
    company = StringField(required=True, reference_field=Company)
    name = StringField(required=True, user_set_allowed=True)
    family_name = StringField(user_set_allowed=True)
    given_name = StringField(user_set_allowed=True)
    avatar = StringField()
    preferences = SafeDictField(default=dict, exclude_by_default=True)
server/database/projection.py (Normal file, 269 lines)
@@ -0,0 +1,269 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain

import dpath

from apierrors import errors
from database.props import PropsMixin


def project_dict(data, projection, separator='.'):
    """
    Project partial data from a dictionary into a new dictionary
    :param data: Input dictionary
    :param projection: List of dictionary paths (each a string with field names separated using a separator)
    :param separator: Separator (default is '.')
    :return: A new dictionary containing only the projected parts from the original dictionary
    """
    assert isinstance(data, dict)
    result = {}

    def copy_path(path_parts, source, destination):
        src, dst = source, destination
        try:
            for depth, path_part in enumerate(path_parts[:-1]):
                src_part = src[path_part]
                if isinstance(src_part, dict):
                    src = src_part
                    dst = dst.setdefault(path_part, {})
                elif isinstance(src_part, (list, tuple)):
                    if path_part not in dst:
                        dst[path_part] = [{} for _ in range(len(src_part))]
                    elif not isinstance(dst[path_part], (list, tuple)):
                        raise TypeError('Incompatible destination type %s for %s (list expected)'
                                        % (type(dst), separator.join(path_parts[:depth + 1])))
                    elif not len(dst[path_part]) == len(src_part):
                        raise ValueError('Destination list length differs from source length for %s'
                                         % separator.join(path_parts[:depth + 1]))

                    dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
                                      for s, d in zip(src_part, dst[path_part])]

                    return destination
                else:
                    raise TypeError('Unsupported projection type %s for %s'
                                    % (type(src), separator.join(path_parts[:depth + 1])))

            last_part = path_parts[-1]
            dst[last_part] = src[last_part]
        except KeyError:
            # Projection field not in source, no biggie.
            pass
        return destination

    for projection_path in sorted(projection):
        copy_path(
            path_parts=projection_path.split(separator),
            source=data,
            destination=result)
    return result
|
||||
|
||||
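# Usage sketch for project_dict (illustrative only, not part of the original
# module; the sample data below is made up):
#
#     data = {'execution': {'queue': 'q1', 'docker': 'img'}, 'name': 'task'}
#     project_dict(data, ['execution.queue', 'name'])
#     # -> {'execution': {'queue': 'q1'}, 'name': 'task'}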
class ProjectionHelper(object):
    pool = ThreadPoolExecutor()

    @property
    def doc_projection(self):
        return self._doc_projection

    def __init__(self, doc_cls, projection, expand_reference_ids=False):
        super(ProjectionHelper, self).__init__()
        self._should_expand_reference_ids = expand_reference_ids
        self._doc_cls = doc_cls
        self._doc_projection = None
        self._ref_projection = None
        self._parse_projection(projection)

    def _collect_projection_fields(self, doc_cls, projection):
        """
        Collect projection for the given document into immediate document projection and reference documents projection
        :param doc_cls: Document class
        :param projection: List of projection fields
        :return: A tuple of document projection and reference fields information
        """
        doc_projection = set()  # Projection fields for this class (used in the main query)
        ref_projection_info = []  # Projection information for reference fields (used in join queries)
        for field in projection:
            for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
                if not field.startswith(ref_field):
                    # Doesn't start with a reference field
                    continue
                if field == ref_field:
                    # Field is exactly a reference field. In this case we won't perform any inner projection
                    # (for that, use '<reference field name>.*')
                    continue
                subfield = field[len(ref_field):]
                if not subfield.startswith('.'):
                    # Starts with something that looks like a reference field, but isn't
                    continue

                ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
                break
            else:
                # Not a reference field, just add to the top-level projection.
                # We strip any trailing '*' since it means nothing for simple fields and for embedded documents
                orig_field = field
                if field.endswith('.*'):
                    field = field[:-2]
                    if not field:
                        raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
                doc_projection.add(field)
        return doc_projection, ref_projection_info
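# Illustrative split (assuming a document whose 'project' field is declared
# with reference_field - the names here are assumptions for the example):
# given projection ['name', 'project', 'project.name'], doc_projection would
# be {'name', 'project'} and ref_projection_info would be
# [('project', Project, 'name')] - 'project' alone triggers no inner projection.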
    def _parse_projection(self, projection):
        """
        Prepare the projection data structures for get_many_with_join().
        :param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (i.e. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
            should be returned (only relevant for fields that represent sub-documents or referenced documents)
        :type projection: list of strings
        :returns: A tuple of (class fields projection, reference fields projection)
        """
        doc_cls = self._doc_cls
        assert issubclass(doc_cls, PropsMixin)
        if not projection:
            return [], {}

        doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)

        def normalize_cls_projection(cls_, fields):
            """ Normalize projection for this class and group (expand *, for once) """
            if '*' in fields:
                return list(fields.difference('*').union(cls_.get_fields()))
            return list(fields)

        def compute_ref_cls_projection(cls_, group):
            """ Compute inner projection for this class and group """
            subfields = set([x[2] for x in group if x[2]])
            return normalize_cls_projection(cls_, subfields)

        def sort_key(proj_info):
            return proj_info[:2]

        # Aggregate by reference field; '*' entries in each group are expanded during normalization
        ref_projection = {
            ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
            for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
        }

        # Make sure this doesn't contain any reference field we'll join anyway
        # (i.e. in case only_fields=[project, project.name])
        doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))

        # In case one or more fields is a subfield of another field, use only the top-level field.
        # This is done since in such a case, MongoDB will only use the most restrictive field (the most nested
        # field) and won't return some of the data we need.
        # This way, we make sure to use the most inclusive field that contains all requested subfields.
        projection_set = set(doc_projection)
        doc_projection = [
            field
            for field in doc_projection
            if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
        ]

        # Make sure we didn't get any invalid projection fields for this class
        invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
        if invalid_fields:
            raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)

        if ref_projection:
            # Join mode - use both normal projection fields and top-level reference fields
            doc_projection = set(doc_projection)
            for field in set(ref_projection).difference(doc_projection):
                if any(f for f in doc_projection if field.startswith(f)):
                    continue
                doc_projection.add(field)
            doc_projection = list(doc_projection)

        self._doc_projection = doc_projection
        self._ref_projection = ref_projection
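# Illustrative normalization (field names are assumptions): with
# only_fields=['execution', 'execution.queue'], the subfield is dropped in
# favor of the more inclusive top-level field, so the final doc_projection is
# ['execution', 'id'] (the 'id' field is always added).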
    @staticmethod
    def _search(doc_cls, obj, path, only_values=True):
        """ Call dpath.search with yielded=True, collect result values """
        norm_path = doc_cls.get_dpath_translated_path(path)
        return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]

    def project(self, results, projection_func):
        """
        Perform projection on query results, using the provided projection func.
        :param results: A list of result dictionaries on which projection should be performed
        :param projection_func: A callable that receives a document type, a list of ids and a projection, and
            returns query results. This callable is used in order to perform sub-queries during projection
        :return: Modified results (in-place)
        """
        cls = self._doc_cls
        ref_projection = self._ref_projection

        if ref_projection:
            # Join mode - get results for each required reference field projection (this is the join step)
            # Note: this is a recursive step, so we support nested reference fields

            def do_projection(item):
                ref_field_name, data = item
                res = {}
                ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
                                                                for res in results))))
                if ids:
                    doc_type = data['cls']
                    doc_only = list(filter(None, data['only']))
                    doc_only = list({'id'} | set(doc_only)) if doc_only else None
                    res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
                data['res'] = res

            items = list(ref_projection.items())
            if len(ref_projection) == 1:
                do_projection(items[0])
            else:
                for _ in self.pool.map(do_projection, items):
                    # From ThreadPoolExecutor.map() documentation: if a call raises an exception then that
                    # exception will be raised when its value is retrieved from the map() iterator
                    pass

        def do_expand_reference_ids(result, skip_fields=None):
            ref_fields = cls.get_reference_fields()
            if skip_fields:
                ref_fields = set(ref_fields) - set(skip_fields)
            self._expand_reference_fields(cls, result, ref_fields)

        def merge_projection_result(result):
            for ref_field_name, data in ref_projection.items():
                res = data.get('res')
                if not res:
                    self._expand_reference_fields(cls, result, [ref_field_name])
                    continue
                ref_ids = self._search(cls, result, ref_field_name, only_values=False)
                if not ref_ids:
                    continue
                for path, value in ref_ids:
                    obj = res.get(value) or {'id': value}
                    dpath.new(result, path, obj, separator='.')

            # Any reference field not projected should be expanded
            do_expand_reference_ids(result, skip_fields=list(ref_projection))

        update_func = merge_projection_result if ref_projection else \
            do_expand_reference_ids if self._should_expand_reference_ids else None

        if update_func:
            for result in results:
                update_func(result)

        return results
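# A minimal projection_func sketch (an assumption for illustration, not the
# original implementation - the actual callable is supplied by the caller of
# project() and must return dicts that include an 'id' key; top-level fields
# only, for brevity):
#
#     def projection_func(doc_type, projection, ids):
#         fields = set(projection or []) | {'id'}
#         return [{f: getattr(doc, f) for f in fields}
#                 for doc in doc_type.objects(id__in=ids).only(*fields)]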
    @classmethod
    def _expand_reference_fields(cls, doc_cls, result, fields):
        for ref_field_name in fields:
            ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
            if not ref_ids:
                continue
            for path, value in ref_ids:
                dpath.set(
                    result,
                    path,
                    {'id': value} if value else {},
                    separator='.')

    @classmethod
    def expand_reference_ids(cls, doc_cls, result):
        cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
142
server/database/props.py
Normal file
@@ -0,0 +1,142 @@
from collections import OrderedDict
from operator import attrgetter
from threading import Lock

import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document

from database.fields import (
    LengthRangeEmbeddedDocumentListField,
    UniqueEmbeddedDocumentListField,
    EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr


class PropsMixin(object):
    __cached_fields = None
    __cached_reference_fields = None
    __cached_exclude_fields = None
    __cached_fields_with_instance = None

    __cached_dpath_computed_fields_lock = Lock()
    __cached_dpath_computed_fields = None

    @classmethod
    def get_fields(cls):
        if cls.__cached_fields is None:
            cls.__cached_fields = get_fields(cls)
        return cls.__cached_fields

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        if cls.__cached_fields_with_instance is None:
            cls.__cached_fields_with_instance = {}
        if doc_cls not in cls.__cached_fields_with_instance:
            cls.__cached_fields_with_instance[doc_cls] = get_fields(
                doc_cls, return_instance=True
            )
        return cls.__cached_fields_with_instance[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """ Get all fields with the specified attribute (supports nested fields) """
        res = get_fields_and_attr(cls_, attr=attr)

        def resolve_doc(v):
            if not isinstance(v, six.string_types):
                return v
            if v == 'self':
                return cls_.owner_document
            return get_document(v)

        fields = {k: resolve_doc(v) for k, v in res.items()}

        def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
            for field, embedded_doc_field in get_fields(
                cls_, of_type=doc_cls, return_instance=True
            ):
                embedded_doc_cls = embedded_doc_field_getter(
                    embedded_doc_field
                ).document_type
                fields.update(
                    {
                        '.'.join((field, subfield)): doc
                        for subfield, doc in PropsMixin._get_fields_with_attr(
                            embedded_doc_cls, attr
                        ).items()
                    }
                )

        collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
        collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))

        return fields
    @classmethod
    def _translate_fields_path(cls, parts):
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                raise ValueError(
                    'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
                )
            try:
                field_name, field = next(
                    (k, v)
                    for k, v in cls.get_fields_with_instance(current_cls)
                    if k == part
                )
            except StopIteration:
                raise ValueError('Invalid field path %s' % parts[:depth])

            translated_parts.append(part)

            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(
                field,
                (
                    EmbeddedDocumentListField,
                    LengthRangeEmbeddedDocumentListField,
                    UniqueEmbeddedDocumentListField,
                    EmbeddedDocumentSortedListField,
                ),
            ):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                current_cls = None

        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        if cls.__cached_reference_fields is None:
            fields = cls._get_fields_with_attr(cls, 'reference_field')
            cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_reference_fields

    @classmethod
    def get_exclude_fields(cls):
        if cls.__cached_exclude_fields is None:
            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
            cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_exclude_fields

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        if cls.__cached_dpath_computed_fields is None:
            cls.__cached_dpath_computed_fields = {}
        if path not in cls.__cached_dpath_computed_fields:
            with cls.__cached_dpath_computed_fields_lock:
                parts = path.split(separator)
                translated = cls._translate_fields_path(parts)
                result = separator.join(translated)
                cls.__cached_dpath_computed_fields[path] = result
        return cls.__cached_dpath_computed_fields[path]
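# Example translation (assuming a document with an EmbeddedDocumentListField
# named 'models' whose embedded documents have a 'name' field - an assumption
# for illustration): get_dpath_translated_path('models.name') returns
# 'models.*.name', letting dpath.search() descend into every list item.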
63
server/database/query.py
Normal file
@@ -0,0 +1,63 @@
import copy
import re

from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination


class RegexWrapper(object):
    def __init__(self, pattern, flags=None):
        super(RegexWrapper, self).__init__()
        self.pattern = pattern
        self.flags = flags

    @property
    def regex(self):
        return re.compile(self.pattern, self.flags if self.flags is not None else 0)


class RegexMixin(object):

    def to_query(self, document):
        query = self.accept(SimplificationVisitor())
        query = query.accept(RegexQueryCompilerVisitor(document))
        return query

    def _combine(self, other, operation):
        """Combine this node with another node into a QCombination object."""
        if getattr(other, 'empty', True):
            return self

        if self.empty:
            return other

        return RegexQCombination(operation, [self, other])


class RegexQCombination(RegexMixin, QCombination):
    pass


class RegexQ(RegexMixin, Q):
    pass


class RegexQueryCompilerVisitor(QueryCompilerVisitor):
    """
    Improved mongoengine compiled-queries visitor class that supports compiled regex expressions as part of the query.

    We need this class since mongoengine's Q (QNode) class uses copy.deepcopy() as part of the tree simplification
    stage, which does not support compiled regex objects (since Python 2.5).
    This class allows users to provide regex strings wrapped in RegexWrapper instances, which are lazily evaluated
    into re.compile() instances just before being visited for compilation (this is done after the simplification
    stage)
    """

    def visit_query(self, query):
        query = copy.deepcopy(query)
        query.query = self._transform_query(query.query)
        return super(RegexQueryCompilerVisitor, self).visit_query(query)

    def _transform_query(self, query):
        return {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in query.items()}
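# Usage sketch (illustrative; Task is an assumed Document subclass here, and
# the query is resolved through RegexMixin.to_query() when mongoengine builds
# the final filter):
#
#     query = RegexQ(name=RegexWrapper('^my.*task$', re.IGNORECASE)) & RegexQ(company='c1')
#     tasks = Task.objects(query)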
160
server/database/utils.py
Normal file
@@ -0,0 +1,160 @@
import hashlib
from inspect import ismethod, getmembers
from uuid import uuid4

from mongoengine import EmbeddedDocumentField, ListField, Document, Q
from mongoengine.base import BaseField

from .errors import translate_errors_context, ParseCallError


def get_fields(cls, of_type=BaseField, return_instance=False):
    """ Get field names from a class containing mongoengine fields """
    res = []
    for cls_ in reversed(cls.mro()):
        res.extend([k if not return_instance else (k, v)
                    for k, v in vars(cls_).items()
                    if isinstance(v, of_type)])
    return res


def get_fields_and_attr(cls, attr):
    """ Get field names and their values for the given attribute, from a class containing mongoengine fields """
    res = {}
    for cls_ in reversed(cls.mro()):
        res.update({k: getattr(v, attr)
                    for k, v in vars(cls_).items()
                    if isinstance(v, BaseField) and hasattr(v, attr)})
    return res
def _get_field_choices(name, field):
    field_t = type(field)
    if issubclass(field_t, EmbeddedDocumentField):
        obj = field.document_type_obj
        n, choices = _get_field_choices(field.name, obj.field)
        return '%s__%s' % (name, n), choices
    elif issubclass(field_t, ListField):
        return name, field.field.choices
    return name, field.choices


def get_fields_with_attr(cls, attr, default=False):
    fields = []
    for field_name, field in cls._fields.items():
        if not getattr(field, attr, default):
            continue
        field_t = type(field)
        if issubclass(field_t, EmbeddedDocumentField):
            fields.extend((('%s__%s' % (field_name, name), choices)
                           for name, choices in get_fields_with_attr(field.document_type, attr, default)))
        elif issubclass(field_t, ListField):
            fields.append((field_name, field.field.choices))
        else:
            fields.append((field_name, field.choices))
    return fields
def get_items(cls):
    """ Get key/value items from an enum-like class (members represent enumeration key/value) """
    res = {
        k: v
        for k, v in getmembers(cls)
        if not (k.startswith("_") or ismethod(v))
    }
    return res


def get_options(cls):
    """ Get options from an enum-like class (members represent enumeration key/value) """
    return list(get_items(cls).values())
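# Example with an enum-like class (made up for illustration):
#
#     class Color(object):
#         red = 'r'
#         blue = 'b'
#
#     get_items(Color)    # -> {'blue': 'b', 'red': 'r'}
#     get_options(Color)  # -> ['b', 'r'] (order follows getmembers(), i.e. name-sorted)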
# Return a dictionary of items which:
# 1. are in the call_data
# 2. are in the fields dictionary, and their value in the call_data matches the type in fields
# 3. are in the cls_fields
def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
    if not isinstance(fields, dict):
        # fields should be a key=>type dict
        fields = {k: None for k in fields}
    fields = {k: v for k, v in fields.items() if k in cls_fields}
    res = {}
    with translate_errors_context('parsing call data'):
        for field, desc in fields.items():
            value = call_data.get(field)
            if value is None:
                if not discard_none_values and field in call_data:
                    # Keep the None value in case the field actually exists in the call data
                    res[field] = None
                continue
            if desc:
                if callable(desc):
                    desc(value)
                else:
                    if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
                        raise ParseCallError('expecting %s' % desc.__name__, field=field)
                    if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
                        raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
            res[field] = value
    return res
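# Usage sketch (made-up call data; _non_empty is a hypothetical validator
# supplied by the caller):
#
#     def _non_empty(value):
#         if not value:
#             raise ParseCallError('expecting a non-empty value', field='name')
#
#     parse_from_call(
#         call_data={'name': 't1', 'comment': None},
#         fields={'name': _non_empty, 'tags': None},
#         cls_fields=['name', 'tags'],
#     )
#     # -> {'name': 't1'} ('comment' is not declared, and a missing 'tags'
#     #    is discarded since discard_none_values defaults to True)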
def init_cls_from_base(cls, instance):
    return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})


def get_company_or_none_constraint(company=None):
    return Q(company__in=(company, None, '')) | Q(company__exists=False)


def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
    """
    Create a query object for finding documents where a field doesn't exist, or holds None or an empty value.
    :param field: Field name
    :param empty_value: The empty value to test for (None means no specific empty value will be used)
    :param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
        the length of the array will be used (len==0 means empty)
    :return: A mongoengine Q object expressing the constraint
    """
    query = (Q(**{f"{field}__exists": False}) |
             Q(**{f"{field}__in": {empty_value, None}}))
    if is_list:
        query |= Q(**{f"{field}__size": 0})
    return query
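# e.g. field_does_not_exist('tags', is_list=True) yields a constraint roughly
# equivalent to: 'tags' is missing, OR 'tags' is None, OR len(tags) == 0.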
def get_subkey(d, key_path, default=None):
    """ Get a key from a nested dictionary. key_path is a '.'-separated string of keys used to traverse
        the nested dictionary.
    """
    keys = key_path.split('.')
    for i, key in enumerate(keys):
        if not isinstance(d, dict):
            raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
        d = d.get(key)
        if d is None:
            return default
    return d
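# e.g. get_subkey({'task': {'execution': {'queue': 'q1'}}}, 'task.execution.queue')
# -> 'q1'; a key missing at any level returns the default (None).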
def id():
    return str(uuid4()).replace("-", "")


def hash_field_name(s):
    """ Hash a field name into a unique, safe string """
    return hashlib.md5(s.encode()).hexdigest()


def merge_dicts(*dicts):
    base = {}
    for dct in dicts:
        base.update(dct)
    return base


def filter_fields(cls, fields):
    """ From the fields dictionary, return only the fields that match cls fields """
    return {key: fields[key] for key in fields if key in get_fields(cls)}