Update pymongo and mongoengine versions

This commit is contained in:
allegroai 2020-06-01 11:29:50 +03:00
parent 35a11db58e
commit a1dcdffa53
12 changed files with 110 additions and 50 deletions

View File

@ -280,7 +280,7 @@ class StatisticsReporter:
] ]
return { return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"} group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline) for group in Task.aggregate(pipeline)
} }

View File

@ -263,7 +263,7 @@ class TaskBLL(object):
] ]
with translate_errors_context(): with translate_errors_context():
result = Task.aggregate(*pipeline) result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result] return [r["metrics"][0] for r in result]
@staticmethod @staticmethod
@ -666,7 +666,7 @@ class TaskBLL(object):
] ]
with translate_errors_context(): with translate_errors_context():
result = next(Task.aggregate(*pipeline), None) result = next(Task.aggregate(pipeline), None)
total = 0 total = 0
remaining = 0 remaining = 0

View File

@ -223,7 +223,7 @@ class WorkerBLL:
}, },
] ]
queues_info = { queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(*projection) res["_id"]: res for res in Queue.objects.aggregate(projection)
} }
task_ids = task_ids.union( task_ids = task_ids.union(
filter( filter(

View File

@ -14,6 +14,9 @@ from mongoengine import (
DictField, DictField,
DynamicField, DynamicField,
) )
from mongoengine.fields import key_not_string, key_starts_with_dollar
NoneType = type(None)
class LengthRangeListField(ListField): class LengthRangeListField(ListField):
@ -125,17 +128,39 @@ def contains_empty_key(d):
return True return True
class SafeMapField(MapField): class DictValidationMixin:
"""
DictField validation in MongoEngine requires default alias and permissions to access DB version:
https://github.com/MongoEngine/mongoengine/issues/2239
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
"""
def _safe_validate(self: DictField, value):
if not isinstance(value, dict):
self.error("Only dictionaries may be used in a DictField")
if key_not_string(value):
msg = "Invalid dictionary key - documents must have only string keys"
self.error(msg)
if key_starts_with_dollar(value):
self.error(
'Invalid dictionary key name - keys may not startswith "$" characters'
)
super(DictField, self).validate(value)
class SafeMapField(MapField, DictValidationMixin):
def validate(self, value): def validate(self, value):
super(SafeMapField, self).validate(value) self._safe_validate(value)
if contains_empty_key(value): if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField") self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField): class SafeDictField(DictField, DictValidationMixin):
def validate(self, value): def validate(self, value):
super(SafeDictField, self).validate(value) self._safe_validate(value)
if contains_empty_key(value): if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField") self.error("Empty keys are not allowed in a DictField")
@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
SortedListField that does not raise an error in case items are not comparable SortedListField that does not raise an error in case items are not comparable
(in which case they will be sorted by their string representation) (in which case they will be sorted by their string representation)
""" """
def to_mongo(self, *args, **kwargs): def to_mongo(self, *args, **kwargs):
try: try:
return super(SafeSortedListField, self).to_mongo(*args, **kwargs) return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
def _safe_to_mongo(self, value, use_db_field=True, fields=None): def _safe_to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields) value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None: if self._ordering is not None:
def key(v): return str(itemgetter(self._ordering)(v))
def key(v):
return str(itemgetter(self._ordering)(v))
else: else:
key = str key = str
return sorted(value, key=key, reverse=self._order_reverse) return sorted(value, key=key, reverse=self._order_reverse)

View File

@ -34,7 +34,12 @@ class AuthDocument(Document):
class ProperDictMixin(object): class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict: def to_proper_dict(
self: Union["ProperDictMixin", Document],
strip_private=True,
only=None,
extra_dict=None,
) -> dict:
return self.properize_dict( return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(), self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private, strip_private=strip_private,
@ -95,7 +100,13 @@ class GetMixin(PropsMixin):
@classmethod @classmethod
def get( def get(
cls, company, id, *, _only=None, include_public=False, **kwargs cls: Union["GetMixin", Document],
company,
id,
*,
_only=None,
include_public=False,
**kwargs,
) -> "GetMixin": ) -> "GetMixin":
q = cls.objects( q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public) cls._prepare_perm_query(company, allow_public=include_public)
@ -409,7 +420,12 @@ class GetMixin(PropsMixin):
) )
@classmethod @classmethod
def _get_many_no_company(cls, query, parameters=None, override_projection=None): def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
parameters=None,
override_projection=None,
):
""" """
Fetch all documents matching a provided query. Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary This is a company-less version for internal uses. We assume the caller has either added any necessary
@ -593,7 +609,13 @@ class UpdateMixin(object):
return update_dict return update_dict
@classmethod @classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None): def safe_update(
cls: Union["UpdateMixin", Document],
company_id,
id,
partial_update_dict,
injected_update=None,
):
update_dict = cls.get_safe_update_dict(partial_update_dict) update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict: if not update_dict:
return 0, {} return 0, {}
@ -610,7 +632,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
@classmethod @classmethod
def aggregate( def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs cls: Union["DbModelMixin", Document],
pipeline: Sequence[dict],
allow_disk_use=None,
**kwargs,
) -> CommandCursor: ) -> CommandCursor:
""" """
Aggregate objects of this document class according to the provided pipeline. Aggregate objects of this document class according to the provided pipeline.
@ -625,7 +650,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
if allow_disk_use is not None if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True) else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
) )
return cls.objects.aggregate(*pipeline, **kwargs) return cls.objects.aggregate(pipeline, **kwargs)
def validate_id(cls, company, **kwargs): def validate_id(cls, company, **kwargs):

View File

