mirror of
https://github.com/clearml/clearml-server
synced 2025-05-08 14:04:44 +00:00
Update pymongo and mongoengine versions
This commit is contained in:
parent
35a11db58e
commit
a1dcdffa53
@ -280,7 +280,7 @@ class StatisticsReporter:
|
|||||||
]
|
]
|
||||||
return {
|
return {
|
||||||
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
|
||||||
for group in Task.aggregate(*pipeline)
|
for group in Task.aggregate(pipeline)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -263,7 +263,7 @@ class TaskBLL(object):
|
|||||||
]
|
]
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
result = Task.aggregate(*pipeline)
|
result = Task.aggregate(pipeline)
|
||||||
return [r["metrics"][0] for r in result]
|
return [r["metrics"][0] for r in result]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -666,7 +666,7 @@ class TaskBLL(object):
|
|||||||
]
|
]
|
||||||
|
|
||||||
with translate_errors_context():
|
with translate_errors_context():
|
||||||
result = next(Task.aggregate(*pipeline), None)
|
result = next(Task.aggregate(pipeline), None)
|
||||||
|
|
||||||
total = 0
|
total = 0
|
||||||
remaining = 0
|
remaining = 0
|
||||||
|
@ -223,7 +223,7 @@ class WorkerBLL:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
queues_info = {
|
queues_info = {
|
||||||
res["_id"]: res for res in Queue.objects.aggregate(*projection)
|
res["_id"]: res for res in Queue.objects.aggregate(projection)
|
||||||
}
|
}
|
||||||
task_ids = task_ids.union(
|
task_ids = task_ids.union(
|
||||||
filter(
|
filter(
|
||||||
|
@ -14,6 +14,9 @@ from mongoengine import (
|
|||||||
DictField,
|
DictField,
|
||||||
DynamicField,
|
DynamicField,
|
||||||
)
|
)
|
||||||
|
from mongoengine.fields import key_not_string, key_starts_with_dollar
|
||||||
|
|
||||||
|
NoneType = type(None)
|
||||||
|
|
||||||
|
|
||||||
class LengthRangeListField(ListField):
|
class LengthRangeListField(ListField):
|
||||||
@ -125,17 +128,39 @@ def contains_empty_key(d):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class SafeMapField(MapField):
|
class DictValidationMixin:
|
||||||
|
"""
|
||||||
|
DictField validation in MongoEngine requires default alias and permissions to access DB version:
|
||||||
|
https://github.com/MongoEngine/mongoengine/issues/2239
|
||||||
|
This is a stripped down implementation that does not require any of the above and implies Mongo ver 3.6+
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _safe_validate(self: DictField, value):
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
self.error("Only dictionaries may be used in a DictField")
|
||||||
|
|
||||||
|
if key_not_string(value):
|
||||||
|
msg = "Invalid dictionary key - documents must have only string keys"
|
||||||
|
self.error(msg)
|
||||||
|
|
||||||
|
if key_starts_with_dollar(value):
|
||||||
|
self.error(
|
||||||
|
'Invalid dictionary key name - keys may not startswith "$" characters'
|
||||||
|
)
|
||||||
|
super(DictField, self).validate(value)
|
||||||
|
|
||||||
|
|
||||||
|
class SafeMapField(MapField, DictValidationMixin):
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
super(SafeMapField, self).validate(value)
|
self._safe_validate(value)
|
||||||
|
|
||||||
if contains_empty_key(value):
|
if contains_empty_key(value):
|
||||||
self.error("Empty keys are not allowed in a MapField")
|
self.error("Empty keys are not allowed in a MapField")
|
||||||
|
|
||||||
|
|
||||||
class SafeDictField(DictField):
|
class SafeDictField(DictField, DictValidationMixin):
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
super(SafeDictField, self).validate(value)
|
self._safe_validate(value)
|
||||||
|
|
||||||
if contains_empty_key(value):
|
if contains_empty_key(value):
|
||||||
self.error("Empty keys are not allowed in a DictField")
|
self.error("Empty keys are not allowed in a DictField")
|
||||||
@ -146,6 +171,7 @@ class SafeSortedListField(SortedListField):
|
|||||||
SortedListField that does not raise an error in case items are not comparable
|
SortedListField that does not raise an error in case items are not comparable
|
||||||
(in which case they will be sorted by their string representation)
|
(in which case they will be sorted by their string representation)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def to_mongo(self, *args, **kwargs):
|
def to_mongo(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
|
||||||
@ -155,7 +181,10 @@ class SafeSortedListField(SortedListField):
|
|||||||
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
|
||||||
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
|
||||||
if self._ordering is not None:
|
if self._ordering is not None:
|
||||||
def key(v): return str(itemgetter(self._ordering)(v))
|
|
||||||
|
def key(v):
|
||||||
|
return str(itemgetter(self._ordering)(v))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
key = str
|
key = str
|
||||||
return sorted(value, key=key, reverse=self._order_reverse)
|
return sorted(value, key=key, reverse=self._order_reverse)
|
||||||
|
@ -34,7 +34,12 @@ class AuthDocument(Document):
|
|||||||
|
|
||||||
|
|
||||||
class ProperDictMixin(object):
|
class ProperDictMixin(object):
|
||||||
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
|
def to_proper_dict(
|
||||||
|
self: Union["ProperDictMixin", Document],
|
||||||
|
strip_private=True,
|
||||||
|
only=None,
|
||||||
|
extra_dict=None,
|
||||||
|
) -> dict:
|
||||||
return self.properize_dict(
|
return self.properize_dict(
|
||||||
self.to_mongo(use_db_field=False).to_dict(),
|
self.to_mongo(use_db_field=False).to_dict(),
|
||||||
strip_private=strip_private,
|
strip_private=strip_private,
|
||||||
@ -95,7 +100,13 @@ class GetMixin(PropsMixin):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(
|
def get(
|
||||||
cls, company, id, *, _only=None, include_public=False, **kwargs
|
cls: Union["GetMixin", Document],
|
||||||
|
company,
|
||||||
|
id,
|
||||||
|
*,
|
||||||
|
_only=None,
|
||||||
|
include_public=False,
|
||||||
|
**kwargs,
|
||||||
) -> "GetMixin":
|
) -> "GetMixin":
|
||||||
q = cls.objects(
|
q = cls.objects(
|
||||||
cls._prepare_perm_query(company, allow_public=include_public)
|
cls._prepare_perm_query(company, allow_public=include_public)
|
||||||
@ -409,7 +420,12 @@ class GetMixin(PropsMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
|
def _get_many_no_company(
|
||||||
|
cls: Union["GetMixin", Document],
|
||||||
|
query,
|
||||||
|
parameters=None,
|
||||||
|
override_projection=None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Fetch all documents matching a provided query.
|
Fetch all documents matching a provided query.
|
||||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||||
@ -593,7 +609,13 @@ class UpdateMixin(object):
|
|||||||
return update_dict
|
return update_dict
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
def safe_update(
|
||||||
|
cls: Union["UpdateMixin", Document],
|
||||||
|
company_id,
|
||||||
|
id,
|
||||||
|
partial_update_dict,
|
||||||
|
injected_update=None,
|
||||||
|
):
|
||||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||||
if not update_dict:
|
if not update_dict:
|
||||||
return 0, {}
|
return 0, {}
|
||||||
@ -610,7 +632,10 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def aggregate(
|
def aggregate(
|
||||||
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
|
cls: Union["DbModelMixin", Document],
|
||||||
|
pipeline: Sequence[dict],
|
||||||
|
allow_disk_use=None,
|
||||||
|
**kwargs,
|
||||||
) -> CommandCursor:
|
) -> CommandCursor:
|
||||||
"""
|
"""
|
||||||
Aggregate objects of this document class according to the provided pipeline.
|
Aggregate objects of this document class according to the provided pipeline.
|
||||||
@ -625,7 +650,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
|||||||
if allow_disk_use is not None
|
if allow_disk_use is not None
|
||||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||||
)
|
)
|
||||||
return cls.objects.aggregate(*pipeline, **kwargs)
|
return cls.objects.aggregate(pipeline, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def validate_id(cls, company, **kwargs):
|
def validate_id(cls, company, **kwargs):
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
from mongoengine import MapField, IntField
|
from database.fields import NoneType, UnionField, SafeMapField
|
||||||
|
|
||||||
|
|
||||||
class ModelLabels(MapField):
|
class ModelLabels(SafeMapField):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
|
super(ModelLabels, self).__init__(
|
||||||
|
field=UnionField(types=(int, NoneType)), *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
super(ModelLabels, self).validate(value)
|
super(ModelLabels, self).validate(value)
|
||||||
if value and (len(set(value.values())) < len(value)):
|
non_empty_values = list(filter(None, value.values()))
|
||||||
|
if non_empty_values and len(set(non_empty_values)) < len(non_empty_values):
|
||||||
self.error("Same label id appears more than once in model labels")
|
self.error("Same label id appears more than once in model labels")
|
||||||
|
@ -1,8 +1,14 @@
|
|||||||
import copy
|
import copy
|
||||||
import re
|
import re
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from mongoengine import Q
|
from mongoengine import Q
|
||||||
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
|
from mongoengine.queryset.visitor import (
|
||||||
|
QueryCompilerVisitor,
|
||||||
|
SimplificationVisitor,
|
||||||
|
QCombination,
|
||||||
|
QNode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RegexWrapper(object):
|
class RegexWrapper(object):
|
||||||
@ -17,17 +23,16 @@ class RegexWrapper(object):
|
|||||||
|
|
||||||
|
|
||||||
class RegexMixin(object):
|
class RegexMixin(object):
|
||||||
|
def to_query(self: Union["RegexMixin", QNode], document):
|
||||||
def to_query(self, document):
|
|
||||||
query = self.accept(SimplificationVisitor())
|
query = self.accept(SimplificationVisitor())
|
||||||
query = query.accept(RegexQueryCompilerVisitor(document))
|
query = query.accept(RegexQueryCompilerVisitor(document))
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def _combine(self, other, operation):
|
def _combine(self: Union["RegexMixin", QNode], other, operation):
|
||||||
"""Combine this node with another node into a QCombination
|
"""Combine this node with another node into a QCombination
|
||||||
object.
|
object.
|
||||||
"""
|
"""
|
||||||
if getattr(other, 'empty', True):
|
if getattr(other, "empty", True):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.empty:
|
if self.empty:
|
||||||
|
@ -95,26 +95,18 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
|
|||||||
res[field] = None
|
res[field] = None
|
||||||
continue
|
continue
|
||||||
if desc:
|
if desc:
|
||||||
if callable(desc):
|
if issubclass(desc, Document):
|
||||||
|
if not desc.objects(id=value).only("id"):
|
||||||
|
raise ParseCallError(
|
||||||
|
"expecting %s id" % desc.__name__, id=value, field=field
|
||||||
|
)
|
||||||
|
elif callable(desc):
|
||||||
try:
|
try:
|
||||||
desc(value)
|
desc(value)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
raise ParseCallError(f"expecting {desc.__name__}", field=field)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise ParseCallError(str(ex), field=field)
|
raise ParseCallError(str(ex), field=field)
|
||||||
else:
|
|
||||||
if issubclass(desc, (list, tuple, dict)) and not isinstance(
|
|
||||||
value, desc
|
|
||||||
):
|
|
||||||
raise ParseCallError(
|
|
||||||
"expecting %s" % desc.__name__, field=field
|
|
||||||
)
|
|
||||||
if issubclass(desc, Document) and not desc.objects(id=value).only(
|
|
||||||
"id"
|
|
||||||
):
|
|
||||||
raise ParseCallError(
|
|
||||||
"expecting %s id" % desc.__name__, id=value, field=field
|
|
||||||
)
|
|
||||||
res[field] = value
|
res[field] = value
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -14,12 +14,12 @@ Jinja2==2.10
|
|||||||
jsonmodels>=2.3
|
jsonmodels>=2.3
|
||||||
jsonschema>=2.6.0
|
jsonschema>=2.6.0
|
||||||
luqum>=0.7.2
|
luqum>=0.7.2
|
||||||
mongoengine==0.16.2
|
mongoengine==0.19.1
|
||||||
nested_dict>=1.61
|
nested_dict>=1.61
|
||||||
psutil>=5.6.5
|
psutil>=5.6.5
|
||||||
pyhocon>=0.3.35
|
pyhocon>=0.3.35
|
||||||
pyjwt>=1.3.0
|
pyjwt>=1.3.0
|
||||||
pymongo==3.6.1 # 3.7 has a bug multiple users logged in
|
pymongo==3.10.1
|
||||||
python-rapidjson>=0.6.3
|
python-rapidjson>=0.6.3
|
||||||
redis>=2.10.5
|
redis>=2.10.5
|
||||||
related>=0.7.2
|
related>=0.7.2
|
||||||
|
@ -290,11 +290,15 @@ def prepare_update_fields(call, fields):
|
|||||||
|
|
||||||
invalid_keys = find_other_types(labels.keys(), str)
|
invalid_keys = find_other_types(labels.keys(), str)
|
||||||
if invalid_keys:
|
if invalid_keys:
|
||||||
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys)
|
raise errors.bad_request.ValidationError(
|
||||||
|
"labels keys must be strings", keys=invalid_keys
|
||||||
|
)
|
||||||
|
|
||||||
invalid_values = find_other_types(labels.values(), int)
|
invalid_values = find_other_types(labels.values(), int)
|
||||||
if invalid_values:
|
if invalid_values:
|
||||||
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values)
|
raise errors.bad_request.ValidationError(
|
||||||
|
"labels values must be integers", values=invalid_values
|
||||||
|
)
|
||||||
|
|
||||||
conform_tag_fields(call, fields)
|
conform_tag_fields(call, fields)
|
||||||
return fields
|
return fields
|
||||||
@ -331,7 +335,7 @@ def edit(call: APICall):
|
|||||||
fields[key] = d
|
fields[key] = d
|
||||||
|
|
||||||
iteration = call.data.get("iteration")
|
iteration = call.data.get("iteration")
|
||||||
task_id = model.task or fields.get('task')
|
task_id = model.task or fields.get("task")
|
||||||
if task_id and iteration is not None:
|
if task_id and iteration is not None:
|
||||||
TaskBLL.update_statistics(
|
TaskBLL.update_statistics(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
@ -393,14 +397,14 @@ def set_ready(call: APICall, company, req_model: PublishModelRequest):
|
|||||||
model_id=req_model.model,
|
model_id=req_model.model,
|
||||||
company_id=company,
|
company_id=company,
|
||||||
publish_task=req_model.publish_task,
|
publish_task=req_model.publish_task,
|
||||||
force_publish_task=req_model.force_publish_task
|
force_publish_task=req_model.force_publish_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
call.result.data_model = PublishModelResponse(
|
call.result.data_model = PublishModelResponse(
|
||||||
updated=updated,
|
updated=updated,
|
||||||
published_task=ModelTaskPublishResponse(
|
published_task=ModelTaskPublishResponse(**published_task_data)
|
||||||
**published_task_data
|
if published_task_data
|
||||||
) if published_task_data else None
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ def get_all_ex(call: APICall):
|
|||||||
|
|
||||||
status_count = defaultdict(lambda: {})
|
status_count = defaultdict(lambda: {})
|
||||||
key = itemgetter(EntityVisibility.archived.value)
|
key = itemgetter(EntityVisibility.archived.value)
|
||||||
for result in Task.aggregate(*status_count_pipeline):
|
for result in Task.aggregate(status_count_pipeline):
|
||||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||||
section = (
|
section = (
|
||||||
EntityVisibility.archived if k else EntityVisibility.active
|
EntityVisibility.archived if k else EntityVisibility.active
|
||||||
@ -224,7 +224,7 @@ def get_all_ex(call: APICall):
|
|||||||
|
|
||||||
runtime = {
|
runtime = {
|
||||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||||
for result in Task.aggregate(*runtime_pipeline)
|
for result in Task.aggregate(runtime_pipeline)
|
||||||
}
|
}
|
||||||
|
|
||||||
def safe_get(obj, path, default=None):
|
def safe_get(obj, path, default=None):
|
||||||
|
@ -212,7 +212,9 @@ def get_queue_metrics(
|
|||||||
dates=data["date"],
|
dates=data["date"],
|
||||||
avg_waiting_times=data["avg_waiting_time"],
|
avg_waiting_times=data["avg_waiting_time"],
|
||||||
queue_lengths=data["queue_length"],
|
queue_lengths=data["queue_length"],
|
||||||
) if data else QueueMetrics(queue=queue)
|
)
|
||||||
|
if data
|
||||||
|
else QueueMetrics(queue=queue)
|
||||||
for queue, data in queue_dicts.items()
|
for queue, data in queue_dicts.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user