from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain

import dpath

from apierrors import errors
from database.props import PropsMixin


def project_dict(data, projection, separator='.'):
    """
    Project partial data from a dictionary into a new dictionary.

    Projection paths are processed in sorted order. Intermediate dictionaries are re-created in the result, while
    leaf values are copied by reference (not deep-copied). When a path crosses a list (or tuple), the remainder of
    the path is applied to every element of that list and the result holds a list of the same length.
    Paths (or path suffixes) that do not exist in the source are silently skipped.

    :param data: Input dictionary
    :param projection: List of dictionary paths (each a string with field names separated using a separator)
    :param separator: Separator (default is '.')
    :return: A new dictionary containing only the projected parts from the original dictionary
    :raises TypeError: If an intermediate path part is neither a dictionary nor a list/tuple, or if a previously
        projected destination value is not a list where the source holds one
    :raises ValueError: If a previously projected destination list differs in length from the source list
    """
    assert isinstance(data, dict)
    result = {}

    def copy_path(path_parts, source, destination):
        """ Copy the value found at path_parts from source into destination, return destination """
        src, dst = source, destination
        try:
            for depth, path_part in enumerate(path_parts[:-1]):
                src_part = src[path_part]
                if isinstance(src_part, dict):
                    src = src_part
                    dst = dst.setdefault(path_part, {})
                elif isinstance(src_part, (list, tuple)):
                    if path_part not in dst:
                        # First time this list is projected: one destination dictionary per source element
                        dst[path_part] = [{} for _ in src_part]
                    elif not isinstance(dst[path_part], (list, tuple)):
                        # Report the type of the offending destination value (not of its parent dictionary)
                        raise TypeError('Incompatible destination type %s for %s (list expected)'
                                        % (type(dst[path_part]), separator.join(path_parts[:depth + 1])))
                    elif len(dst[path_part]) != len(src_part):
                        raise ValueError('Destination list length differs from source length for %s'
                                         % separator.join(path_parts[:depth + 1]))

                    # Apply the rest of the path to each list element, the recursion handles the leaf assignment
                    dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
                                      for s, d in zip(src_part, dst[path_part])]

                    return destination
                else:
                    # Report the type of the unsupported source value (not of its parent dictionary)
                    raise TypeError('Unsupported projection type %s for %s'
                                    % (type(src_part), separator.join(path_parts[:depth + 1])))

            last_part = path_parts[-1]
            dst[last_part] = src[last_part]
        except KeyError:
            # Projection field not in source, no biggie.
            pass
        return destination

    for projection_path in sorted(projection):
        copy_path(
            path_parts=projection_path.split(separator),
            source=data,
            destination=result)
    return result


