mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Rename server to apiserver
This commit is contained in:
92
apiserver/database/__init__.py
Normal file
92
apiserver/database/__init__.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from os import getenv
|
||||
|
||||
from boltons.iterutils import first
|
||||
from furl import furl
|
||||
from jsonmodels import models
|
||||
from jsonmodels.errors import ValidationError
|
||||
from jsonmodels.fields import StringField
|
||||
from mongoengine import register_connection
|
||||
from mongoengine.connection import get_connection
|
||||
|
||||
from config import config
|
||||
from .defs import Database
|
||||
from .utils import get_items
|
||||
|
||||
log = config.logger("database")
|
||||
|
||||
strict = config.get("apiserver.mongo.strict", True)
|
||||
|
||||
OVERRIDE_HOST_ENV_KEY = (
|
||||
"TRAINS_MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_HOST",
|
||||
"MONGODB_SERVICE_SERVICE_HOST",
|
||||
)
|
||||
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
|
||||
|
||||
_entries = []
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
alias = StringField()
|
||||
|
||||
@property
|
||||
def health_alias(self):
|
||||
return "__health__" + self.alias
|
||||
|
||||
|
||||
def initialize():
|
||||
db_entries = config.get("hosts.mongo", {})
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
missing.append(key)
|
||||
continue
|
||||
|
||||
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
log.info(
|
||||
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
|
||||
)
|
||||
register_connection(alias=alias, host=entry.host)
|
||||
|
||||
_entries.append(entry)
|
||||
except ValidationError as ex:
|
||||
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
|
||||
if missing:
|
||||
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
|
||||
|
||||
|
||||
def get_entries():
|
||||
return _entries
|
||||
|
||||
|
||||
def get_hosts():
|
||||
return [entry.host for entry in get_entries()]
|
||||
|
||||
|
||||
def get_aliases():
|
||||
return [entry.alias for entry in get_entries()]
|
||||
|
||||
|
||||
def reconnect():
|
||||
for entry in get_entries():
|
||||
get_connection(entry.alias, reconnect=True)
|
||||
10
apiserver/database/defs.py
Normal file
10
apiserver/database/defs.py
Normal file
@@ -0,0 +1,10 @@
|
||||
|
||||
|
||||
class Database(object):
|
||||
""" Database names for our different DB instances """
|
||||
|
||||
backend = 'backend-db'
|
||||
''' Used for all backend objects (tasks, models etc.) '''
|
||||
|
||||
auth = 'auth-db'
|
||||
''' Used for all authentication and permission objects '''
|
||||
189
apiserver/database/errors.py
Normal file
189
apiserver/database/errors.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import re
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
import dpath
|
||||
from dpath.exceptions import InvalidKeyName
|
||||
from elasticsearch import ElasticsearchException
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from jsonmodels.errors import ValidationError as JsonschemaValidationError
|
||||
from mongoengine.errors import (
|
||||
ValidationError,
|
||||
NotUniqueError,
|
||||
FieldDoesNotExist,
|
||||
InvalidDocumentError,
|
||||
LookUpError,
|
||||
InvalidQueryError,
|
||||
)
|
||||
from pymongo.errors import PyMongoError, NotMasterError
|
||||
|
||||
from apierrors import errors
|
||||
|
||||
|
||||
class MakeGetAllQueryError(Exception):
|
||||
def __init__(self, error, field):
|
||||
super(MakeGetAllQueryError, self).__init__(f"{error}: field={field}")
|
||||
self.error = error
|
||||
self.field = field
|
||||
|
||||
|
||||
class ParseCallError(Exception):
|
||||
def __init__(self, msg, **kwargs):
|
||||
super(ParseCallError, self).__init__(msg)
|
||||
self.params = kwargs
|
||||
|
||||
|
||||
def throws_default_error(err_cls):
|
||||
"""
|
||||
Used to make functions (Exception, str) -> Optional[str] searching for specialized error messages raise those
|
||||
messages in ``err_cls``. If the decorated function does not find a suitable error message,
|
||||
the underlying exception is returned.
|
||||
:param err_cls: Error class (generated by apierrors)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, e, message, **kwargs):
|
||||
extra_info = func(self, e, message, **kwargs)
|
||||
raise err_cls(message, err=e, extra_info=extra_info)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ElasticErrorsHandler(object):
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError)
|
||||
def bulk_error(cls, e, _, **__):
|
||||
if not e.errors:
|
||||
return
|
||||
|
||||
# Else try returning a better error string
|
||||
for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
|
||||
return reason
|
||||
|
||||
|
||||
class MongoEngineErrorsHandler(object):
|
||||
# NotUniqueError
|
||||
__not_unique_regex = re.compile(
|
||||
r"collection:\s(?P<collection>[\w.]+)\sindex:\s(?P<index>\w+)\sdup\skey:\s{(?P<values>[^\}]+)\}"
|
||||
)
|
||||
__not_unique_value_regex = re.compile(r':\s"(?P<value>[^"]+)"')
|
||||
__id_index = "_id_"
|
||||
__index_sep_regex = re.compile(r"_[0-9]+_?")
|
||||
|
||||
# FieldDoesNotExist
|
||||
__not_exist_fields_regex = re.compile(r'"{(?P<fields>.+?)}".+?"(?P<document>.+?)"')
|
||||
__not_exist_field_regex = re.compile(r"'(?P<field>\w+)'")
|
||||
|
||||
@classmethod
|
||||
def validation_error(cls, e: ValidationError, message, **_):
|
||||
# Thrown when a document is validated. Documents are validated by default on save and on update
|
||||
err_dict = e.errors or {e.field_name: e.message}
|
||||
raise errors.bad_request.DataValidationError(message, **err_dict)
|
||||
|
||||
@classmethod
|
||||
def not_unique_error(cls, e, message, **_):
|
||||
# Thrown when a save/update violates a unique index constraint
|
||||
m = cls.__not_unique_regex.search(str(e))
|
||||
if not m:
|
||||
raise errors.bad_request.ExpectedUniqueData(message, err=str(e))
|
||||
values = cls.__not_unique_value_regex.findall(m.group("values"))
|
||||
index = m.group("index")
|
||||
if index == cls.__id_index:
|
||||
fields = "id"
|
||||
else:
|
||||
fields = cls.__index_sep_regex.split(index)[:-1]
|
||||
raise errors.bad_request.ExpectedUniqueData(
|
||||
message, **dict(zip(fields, values))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def field_does_not_exist(cls, e, message, **kwargs):
|
||||
# Strict mode. Unknown fields encountered in loaded document(s)
|
||||
field_does_not_exist_cls = kwargs.get(
|
||||
"field_does_not_exist_cls", errors.server_error.InconsistentData
|
||||
)
|
||||
m = cls.__not_exist_fields_regex.search(str(e))
|
||||
params = {}
|
||||
if m:
|
||||
params["document"] = m.group("document")
|
||||
fields = cls.__not_exist_field_regex.findall(m.group("fields"))
|
||||
if fields:
|
||||
if len(fields) > 1:
|
||||
params["fields"] = "(%s)" % ", ".join(fields)
|
||||
else:
|
||||
params["field"] = fields[0]
|
||||
raise field_does_not_exist_cls(message, **params)
|
||||
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.DataError)
|
||||
def invalid_document_error(cls, e, message, **_):
|
||||
# Reverse_delete_rule used in reference field
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def lookup_error(cls, e, message, **_):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
"probably an invalid field name or unsupported nested field",
|
||||
replacement_msg="Lookup error",
|
||||
err=str(e),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@throws_default_error(errors.bad_request.InvalidRegexError)
|
||||
def invalid_regex_error(cls, e, _, **__):
|
||||
if e.args and e.args[0] == "unexpected end of regular expression":
|
||||
raise errors.bad_request.InvalidRegexError(e.args[0])
|
||||
|
||||
@classmethod
|
||||
@throws_default_error(errors.server_error.InternalError)
|
||||
def invalid_query_error(cls, e, message, **_):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def translate_errors_context(message=None, **kwargs):
|
||||
"""
|
||||
A context manager that translates MongoEngine's and Elastic thrown errors into our apierrors classes,
|
||||
with an appropriate message.
|
||||
"""
|
||||
try:
|
||||
if message:
|
||||
message = "while " + message
|
||||
yield True
|
||||
except ValidationError as e:
|
||||
MongoEngineErrorsHandler.validation_error(e, message, **kwargs)
|
||||
except NotUniqueError as e:
|
||||
MongoEngineErrorsHandler.not_unique_error(e, message, **kwargs)
|
||||
except FieldDoesNotExist as e:
|
||||
MongoEngineErrorsHandler.field_does_not_exist(e, message, **kwargs)
|
||||
except InvalidDocumentError as e:
|
||||
MongoEngineErrorsHandler.invalid_document_error(e, message, **kwargs)
|
||||
except LookUpError as e:
|
||||
MongoEngineErrorsHandler.lookup_error(e, message, **kwargs)
|
||||
except re.error as e:
|
||||
MongoEngineErrorsHandler.invalid_regex_error(e, message, **kwargs)
|
||||
except InvalidQueryError as e:
|
||||
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
|
||||
except PyMongoError as e:
|
||||
raise errors.server_error.InternalError(message, err=str(e))
|
||||
except NotMasterError as e:
|
||||
raise errors.server_error.InternalError(message, err=str(e))
|
||||
except MakeGetAllQueryError as e:
|
||||
raise errors.bad_request.ValidationError(e.error, field=e.field)
|
||||
except ParseCallError as e:
|
||||
raise errors.bad_request.FieldsValueError(e.args[0], **e.params)
|
||||
except JsonschemaValidationError as e:
|
||||
if len(e.args) >= 2:
|
||||
raise errors.bad_request.ValidationError(e.args[0], reason=e.args[1])
|
||||
raise errors.bad_request.ValidationError(e.args[0])
|
||||
except BulkIndexError as e:
|
||||
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
|
||||
except ElasticsearchException as e:
|
||||
raise errors.server_error.DataError(e, message, **kwargs)
|
||||
except InvalidKeyName:
|
||||
raise errors.server_error.DataError("invalid empty key encountered in data")
|
||||
except Exception as ex:
|
||||
raise
|
||||
209
apiserver/database/fields.py
Normal file
209
apiserver/database/fields.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from operator import itemgetter
|
||||
from sys import maxsize
|
||||
from typing import Type, Tuple
|
||||
|
||||
import six
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentListField,
|
||||
ListField,
|
||||
FloatField,
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
SortedListField,
|
||||
MapField,
|
||||
DictField,
|
||||
DynamicField,
|
||||
)
|
||||
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
|
||||
class LengthRangeListField(ListField):
|
||||
def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
|
||||
self.__min_length = min_length
|
||||
self.__max_length = max_length
|
||||
super(LengthRangeListField, self).__init__(field, **kwargs)
|
||||
|
||||
def validate(self, value):
|
||||
min, val, max = self.__min_length, len(value), self.__max_length
|
||||
if not min <= val <= max:
|
||||
self.error("Item count %d exceeds range [%d, %d]" % (val, min, max))
|
||||
super(LengthRangeListField, self).validate(value)
|
||||
|
||||
|
||||
class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
|
||||
def __init__(self, field=None, *args, **kwargs):
|
||||
super(LengthRangeEmbeddedDocumentListField, self).__init__(
|
||||
EmbeddedDocumentField(field), *args, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
|
||||
def __init__(self, document_type, key, **kwargs):
|
||||
"""
|
||||
Create a unique embedded document list field for a document type with a unique comparison key func/property
|
||||
:param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
|
||||
:param key: A callable to extract a key from each item
|
||||
"""
|
||||
if not callable(key):
|
||||
raise KeyError("key must be callable")
|
||||
self.__key = key
|
||||
super(UniqueEmbeddedDocumentListField, self).__init__(document_type)
|
||||
|
||||
def validate(self, value):
|
||||
if len({self.__key(i) for i in value}) != len(value):
|
||||
self.error("Items with duplicate key exist in the list")
|
||||
super(UniqueEmbeddedDocumentListField, self).validate(value)
|
||||
|
||||
|
||||
def object_to_key_value_pairs(obj):
|
||||
if isinstance(obj, dict):
|
||||
return [(key, object_to_key_value_pairs(value)) for key, value in obj.items()]
|
||||
if isinstance(obj, list):
|
||||
return list(map(object_to_key_value_pairs, obj))
|
||||
return obj
|
||||
|
||||
|
||||
class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
|
||||
"""
|
||||
A sorted list of embedded documents
|
||||
"""
|
||||
|
||||
def to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(EmbeddedDocumentSortedListField, self).to_mongo(
|
||||
value, use_db_field, fields
|
||||
)
|
||||
return sorted(value, key=object_to_key_value_pairs)
|
||||
|
||||
|
||||
class LengthRangeSortedListField(LengthRangeListField, SortedListField):
|
||||
pass
|
||||
|
||||
|
||||
class CustomFloatField(FloatField):
|
||||
def __init__(self, greater_than=None, **kwargs):
|
||||
self.greater_than = greater_than
|
||||
super(CustomFloatField, self).__init__(**kwargs)
|
||||
|
||||
def validate(self, value):
|
||||
super(CustomFloatField, self).validate(value)
|
||||
|
||||
if self.greater_than is not None and value <= self.greater_than:
|
||||
self.error("Float value must be greater than %s" % str(self.greater_than))
|
||||
|
||||
|
||||
class StrippedStringField(StringField):
|
||||
def __init__(
|
||||
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
|
||||
):
|
||||
super(StrippedStringField, self).__init__(
|
||||
regex, max_length, min_length, **kwargs
|
||||
)
|
||||
self._strip_chars = strip_chars
|
||||
|
||||
def __set__(self, instance, value):
|
||||
if value is not None:
|
||||
try:
|
||||
value = value.strip(self._strip_chars)
|
||||
except AttributeError:
|
||||
pass
|
||||
super(StrippedStringField, self).__set__(instance, value)
|
||||
|
||||
def prepare_query_value(self, op, value):
|
||||
if not isinstance(op, six.string_types):
|
||||
return value
|
||||
if value is not None:
|
||||
value = value.strip(self._strip_chars)
|
||||
return super(StrippedStringField, self).prepare_query_value(op, value)
|
||||
|
||||
|
||||
def contains_empty_key(d):
|
||||
"""
|
||||
Helper function to recursively determine if any key in a
|
||||
dictionary is empty (based on mongoengine.fields.key_not_string)
|
||||
"""
|
||||
for k, v in list(d.items()):
|
||||
if not k or (isinstance(v, dict) and contains_empty_key(v)):
|
||||
return True
|
||||
|
||||
|
||||
class DictValidationMixin:
|
||||
"""
|
||||
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||
"""
|
||||
|
||||
def _safe_validate(self: DictField, value):
|
||||
if not isinstance(value, dict):
|
||||
self.error("Only dictionaries may be used in a DictField")
|
||||
|
||||
if key_not_string(value):
|
||||
msg = "Invalid dictionary key - documents must have only string keys"
|
||||
self.error(msg)
|
||||
|
||||
if key_starts_with_dollar(value):
|
||||
self.error(
|
||||
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||
)
|
||||
super(DictField, self).validate(value)
|
||||
|
||||
|
||||
class SafeMapField(MapField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a MapField")
|
||||
|
||||
|
||||
class SafeDictField(DictField, DictValidationMixin):
|
||||
def validate(self, value):
|
||||
self._safe_validate(value)
|
||||
|
||||
if contains_empty_key(value):
|
||||
self.error("Empty keys are not allowed in a DictField")
|
||||
|
||||
|
||||
class SafeSortedListField(SortedListField):
|
||||
"""
|
||||
SortedListField that does not raise an error in case items are not comparable
|
||||
(in which case they will be sorted by their string representation)
|
||||
"""
|
||||
|
||||
def to_mongo(self, *args, **kwargs):
|
||||
try:
|
||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||
except TypeError:
|
||||
return self._safe_to_mongo(*args, **kwargs)
|
||||
|
||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||
if self._ordering is not None:
|
||||
|
||||
def key(v):
|
||||
return str(itemgetter(self._ordering)(v))
|
||||
|
||||
else:
|
||||
key = str
|
||||
return sorted(value, key=key, reverse=self._order_reverse)
|
||||
|
||||
|
||||
class UnionField(DynamicField):
|
||||
def __init__(self, types, *args, **kwargs):
|
||||
super(UnionField, self).__init__(*args, **kwargs)
|
||||
self.types: Tuple[Type] = tuple(types)
|
||||
|
||||
def validate(self, value, clean=True):
|
||||
if not isinstance(value, self.types):
|
||||
type_names = [t.__name__ for t in self.types]
|
||||
expected = " or ".join(
|
||||
filter(
|
||||
None,
|
||||
(", ".join(type_names[:-1]), type_names[-1]))
|
||||
)
|
||||
self.error(
|
||||
f"Expected {expected}, got {type(value).__name__}: {value}"
|
||||
)
|
||||
super(UnionField, self).validate(value, clean)
|
||||
62
apiserver/database/model/__init__.py
Normal file
62
apiserver/database/model/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from enum import Enum
|
||||
|
||||
from mongoengine import Document, StringField
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.base import DbModelMixin, ABSTRACT_FLAG
|
||||
from database.model.company import Company
|
||||
from database.model.user import User
|
||||
|
||||
|
||||
class AttributedDocument(DbModelMixin, Document):
|
||||
"""
|
||||
Represents objects which are attributed to a company and a user or to "no one".
|
||||
Company must be required since it can be used as unique field.
|
||||
"""
|
||||
meta = ABSTRACT_FLAG
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
user = StringField(reference_field=User)
|
||||
|
||||
def is_public(self) -> bool:
|
||||
return bool(self.company)
|
||||
|
||||
|
||||
class PrivateDocument(AttributedDocument):
|
||||
"""
|
||||
Represents documents which always belong to a single company
|
||||
"""
|
||||
meta = ABSTRACT_FLAG
|
||||
# can not have an empty string as this is the "public" marker
|
||||
company = StringField(required=True, reference_field=Company, min_length=1)
|
||||
user = StringField(reference_field=User, required=True)
|
||||
|
||||
def is_public(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
Validate existence of objects with certain IDs. within company.
|
||||
:param cls: Model class to search in
|
||||
:param company: Company to search in
|
||||
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
|
||||
it will be reported along with the name it was assigned to.
|
||||
:return:
|
||||
"""
|
||||
ids = set(kwargs.values())
|
||||
objs = list(cls.objects(company=company, id__in=ids).only('id'))
|
||||
missing = ids - set(x.id for x in objs)
|
||||
if not missing:
|
||||
return
|
||||
id_to_name = {}
|
||||
for name, obj_id in kwargs.items():
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
'Invalid {} ids'.format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
)
|
||||
|
||||
|
||||
class EntityVisibility(Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
76
apiserver/database/model/auth.py
Normal file
76
apiserver/database/model/auth.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocument,
|
||||
EmbeddedDocumentListField,
|
||||
EmailField,
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import AuthDocument
|
||||
from database.utils import get_options
|
||||
|
||||
|
||||
class Entities(object):
|
||||
company = "company"
|
||||
task = "task"
|
||||
user = "user"
|
||||
model = "model"
|
||||
|
||||
|
||||
class Role(object):
|
||||
system = "system"
|
||||
""" Internal system component """
|
||||
root = "root"
|
||||
""" Root admin (person) """
|
||||
admin = "admin"
|
||||
""" Company administrator """
|
||||
superuser = "superuser"
|
||||
""" Company super user """
|
||||
user = "user"
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
return {cls.system, cls.root}
|
||||
|
||||
@classmethod
|
||||
def get_company_roles(cls) -> set:
|
||||
return set(get_options(cls)) - cls.get_system_roles()
|
||||
|
||||
|
||||
class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
last_used = DateTimeField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
meta = {"db_alias": Database.auth, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StringField()
|
||||
|
||||
created = DateTimeField()
|
||||
""" User auth entry creation time """
|
||||
|
||||
validated = DateTimeField()
|
||||
""" Last validation (login) time """
|
||||
|
||||
role = StringField(required=True, choices=get_options(Role), default=Role.user)
|
||||
""" User role """
|
||||
|
||||
company = StringField(required=True)
|
||||
""" Company this user belongs to """
|
||||
|
||||
credentials = EmbeddedDocumentListField(Credentials, default=list)
|
||||
""" Credentials generated for this user """
|
||||
|
||||
email = EmailField(unique=True, required=True)
|
||||
""" Email uniquely identifying the user """
|
||||
812
apiserver/database/model/base.py
Normal file
812
apiserver/database/model/base.py
Normal file
@@ -0,0 +1,812 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
from database.props import PropsMixin
|
||||
from database.query import RegexQ, RegexWrapper
|
||||
from database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
||||
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
||||
|
||||
ABSTRACT_FLAG = {"abstract": True}
|
||||
|
||||
|
||||
class AuthDocument(Document):
|
||||
meta = ABSTRACT_FLAG
|
||||
|
||||
|
||||
class ProperDictMixin(object):
|
||||
def to_proper_dict(
|
||||
self: Union["ProperDictMixin", Document],
|
||||
strip_private=True,
|
||||
only=None,
|
||||
extra_dict=None,
|
||||
) -> dict:
|
||||
return self.properize_dict(
|
||||
self.to_mongo(use_db_field=False).to_dict(),
|
||||
strip_private=strip_private,
|
||||
only=only,
|
||||
extra_dict=extra_dict,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def properize_dict(
|
||||
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
|
||||
):
|
||||
res = d
|
||||
if normalize_id and "_id" in res:
|
||||
res["id"] = res.pop("_id")
|
||||
if strip_private:
|
||||
res = {k: v for k, v in res.items() if k[0] != "_"}
|
||||
if only:
|
||||
res = project_dict(res, only)
|
||||
if extra_dict:
|
||||
res.update(extra_dict)
|
||||
return res
|
||||
|
||||
|
||||
class GetMixin(PropsMixin):
|
||||
_text_score = "$text_score"
|
||||
_projection_key = "projection"
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
("_all_", "_and_"): lambda a, b: a & b,
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
def __init__(
|
||||
self,
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
datetime_fields=None,
|
||||
fields=None,
|
||||
):
|
||||
"""
|
||||
:param pattern_fields: Fields for which a "string contains" condition should be generated
|
||||
:param list_fields: Fields for which a "list contains" condition should be generated
|
||||
:param datetime_fields: Fields for which datetime condition should be generated (see ACCESS_MODIFIER)
|
||||
:param fields: Fields which which a simple equality condition should be generated (basically filters out all
|
||||
other unsupported query fields)
|
||||
"""
|
||||
self.fields = fields
|
||||
self.datetime_fields = datetime_fields
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
company,
|
||||
id,
|
||||
*,
|
||||
_only=None,
|
||||
include_public=False,
|
||||
**kwargs,
|
||||
) -> "GetMixin":
|
||||
q = cls.objects(
|
||||
cls._prepare_perm_query(company, allow_public=include_public)
|
||||
& Q(id=id, **kwargs)
|
||||
)
|
||||
if _only:
|
||||
q = q.only(*_only)
|
||||
return q.first()
|
||||
|
||||
@classmethod
|
||||
def prepare_query(
|
||||
cls,
|
||||
company: str,
|
||||
parameters: dict = None,
|
||||
parameters_options: QueryParameterOptions = None,
|
||||
allow_public=False,
|
||||
):
|
||||
"""
|
||||
Prepare a query object based on the provided query dictionary and various fields.
|
||||
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
||||
:param company: Company ID (required)
|
||||
:param allow_public: Allow results from public objects
|
||||
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
||||
Supported parameters:
|
||||
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
|
||||
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
|
||||
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
|
||||
provided fields match the provided pattern.
|
||||
:return: mongoengine.Q query object
|
||||
"""
|
||||
return cls._prepare_query_no_company(
|
||||
parameters, parameters_options
|
||||
) & cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
|
||||
@classmethod
|
||||
def _prepare_query_no_company(
|
||||
cls, parameters=None, parameters_options=QueryParameterOptions()
|
||||
):
|
||||
"""
|
||||
Prepare a query object based on the provided query dictionary and various fields.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
|
||||
|
||||
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
||||
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
||||
Supported parameters:
|
||||
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
|
||||
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
|
||||
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
|
||||
provided fields match the provided pattern.
|
||||
:return: mongoengine.Q query object
|
||||
"""
|
||||
parameters_options = parameters_options or cls.get_all_query_options
|
||||
dict_query = {}
|
||||
query = RegexQ()
|
||||
if parameters:
|
||||
parameters = parameters.copy()
|
||||
opts = parameters_options
|
||||
for field in opts.pattern_fields:
|
||||
pattern = parameters.pop(field, None)
|
||||
if pattern:
|
||||
dict_query[field] = RegexWrapper(pattern)
|
||||
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
dict_query[field] = data
|
||||
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
if not isinstance(data, list):
|
||||
data = [data]
|
||||
for d in data: # type: str
|
||||
m = ACCESS_REGEX.match(d)
|
||||
if not m:
|
||||
continue
|
||||
try:
|
||||
value = parse_datetime(m.group("value"))
|
||||
prefix = m.group("prefix")
|
||||
modifier = ACCESS_MODIFIER.get(prefix)
|
||||
f = field if not modifier else "__".join((field, modifier))
|
||||
dict_query[f] = value
|
||||
except (ValueError, OverflowError):
|
||||
pass
|
||||
|
||||
for field, value in parameters.items():
|
||||
for keys, func in cls._multi_field_param_prefix.items():
|
||||
if field not in keys:
|
||||
continue
|
||||
try:
|
||||
data = cls.MultiFieldParameters(**value)
|
||||
except Exception:
|
||||
raise MakeGetAllQueryError("incorrect field format", field)
|
||||
if not data.fields:
|
||||
break
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
)
|
||||
query = query & q
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
return get_company_or_none_constraint(company)
|
||||
return Q(company=company)
|
||||
|
||||
@classmethod
|
||||
def validate_order_by(cls, parameters, search_text) -> Sequence:
|
||||
"""
|
||||
Validate and extract order_by params as a list
|
||||
"""
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if not order_by:
|
||||
return []
|
||||
|
||||
order_by = order_by if isinstance(order_by, list) else [order_by]
|
||||
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
||||
if not search_text and cls._text_score in order_by:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
):
|
||||
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
default_page = parameters.get("page", default_page)
|
||||
if default_page is None:
|
||||
return None, None
|
||||
default_page_size = parameters.get("page_size", default_page_size)
|
||||
if not default_page_size:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
elif default_page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
elif default_page_size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"page_size must be >0", field="page_size"
|
||||
)
|
||||
return default_page, default_page_size
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
""" Extract a projection list from the provided dictionary. Supports an override projection. """
|
||||
if override_projection is not None:
|
||||
return override_projection
|
||||
if not parameters:
|
||||
return []
|
||||
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
|
||||
|
||||
@classmethod
|
||||
def split_projection(
|
||||
cls, projection: Sequence[str]
|
||||
) -> Tuple[Collection[str], Collection[str]]:
|
||||
"""Return include and exclude lists based on passed projection and class definition"""
|
||||
if projection:
|
||||
include, exclude = partition(
|
||||
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
|
||||
)
|
||||
else:
|
||||
include, exclude = [], []
|
||||
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
|
||||
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
|
||||
|
||||
@classmethod
|
||||
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters.pop("only_fields", None)
|
||||
parameters[cls._projection_key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def get_ordering(cls, parameters: dict) -> Optional[Sequence[str]]:
|
||||
return parameters.get(cls._ordering_key)
|
||||
|
||||
@classmethod
|
||||
def set_ordering(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
|
||||
parameters[cls._ordering_key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
company,
|
||||
query_dict=None,
|
||||
query_options=None,
|
||||
query=None,
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
requested projection. See get_many() for more info.
|
||||
:param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
|
||||
a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
|
||||
"""
|
||||
if issubclass(cls, AuthDocument):
|
||||
# Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
|
||||
# auth-related secrets and prevent security leaks
|
||||
log.error(
|
||||
f"Attempted projection of {cls.__name__} auth document (ignored)",
|
||||
stack_info=True,
|
||||
)
|
||||
return []
|
||||
|
||||
override_projection = cls.get_projection(
|
||||
parameters=query_dict, override_projection=override_projection
|
||||
)
|
||||
|
||||
helper = ProjectionHelper(
|
||||
doc_cls=cls,
|
||||
projection=override_projection,
|
||||
expand_reference_ids=expand_reference_ids,
|
||||
)
|
||||
|
||||
# Make the main query
|
||||
results = cls.get_many(
|
||||
override_projection=helper.doc_projection,
|
||||
company=company,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
return doc_type.get_many_with_join(
|
||||
company=company,
|
||||
override_projection=projection,
|
||||
query=Q(id__in=ids),
|
||||
expand_reference_ids=expand_reference_ids,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
|
||||
return helper.project(results, projection_func)
|
||||
|
||||
@classmethod
|
||||
def get_many(
|
||||
cls,
|
||||
company,
|
||||
parameters: dict = None,
|
||||
query_dict: dict = None,
|
||||
query_options: QueryParameterOptions = None,
|
||||
query: Q = None,
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
(aside from those provided by the parameters):
|
||||
- Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
|
||||
field names. Using field names not defined in the document will cause an error.
|
||||
- Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
|
||||
larger than 0 and is required when specifying a page.
|
||||
- Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
|
||||
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
|
||||
be raised.
|
||||
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
||||
requested, each contains only the requested projection). If False, a QuerySet object is returned
|
||||
(lazy evaluated). If return_dicts is requested then the entities with the None value in order_by field
|
||||
are returned last in the ordering.
|
||||
:param company: Company ID (required)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
|
||||
a query. The resulting query is AND'ed with the `query` parameter (if provided).
|
||||
:param query_options: query parameters options (see ParametersOptions)
|
||||
:param query: Optional query object (mongoengine.Q)
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
:param allow_public: If True, objects marked as public (no associated company) are also queried.
|
||||
:return: A list of objects matching the query.
|
||||
"""
|
||||
if query_dict is not None:
|
||||
q = cls.prepare_query(
|
||||
parameters=query_dict,
|
||||
company=company,
|
||||
parameters_options=query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
else:
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
query: Q,
|
||||
parameters=None,
|
||||
override_projection=None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
if not query:
|
||||
raise ValueError("query or call_data must be provided")
|
||||
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
qs = qs.search_text(search_text)
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = qs.order_by(*order_by)
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
qs = qs.only(*include)
|
||||
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def _get_many_override_none_ordering(
|
||||
cls: Union[Document, "GetMixin"],
|
||||
query: Q = None,
|
||||
parameters: dict = None,
|
||||
override_projection: Collection[str] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Fetch all documents matching a provided query. For the first order by field
|
||||
the None values are sorted in the end regardless of the sorting order.
|
||||
If the first order field is a user defined parameter (either from execution.parameters,
|
||||
or from last_metrics) then the collation is set that sorts strings in numeric order where possible.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
if not query:
|
||||
raise ValueError("query or call_data must be provided")
|
||||
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
|
||||
query_sets = [cls.objects(query)]
|
||||
if order_by:
|
||||
order_field = first(
|
||||
field for field in order_by if not field.startswith("$")
|
||||
)
|
||||
if (
|
||||
order_field
|
||||
and not order_field.startswith("-")
|
||||
and "[" not in order_field
|
||||
):
|
||||
params = {}
|
||||
mongo_field = order_field.replace(".", "__")
|
||||
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
|
||||
params["is_list"] = True
|
||||
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
|
||||
params["empty_value"] = ""
|
||||
non_empty = query & field_exists(mongo_field, **params)
|
||||
empty = query & field_does_not_exist(mongo_field, **params)
|
||||
query_sets = [cls.objects(non_empty), cls.objects(empty)]
|
||||
|
||||
query_sets = [qs.order_by(*order_by) for qs in query_sets]
|
||||
if order_field:
|
||||
collation_override = first(
|
||||
v
|
||||
for k, v in cls._field_collation_overrides.items()
|
||||
if order_field.startswith(k)
|
||||
)
|
||||
if collation_override:
|
||||
query_sets = [
|
||||
qs.collation(collation=collation_override) for qs in query_sets
|
||||
]
|
||||
|
||||
if search_text:
|
||||
query_sets = [qs.search_text(search_text) for qs in query_sets]
|
||||
|
||||
if include:
|
||||
# add projection
|
||||
query_sets = [qs.only(*include) for qs in query_sets]
|
||||
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def get_for_writing(
|
||||
cls, *args, _only: Collection[str] = None, **kwargs
|
||||
) -> "GetMixin":
|
||||
if _only and "company" not in _only:
|
||||
_only = list(set(_only) | {"company"})
|
||||
result = cls.get(*args, _only=_only, include_public=True, **kwargs)
|
||||
if result and not result.company:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={(result.id,)}"
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_many_for_writing(cls, company, *args, **kwargs):
|
||||
result = cls.get_many(
|
||||
company=company,
|
||||
*args,
|
||||
**dict(return_dicts=False, **kwargs),
|
||||
allow_public=True,
|
||||
)
|
||||
forbidden_objects = {obj.id for obj in result if not obj.company}
|
||||
if forbidden_objects:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
@classmethod
|
||||
def user_set_allowed(cls):
|
||||
res = getattr(cls, "__user_set_allowed_fields", None)
|
||||
if res is None:
|
||||
res = cls.__user_set_allowed_fields = get_fields_choices(
|
||||
cls, "user_set_allowed"
|
||||
)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_safe_update_dict(cls, fields):
|
||||
if not fields:
|
||||
return {}
|
||||
valid_fields = cls.user_set_allowed()
|
||||
fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
|
||||
update_dict = {
|
||||
field: value
|
||||
for field, allowed, value in fields
|
||||
if allowed is None
|
||||
or (
|
||||
(value in allowed)
|
||||
if not isinstance(value, list)
|
||||
else all(v in allowed for v in value)
|
||||
)
|
||||
}
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(
|
||||
cls: Union["UpdateMixin", Document],
|
||||
company_id,
|
||||
id,
|
||||
partial_update_dict,
|
||||
injected_update=None,
|
||||
):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
if injected_update:
|
||||
update_dict.update(injected_update)
|
||||
update_count = cls.objects(id=id, company=company_id).update(
|
||||
upsert=False, **update_dict
|
||||
)
|
||||
return update_count, update_dict
|
||||
|
||||
|
||||
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
""" Provide convenience methods for a subclass of mongoengine.Document """
|
||||
|
||||
@classmethod
|
||||
def aggregate(
|
||||
cls: Union["DbModelMixin", Document],
|
||||
pipeline: Sequence[dict],
|
||||
allow_disk_use=None,
|
||||
**kwargs,
|
||||
) -> CommandCursor:
|
||||
"""
|
||||
Aggregate objects of this document class according to the provided pipeline.
|
||||
:param pipeline: a list of dictionaries describing the pipeline stages
|
||||
:param allow_disk_use: if True, allow the server to use disk space if aggregation query cannot fit in memory.
|
||||
If None, default behavior will be used (see apiserver.conf/mongo/aggregate/allow_disk_use)
|
||||
:param kwargs: additional keyword arguments passed to mongoengine
|
||||
:return:
|
||||
"""
|
||||
kwargs.update(
|
||||
allowDiskUse=allow_disk_use
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
Validate existence of objects with certain IDs. within company.
|
||||
:param cls: Model class to search in
|
||||
:param company: Company to search in
|
||||
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
|
||||
it will be reported along with the name it was assigned to.
|
||||
:return:
|
||||
"""
|
||||
ids = set(kwargs.values())
|
||||
objs = list(cls.objects(company=company, id__in=ids).only("id"))
|
||||
missing = ids - set(x.id for x in objs)
|
||||
if not missing:
|
||||
return
|
||||
id_to_name = {}
|
||||
for name, obj_id in kwargs.items():
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]},
|
||||
)
|
||||
38
apiserver/database/model/company.py
Normal file
38
apiserver/database/model/company.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
EmbeddedDocumentField,
|
||||
StringField,
|
||||
Q,
|
||||
BooleanField,
|
||||
DateTimeField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class ReportStatsOption(EmbeddedDocument):
|
||||
enabled = BooleanField(default=False) # opt-in for statistics reporting
|
||||
enabled_version = StringField() # server version when enabled
|
||||
enabled_time = DateTimeField() # time when enabled
|
||||
enabled_user = StringField() # ID of user who enabled
|
||||
|
||||
|
||||
class CompanyDefaults(EmbeddedDocument):
|
||||
cluster = StringField()
|
||||
stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption)
|
||||
|
||||
|
||||
class Company(DbModelMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": strict}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(unique=True, min_length=3)
|
||||
defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
""" Override default behavior since a 'company' constraint is not supported for this document... """
|
||||
return Q()
|
||||
75
apiserver/database/model/model.py
Normal file
75
apiserver/database/model/model.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.task.task import Task
|
||||
from database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"parent",
|
||||
"project",
|
||||
"task",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
{
|
||||
"name": "%s.model.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"parent": 5,
|
||||
"task": 3,
|
||||
"project": 3,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("ready",),
|
||||
list_fields=(
|
||||
"tags",
|
||||
"system_tags",
|
||||
"framework",
|
||||
"uri",
|
||||
"id",
|
||||
"user",
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
labels = ModelLabels()
|
||||
ready = BooleanField(required=True)
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
14
apiserver/database/model/model_labels.py
Normal file
14
apiserver/database/model/model_labels.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from database.fields import NoneType, UnionField, SafeMapField
|
||||
|
||||
|
||||
class ModelLabels(SafeMapField):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ModelLabels, self).__init__(
|
||||
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||
)
|
||||
|
||||
def validate(self, value):
|
||||
super(ModelLabels, self).validate(value)
|
||||
non_empty_values = list(filter(None, value.values()))
|
||||
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||
self.error("Same label id appears more than once in model labels")
|
||||
46
apiserver/database/model/project.py
Normal file
46
apiserver/database/model/project.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
|
||||
class Project(AttributedDocument):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
("company", "name"),
|
||||
{
|
||||
"name": "%s.project.main_text_index" % Database.backend,
|
||||
"fields": ["$name", "$id", "$description"],
|
||||
"default_language": "english",
|
||||
"weights": {"name": 10, "id": 10, "description": 10},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
required=True,
|
||||
unique_with=AttributedDocument.company.name,
|
||||
min_length=3,
|
||||
sparse=True,
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
46
apiserver/database/model/queue.py
Normal file
46
apiserver/database/model/queue.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
from database.model.task.task import Task
|
||||
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
required=True, unique_with="company", min_length=3, user_set_allowed=True
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
57
apiserver/database/model/settings.py
Normal file
57
apiserver/database/model/settings.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
from mongoengine import Document, StringField, DynamicField, Q
|
||||
from mongoengine.errors import NotUniqueError
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class SettingKeys:
|
||||
server__uuid = "server.uuid"
|
||||
|
||||
|
||||
class Settings(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
key = StringField(primary_key=True)
|
||||
value = DynamicField()
|
||||
|
||||
@classmethod
|
||||
def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any:
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).first()
|
||||
if not res:
|
||||
return default
|
||||
return res.value
|
||||
|
||||
@classmethod
|
||||
def get_by_prefix(
|
||||
cls, key_prefix: str, default: Optional[Any] = None, sep: str = "."
|
||||
) -> Sequence[Tuple[str, Any]]:
|
||||
key_prefix = key_prefix.strip(sep)
|
||||
query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep)
|
||||
res = Settings.objects(query)
|
||||
if not res:
|
||||
return default
|
||||
return [(x.key, x.value) for x in res]
|
||||
|
||||
@classmethod
|
||||
def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Sets a new value or adds a new key/value setting (if key does not exist) """
|
||||
key = key.strip(sep)
|
||||
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
|
||||
return bool(res)
|
||||
|
||||
@classmethod
|
||||
def add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
|
||||
""" Adds a new key/value settings. Fails if key already exists. """
|
||||
key = key.strip(sep)
|
||||
try:
|
||||
res = cls(key=key, value=value).save(force_insert=True)
|
||||
return bool(res)
|
||||
except NotUniqueError:
|
||||
return False
|
||||
39
apiserver/database/model/task/metrics.py
Normal file
39
apiserver/database/model/task/metrics.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from mongoengine import (
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DynamicField,
|
||||
LongField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from database.fields import SafeMapField
|
||||
|
||||
|
||||
class MetricEvent(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
|
||||
metric = StringField(required=True)
|
||||
variant = StringField(required=True)
|
||||
value = DynamicField(required=True)
|
||||
min_value = DynamicField() # for backwards compatibility reasons
|
||||
max_value = DynamicField() # for backwards compatibility reasons
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
last_update = LongField()
|
||||
|
||||
|
||||
class MetricEventStats(EmbeddedDocument):
|
||||
meta = {
|
||||
# For backwards compatibility reasons
|
||||
"strict": False,
|
||||
}
|
||||
metric = StringField(required=True)
|
||||
event_stats_by_type = SafeMapField(field=EmbeddedDocumentField(EventStats))
|
||||
16
apiserver/database/model/task/output.py
Normal file
16
apiserver/database/model/task/output.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from mongoengine import EmbeddedDocument, StringField
|
||||
|
||||
from database.fields import StrippedStringField
|
||||
from database.utils import get_options
|
||||
|
||||
|
||||
class Result(object):
|
||||
success = 'success'
|
||||
failure = 'failure'
|
||||
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
219
apiserver/database/model/task/task.py
Normal file
219
apiserver/database/model/task/task.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocument,
|
||||
DateTimeField,
|
||||
IntField,
|
||||
ListField,
|
||||
LongField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import (
|
||||
StrippedStringField,
|
||||
SafeMapField,
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
from database.model.project import Project
|
||||
from database.utils import get_options
|
||||
from .metrics import MetricEvent, MetricEventStats
|
||||
from .output import Output
|
||||
|
||||
DEFAULT_LAST_ITERATION = 0
|
||||
|
||||
|
||||
class TaskStatus(object):
|
||||
created = "created"
|
||||
queued = "queued"
|
||||
in_progress = "in_progress"
|
||||
stopped = "stopped"
|
||||
publishing = "publishing"
|
||||
published = "published"
|
||||
closed = "closed"
|
||||
failed = "failed"
|
||||
completed = "completed"
|
||||
unknown = "unknown"
|
||||
|
||||
|
||||
class TaskStatusMessage(object):
|
||||
stopping = "stopping"
|
||||
|
||||
|
||||
class TaskSystemTags(object):
|
||||
development = "development"
|
||||
|
||||
|
||||
class Script(EmbeddedDocument, ProperDictMixin):
|
||||
binary = StringField(default="python")
|
||||
repository = StringField(default="")
|
||||
tag = StringField()
|
||||
branch = StringField()
|
||||
version_num = StringField()
|
||||
entry_point = StringField(default="")
|
||||
working_dir = StringField()
|
||||
requirements = SafeDictField()
|
||||
diff = StringField()
|
||||
|
||||
|
||||
class ArtifactTypeData(EmbeddedDocument):
|
||||
preview = StringField()
|
||||
content_type = StringField()
|
||||
data_hash = StringField()
|
||||
|
||||
|
||||
class ArtifactModes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
class Artifact(EmbeddedDocument):
|
||||
key = StringField(required=True)
|
||||
type = StringField(required=True)
|
||||
mode = StringField(choices=get_options(ArtifactModes), default=ArtifactModes.output)
|
||||
uri = StringField()
|
||||
hash = StringField()
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
section = StringField(required=True)
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
value = StringField(required=True)
|
||||
type = StringField()
|
||||
description = StringField()
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
artifacts = EmbeddedDocumentSortedListField(Artifact)
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
class TaskType(object):
|
||||
training = "training"
|
||||
testing = "testing"
|
||||
inference = "inference"
|
||||
data_processing = "data_processing"
|
||||
application = "application"
|
||||
monitor = "monitor"
|
||||
controller = "controller"
|
||||
optimizer = "optimizer"
|
||||
service = "service"
|
||||
qc = "qc"
|
||||
custom = "custom"
|
||||
|
||||
|
||||
external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"configuration.": _numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
"indexes": [
|
||||
"created",
|
||||
"started",
|
||||
"completed",
|
||||
"parent",
|
||||
"project",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
"default_language": "english",
|
||||
"weights": {
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
|
||||
datetime_fields=("status_changed",),
|
||||
pattern_fields=("name", "comment"),
|
||||
fields=("parent",),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(
|
||||
required=True, user_set_allowed=True, sparse=False, min_length=3
|
||||
)
|
||||
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
started = DateTimeField()
|
||||
completed = DateTimeField()
|
||||
published = DateTimeField()
|
||||
parent = StringField()
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output: Output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script: Script = EmbeddedDocumentField(Script, default=Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
last_update = DateTimeField()
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
22
apiserver/database/model/user.py
Normal file
22
apiserver/database/model/user.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from mongoengine import Document, StringField, DynamicField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.company import Company
|
||||
|
||||
|
||||
class User(DbModelMixin, Document):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
name = StringField(required=True, user_set_allowed=True)
|
||||
family_name = StringField(user_set_allowed=True)
|
||||
given_name = StringField(user_set_allowed=True)
|
||||
avatar = StringField()
|
||||
preferences = DynamicField(default="", exclude_by_default=True)
|
||||
18
apiserver/database/model/version.py
Normal file
18
apiserver/database/model/version.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from mongoengine import Document, DateTimeField, StringField
|
||||
|
||||
from database import Database, strict
|
||||
from database.model import DbModelMixin
|
||||
|
||||
|
||||
class Version(DbModelMixin, Document):
|
||||
meta = {
|
||||
"collection": "versions", # custom collection name ('version' is not a proper collection name...)
|
||||
"db_alias": Database.backend, # although we'll use this model for all databases, a default must be defined
|
||||
"strict": strict,
|
||||
"indexes": [("-created", "-num")],
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
num = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
desc = StringField()
|
||||
377
apiserver/database/projection.py
Normal file
377
apiserver/database/projection.py
Normal file
@@ -0,0 +1,377 @@
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby, chain
|
||||
from typing import Sequence, Dict, Callable, Tuple, Any, Type
|
||||
|
||||
import dpath.path
|
||||
|
||||
from apierrors import errors
|
||||
from database.props import PropsMixin
|
||||
|
||||
SEP = "."
|
||||
|
||||
|
||||
def project_dict(data, projection, separator=SEP):
|
||||
"""
|
||||
Project partial data from a dictionary into a new dictionary
|
||||
:param data: Input dictionary
|
||||
:param projection: List of dictionary paths (each a string with field names separated using a separator)
|
||||
:param separator: Separator (default is '.')
|
||||
:return: A new dictionary containing only the projected parts from the original dictionary
|
||||
"""
|
||||
assert isinstance(data, dict)
|
||||
result = {}
|
||||
|
||||
def copy_path(path_parts, source, destination):
|
||||
src, dst = source, destination
|
||||
try:
|
||||
for depth, path_part in enumerate(path_parts[:-1]):
|
||||
src_part = src[path_part]
|
||||
if isinstance(src_part, dict):
|
||||
src = src_part
|
||||
dst = dst.setdefault(path_part, {})
|
||||
elif isinstance(src_part, (list, tuple)):
|
||||
if path_part not in dst:
|
||||
dst[path_part] = [{} for _ in range(len(src_part))]
|
||||
elif not isinstance(dst[path_part], (list, tuple)):
|
||||
raise TypeError(
|
||||
"Incompatible destination type %s for %s (list expected)"
|
||||
% (type(dst), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
elif not len(dst[path_part]) == len(src_part):
|
||||
raise ValueError(
|
||||
"Destination list length differs from source length for %s"
|
||||
% separator.join(path_parts[: depth + 1])
|
||||
)
|
||||
|
||||
dst[path_part] = [
|
||||
copy_path(path_parts[depth + 1 :], s, d)
|
||||
for s, d in zip(src_part, dst[path_part])
|
||||
]
|
||||
|
||||
return destination
|
||||
else:
|
||||
raise TypeError(
|
||||
"Unsupported projection type %s for %s"
|
||||
% (type(src), separator.join(path_parts[: depth + 1]))
|
||||
)
|
||||
|
||||
last_part = path_parts[-1]
|
||||
dst[last_part] = src[last_part]
|
||||
except KeyError:
|
||||
# Projection field not in source, no biggie.
|
||||
pass
|
||||
return destination
|
||||
|
||||
for projection_path in sorted(projection):
|
||||
copy_path(
|
||||
path_parts=projection_path.split(separator), source=data, destination=result
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _ReferenceProxy(dict):
|
||||
def __init__(self, id):
|
||||
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
|
||||
|
||||
|
||||
class _ProxyManager:
|
||||
lock = threading.Lock()
|
||||
|
||||
def __init__(self):
|
||||
self._proxies: Dict[str, _ReferenceProxy] = {}
|
||||
|
||||
def add(self, id):
|
||||
with self.lock:
|
||||
proxy = self._proxies.get(id)
|
||||
if proxy is None:
|
||||
proxy = self._proxies[id] = _ReferenceProxy(id)
|
||||
return proxy
|
||||
|
||||
def update(self, result):
|
||||
proxy = self._proxies.get(result.get("id"))
|
||||
if proxy is not None:
|
||||
proxy.update(result)
|
||||
|
||||
|
||||
class ProjectionHelper(object):
|
||||
pool = ThreadPoolExecutor()
|
||||
exclusion_prefix = "-"
|
||||
|
||||
@property
|
||||
def doc_projection(self):
|
||||
return self._doc_projection
|
||||
|
||||
def __init__(self, doc_cls, projection, expand_reference_ids=False):
|
||||
super(ProjectionHelper, self).__init__()
|
||||
self._should_expand_reference_ids = expand_reference_ids
|
||||
self._doc_cls = doc_cls
|
||||
self._doc_projection = None
|
||||
self._ref_projection = None
|
||||
self._proxy_manager = _ProxyManager()
|
||||
|
||||
# Cached dpath paths for each of the result documents
|
||||
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
|
||||
|
||||
self._parse_projection(projection)
|
||||
|
||||
def _collect_projection_fields(self, doc_cls, projection):
|
||||
"""
|
||||
Collect projection for the given document into immediate document projection and reference documents projection
|
||||
:param doc_cls: Document class
|
||||
:param projection: List of projection fields
|
||||
:return: A tuple of document projection and reference fields information
|
||||
"""
|
||||
doc_projection = (
|
||||
set()
|
||||
) # Projection fields for this class (used in the main query)
|
||||
ref_projection_info = (
|
||||
[]
|
||||
) # Projection information for reference fields (used in join queries)
|
||||
for field in projection:
|
||||
field_ = field.lstrip(self.exclusion_prefix)
|
||||
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
|
||||
if not field_.startswith(ref_field):
|
||||
# Doesn't start with a reference field
|
||||
continue
|
||||
if field_ == ref_field:
|
||||
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
|
||||
# use '<reference field name>.*')
|
||||
continue
|
||||
subfield = field_[len(ref_field) :]
|
||||
if not subfield.startswith(SEP):
|
||||
# Starts with something that looks like a reference field, but isn't
|
||||
continue
|
||||
|
||||
ref_projection_info.append(
|
||||
(
|
||||
ref_field,
|
||||
ref_field_cls,
|
||||
("" if field_[0] == field[0] else self.exclusion_prefix)
|
||||
+ subfield[1:],
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Not a reference field, just add to the top-level projection
|
||||
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
|
||||
orig_field = field
|
||||
if field.endswith(".*"):
|
||||
field = field[:-2]
|
||||
if not field.lstrip(self.exclusion_prefix):
|
||||
raise errors.bad_request.InvalidFields(
|
||||
field=orig_field, object=doc_cls.__name__
|
||||
)
|
||||
doc_projection.add(field)
|
||||
return doc_projection, ref_projection_info
|
||||
|
||||
def _parse_projection(self, projection):
|
||||
"""
|
||||
Prepare the projection data structures for get_many_with_join().
|
||||
:param projection: A list of field names that should be returned by the query. Sub-fields can be specified
|
||||
using '.' (i.e. "parent.name"). A field terminated by '.*' indicated that all of the field's sub-fields
|
||||
should be returned (only relevant for fields that represent sub-documents or referenced documents)
|
||||
:type projection: list of strings
|
||||
:returns A tuple of (class fields projection, reference fields projection)
|
||||
"""
|
||||
doc_cls = self._doc_cls
|
||||
assert issubclass(doc_cls, PropsMixin)
|
||||
if not projection:
|
||||
return [], {}
|
||||
|
||||
doc_projection, ref_projection_info = self._collect_projection_fields(
|
||||
doc_cls, projection
|
||||
)
|
||||
|
||||
def normalize_cls_projection(cls_, fields):
|
||||
""" Normalize projection for this class and group (expand *, for once) """
|
||||
if "*" in fields:
|
||||
return list(fields.difference("*").union(cls_.get_fields()))
|
||||
return list(fields)
|
||||
|
||||
def compute_ref_cls_projection(cls_, group):
|
||||
""" Compute inner projection for this class and group """
|
||||
subfields = set([x[2] for x in group if x[2]])
|
||||
return normalize_cls_projection(cls_, subfields)
|
||||
|
||||
def sort_key(proj_info):
|
||||
return proj_info[:2]
|
||||
|
||||
# Aggregate by reference field. We'll leave out '*' from the projected items since
|
||||
ref_projection = {
|
||||
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
|
||||
for (ref_field, ref_cls), g in groupby(
|
||||
sorted(ref_projection_info, key=sort_key), sort_key
|
||||
)
|
||||
}
|
||||
|
||||
# Make sure this doesn't contain any reference field we'll join anyway
|
||||
# (i.e. in case only_fields=[project, project.name])
|
||||
doc_projection = normalize_cls_projection(
|
||||
doc_cls, doc_projection.difference(ref_projection)
|
||||
)
|
||||
|
||||
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
|
||||
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
|
||||
# won't return some of the data we need.
|
||||
# This way, we make sure to use the most inclusive field that contains all requested subfields.
|
||||
projection_set = set(doc_projection)
|
||||
doc_projection = [
|
||||
field
|
||||
for field in doc_projection
|
||||
if not any(
|
||||
field.startswith(f"{other_field}.")
|
||||
for other_field in projection_set - {field}
|
||||
)
|
||||
]
|
||||
|
||||
# Make sure we didn't get any invalid projection fields for this class
|
||||
invalid_fields = [
|
||||
f
|
||||
for f in doc_projection
|
||||
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
|
||||
not in doc_cls.get_fields()
|
||||
]
|
||||
if invalid_fields:
|
||||
raise errors.bad_request.InvalidFields(
|
||||
fields=invalid_fields, object=doc_cls.__name__
|
||||
)
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - use both normal projection fields and top-level reference fields
|
||||
doc_projection = set(doc_projection)
|
||||
for field in set(ref_projection).difference(doc_projection):
|
||||
if any(f for f in doc_projection if field.startswith(f)):
|
||||
continue
|
||||
doc_projection.add(field)
|
||||
doc_projection = list(doc_projection)
|
||||
|
||||
# If there are include fields (not only exclude) then add an id field
|
||||
if (
|
||||
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
|
||||
and "id" not in doc_projection
|
||||
):
|
||||
doc_projection.append("id")
|
||||
|
||||
self._doc_projection = doc_projection
|
||||
self._ref_projection = ref_projection
|
||||
|
||||
def _search(
|
||||
self,
|
||||
doc_cls: PropsMixin,
|
||||
obj: dict,
|
||||
path: str,
|
||||
factory: Callable[[str], dict] = None,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Search for a path in the given object, return the list of values found for the
|
||||
given path (multiple values may exist if the path is a glob expression)
|
||||
:param doc_cls: The document class represented by the object
|
||||
:param obj: Data object
|
||||
:param path: Path to a leaf in the data object ("." separated, may contain "*")
|
||||
(in case the path contains "*", there may be multiple values)
|
||||
:param factory: If provided, replace each value found with an instance provided by the factory.
|
||||
"""
|
||||
norm_path = doc_cls.get_dpath_translated_path(path)
|
||||
globlist = norm_path.strip(SEP).split(SEP)
|
||||
|
||||
obj_paths = self._cached_results_paths.get(id(obj))
|
||||
if obj_paths is None:
|
||||
obj_paths = self._cached_results_paths[id(obj)] = list(
|
||||
dpath.path.paths(obj, dirs=True, skip=True)
|
||||
)
|
||||
|
||||
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
|
||||
|
||||
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
|
||||
parent = None
|
||||
target = obj
|
||||
for part in p:
|
||||
parent = target
|
||||
target = target[part[0]]
|
||||
if parent and factory:
|
||||
parent[p[-1][0]] = factory(target)
|
||||
return target
|
||||
|
||||
return [search_and_replace(p) for p in paths]
|
||||
|
||||
def project(self, results, projection_func):
|
||||
"""
|
||||
Perform projection on query results, using the provided projection func.
|
||||
:param results: A list of results dictionaries on which projection should be performed
|
||||
:param projection_func: A callable that receives a document type, list of ids and projection and returns query
|
||||
results. This callable is used in order to perform sub-queries during projection
|
||||
:return: Modified results (in-place)
|
||||
"""
|
||||
cls = self._doc_cls
|
||||
ref_projection = self._ref_projection
|
||||
|
||||
if ref_projection:
|
||||
# Join mode - get results for each reference fields projection required (this is the join step)
|
||||
# Note: this is a recursive step, so nested reference fields are supported
|
||||
|
||||
def collect_ids(ref_field_name):
|
||||
"""
|
||||
Collect unique IDs for the given reference path from all result documents.
|
||||
All collected IDs are replaced in the result dictionaries with a reference proxy generated by the
|
||||
proxies manager to allow rapid update later on when projection results are obtained.
|
||||
"""
|
||||
all_ids = (
|
||||
self._search(
|
||||
cls, res, ref_field_name, factory=self._proxy_manager.add
|
||||
)
|
||||
for res in results
|
||||
)
|
||||
return list(filter(None, set(chain.from_iterable(all_ids))))
|
||||
|
||||
items = [
|
||||
tup
|
||||
for tup in (
|
||||
(*item, collect_ids(item[0])) for item in ref_projection.items()
|
||||
)
|
||||
if tup[2]
|
||||
]
|
||||
|
||||
if items:
|
||||
|
||||
def do_projection(item):
|
||||
ref_field_name, data, ids = item
|
||||
|
||||
doc_type = data["cls"]
|
||||
doc_only = list(filter(None, data["only"]))
|
||||
doc_only = list({"id"} | set(doc_only)) if doc_only else None
|
||||
|
||||
for res in projection_func(
|
||||
doc_type=doc_type, projection=doc_only, ids=ids
|
||||
):
|
||||
self._proxy_manager.update(res)
|
||||
|
||||
if len(ref_projection) == 1:
|
||||
do_projection(items[0])
|
||||
else:
|
||||
for _ in self.pool.map(do_projection, items):
|
||||
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
|
||||
# will be raised when its value is retrieved from the map() iterator
|
||||
pass
|
||||
|
||||
def do_expand_reference_ids(result, skip_fields=None):
|
||||
ref_fields = cls.get_reference_fields()
|
||||
if skip_fields:
|
||||
ref_fields = set(ref_fields) - set(skip_fields)
|
||||
self._expand_reference_fields(cls, result, ref_fields)
|
||||
|
||||
# any reference field not projected should be expanded
|
||||
if self._should_expand_reference_ids:
|
||||
for result in results:
|
||||
do_expand_reference_ids(
|
||||
result, skip_fields=list(ref_projection) if ref_projection else None
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _expand_reference_fields(self, doc_cls, result, fields):
|
||||
for ref_field_name in fields:
|
||||
self._search(doc_cls, result, ref_field_name, factory=_ReferenceProxy)
|
||||
|
||||
def expand_reference_ids(self, doc_cls, result):
|
||||
self._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
|
||||
201
apiserver/database/props.py
Normal file
201
apiserver/database/props.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from collections import OrderedDict, defaultdict
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine.base import get_document, BaseField
|
||||
|
||||
from database.fields import (
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
UniqueEmbeddedDocumentListField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
)
|
||||
from database.utils import get_fields, get_fields_attr
|
||||
|
||||
|
||||
class PropsMixin(object):
|
||||
__cached_fields = None
|
||||
__cached_reference_fields = None
|
||||
__cached_exclude_fields = None
|
||||
__cached_fields_with_instance = None
|
||||
__cached_field_names_per_type = None
|
||||
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
cls.__cached_fields = get_fields(cls)
|
||||
return cls.__cached_fields
|
||||
|
||||
@classmethod
|
||||
def get_field_names_for_type(cls, of_type=BaseField):
|
||||
"""
|
||||
Return field names per type including subfields
|
||||
The fields of derived types are also returned
|
||||
"""
|
||||
assert issubclass(of_type, BaseField)
|
||||
if cls.__cached_field_names_per_type is None:
|
||||
fields = defaultdict(list)
|
||||
for name, field in get_fields(cls, return_instance=True, subfields=True):
|
||||
fields[type(field)].append(name)
|
||||
for type_ in fields:
|
||||
fields[type_].extend(
|
||||
chain.from_iterable(
|
||||
fields[other_type]
|
||||
for other_type in fields
|
||||
if other_type != type_ and issubclass(other_type, type_)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type = fields
|
||||
|
||||
if of_type not in cls.__cached_field_names_per_type:
|
||||
names = list(
|
||||
chain.from_iterable(
|
||||
field_names
|
||||
for type_, field_names in cls.__cached_field_names_per_type.items()
|
||||
if issubclass(type_, of_type)
|
||||
)
|
||||
)
|
||||
cls.__cached_field_names_per_type[of_type] = names
|
||||
|
||||
return cls.__cached_field_names_per_type[of_type]
|
||||
|
||||
@classmethod
|
||||
def get_fields_with_instance(cls, doc_cls):
|
||||
if cls.__cached_fields_with_instance is None:
|
||||
cls.__cached_fields_with_instance = {}
|
||||
if doc_cls not in cls.__cached_fields_with_instance:
|
||||
cls.__cached_fields_with_instance[doc_cls] = get_fields(
|
||||
doc_cls, return_instance=True
|
||||
)
|
||||
return cls.__cached_fields_with_instance[doc_cls]
|
||||
|
||||
@staticmethod
|
||||
def _get_fields_with_attr(cls_, attr):
|
||||
""" Get all fields with the specified attribute (supports nested fields) """
|
||||
res = get_fields_attr(cls_, attr=attr)
|
||||
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
return cls_.owner_document
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
|
||||
def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
|
||||
for field, embedded_doc_field in get_fields(
|
||||
cls_, of_type=doc_cls, return_instance=True
|
||||
):
|
||||
embedded_doc_cls = embedded_doc_field_getter(
|
||||
embedded_doc_field
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
}
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
|
||||
return fields
|
||||
|
||||
@classmethod
|
||||
def _translate_fields_path(cls, parts):
|
||||
current_cls = cls
|
||||
translated_parts = []
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
(k, v)
|
||||
for k, v in cls.get_fields_with_instance(current_cls)
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
if isinstance(field, EmbeddedDocumentField):
|
||||
current_cls = field.document_type
|
||||
elif isinstance(
|
||||
field,
|
||||
(
|
||||
EmbeddedDocumentListField,
|
||||
LengthRangeEmbeddedDocumentListField,
|
||||
UniqueEmbeddedDocumentListField,
|
||||
EmbeddedDocumentSortedListField,
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
return translated_parts
|
||||
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@classmethod
|
||||
def get_extra_projection(cls, fields: Sequence) -> tuple:
|
||||
if isinstance(fields, str):
|
||||
fields = [fields]
|
||||
return tuple(
|
||||
set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
with cls.__cached_dpath_computed_fields_lock:
|
||||
parts = path.split(separator)
|
||||
translated = cls._translate_fields_path(parts)
|
||||
result = separator.join(translated)
|
||||
cls.__cached_dpath_computed_fields[path] = result
|
||||
return cls.__cached_dpath_computed_fields[path]
|
||||
|
||||
def get_field_value(self, field_path: str, default=None):
|
||||
"""
|
||||
Return the document field_path value by the field_path name.
|
||||
The path may contain '.'. If on any level the path is
|
||||
not found then the default value is returned
|
||||
"""
|
||||
path_elements = field_path.split(".")
|
||||
current = self
|
||||
for name in path_elements:
|
||||
current = getattr(current, name, default)
|
||||
if current == default:
|
||||
break
|
||||
|
||||
return current
|
||||
68
apiserver/database/query.py
Normal file
68
apiserver/database/query.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import copy
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
from mongoengine import Q
|
||||
from mongoengine.queryset.visitor import (
|
||||
QueryCompilerVisitor,
|
||||
SimplificationVisitor,
|
||||
QCombination,
|
||||
QNode,
|
||||
)
|
||||
|
||||
|
||||
class RegexWrapper(object):
|
||||
def __init__(self, pattern, flags=None):
|
||||
super(RegexWrapper, self).__init__()
|
||||
self.pattern = pattern
|
||||
self.flags = flags
|
||||
|
||||
@property
|
||||
def regex(self):
|
||||
return re.compile(self.pattern, self.flags if self.flags is not None else 0)
|
||||
|
||||
|
||||
class RegexMixin(object):
|
||||
def to_query(self: Union["RegexMixin", QNode], document):
|
||||
query = self.accept(SimplificationVisitor())
|
||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||
return query
|
||||
|
||||
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||
"""Combine this node with another node into a QCombination
|
||||
object.
|
||||
"""
|
||||
if getattr(other, "empty", True):
|
||||
return self
|
||||
|
||||
if self.empty:
|
||||
return other
|
||||
|
||||
return RegexQCombination(operation, [self, other])
|
||||
|
||||
|
||||
class RegexQCombination(RegexMixin, QCombination):
|
||||
pass
|
||||
|
||||
|
||||
class RegexQ(RegexMixin, Q):
|
||||
pass
|
||||
|
||||
|
||||
class RegexQueryCompilerVisitor(QueryCompilerVisitor):
|
||||
"""
|
||||
Improved mongoengine complied queries visitor class that supports compiled regex expressions as part of the query.
|
||||
|
||||
We need this class since mongoengine's Q (QNode) class uses copy.deepcopy() as part of the tree simplification
|
||||
stage, which does not support re.compiled objects (since Python 2.5).
|
||||
This class allows users to provide regex strings wrapped in QueryRegex instances, which are lazily evaluated to
|
||||
to re.compile instances just before being visited for compilation (this is done after the simplification stage)
|
||||
"""
|
||||
|
||||
def visit_query(self, query):
|
||||
query = copy.deepcopy(query)
|
||||
query.query = self._transform_query(query.query)
|
||||
return super(RegexQueryCompilerVisitor, self).visit_query(query)
|
||||
|
||||
def _transform_query(self, query):
|
||||
return {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in query.items()}
|
||||
235
apiserver/database/utils.py
Normal file
235
apiserver/database/utils.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import hashlib
|
||||
from inspect import ismethod, getmembers
|
||||
from typing import Sequence, Tuple, Set, Optional, Callable, Any
|
||||
from uuid import uuid4
|
||||
|
||||
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
|
||||
from mongoengine.base import BaseField
|
||||
|
||||
from .errors import translate_errors_context, ParseCallError
|
||||
|
||||
|
||||
def get_fields(cls, of_type=BaseField, return_instance=False, subfields=False):
|
||||
return _get_fields(
|
||||
cls,
|
||||
of_type=of_type,
|
||||
subfields=subfields,
|
||||
selector=lambda k, v: (k, v) if return_instance else k,
|
||||
)
|
||||
|
||||
|
||||
def get_fields_attr(cls, attr):
|
||||
""" get field names from a class containing mongoengine fields """
|
||||
return dict(
|
||||
_get_fields(cls, with_attr=attr, selector=lambda k, v: (k, getattr(v, attr)))
|
||||
)
|
||||
|
||||
|
||||
def get_fields_choices(cls, attr):
|
||||
def get_choices(field_name: str, field: BaseField) -> Tuple:
|
||||
if isinstance(field, ListField):
|
||||
return field_name, field.field.choices
|
||||
return field_name, field.choices
|
||||
|
||||
return dict(_get_fields(cls, with_attr=attr, subfields=True, selector=get_choices))
|
||||
|
||||
|
||||
def _get_fields(
|
||||
cls,
|
||||
with_attr=None,
|
||||
of_type=BaseField,
|
||||
subfields=False,
|
||||
selector: Optional[Callable[[str, BaseField], Any]] = None,
|
||||
path: Tuple[str, ...] = (),
|
||||
):
|
||||
fields = []
|
||||
for field_name, field in cls._fields.items():
|
||||
field_path = path + (field_name,)
|
||||
if isinstance(field, of_type) and (not with_attr or hasattr(field, with_attr)):
|
||||
full_name = "__".join(field_path)
|
||||
fields.append(selector(full_name, field) if selector else full_name)
|
||||
|
||||
if subfields and isinstance(field, EmbeddedDocumentField):
|
||||
fields.extend(
|
||||
_get_fields(
|
||||
field.document_type,
|
||||
with_attr=with_attr,
|
||||
of_type=of_type,
|
||||
subfields=subfields,
|
||||
selector=selector,
|
||||
path=field_path,
|
||||
)
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def get_items(cls):
|
||||
""" get key/value items from an enum-like class (members represent enumeration key/value) """
|
||||
|
||||
res = {k: v for k, v in getmembers(cls) if not (k.startswith("_") or ismethod(v))}
|
||||
return res
|
||||
|
||||
|
||||
def get_options(cls):
|
||||
""" get options from an enum-like class (members represent enumeration key/value) """
|
||||
return list(get_items(cls).values())
|
||||
|
||||
|
||||
# return a dictionary of items which:
|
||||
# 1. are in the call_data
|
||||
# 2. are in the fields dictionary, and their value in the call_data matches the type in fields
|
||||
# 3. are in the cls_fields
|
||||
def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
||||
if not isinstance(fields, dict):
|
||||
# fields should be key=>type dict
|
||||
fields = {k: None for k in fields}
|
||||
fields = {k: v for k, v in fields.items() if k in cls_fields}
|
||||
res = {}
|
||||
with translate_errors_context("parsing call data"):
|
||||
for field, desc in fields.items():
|
||||
value = call_data.get(field)
|
||||
if value is None:
|
||||
if not discard_none_values and field in call_data:
|
||||
# we'll keep the None value in case the field actually exists in the call data
|
||||
res[field] = None
|
||||
continue
|
||||
if desc:
|
||||
if issubclass(desc, Document):
|
||||
if not desc.objects(id=value).only("id"):
|
||||
raise ParseCallError(
|
||||
"expecting %s id" % desc.__name__, id=value, field=field
|
||||
)
|
||||
elif callable(desc):
|
||||
try:
|
||||
desc(value)
|
||||
except TypeError:
|
||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||
except Exception as ex:
|
||||
raise ParseCallError(str(ex), field=field)
|
||||
res[field] = value
|
||||
return res
|
||||
|
||||
|
||||
def init_cls_from_base(cls, instance):
|
||||
return cls(
|
||||
**{
|
||||
k: v
|
||||
for k, v in instance.to_mongo(use_db_field=False).to_dict().items()
|
||||
if k[0] != "_"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_company_or_none_constraint(company=None):
|
||||
return Q(company__in=(company, None, "")) | Q(company__exists=False)
|
||||
|
||||
|
||||
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
|
||||
"""
|
||||
Creates a query object used for finding a field that doesn't exist, or has None or an empty value.
|
||||
:param field: Field name
|
||||
:param empty_value: The empty value to test for (None means no specific empty value will be used)
|
||||
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
|
||||
the length of the array will be used (len==0 means empty)
|
||||
:return:
|
||||
"""
|
||||
query = Q(**{f"{field}__exists": False}) | Q(
|
||||
**{f"{field}__in": {empty_value, None}}
|
||||
)
|
||||
if is_list:
|
||||
query |= Q(**{f"{field}__size": 0})
|
||||
return query
|
||||
|
||||
|
||||
def field_exists(field: str, empty_value=None, is_list=False) -> Q:
|
||||
"""
|
||||
Creates a query object used for finding a field that exists and is not None or empty.
|
||||
:param field: Field name
|
||||
:param empty_value: The empty value to test for (None means no specific empty value will be used)
|
||||
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
|
||||
the length of the array will be used (len==0 means empty)
|
||||
:return:
|
||||
"""
|
||||
query = Q(**{f"{field}__exists": True}) & Q(
|
||||
**{f"{field}__nin": {empty_value, None}}
|
||||
)
|
||||
if is_list:
|
||||
query &= Q(**{f"{field}__not__size": 0})
|
||||
return query
|
||||
|
||||
|
||||
def get_subkey(d, key_path, default=None):
|
||||
""" Get a key from a nested dictionary. kay_path is a '.' separated string of keys used to traverse
|
||||
the nested dictionary.
|
||||
"""
|
||||
keys = key_path.split(".")
|
||||
for i, key in enumerate(keys):
|
||||
if not isinstance(d, dict):
|
||||
raise KeyError(
|
||||
"Expecting a dict (%s)" % (".".join(keys[:i]) if i else "bad input")
|
||||
)
|
||||
d = d.get(key)
|
||||
if d is None:
|
||||
return default
|
||||
return d
|
||||
|
||||
|
||||
def id():
|
||||
return str(uuid4()).replace("-", "")
|
||||
|
||||
|
||||
def hash_field_name(s):
|
||||
""" Hash field name into a unique safe string """
|
||||
return hashlib.md5(s.encode()).hexdigest()
|
||||
|
||||
|
||||
def merge_dicts(*dicts):
|
||||
base = {}
|
||||
for dct in dicts:
|
||||
base.update(dct)
|
||||
return base
|
||||
|
||||
|
||||
def filter_fields(cls, fields):
|
||||
"""From the fields dictionary return only the fields that match cls fields"""
|
||||
return {key: fields[key] for key in fields if key in get_fields(cls)}
|
||||
|
||||
|
||||
def _names_set(*names: str) -> Set[str]:
|
||||
"""
|
||||
Given a list of names return set with names and '-names'
|
||||
"""
|
||||
return set(names) | set(f"-{name}" for name in names)
|
||||
|
||||
|
||||
system_tag_names = {
|
||||
"model": _names_set("active", "archived"),
|
||||
"project": _names_set("archived", "public", "default"),
|
||||
"task": _names_set("active", "archived", "development"),
|
||||
"queue": _names_set("default"),
|
||||
}
|
||||
|
||||
system_tag_prefixes = {"task": _names_set("annotat")}
|
||||
|
||||
|
||||
def partition_tags(
|
||||
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Partition the given tags sequence into system and user-defined tags
|
||||
:param entity: The name of the entity that defines the list of the system tags
|
||||
:param tags: The tags to partition
|
||||
:param system_tags: Optional. If passed then these tags are considered system together
|
||||
with those defined for the entity.
|
||||
:return: a tuple where the first element is the sequence of user-defined tags and
|
||||
the second element is the sequence of system tags
|
||||
"""
|
||||
tags = set(tags)
|
||||
system_tags = set(system_tags)
|
||||
system_tags |= tags & system_tag_names[entity]
|
||||
|
||||
prefixes = system_tag_prefixes.get(entity, [])
|
||||
system_tags |= {t for t in tags for p in prefixes if t.lower().startswith(p)}
|
||||
|
||||
return list(tags - system_tags), list(system_tags)
|
||||
Reference in New Issue
Block a user