# Mirror of https://github.com/clearml/clearml-server
# (synced 2025-01-31 10:56:48 +00:00; 202 lines, 7.4 KiB, Python)
from collections import OrderedDict, defaultdict
|
|
from itertools import chain
|
|
from operator import attrgetter
|
|
from threading import Lock
|
|
from typing import Sequence
|
|
|
|
import six
|
|
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
|
from mongoengine.base import get_document, BaseField
|
|
|
|
from apiserver.database.fields import (
|
|
LengthRangeEmbeddedDocumentListField,
|
|
UniqueEmbeddedDocumentListField,
|
|
EmbeddedDocumentSortedListField,
|
|
)
|
|
from apiserver.database.utils import get_fields, get_fields_attr
|
|
|
|
|
|
class PropsMixin(object):
    """
    Class-level field introspection helpers for mongoengine document classes.

    Every result is computed lazily on first use and cached in a private class
    attribute.
    NOTE(review): the caches are read through normal class attribute lookup,
    so a subclass of a class that already populated a cache reuses the
    parent's value instead of computing its own -- confirm that no document
    class derives from another one that is queried first.
    """

    __cached_fields = None
    __cached_reference_fields = None
    __cached_exclude_fields = None
    __cached_fields_with_instance = None
    __cached_field_names_per_type = None

    # Guards creation and population of the path translation cache below
    __cached_dpath_computed_fields_lock = Lock()
    # Maps (path, separator) to the translated path string
    __cached_dpath_computed_fields = None

    # List fields whose items are embedded documents; the item field is
    # reachable through the list field's ``field`` attribute
    __embedded_list_field_types = (
        EmbeddedDocumentListField,
        LengthRangeEmbeddedDocumentListField,
        UniqueEmbeddedDocumentListField,
        EmbeddedDocumentSortedListField,
    )

    @classmethod
    def get_fields(cls):
        """Return ``get_fields(cls)`` from the database utils, computed once and cached."""
        if cls.__cached_fields is None:
            # Module-level helper from apiserver.database.utils, not this classmethod
            cls.__cached_fields = get_fields(cls)
        return cls.__cached_fields

    @classmethod
    def get_field_names_for_type(cls, of_type=BaseField):
        """
        Return field names per type including subfields.
        The fields of derived types are also returned.

        :param of_type: a ``BaseField`` subclass (default: ``BaseField``)
        :return: names of the fields (including subfields) of ``cls`` whose
            field instance is of type ``of_type`` or of a type derived from it,
            without duplicates. The returned list is the cached object itself.
        """
        assert issubclass(of_type, BaseField)

        if cls.__cached_field_names_per_type is None:
            # Group field names by the exact type of their field instance
            names_by_exact_type = defaultdict(list)
            for name, field in get_fields(cls, return_instance=True, subfields=True):
                names_by_exact_type[type(field)].append(name)

            # Each type also owns the names of the fields whose type derives
            # from it. Aggregating from the exact-type lists (instead of
            # extending the lists in place) keeps every list free of duplicates
            # regardless of the order in which the types are visited.
            cls.__cached_field_names_per_type = {
                type_: own_names
                + [
                    name
                    for other_type, other_names in names_by_exact_type.items()
                    if other_type is not type_ and issubclass(other_type, type_)
                    for name in other_names
                ]
                for type_, own_names in names_by_exact_type.items()
            }

        cache = cls.__cached_field_names_per_type
        if of_type not in cache:
            # of_type is not the exact type of any field: merge the lists of all
            # cached types deriving from it. Each of those lists already holds
            # its own subtypes' names, so dict.fromkeys drops the repeats while
            # keeping first-seen order. The generator is fully consumed before
            # the new entry is stored.
            cache[of_type] = list(
                dict.fromkeys(
                    chain.from_iterable(
                        names
                        for type_, names in cache.items()
                        if issubclass(type_, of_type)
                    )
                )
            )

        return cache[of_type]

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        """
        Return ``get_fields(doc_cls, return_instance=True)``, cached per ``doc_cls``.

        The result is iterated elsewhere in this class as (name, field) pairs.
        """
        if cls.__cached_fields_with_instance is None:
            cls.__cached_fields_with_instance = {}
        cache = cls.__cached_fields_with_instance
        if doc_cls not in cache:
            cache[doc_cls] = get_fields(doc_cls, return_instance=True)
        return cache[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """
        Get all fields with the specified attribute (supports nested fields).

        :param cls_: document class to inspect
        :param attr: name of the field attribute to look for
        :return: dict mapping field names to the attribute value. Fields of
            embedded documents (single or list) are included under
            '<field>.<subfield>' names. String values are resolved to document
            classes through ``get_document``.
        """

        def resolve_doc(value):
            # Non-string values are kept as-is; strings name a document class
            if not isinstance(value, six.string_types):
                return value
            if value == 'self':
                # NOTE(review): returns cls_.owner_document for the 'self'
                # marker; confirm cls_ actually exposes that attribute
                return cls_.owner_document
            return get_document(value)

        fields = {
            name: resolve_doc(value)
            for name, value in get_fields_attr(cls_, attr=attr).items()
        }

        # (field type, getter returning the field that carries document_type)
        embedded_field_getters = [(EmbeddedDocumentField, lambda field: field)] + [
            (list_type, attrgetter('field'))
            for list_type in PropsMixin.__embedded_list_field_types
        ]
        for field_type, get_doc_field in embedded_field_getters:
            for field_name, embedded_field in get_fields(
                cls_, of_type=field_type, return_instance=True
            ):
                embedded_doc_cls = get_doc_field(embedded_field).document_type
                # Recurse into the embedded document and prefix its names
                for sub_name, doc in PropsMixin._get_fields_with_attr(
                    embedded_doc_cls, attr
                ).items():
                    fields['.'.join((field_name, sub_name))] = doc

        return fields

    @classmethod
    def _translate_fields_path(cls, parts):
        """
        Validate a split field path against the document fields and translate it.

        Each part is looked up among the fields of the current document class,
        starting at ``cls``. After a part naming an embedded document list, a
        '*' part is inserted.

        :param parts: sequence of path parts
        :return: list of translated parts
        :raises ValueError: if a part does not name a field, or if the path
            continues past a field that is neither an embedded document nor
            an embedded document list
        """
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                # The previous part named a field that has no sub-fields;
                # report the path up to and including that field
                raise ValueError(
                    'Invalid path (non-document encountered at %s)'
                    % (parts[:depth],)
                )
            try:
                field = next(
                    value
                    for name, value in cls.get_fields_with_instance(current_cls)
                    if name == part
                )
            except StopIteration:
                # Report the path up to and including the unknown part
                raise ValueError(
                    'Invalid field path %s' % (parts[: depth + 1],)
                ) from None

            translated_parts.append(part)

            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(field, cls.__embedded_list_field_types):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                current_cls = None

        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        """Return fields having a ``reference_field`` attribute, sorted by name (cached)."""
        if cls.__cached_reference_fields is None:
            fields = cls._get_fields_with_attr(cls, 'reference_field')
            cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_reference_fields

    @classmethod
    def get_extra_projection(cls, fields: Sequence) -> tuple:
        """
        Return ``fields`` plus the class fields, minus the excluded-by-default fields.

        :param fields: a field name or a sequence of field names
        :return: tuple of unique names; its order is not guaranteed (it is
            built from a set)
        """
        if isinstance(fields, str):
            fields = [fields]
        return tuple(
            set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
        )

    @classmethod
    def get_exclude_fields(cls):
        """Return fields having an ``exclude_by_default`` attribute, sorted by name (cached)."""
        if cls.__cached_exclude_fields is None:
            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
            cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_exclude_fields

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        """
        Translate a field path, inserting '*' after embedded document lists.

        :param path: field path whose parts are joined by ``separator``
        :param separator: part separator used both to split and to re-join
        :return: the translated path, cached per ``(path, separator)``
        :raises ValueError: if the path does not match the document fields
        """
        # The result depends on the separator, so it is part of the cache key
        key = (path, separator)
        cache = cls.__cached_dpath_computed_fields
        if cache is not None and key in cache:
            return cache[key]

        with cls.__cached_dpath_computed_fields_lock:
            # Create the cache and re-check under the lock so concurrent callers
            # neither replace each other's dict nor compute an entry twice
            if cls.__cached_dpath_computed_fields is None:
                cls.__cached_dpath_computed_fields = {}
            cache = cls.__cached_dpath_computed_fields
            if key not in cache:
                translated = cls._translate_fields_path(path.split(separator))
                cache[key] = separator.join(translated)
            return cache[key]

    def get_field_value(self, field_path: str, default=None):
        """
        Return the document field_path value by the field_path name.

        The path may contain '.'. Each element is resolved with ``getattr`` on
        the value found so far. If on any level the path is not found, the
        default value is returned.
        """
        # A private sentinel distinguishes "attribute missing" from an
        # attribute whose value merely equals ``default``, and avoids invoking
        # arbitrary ``__eq__`` implementations on field values
        missing = object()
        current = self
        for name in field_path.split("."):
            current = getattr(current, name, missing)
            if current is missing:
                return default
        return current
|