class ProjectionHelper(object):
    """
    Parse a list of projection fields for a document class into a projection for the main query and a projection
    for each referenced document, and merge the referenced documents into query results ("join").
    """

    # Executor shared by all instances, used to run the sub-queries of several reference fields concurrently
    pool = ThreadPoolExecutor()

    @property
    def doc_projection(self):
        """ Projection fields for the main document query (None if an empty projection was provided) """
        return self._doc_projection

    def __init__(self, doc_cls, projection, expand_reference_ids=False):
        """
        :param doc_cls: Document class (must be a PropsMixin subclass)
        :param projection: List of projection fields (see _parse_projection)
        :param expand_reference_ids: If True, project() replaces reference field values with {'id': <value>}
            dictionaries even when no reference field sub-projection was requested
        """
        super().__init__()
        self._should_expand_reference_ids = expand_reference_ids
        self._doc_cls = doc_cls
        self._doc_projection = None
        self._ref_projection = None
        self._parse_projection(projection)

    @staticmethod
    def _is_nested_under(field, parents):
        """
        Check whether a field is a sub-field of one of the given fields
        :param field: Field path (field names separated by '.')
        :param parents: Iterable of candidate parent field paths
        :return: True if field is strictly nested ('<parent>.<something>') under any of the parents. A field is
            never considered nested under itself or under a field that merely shares a string prefix with it.
        """
        return any(field.startswith(f"{parent}.") for parent in parents)

    def _collect_projection_fields(self, doc_cls, projection):
        """
        Collect projection for the given document into immediate document projection and reference documents projection
        :param doc_cls: Document class
        :param projection: List of projection fields
        :return: A tuple of document projection (set of fields) and reference fields information (list of
            (reference field, reference field class, sub-field) tuples)
        :raises errors.bad_request.InvalidFields: If a field is empty (or empty once a trailing '.*' is stripped)
        """
        doc_projection = set()  # Projection fields for this class (used in the main query)
        ref_projection_info = []  # Projection information for reference fields (used in join queries)
        reference_fields = doc_cls.get_reference_fields()
        for field in projection:
            for ref_field, ref_field_cls in reference_fields.items():
                if not field.startswith(ref_field):
                    # Doesn't start with a reference field
                    continue
                if field == ref_field:
                    # Field is exactly a reference field. In this case we won't perform any inner projection (for that,
                    # use '<reference field name>.*')
                    continue
                subfield = field[len(ref_field):]
                if not subfield.startswith('.'):
                    # Starts with something that looks like a reference field, but isn't
                    continue

                ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
                break
            else:
                # Not a reference field, just add to the top-level projection
                # We strip any trailing '*' since it means nothing for simple fields and for embedded documents
                orig_field = field
                if field.endswith('.*'):
                    field = field[:-2]
                if not field:
                    raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
                doc_projection.add(field)
        return doc_projection, ref_projection_info

    def _parse_projection(self, projection):
        """
        Prepare the projection data structures for get_many_with_join().
        The results are stored in self._doc_projection (list of fields for the main query) and in
        self._ref_projection (a dictionary of reference field name -> dict(cls=<reference class>, only=<fields>)).
        If the projection is empty, both are left as None.
        :param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (i.e. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
            should be returned (only relevant for fields that represent sub-documents or referenced documents)
        :type projection: list of strings
        :raises errors.bad_request.InvalidFields: If a projection field is empty or is not a field of the document
            class
        """
        doc_cls = self._doc_cls
        assert issubclass(doc_cls, PropsMixin)
        if not projection:
            # Nothing to parse. The returned value is not used by __init__
            return [], {}

        doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)

        def normalize_cls_projection(cls_, fields):
            """ Normalize projection for this class and group (expand *, for once) """
            if '*' in fields:
                return list(fields.difference({'*'}).union(cls_.get_fields()))
            return list(fields)

        def compute_ref_cls_projection(cls_, group):
            """ Compute inner projection for this class and group """
            subfields = {info[2] for info in group if info[2]}
            return normalize_cls_projection(cls_, subfields)

        def sort_key(proj_info):
            return proj_info[:2]

        # Aggregate by reference field. A '*' sub-field is expanded into all of the referenced class fields
        ref_projection = {
            ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
            for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
        }

        # Make sure this doesn't contain any reference field we'll join anyway
        # (i.e. in case only_fields=[project, project.name])
        doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))

        # Make sure that in case one or more field is a subfield of another field, we only use the top-level field.
        # This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
        # won't return some of the data we need.
        # This way, we make sure to use the most inclusive field that contains all requested subfields.
        projection_set = set(doc_projection)
        doc_projection = [
            field
            for field in doc_projection
            if not self._is_nested_under(field, projection_set)
        ]

        # Make sure we didn't get any invalid projection fields for this class
        invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
        if invalid_fields:
            raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)

        if ref_projection:
            # Join mode - use both normal projection fields and top-level reference fields
            doc_projection = set(doc_projection)
            # Sorted so that a parent field is always handled before any of its sub-fields
            for field in sorted(set(ref_projection).difference(doc_projection)):
                # Skip only reference fields already covered by a projected parent field. A field that merely
                # shares a string prefix with a projected field (e.g. the always-projected 'id') must still be
                # returned by the main query, otherwise there are no ids to join on.
                if self._is_nested_under(field, doc_projection):
                    continue
                doc_projection.add(field)
            doc_projection = list(doc_projection)

        self._doc_projection = doc_projection
        self._ref_projection = ref_projection

    @staticmethod
    def _search(doc_cls, obj, path, only_values=True):
        """
        Call dpath.search with yielded=True, collect results
        :param doc_cls: Document class, used to translate the path using get_dpath_translated_path()
        :param obj: Object (query result) to search in
        :param path: Field path to search for
        :param only_values: If True return a list of the matched values, otherwise a list of (path, value) tuples
        """
        norm_path = doc_cls.get_dpath_translated_path(path)
        return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]

    def project(self, results, projection_func):
        """
        Perform projection on query results, using the provided projection func.
        :param results: A list of results dictionaries on which projection should be performed
        :param projection_func: A callable that receives a document type, list of ids and projection (passed as the
            doc_type, ids and projection keyword arguments) and returns query results, each containing an 'id' key.
            This callable is used in order to perform sub-queries during projection
        :return: Modified results (in-place)
        """
        cls = self._doc_cls
        ref_projection = self._ref_projection

        if ref_projection:
            # Join mode - get results for each reference fields projection required (this is the join step)
            # Note: this is a recursive step, so we support nested reference fields

            def do_projection(item):
                """ Query the documents referenced by a single reference field, store them (by id) in data['res'] """
                ref_field_name, data = item
                found_ids = set(chain.from_iterable(
                    self._search(cls, result, ref_field_name) for result in results))
                ids = list(filter(None, found_ids))
                res = {}
                if ids:
                    doc_type = data['cls']
                    doc_only = list(filter(None, data['only']))
                    # Always ask for the id, it is required in order to match sub-query results to references
                    doc_only = list({'id'} | set(doc_only)) if doc_only else None
                    res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
                data['res'] = res

            items = list(ref_projection.items())
            if len(ref_projection) == 1:
                do_projection(items[0])
            else:
                for _ in self.pool.map(do_projection, items):
                    # From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
                    #  will be raised when its value is retrieved from the map() iterator
                    pass

        def do_expand_reference_ids(result, skip_fields=None):
            """ Expand the values of all reference fields (except for skip_fields) in a single result """
            ref_fields = cls.get_reference_fields()
            if skip_fields:
                ref_fields = set(ref_fields) - set(skip_fields)
            self._expand_reference_fields(cls, result, ref_fields)

        def merge_projection_result(result):
            """ Replace projected reference values in a single result with the referenced documents """
            for ref_field_name, data in ref_projection.items():
                res = data.get('res')
                if not res:
                    # No sub-query results for this field, only expand the reference values
                    self._expand_reference_fields(cls, result, [ref_field_name])
                    continue
                ref_ids = self._search(cls, result, ref_field_name, only_values=False)
                if not ref_ids:
                    continue
                for path, value in ref_ids:
                    # Fall back to an id-only dictionary if the referenced document was not returned
                    obj = res.get(value) or {'id': value}
                    dpath.new(result, path, obj, separator='.')

            # any reference field not projected should be expanded
            do_expand_reference_ids(result, skip_fields=list(ref_projection))

        if ref_projection:
            update_func = merge_projection_result
        elif self._should_expand_reference_ids:
            update_func = do_expand_reference_ids
        else:
            update_func = None

        if update_func:
            for result in results:
                update_func(result)

        return results

    @classmethod
    def _expand_reference_fields(cls, doc_cls, result, fields):
        """
        Replace (in-place) the value of each of the given reference fields found in the result with an
        {'id': <value>} dictionary, or with an empty dictionary if the value is empty
        :param doc_cls: Document class
        :param result: Result dictionary to update
        :param fields: Iterable of reference field names
        """
        for ref_field_name in fields:
            ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
            if not ref_ids:
                continue
            for path, value in ref_ids:
                dpath.set(
                    result,
                    path,
                    {'id': value} if value else {},
                    separator='.')

    @classmethod
    def expand_reference_ids(cls, doc_cls, result):
        """ Expand (in-place) the values of all of the document class reference fields in the result """
        cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())