clearml-server/apiserver/database/props.py
2021-05-03 17:36:04 +03:00

176 lines
6.3 KiB
Python

from collections import OrderedDict
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document
from apiserver.database.fields import (
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
)
from apiserver.database.utils import get_fields, get_fields_attr
class PropsMixin(object):
    """
    Mixin providing cached, class-level schema introspection for document
    classes: field listings, fields carrying a given field attribute (resolved
    recursively through embedded documents), and translation of dotted field
    paths into paths with '*' after embedded-document list fields.

    NOTE(review): every cache below is a (name-mangled) class attribute that is
    filled only while it is still None, and the assignment lands on the class
    the method was called on. A subclass whose base class already filled a
    cache reads the base class's value instead of computing its own - confirm
    no caller queries a base document class before its subclasses.
    """

    __cached_fields = None
    __cached_reference_fields = None
    __cached_exclude_fields = None
    __cached_fields_with_instance = None
    __cached_all_fields_with_instance = None
    # A single lock object shared by all subclasses; guards the dpath cache.
    __cached_dpath_computed_fields_lock = Lock()
    __cached_dpath_computed_fields = None

    @classmethod
    def get_fields(cls):
        """
        Return this class's fields as reported by
        ``apiserver.database.utils.get_fields`` (presumably field names),
        computed on first use and cached.
        """
        if cls.__cached_fields is None:
            # Resolves to the module-level helper, not this classmethod.
            cls.__cached_fields = get_fields(cls)
        return cls.__cached_fields

    @classmethod
    def get_all_fields_with_instance(cls):
        """
        Return this class's fields together with their field instances,
        including subfields, computed on first use and cached.
        """
        if cls.__cached_all_fields_with_instance is None:
            cls.__cached_all_fields_with_instance = get_fields(
                cls, return_instance=True, subfields=True
            )
        return cls.__cached_all_fields_with_instance

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        """
        Return the fields of ``doc_cls`` with their field instances, cached
        per ``doc_cls``. Each item unpacks as a ``(name, field)`` pair.
        """
        if cls.__cached_fields_with_instance is None:
            cls.__cached_fields_with_instance = {}
        if doc_cls not in cls.__cached_fields_with_instance:
            cls.__cached_fields_with_instance[doc_cls] = get_fields(
                doc_cls, return_instance=True
            )
        return cls.__cached_fields_with_instance[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """
        Get all fields with the specified attribute (supports nested fields).

        :param cls_: document class to inspect
        :param attr: name of the field attribute to look for
        :return: dict mapping a field path to the attribute's value. Paths of
            fields inside embedded documents are '.'-joined. String values are
            resolved to document classes: 'self' becomes
            ``cls_.owner_document``; any other string is looked up with
            mongoengine's ``get_document``.
        """

        def resolve_doc(v):
            if not isinstance(v, six.string_types):
                return v
            if v == 'self':
                return cls_.owner_document
            return get_document(v)

        fields = {
            k: resolve_doc(v) for k, v in get_fields_attr(cls_, attr=attr).items()
        }

        # Each entry pairs an embedded field type with a getter that returns the
        # field object holding ``document_type`` (list fields wrap it in ``.field``).
        embedded_field_kinds = (
            (EmbeddedDocumentField, lambda x: x),
            (EmbeddedDocumentListField, attrgetter('field')),
            (LengthRangeEmbeddedDocumentListField, attrgetter('field')),
            (UniqueEmbeddedDocumentListField, attrgetter('field')),
            (EmbeddedDocumentSortedListField, attrgetter('field')),
        )
        for field_type, embedded_doc_field_getter in embedded_field_kinds:
            for field, embedded_doc_field in get_fields(
                cls_, of_type=field_type, return_instance=True
            ):
                embedded_doc_cls = embedded_doc_field_getter(
                    embedded_doc_field
                ).document_type
                # Recurse so that arbitrarily deep embedded documents are covered.
                nested = PropsMixin._get_fields_with_attr(embedded_doc_cls, attr)
                for subfield, doc in nested.items():
                    fields['.'.join((field, subfield))] = doc
        return fields

    @classmethod
    def _translate_fields_path(cls, parts):
        """
        Validate a split field path against the document schema and translate it.

        Each part is looked up among the fields of the current document class.
        A plain embedded-document field continues the walk into its document
        type. An embedded-document list field continues into the list's element
        type and appends '*' (presumably a dpath wildcard matching any list
        index). Any other field ends the walk, so a further part is an error.

        :param parts: sequence of path components
        :return: list of translated path components
        :raises ValueError: if a part is not a field of the current document,
            or if a part follows a field that is not an embedded document
        """
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                # parts[depth - 1] named a non-document field; include it in the
                # reported path. The slice is wrapped in a 1-tuple so tuple
                # input formats correctly instead of raising TypeError.
                raise ValueError(
                    'Invalid path (non-document encountered at %s)' % (parts[:depth],)
                )
            try:
                field = next(
                    v
                    for k, v in cls.get_fields_with_instance(current_cls)
                    if k == part
                )
            except StopIteration:
                # Report the path up to and including the unknown part.
                raise ValueError(
                    'Invalid field path %s' % (parts[: depth + 1],)
                ) from None
            translated_parts.append(part)
            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(
                field,
                (
                    EmbeddedDocumentListField,
                    LengthRangeEmbeddedDocumentListField,
                    UniqueEmbeddedDocumentListField,
                    EmbeddedDocumentSortedListField,
                ),
            ):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                current_cls = None
        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        """
        Return an OrderedDict, sorted by field path, of the fields (including
        nested ones) that have a ``reference_field`` attribute, mapped to the
        resolved attribute value. Computed on first use and cached.
        """
        if cls.__cached_reference_fields is None:
            fields = cls._get_fields_with_attr(cls, 'reference_field')
            cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_reference_fields

    @classmethod
    def get_extra_projection(cls, fields: Sequence) -> tuple:
        """
        Return a projection tuple: ``fields`` (a sequence, or a single string)
        unioned with all of this class's fields, minus the fields excluded by
        default.

        NOTE(review): the exclusion is applied after the union, so an
        explicitly requested excluded-by-default field is dropped too -
        confirm this is intended.
        """
        if isinstance(fields, str):
            fields = [fields]
        return tuple(
            set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
        )

    @classmethod
    def get_exclude_fields(cls):
        """
        Return an OrderedDict, sorted by field path, of the fields (including
        nested ones) that have an ``exclude_by_default`` attribute. Computed on
        first use and cached.
        """
        if cls.__cached_exclude_fields is None:
            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
            cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_exclude_fields

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        """
        Translate a ``separator``-delimited field path with
        ``_translate_fields_path`` and join the result with the same separator.

        Results are cached per ``(path, separator)``. Keying on the separator
        too keeps the same path split with different separators from
        returning a stale translation.

        :raises ValueError: propagated from ``_translate_fields_path`` for an
            invalid path (nothing is cached in that case)
        """
        key = (path, separator)
        cache = cls.__cached_dpath_computed_fields
        if cache is None or key not in cache:
            with cls.__cached_dpath_computed_fields_lock:
                # Re-read under the lock: another thread may have created the
                # dict or filled this key while we were waiting.
                cache = cls.__cached_dpath_computed_fields
                if cache is None:
                    cache = {}
                    cls.__cached_dpath_computed_fields = cache
                if key not in cache:
                    translated = cls._translate_fields_path(path.split(separator))
                    cache[key] = separator.join(translated)
        return cache[key]

    def get_field_value(self, field_path: str, default=None):
        """
        Return the value found at ``field_path`` on this document.

        The path may contain '.' to walk nested attributes. If any level of the
        path is not found, ``default`` is returned. A private sentinel with an
        identity check detects "not found", so values that merely compare
        equal to ``default`` are not mistaken for missing ones, and objects
        with unusual ``__eq__`` are never compared.
        """
        missing = object()
        current = self
        for name in field_path.split("."):
            current = getattr(current, name, missing)
            if current is missing:
                return default
        return current