from operator import itemgetter
from sys import maxsize
from typing import Type, Tuple

import six
from mongoengine import (
    EmbeddedDocumentListField,
    ListField,
    FloatField,
    StringField,
    EmbeddedDocumentField,
    SortedListField,
    MapField,
    DictField,
    DynamicField,
)
from mongoengine.fields import key_not_string, key_starts_with_dollar, EmailField

# Alias for the type of ``None``; nothing in the visible code references it,
# presumably it is kept for other modules that import it from here (TODO confirm).
NoneType = type(None)


class LengthRangeListField(ListField):
    """
    A list field whose item count must lie within an inclusive [min_length, max_length] range
    """

    def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
        """
        :param field: Forwarded as the first argument of the parent constructor
        :param max_length: Largest allowed item count (inclusive)
        :param min_length: Smallest allowed item count (inclusive)
        :param kwargs: Forwarded to the parent constructor
        """
        self.__min_length = min_length
        self.__max_length = max_length
        super().__init__(field, **kwargs)

    def validate(self, value):
        """
        Report an error when the item count is out of range, then run the parent validation
        """
        lower, upper = self.__min_length, self.__max_length
        count = len(value)
        if not (lower <= count <= upper):
            self.error("Item count %d exceeds range [%d, %d]" % (count, lower, upper))
        super().validate(value)


class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
    """
    A length-restricted list field whose item field is an EmbeddedDocumentField
    """

    def __init__(self, field=None, *args, **kwargs):
        """
        :param field: Passed to EmbeddedDocumentField to build the item field
        :param args: Forwarded positionally to LengthRangeListField
        :param kwargs: Forwarded to LengthRangeListField
        """
        item_field = EmbeddedDocumentField(field)
        super().__init__(item_field, *args, **kwargs)


class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
    """
    An embedded document list in which no two items may produce the same key
    """

    def __init__(self, document_type, key, **kwargs):
        """
        Create a unique embedded document list field for a document type with a unique comparison key func/property
        :param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
        :param key: A callable to extract a key from each item
        :param kwargs: Additional field options, forwarded to the parent constructor
        :raises KeyError: If ``key`` is not callable
        """
        if not callable(key):
            # KeyError is kept (rather than TypeError) so callers catching it keep working
            raise KeyError("key must be callable")
        self.__key = key
        # Field options must be forwarded, otherwise they are accepted but silently ignored
        super(UniqueEmbeddedDocumentListField, self).__init__(document_type, **kwargs)

    def validate(self, value):
        """
        Report an error if two items share a key, then run the parent validation

        The extracted keys are collected in a set, so they need to be hashable.
        """
        unique_keys = {self.__key(item) for item in value}
        if len(unique_keys) != len(value):
            self.error("Items with duplicate key exist in the list")
        super(UniqueEmbeddedDocumentListField, self).validate(value)


def object_to_key_value_pairs(obj):
    """
    Recursively replace every dictionary in ``obj`` with a list of (key, value) tuples

    Lists are rebuilt with each item converted the same way; any other
    object is returned as is.
    """
    if isinstance(obj, list):
        return [object_to_key_value_pairs(item) for item in obj]
    if not isinstance(obj, dict):
        return obj

    pairs = []
    for name, inner in obj.items():
        pairs.append((name, object_to_key_value_pairs(inner)))
    return pairs


class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
    """
    A sorted list of embedded documents
    """

    def to_mongo(self, value, use_db_field=True, fields=None):
        """
        Convert the list through the parent class and return the converted items sorted

        Items are ordered by their key/value pair representation
        (see object_to_key_value_pairs).
        """
        ordered = list(super().to_mongo(value, use_db_field, fields))
        ordered.sort(key=object_to_key_value_pairs)
        return ordered


class LengthRangeSortedListField(LengthRangeListField, SortedListField):
    """
    A SortedListField that also enforces the item count range of LengthRangeListField
    """


class CustomFloatField(FloatField):
    """
    A float field with an optional exclusive lower bound
    """

    def __init__(self, greater_than=None, **kwargs):
        """
        :param greater_than: If not None, values must be strictly greater than this bound
        :param kwargs: Forwarded to the parent constructor
        """
        self.greater_than = greater_than
        super().__init__(**kwargs)

    def validate(self, value):
        """
        Run the parent validation, then report an error if the value does not exceed the bound
        """
        super().validate(value)

        bound = self.greater_than
        if bound is None:
            return
        if value <= bound:
            self.error("Float value must be greater than %s" % str(bound))


class CanonicEmailField(EmailField):
    """email field that is always lower cased"""

    def __set__(self, instance, value: str):
        """
        Lower-case the value before storing it; values without a ``lower`` method are stored as is
        """
        if value is not None:
            try:
                value = value.lower()
            except AttributeError:
                pass
        super().__set__(instance, value)

    def prepare_query_value(self, op, value):
        """
        Lower-case the query value before handing it to the parent class

        :param op: The query operator; when it is not a string the value is returned untouched
        :param value: The value being queried for
        """
        # NOTE(review): a non-string op (presumably a plain equality query) skips
        # lower-casing entirely - confirm this is intended.
        if not isinstance(op, six.string_types):
            return value
        if value is not None:
            try:
                value = value.lower()
            except AttributeError:
                # Tolerate values without lower() the same way __set__ does,
                # instead of failing with AttributeError while building the query
                pass
        return super().prepare_query_value(op, value)


