Update pymongo and mongoengine versions

This commit is contained in:
allegroai
2020-06-01 11:29:50 +03:00
parent 35a11db58e
commit a1dcdffa53
12 changed files with 110 additions and 50 deletions

View File

@@ -34,7 +34,12 @@ class AuthDocument(Document):
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
def to_proper_dict(
self: Union["ProperDictMixin", Document],
strip_private=True,
only=None,
extra_dict=None,
) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
@@ -95,7 +100,13 @@ class GetMixin(PropsMixin):
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
cls: Union["GetMixin", Document],
company,
id,
*,
_only=None,
include_public=False,
**kwargs,
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
@@ -409,7 +420,12 @@ class GetMixin(PropsMixin):
)
@classmethod
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
parameters=None,
override_projection=None,
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -593,7 +609,13 @@ class UpdateMixin(object):
return update_dict
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
def safe_update(
cls: Union["UpdateMixin", Document],
company_id,
id,
partial_update_dict,
injected_update=None,
):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
@@ -610,7 +632,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
cls: Union["DbModelMixin", Document],
pipeline: Sequence[dict],
allow_disk_use=None,
**kwargs,
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
@@ -625,7 +650,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
return cls.objects.aggregate(pipeline, **kwargs)
def validate_id(cls, company, **kwargs):

View File

@@ -1,11 +1,14 @@
from mongoengine import MapField, IntField
from database.fields import NoneType, UnionField, SafeMapField
class ModelLabels(MapField):
class ModelLabels(SafeMapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
super(ModelLabels, self).__init__(
field=UnionField(types=(int, NoneType)), *args, **kwargs
)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
non_empty_values = list(filter(None, value.values()))
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
self.error("Same label id appears more than once in model labels")