@ -1,11 +1,14 @@
from mongoengine import MapField, IntField from database.fields import NoneType, UnionField, SafeMapField
class ModelLabels(MapField): class ModelLabels(SafeMapField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs) super(ModelLabels, self).__init__(
field=UnionField(types=(int, NoneType)), *args, **kwargs
)
def validate(self, value): def validate(self, value):
super(ModelLabels, self).validate(value) super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)): non_empty_values = list(filter(None, value.values()))
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
self.error("Same label id appears more than once in model labels") self.error("Same label id appears more than once in model labels")

View File

@ -1,8 +1,14 @@
import copy import copy
import re import re
from typing import Union
from mongoengine import Q from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination from mongoengine.queryset.visitor import (
QueryCompilerVisitor,
SimplificationVisitor,
QCombination,
QNode,
)
class RegexWrapper(object): class RegexWrapper(object):
@ -17,17 +23,16 @@ class RegexWrapper(object):
class RegexMixin(object): class RegexMixin(object):
def to_query(self: Union["RegexMixin", QNode], document):
def to_query(self, document):
query = self.accept(SimplificationVisitor()) query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document)) query = query.accept(RegexQueryCompilerVisitor(document))
return query return query
def _combine(self, other, operation): def _combine(self: Union["RegexMixin", QNode], other, operation):
"""Combine this node with another node into a QCombination """Combine this node with another node into a QCombination
object. object.
""" """
if getattr(other, 'empty', True): if getattr(other, "empty", True):
return self return self
if self.empty: if self.empty:

View File

@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
res[field] = None res[field] = None
continue continue
if desc: if desc:
if callable(desc): if issubclass(desc, Document):
if not desc.objects(id=value).only("id"):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
elif callable(desc):
try: try:
desc(value) desc(value)
except TypeError: except TypeError:
raise ParseCallError(f"expecting {desc.__name__}", field=field) raise ParseCallError(f"expecting {desc.__name__}", field=field)
except Exception as ex: except Exception as ex:
raise ParseCallError(str(ex), field=field) raise ParseCallError(str(ex), field=field)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc
):
raise ParseCallError(
"expecting %s" % desc.__name__, field=field
)
if issubclass(desc, Document) and not desc.objects(id=value).only(
"id"
):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
res[field] = value res[field] = value
return res return res

View File

@ -14,12 +14,12 @@ Jinja2==2.10
jsonmodels>=2.3 jsonmodels>=2.3
jsonschema>=2.6.0 jsonschema>=2.6.0
luqum>=0.7.2 luqum>=0.7.2
mongoengine==0.16.2 mongoengine==0.19.1
nested_dict>=1.61 nested_dict>=1.61
psutil>=5.6.5 psutil>=5.6.5
pyhocon>=0.3.35 pyhocon>=0.3.35
pyjwt>=1.3.0 pyjwt>=1.3.0
pymongo==3.6.1 # 3.7 has a bug multiple users logged in pymongo==3.10.1
python-rapidjson>=0.6.3 python-rapidjson>=0.6.3
redis>=2.10.5 redis>=2.10.5
related>=0.7.2 related>=0.7.2

View File

@ -290,11 +290,15 @@ def prepare_update_fields(call, fields):
invalid_keys = find_other_types(labels.keys(), str) invalid_keys = find_other_types(labels.keys(), str)
if invalid_keys: if invalid_keys:
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys) raise errors.bad_request.ValidationError(
"labels keys must be strings", keys=invalid_keys
)
invalid_values = find_other_types(labels.values(), int) invalid_values = find_other_types(labels.values(), int)
if invalid_values: if invalid_values:
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values) raise errors.bad_request.ValidationError(
"labels values must be integers", values=invalid_values
)
conform_tag_fields(call, fields) conform_tag_fields(call, fields)
return fields return fields
@ -331,7 +335,7 @@ def edit(call: APICall):
fields[key] = d fields[key] = d
iteration = call.data.get("iteration") iteration = call.data.get("iteration")
task_id = model.task or fields.get('task') task_id = model.task or fields.get("task")
if task_id and iteration is not None: if task_id and iteration is not None:
TaskBLL.update_statistics( TaskBLL.update_statistics(
task_id=task_id, task_id=task_id,
@ -393,14 +397,14 @@ def set_ready(call: APICall, company, req_model: PublishModelRequest):
model_id=req_model.model, model_id=req_model.model,
company_id=company, company_id=company,
publish_task=req_model.publish_task, publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task force_publish_task=req_model.force_publish_task,
) )
call.result.data_model = PublishModelResponse( call.result.data_model = PublishModelResponse(
updated=updated, updated=updated,
published_task=ModelTaskPublishResponse( published_task=ModelTaskPublishResponse(**published_task_data)
**published_task_data if published_task_data
) if published_task_data else None else None,
) )

View File

@ -210,7 +210,7 @@ def get_all_ex(call: APICall):
status_count = defaultdict(lambda: {}) status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value) key = itemgetter(EntityVisibility.archived.value)
for result in Task.aggregate(*status_count_pipeline): for result in Task.aggregate(status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key): for k, group in groupby(sorted(result["counts"], key=key), key):
section = ( section = (
EntityVisibility.archived if k else EntityVisibility.active EntityVisibility.archived if k else EntityVisibility.active
@ -224,7 +224,7 @@ def get_all_ex(call: APICall):
runtime = { runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"} result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.aggregate(*runtime_pipeline) for result in Task.aggregate(runtime_pipeline)
} }
def safe_get(obj, path, default=None): def safe_get(obj, path, default=None):

View File

@ -212,7 +212,9 @@ def get_queue_metrics(
dates=data["date"], dates=data["date"],
avg_waiting_times=data["avg_waiting_time"], avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"], queue_lengths=data["queue_length"],
) if data else QueueMetrics(queue=queue) )
if data
else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items() for queue, data in queue_dicts.items()
] ]
) )