diff --git a/server/bll/statistics/stats_reporter.py b/server/bll/statistics/stats_reporter.py index 0eeb641..5d9f17c 100644 --- a/server/bll/statistics/stats_reporter.py +++ b/server/bll/statistics/stats_reporter.py @@ -280,7 +280,7 @@ class StatisticsReporter: ] return { group["_id"]: {k: v for k, v in group.items() if k != "_id"} - for group in Task.aggregate(*pipeline) + for group in Task.aggregate(pipeline) } diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index 38e53e4..8d83e8d 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -263,7 +263,7 @@ class TaskBLL(object): ] with translate_errors_context(): - result = Task.aggregate(*pipeline) + result = Task.aggregate(pipeline) return [r["metrics"][0] for r in result] @staticmethod @@ -666,7 +666,7 @@ class TaskBLL(object): ] with translate_errors_context(): - result = next(Task.aggregate(*pipeline), None) + result = next(Task.aggregate(pipeline), None) total = 0 remaining = 0 diff --git a/server/bll/workers/__init__.py b/server/bll/workers/__init__.py index 914814d..e8b3128 100644 --- a/server/bll/workers/__init__.py +++ b/server/bll/workers/__init__.py @@ -223,7 +223,7 @@ class WorkerBLL: }, ] queues_info = { - res["_id"]: res for res in Queue.objects.aggregate(*projection) + res["_id"]: res for res in Queue.objects.aggregate(projection) } task_ids = task_ids.union( filter( diff --git a/server/database/fields.py b/server/database/fields.py index 94a012f..ca3ee9b 100644 --- a/server/database/fields.py +++ b/server/database/fields.py @@ -14,6 +14,9 @@ from mongoengine import ( DictField, DynamicField, ) +from mongoengine.fields import key_not_string, key_starts_with_dollar + +NoneType = type(None) class LengthRangeListField(ListField): @@ -125,17 +128,39 @@ def contains_empty_key(d): return True -class SafeMapField(MapField): +class DictValidationMixin: + """ + DictField validation in MongoEngine requires default alias and permissions to access DB version: + https://github.com/MongoEngine/mongoengine/issues/2239 + This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+ + """ + + def _safe_validate(self: DictField, value): + if not isinstance(value, dict): + self.error("Only dictionaries may be used in a DictField") + + if key_not_string(value): + msg = "Invalid dictionary key - documents must have only string keys" + self.error(msg) + + if key_starts_with_dollar(value): + self.error( + 'Invalid dictionary key name - keys may not startswith "$" characters' + ) + super(DictField, self).validate(value) + + +class SafeMapField(MapField, DictValidationMixin): def validate(self, value): - super(SafeMapField, self).validate(value) + self._safe_validate(value) if contains_empty_key(value): self.error("Empty keys are not allowed in a MapField") -class SafeDictField(DictField): +class SafeDictField(DictField, DictValidationMixin): def validate(self, value): - super(SafeDictField, self).validate(value) + self._safe_validate(value) if contains_empty_key(value): self.error("Empty keys are not allowed in a DictField") @@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField): SortedListField that does not raise an error in case items are not comparable (in which case they will be sorted by their string representation) """ + def to_mongo(self, *args, **kwargs): try: return super(SafeSortedListField, self).to_mongo(*args, **kwargs) @@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField): def _safe_to_mongo(self, value, use_db_field=True, fields=None): value = super(SortedListField, self).to_mongo(value, use_db_field, fields) if self._ordering is not None: - def key(v): return str(itemgetter(self._ordering)(v)) + + def key(v): + return str(itemgetter(self._ordering)(v)) + else: key = str return sorted(value, key=key, reverse=self._order_reverse) diff --git a/server/database/model/base.py b/server/database/model/base.py index 272f8ef..2ba6a27 100644 --- a/server/database/model/base.py +++ b/server/database/model/base.py @@ -34,7 +34,12 @@ class AuthDocument(Document): class ProperDictMixin(object): - def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict: + def to_proper_dict( + self: Union["ProperDictMixin", Document], + strip_private=True, + only=None, + extra_dict=None, + ) -> dict: return self.properize_dict( self.to_mongo(use_db_field=False).to_dict(), strip_private=strip_private, @@ -95,7 +100,13 @@ class GetMixin(PropsMixin): @classmethod def get( - cls, company, id, *, _only=None, include_public=False, **kwargs + cls: Union["GetMixin", Document], + company, + id, + *, + _only=None, + include_public=False, + **kwargs, ) -> "GetMixin": q = cls.objects( cls._prepare_perm_query(company, allow_public=include_public) @@ -409,7 +420,12 @@ class GetMixin(PropsMixin): ) @classmethod - def _get_many_no_company(cls, query, parameters=None, override_projection=None): + def _get_many_no_company( + cls: Union["GetMixin", Document], + query, + parameters=None, + override_projection=None, + ): """ Fetch all documents matching a provided query. This is a company-less version for internal uses. We assume the caller has either added any necessary @@ -593,7 +609,13 @@ class UpdateMixin(object): return update_dict @classmethod - def safe_update(cls, company_id, id, partial_update_dict, injected_update=None): + def safe_update( + cls: Union["UpdateMixin", Document], + company_id, + id, + partial_update_dict, + injected_update=None, + ): update_dict = cls.get_safe_update_dict(partial_update_dict) if not update_dict: return 0, {} @@ -610,7 +632,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin): @classmethod def aggregate( - cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs + cls: Union["DbModelMixin", Document], + pipeline: Sequence[dict], + allow_disk_use=None, + **kwargs, ) -> CommandCursor: """ Aggregate objects of this document class according to the provided pipeline. @@ -625,7 +650,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin): if allow_disk_use is not None else config.get("apiserver.mongo.aggregate.allow_disk_use", True) ) - return cls.objects.aggregate(*pipeline, **kwargs) + return cls.objects.aggregate(pipeline, **kwargs) def validate_id(cls, company, **kwargs): diff --git a/server/database/model/model_labels.py b/server/database/model/model_labels.py index 1f18c99..7b6f4e6 100644 --- a/server/database/model/model_labels.py +++ b/server/database/model/model_labels.py @@ -1,11 +1,14 @@ -from mongoengine import MapField, IntField +from database.fields import NoneType, UnionField, SafeMapField -class ModelLabels(MapField): +class ModelLabels(SafeMapField): def __init__(self, *args, **kwargs): - super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs) + super(ModelLabels, self).__init__( + field=UnionField(types=(int, NoneType)), *args, **kwargs + ) def validate(self, value): super(ModelLabels, self).validate(value) - if value and (len(set(value.values())) < len(value)): + non_empty_values = list(filter(None, value.values())) + if non_empty_values and len(set(non_empty_values)) < len(non_empty_values): self.error("Same label id appears more than once in model labels") diff --git a/server/database/query.py b/server/database/query.py index a6885c7..62f409b 100644 --- a/server/database/query.py +++ b/server/database/query.py @@ -1,8 +1,14 @@ import copy import re +from typing import Union from mongoengine import Q -from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination +from mongoengine.queryset.visitor import ( + QueryCompilerVisitor, + SimplificationVisitor, + QCombination, + QNode, +) class RegexWrapper(object): @@ -17,17 +23,16 @@ class RegexWrapper(object): class RegexMixin(object): - - def to_query(self, document): + def to_query(self: Union["RegexMixin", QNode], document): query = self.accept(SimplificationVisitor()) query = query.accept(RegexQueryCompilerVisitor(document)) return query - def _combine(self, other, operation): + def _combine(self: Union["RegexMixin", QNode], other, operation): """Combine this node with another node into a QCombination object. """ - if getattr(other, 'empty', True): + if getattr(other, "empty", True): return self if self.empty: diff --git a/server/database/utils.py b/server/database/utils.py index 1af37b8..7b4f92d 100644 --- a/server/database/utils.py +++ b/server/database/utils.py @@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True): res[field] = None continue if desc: - if callable(desc): + if issubclass(desc, Document): + if not desc.objects(id=value).only("id"): + raise ParseCallError( + "expecting %s id" % desc.__name__, id=value, field=field + ) + elif callable(desc): try: desc(value) except TypeError: raise ParseCallError(f"expecting {desc.__name__}", field=field) except Exception as ex: raise ParseCallError(str(ex), field=field) - else: - if issubclass(desc, (list, tuple, dict)) and not isinstance( - value, desc - ): - raise ParseCallError( - "expecting %s" % desc.__name__, field=field - ) - if issubclass(desc, Document) and not desc.objects(id=value).only( - "id" - ): - raise ParseCallError( - "expecting %s id" % desc.__name__, id=value, field=field - ) res[field] = value return res diff --git a/server/requirements.txt b/server/requirements.txt index 7cb394c..ef479c9 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -14,12 +14,12 @@ Jinja2==2.10 jsonmodels>=2.3 jsonschema>=2.6.0 luqum>=0.7.2 -mongoengine==0.16.2 +mongoengine==0.19.1 nested_dict>=1.61 psutil>=5.6.5 pyhocon>=0.3.35 pyjwt>=1.3.0 -pymongo==3.6.1 # 3.7 has a bug multiple users logged in +pymongo==3.10.1 python-rapidjson>=0.6.3 redis>=2.10.5 related>=0.7.2 diff --git a/server/services/models.py b/server/services/models.py index 25fa202..b9adc9d 100644 --- a/server/services/models.py +++ b/server/services/models.py @@ -290,11 +290,15 @@ def prepare_update_fields(call, fields): invalid_keys = find_other_types(labels.keys(), str) if invalid_keys: - raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys) + raise errors.bad_request.ValidationError( + "labels keys must be strings", keys=invalid_keys + ) invalid_values = find_other_types(labels.values(), int) if invalid_values: - raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values) + raise errors.bad_request.ValidationError( + "labels values must be integers", values=invalid_values + ) conform_tag_fields(call, fields) return fields @@ -331,7 +335,7 @@ def edit(call: APICall): fields[key] = d iteration = call.data.get("iteration") - task_id = model.task or fields.get('task') + task_id = model.task or fields.get("task") if task_id and iteration is not None: TaskBLL.update_statistics( task_id=task_id, @@ -393,14 +397,14 @@ def set_ready(call: APICall, company, req_model: PublishModelRequest): model_id=req_model.model, company_id=company, publish_task=req_model.publish_task, - force_publish_task=req_model.force_publish_task + force_publish_task=req_model.force_publish_task, ) call.result.data_model = PublishModelResponse( updated=updated, - published_task=ModelTaskPublishResponse( - **published_task_data - ) if published_task_data else None + published_task=ModelTaskPublishResponse(**published_task_data) + if published_task_data + else None, ) diff --git a/server/services/projects.py b/server/services/projects.py index cdaecb4..ff98102 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -210,7 +210,7 @@ def get_all_ex(call: APICall): status_count = defaultdict(lambda: {}) key = itemgetter(EntityVisibility.archived.value) - for result in Task.aggregate(*status_count_pipeline): + for result in Task.aggregate(status_count_pipeline): for k, group in groupby(sorted(result["counts"], key=key), key): section = ( EntityVisibility.archived if k else EntityVisibility.active @@ -224,7 +224,7 @@ def get_all_ex(call: APICall): runtime = { result["_id"]: {k: v for k, v in result.items() if k != "_id"} - for result in Task.aggregate(*runtime_pipeline) + for result in Task.aggregate(runtime_pipeline) } def safe_get(obj, path, default=None): diff --git a/server/services/queues.py b/server/services/queues.py index 4c73122..bbf60b2 100644 --- a/server/services/queues.py +++ b/server/services/queues.py @@ -212,7 +212,9 @@ def get_queue_metrics( dates=data["date"], avg_waiting_times=data["avg_waiting_time"], queue_lengths=data["queue_length"], - ) if data else QueueMetrics(queue=queue) + ) + if data + else QueueMetrics(queue=queue) for queue, data in queue_dicts.items() ] )