from collections import OrderedDict, defaultdict
from itertools import chain
from operator import attrgetter
from threading import Lock
from typing import Sequence

import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document, BaseField

from database.fields import (
    LengthRangeEmbeddedDocumentListField,
    UniqueEmbeddedDocumentListField,
    EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_attr


class PropsMixin(object):
    """
    Mixin providing lazily computed, cached field introspection helpers.

    Every cached value is stored per *exact* class. Plain class attributes
    would be found through the MRO, so a subclass would silently reuse the
    values that were computed for its parent class.
    """

    # {owning class: {cache name: cached value}}. The dict is created once on
    # PropsMixin and never rebound, so every subclass reaches the same registry
    # and is isolated from the others by its own key.
    __caches = {}
    # Guards the computation of translated dpath entries (shared by all classes)
    __cached_dpath_computed_fields_lock = Lock()

    @classmethod
    def __class_cache(cls):
        """ Return the cache dict owned by exactly this class (created on demand) """
        return cls.__caches.setdefault(cls, {})

    @classmethod
    def get_fields(cls):
        """ Return the cached result of ``get_fields(cls)`` for this class """
        cache = cls.__class_cache()
        if cache.get('fields') is None:
            cache['fields'] = get_fields(cls)
        return cache['fields']

    @classmethod
    def get_field_names_for_type(cls, of_type=BaseField):
        """
        Return field names per type including subfields
        The fields of derived types are also returned

        :param of_type: a BaseField subclass; names of fields whose type is
            ``of_type`` or any subclass of it are returned
        :return: list of field names, fields of exactly ``of_type`` first.
            Every field is listed once. The list is cached and shared between
            calls, so callers should not modify it.
        """
        assert issubclass(of_type, BaseField)

        cache = cls.__class_cache()

        # Names grouped by the concrete type of each field. This map is never
        # modified after it is built, so resolving one type cannot leak
        # duplicated names into the result of another type.
        direct = cache.get('direct_field_names')
        if direct is None:
            direct = {}
            for name, field in get_fields(cls, return_instance=True, subfields=True):
                direct.setdefault(type(field), []).append(name)
            cache['direct_field_names'] = direct

        resolved = cache.setdefault('field_names_per_type', {})
        if of_type not in resolved:
            names = list(direct.get(of_type, ()))
            names.extend(
                chain.from_iterable(
                    field_names
                    for type_, field_names in direct.items()
                    if type_ is not of_type and issubclass(type_, of_type)
                )
            )
            resolved[of_type] = names

        return resolved[of_type]

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        """
        Return the cached result of ``get_fields(doc_cls, return_instance=True)``

        :param doc_cls: the class whose fields are requested; it is used as the
            cache key and does not have to be ``cls``
        """
        per_doc_cls = cls.__class_cache().setdefault('fields_with_instance', {})
        if doc_cls not in per_doc_cls:
            per_doc_cls[doc_cls] = get_fields(doc_cls, return_instance=True)
        return per_doc_cls[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """
        Get all fields with the specified attribute (supports nested fields)

        :param cls_: the class to inspect
        :param attr: name of the field attribute to collect
        :return: dict mapping a field path to the attribute value. Paths of
            fields found inside embedded documents are prefixed with the name
            of the embedding field and a '.'. String values are resolved to
            document classes.
        """
        res = get_fields_attr(cls_, attr=attr)

        def resolve_doc(v):
            # Non-string values are passed through untouched
            if not isinstance(v, six.string_types):
                return v

            if v == 'self':
                # NOTE(review): 'owner_document' is read from cls_ itself -
                # confirm that the inspected classes expose this attribute
                return cls_.owner_document

            return get_document(v)

        fields = {k: resolve_doc(v) for k, v in res.items()}

        def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
            # Recurse into the document type of every field of type doc_cls and
            # merge the results in under a '<field>.<subfield>' path
            for field, embedded_doc_field in get_fields(
                cls_, of_type=doc_cls, return_instance=True
            ):
                embedded_doc_cls = embedded_doc_field_getter(
                    embedded_doc_field
                ).document_type
                fields.update(
                    {
                        '.'.join((field, subfield)): doc
                        for subfield, doc in PropsMixin._get_fields_with_attr(
                            embedded_doc_cls, attr
                        ).items()
                    }
                )

        # A plain embedded field carries the document type itself, the list
        # flavours carry it on their inner 'field'
        collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
        collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))

        return fields

    @classmethod
    def _translate_fields_path(cls, parts):
        """
        Translate the components of a field path, adding a '*' component after
        every embedded document list field

        :param parts: sequence of field names, outermost first
        :return: list of translated path components
        :raise ValueError: if a component is not a field of the class reached
            so far, or if the path continues below a field that does not hold
            an embedded document
        """
        missing = object()
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                # The previous component (index depth - 1) was not an embedded
                # document, so report the path up to and including it. The
                # 1-tuple keeps '%' working when 'parts' is itself a tuple.
                raise ValueError(
                    'Invalid path (non-document encountered at %s)'
                    % (parts[:depth],)
                )

            field = next(
                (
                    v
                    for k, v in cls.get_fields_with_instance(current_cls)
                    if k == part
                ),
                missing,
            )
            if field is missing:
                # Include the offending component itself in the message
                raise ValueError('Invalid field path %s' % (parts[: depth + 1],))

            translated_parts.append(part)

            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(
                field,
                (
                    EmbeddedDocumentListField,
                    LengthRangeEmbeddedDocumentListField,
                    UniqueEmbeddedDocumentListField,
                    EmbeddedDocumentSortedListField,
                ),
            ):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                # Nothing can be resolved below this field
                current_cls = None

        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        """
        Return the cached fields having a 'reference_field' attribute as an
        OrderedDict sorted by field path
        """
        cache = cls.__class_cache()
        if cache.get('reference_fields') is None:
            fields = cls._get_fields_with_attr(cls, 'reference_field')
            cache['reference_fields'] = OrderedDict(sorted(fields.items()))
        return cache['reference_fields']

    @classmethod
    def get_extra_projection(cls, fields: Sequence) -> tuple:
        """
        Return the requested fields together with all the fields of the class,
        without the fields returned by get_exclude_fields

        :param fields: a sequence of field names or a single field name
        :return: tuple of field names; it is built from a set, so the order of
            the names is not defined
        """
        if isinstance(fields, str):
            fields = [fields]
        return tuple(
            set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
        )

    @classmethod
    def get_exclude_fields(cls):
        """
        Return the cached fields having an 'exclude_by_default' attribute as an
        OrderedDict sorted by field path
        """
        cache = cls.__class_cache()
        if cache.get('exclude_fields') is None:
            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
            cache['exclude_fields'] = OrderedDict(sorted(fields.items()))
        return cache['exclude_fields']

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        """
        Return the translated form of the path (see _translate_fields_path),
        joined back with the separator. Results are cached.

        :param path: field path whose components are joined by the separator
        :param separator: the string separating the path components
        :raise ValueError: if the path cannot be translated (nothing is cached
            in that case)
        """
        computed = cls.__class_cache().setdefault('dpath_computed_fields', {})
        # The same string splits differently under another separator, so the
        # separator has to be part of the cache key
        key = (path, separator)
        if key not in computed:
            with cls.__cached_dpath_computed_fields_lock:
                # Another thread may have stored the entry while this one was
                # waiting for the lock
                if key not in computed:
                    parts = path.split(separator)
                    translated = cls._translate_fields_path(parts)
                    computed[key] = separator.join(translated)

        return computed[key]

    def get_field_value(self, field_path: str, default=None):
        """
        Return the document field_path value by the field_path name.
        The path may contain '.'. If on any level the path is not found
        then the default value is returned

        Traversal also stops as soon as a value that compares equal to the
        default is reached; that value is returned as is.
        """
        path_elements = field_path.split(".")
        current = self
        for name in path_elements:
            current = getattr(current, name, default)
            if current == default:
                break

        return current