class StrippedStringField(StringField):
    """
    A string field whose values are stripped of leading and trailing characters
    """

    def __init__(
        self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
    ):
        """
        :param regex: Forwarded to the parent constructor
        :param max_length: Forwarded to the parent constructor
        :param min_length: Forwarded to the parent constructor
        :param strip_chars: Argument given to ``strip()``; for ``str`` values None strips whitespace
        :param kwargs: Forwarded to the parent constructor
        """
        super(StrippedStringField, self).__init__(
            regex, max_length, min_length, **kwargs
        )
        self._strip_chars = strip_chars

    def __set__(self, instance, value):
        """
        Strip the value before storing it; values without a ``strip`` method are stored as is
        """
        if value is not None:
            try:
                value = value.strip(self._strip_chars)
            except AttributeError:
                pass
        super(StrippedStringField, self).__set__(instance, value)

    def prepare_query_value(self, op, value):
        """
        Strip the query value before handing it to the parent class

        :param op: The query operator; when it is not a string the value is returned untouched
        :param value: The value being queried for
        """
        # NOTE(review): a non-string op (presumably a plain equality query) skips
        # stripping entirely - confirm this is intended.
        if not isinstance(op, six.string_types):
            return value
        if value is not None:
            try:
                value = value.strip(self._strip_chars)
            except AttributeError:
                # Tolerate values without strip() the same way __set__ does,
                # instead of failing with AttributeError while building the query
                pass
        return super(StrippedStringField, self).prepare_query_value(op, value)


def contains_empty_key(d):
    """
    Helper function to recursively determine if any key in a
    dictionary is empty (based on mongoengine.fields.key_not_string)

    :param d: The dictionary to inspect; dictionary values are inspected recursively
    :return: True if a falsy key (e.g. "") exists at any nesting level, False otherwise
    """
    # any() always yields a bool, so the "no empty key" case is an explicit False
    return any(
        not key or (isinstance(value, dict) and contains_empty_key(value))
        for key, value in d.items()
    )


class DictValidationMixin:
    """
    Stripped down replacement for the DictField validation of MongoEngine, which requires
    a default alias and permissions to access the DB version:
    https://github.com/MongoEngine/mongoengine/issues/2239
    This implementation requires neither and implies Mongo ver 3.6+
    """

    def _safe_validate(self: DictField, value):
        """
        Check that the value is a dictionary with string keys that do not start with "$"
        """
        if not isinstance(value, dict):
            self.error("Only dictionaries may be used in a DictField")

        has_non_string_key = key_not_string(value)
        if has_non_string_key:
            self.error("Invalid dictionary key - documents must have only string keys")

        has_dollar_key = key_starts_with_dollar(value)
        if has_dollar_key:
            msg = 'Invalid dictionary key name - keys may not startswith "$" characters'
            self.error(msg)

        # Method lookup starts after DictField in the MRO, so DictField.validate itself is skipped
        super(DictField, self).validate(value)


class SafeMapField(MapField, DictValidationMixin):
    """
    MapField validated through DictValidationMixin that additionally rejects empty keys
    """

    def validate(self, value):
        """
        Run the mixin validation, then report an error if any (nested) key is empty
        """
        self._safe_validate(value)

        if not contains_empty_key(value):
            return
        self.error("Empty keys are not allowed in a MapField")


class SafeDictField(DictField, DictValidationMixin):
    """
    DictField validated through DictValidationMixin that additionally rejects empty keys
    """

    def validate(self, value):
        """
        Run the mixin validation, then report an error if any (nested) key is empty
        """
        self._safe_validate(value)

        if not contains_empty_key(value):
            return
        self.error("Empty keys are not allowed in a DictField")


class SafeSortedListField(SortedListField):
    """
    SortedListField that does not raise an error in case items are not comparable
    (in which case they will be sorted by their string representation)
    """

    def to_mongo(self, *args, **kwargs):
        """
        Use the regular conversion, falling back to the string based one on TypeError
        """
        try:
            result = super().to_mongo(*args, **kwargs)
        except TypeError:
            result = self._safe_to_mongo(*args, **kwargs)
        return result

    def _safe_to_mongo(self, value, use_db_field=True, fields=None):
        """
        Convert the list while bypassing SortedListField.to_mongo, then sort by string keys

        With an ordering set, items are sorted by the string form of that
        item entry; otherwise by the string form of the whole item.
        """
        # Method lookup starts after SortedListField in the MRO, skipping its own sorting
        converted = super(SortedListField, self).to_mongo(value, use_db_field, fields)

        ordering = self._ordering
        if ordering is None:
            sort_key = str
        else:
            pick = itemgetter(ordering)

            def sort_key(item):
                return str(pick(item))

        return sorted(converted, key=sort_key, reverse=self._order_reverse)


class UnionField(DynamicField):
    """
    A dynamic field whose value must be an instance of one of the given types
    """

    def __init__(self, types, *args, **kwargs):
        """
        :param types: Iterable of the types a value may be an instance of
        :param args: Forwarded positionally to the parent constructor
        :param kwargs: Forwarded to the parent constructor
        """
        super(UnionField, self).__init__(*args, **kwargs)
        # Variable-length tuple, used as the second argument of isinstance()
        self.types: Tuple[Type, ...] = tuple(types)

    def validate(self, value, clean=True):
        """
        Report an error if the value is not an instance of any allowed type, then run the parent validation

        :param value: The value to check
        :param clean: Forwarded to the parent validation
        """
        if not isinstance(value, self.types):
            type_names = [t.__name__ for t in self.types]
            # Slicing (rather than indexing [-1]) keeps this safe when no types were given
            leading, last = type_names[:-1], type_names[-1:]
            # Renders as "a", "a or b", "a, b or c", ...
            expected = " or ".join(filter(None, (", ".join(leading), *last)))
            if not expected:
                expected = "nothing (no types configured)"
            self.error(
                f"Expected {expected}, got {type(value).__name__}: {value}"
            )
        super(UnionField, self).validate(value, clean)