from collections import OrderedDict, defaultdict
from itertools import chain
from operator import attrgetter
from threading import Lock
from typing import Sequence

import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document, BaseField

from database.fields import (
    LengthRangeEmbeddedDocumentListField,
    UniqueEmbeddedDocumentListField,
    EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_attr


class PropsMixin(object):
    """
    Mixin adding cached field-introspection helpers to document classes.

    Every cache is stored on the concrete class that computed it and is read
    back from that class's own ``__dict__`` only. Inherited (MRO) lookups are
    deliberately avoided so that a subclass never reuses a result computed for
    its parent's (different) set of fields.
    """

    __cached_fields = None
    __cached_reference_fields = None
    __cached_exclude_fields = None
    __cached_fields_with_instance = None
    # of_type -> resolved list of field names (own type first, then subtypes)
    __cached_field_names_per_type = None
    # exact field type -> names of the fields of exactly that type
    __cached_direct_field_names_per_type = None

    # Shared by all classes; guards population of the dpath translation caches
    __cached_dpath_computed_fields_lock = Lock()
    __cached_dpath_computed_fields = None

    # Embedded-document list field types: the element document class is
    # reached through ``field.field.document_type``
    __embedded_list_field_types = (
        EmbeddedDocumentListField,
        LengthRangeEmbeddedDocumentListField,
        UniqueEmbeddedDocumentListField,
        EmbeddedDocumentSortedListField,
    )

    @classmethod
    def __own_cache(cls, attr_name, factory):
        """
        Return the value cached under attr_name on cls itself, building it with
        factory() and storing it on cls on first use.

        Only cls.__dict__ is consulted, never the MRO, so a value a parent
        class cached is not returned for a subclass.
        :param attr_name: mangled attribute name, e.g. "_PropsMixin__cached_fields"
        :param factory: zero-argument callable producing the value to cache
        """
        value = cls.__dict__.get(attr_name)
        if value is None:
            value = factory()
            setattr(cls, attr_name, value)
        return value

    @classmethod
    def get_fields(cls):
        """Return the cached result of database.utils.get_fields(cls)."""
        return cls.__own_cache(
            "_PropsMixin__cached_fields", lambda: get_fields(cls)
        )

    @classmethod
    def get_field_names_for_type(cls, of_type=BaseField):
        """
        Return field names per type including subfields
        The fields of derived types are also returned

        :param of_type: a BaseField subclass. Names of fields whose type is
            of_type or any subclass of it are returned, with fields of exactly
            of_type listed first. Each name appears once.
        :return: list of (possibly dotted) field names. The list is cached and
            shared between calls, so do not modify it.
        """
        assert issubclass(of_type, BaseField)
        resolved = cls.__own_cache("_PropsMixin__cached_field_names_per_type", dict)
        if of_type not in resolved:
            direct = cls.__own_cache(
                "_PropsMixin__cached_direct_field_names_per_type",
                cls.__collect_field_names_by_exact_type,
            )
            # Combine from the exact-type mapping only, so names of nested
            # subtypes are not counted once per intermediate type
            names = list(direct.get(of_type, ()))
            names.extend(
                chain.from_iterable(
                    type_names
                    for type_, type_names in direct.items()
                    if type_ is not of_type and issubclass(type_, of_type)
                )
            )
            resolved[of_type] = names
        return resolved[of_type]

    @classmethod
    def __collect_field_names_by_exact_type(cls):
        """Map each exact field type of cls (subfields included) to its field names."""
        by_type = defaultdict(list)
        for name, field in get_fields(cls, return_instance=True, subfields=True):
            by_type[type(field)].append(name)
        return dict(by_type)

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        """
        Return the cached result of get_fields(doc_cls, return_instance=True).

        The result is iterated as (name, field instance) pairs by
        _translate_fields_path.
        """
        cache = cls.__own_cache("_PropsMixin__cached_fields_with_instance", dict)
        if doc_cls not in cache:
            cache[doc_cls] = get_fields(doc_cls, return_instance=True)
        return cache[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """
        Get all fields with the specified attribute (supports nested fields)

        Fields of embedded documents, whether held directly or inside
        embedded-document list fields, are included recursively under dotted
        names.
        :param cls_: document class to inspect
        :param attr: field attribute name, passed to get_fields_attr
        :return: dict of field name -> attribute value. String values are
            resolved: 'self' via cls_.owner_document, anything else via
            mongoengine get_document.
        """
        res = get_fields_attr(cls_, attr=attr)

        def resolve_doc(v):
            if not isinstance(v, six.string_types):
                return v
            if v == 'self':
                # NOTE(review): owner_document is read from the document class
                # itself; confirm cls_ actually provides this attribute
                return cls_.owner_document
            return get_document(v)

        fields = {k: resolve_doc(v) for k, v in res.items()}

        def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
            # For every field of type doc_cls, recurse into the embedded
            # document class it holds and add matches under "field.subfield"
            for field, embedded_doc_field in get_fields(
                cls_, of_type=doc_cls, return_instance=True
            ):
                embedded_doc_cls = embedded_doc_field_getter(
                    embedded_doc_field
                ).document_type
                fields.update(
                    {
                        '.'.join((field, subfield)): doc
                        for subfield, doc in PropsMixin._get_fields_with_attr(
                            embedded_doc_cls, attr
                        ).items()
                    }
                )

        collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
        for list_field_type in PropsMixin.__embedded_list_field_types:
            collect_embedded_docs(list_field_type, attrgetter('field'))

        return fields

    @classmethod
    def _translate_fields_path(cls, parts):
        """
        Translate field path parts into dpath-style parts.

        A '*' part is inserted after every embedded-document list field so the
        following parts address each list element.
        :param parts: sequence of field names, outermost first
        :return: list of translated parts
        :raises ValueError: if a part is not a field of the current document,
            or the path continues past a field that is not an embedded document
        """
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                # parts[depth - 1] was not an embedded document; report the
                # prefix ending with it
                raise ValueError(
                    'Invalid path (non-document encountered at %s)'
                    % (parts[:depth],)
                )
            try:
                field = next(
                    v
                    for k, v in cls.get_fields_with_instance(current_cls)
                    if k == part
                )
            except StopIteration:
                # Report the prefix including the part that was not found
                raise ValueError(
                    'Invalid field path %s' % (parts[: depth + 1],)
                ) from None

            translated_parts.append(part)

            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(field, cls.__embedded_list_field_types):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                current_cls = None

        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        """Return the cached, name-sorted fields that have a 'reference_field' attribute."""
        return cls.__own_cache(
            "_PropsMixin__cached_reference_fields",
            lambda: OrderedDict(
                sorted(cls._get_fields_with_attr(cls, 'reference_field').items())
            ),
        )

    @classmethod
    def get_extra_projection(cls, fields: Sequence) -> tuple:
        """
        Return the given fields united with all fields of cls, minus the
        fields excluded by default (get_exclude_fields).

        :param fields: a field name or a sequence of field names
        :return: tuple of field names, in no guaranteed order
        NOTE(review): a requested field that is also excluded by default is
        dropped as well; confirm that this is intended.
        """
        if isinstance(fields, str):
            fields = [fields]
        return tuple(
            set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
        )

    @classmethod
    def get_exclude_fields(cls):
        """Return the cached, name-sorted fields that have an 'exclude_by_default' attribute."""
        return cls.__own_cache(
            "_PropsMixin__cached_exclude_fields",
            lambda: OrderedDict(
                sorted(cls._get_fields_with_attr(cls, 'exclude_by_default').items())
            ),
        )

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        """
        Translate a separator-delimited field path into a dpath path.

        A '*' is added after every embedded-document list field. Results are
        cached per class and per (path, separator) pair.
        :raises ValueError: if the path is not valid for this class
        """
        key = (path, separator)
        # Lock-free fast path: entries are only ever added, never removed
        cache = cls.__dict__.get("_PropsMixin__cached_dpath_computed_fields")
        if cache is not None and key in cache:
            return cache[key]
        with cls.__cached_dpath_computed_fields_lock:
            cache = cls.__own_cache("_PropsMixin__cached_dpath_computed_fields", dict)
            if key not in cache:
                translated = cls._translate_fields_path(path.split(separator))
                cache[key] = separator.join(translated)
            return cache[key]

    def get_field_value(self, field_path: str, default=None):
        """
        Return the document field_path value by the field_path name.
        The path may contain '.'. If on any level the path is
        not found then the default value is returned
        """
        # A private sentinel distinguishes "attribute missing" from any real
        # value without invoking the values' __eq__
        missing = object()
        current = self
        for name in field_path.split("."):
            current = getattr(current, name, missing)
            if current is missing:
                return default

        return current