import abc
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Callable, Mapping

import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis

from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
    EventSettings,
    EventType,
    check_empty_data,
    search_company_events,
    get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get


class VariantState(Base):
    """
    The iteration range stored for a single metric/variant pair
    Instances are kept in the variant_states list of HistorySampleState
    """

    # variant name
    name: str = StringField(required=True)
    # metric that the variant belongs to. May be None in states that were stored
    # without it; get_sample_for_variant fills it in when validating such a state
    metric: str = StringField(default=None)
    # iteration bounds as extracted by the function returned from _get_min_max_aggs
    min_iteration: int = IntField()
    max_iteration: int = IntField()


class HistorySampleState(Base, JsonSerializableMixin):
    """
    Navigation state that is stored in the cache between the history sample calls
    It holds the position of the last returned sample and the known variants
    """

    # state id. It is returned to the caller as the scroll id
    id: str = StringField(required=True)
    # position (iteration, variant, metric) of the last returned event
    iteration: int = IntField()
    variant: str = StringField()
    # the task that the state was created for
    task: str = StringField()
    metric: str = StringField()
    # NOTE(review): reached_first, reached_last and warning are not used
    # in this module; presumably kept for the states stored by other code - confirm
    reached_first: bool = BoolField()
    reached_last: bool = BoolField()
    # iteration ranges per metric/variant, rebuilt by _reset_variant_states
    variant_states: Sequence[VariantState] = ListField([VariantState])
    warning: str = StringField()
    # if True then navigation considers only the variants of the state metric,
    # otherwise the variants of all the metrics stored in variant_states
    navigate_current_metric: bool = BoolField(default=True)


@attr.s(auto_attribs=True)
class HistorySampleResult(object):
    """
    The result of a history sample request
    All the fields default to None, which means that the value is not available
    """

    # id of the navigation state to pass to the following calls
    scroll_id: Optional[str] = None
    # the source of the event that was found
    event: Optional[dict] = None
    # iteration bounds taken from the variant state that matches the event
    min_iteration: Optional[int] = None
    max_iteration: Optional[int] = None


class HistorySampleIterator(abc.ABC):
    def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
        self.es = es
        self.event_type = event_type
        self.cache_manager = RedisCacheManager(
            state_class=HistorySampleState,
            redis=redis,
            expiration_interval=EventSettings.state_expiration_sec,
        )

    def get_next_sample(
        self, company_id: str, task: str, state_id: str, navigate_earlier: bool
    ) -> HistorySampleResult:
        """
        Get the sample for next/prev variant on the current iteration
        If does not exist then try getting sample for the first/last variant from next/prev iteration
        """
        res = HistorySampleResult(scroll_id=state_id)
        state = self.cache_manager.get_state(state_id)
        if not state or state.task != task:
            raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        event = self._get_next_for_current_iteration(
            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        ) or self._get_next_for_another_iteration(
            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        )
        if not event:
            return res

        self._fill_res_and_update_state(event=event, res=res, state=state)
        self.cache_manager.set_state(state=state)
        return res

    def _fill_res_and_update_state(
        self, event: dict, res: HistorySampleResult, state: HistorySampleState
    ):
        self._process_event(event)
        state.variant = event["variant"]
        state.metric = event["metric"]
        state.iteration = event["iter"]
        res.event = event
        var_state = first(
            vs
            for vs in state.variant_states
            if vs.name == state.variant and vs.metric == state.metric
        )
        if var_state:
            res.min_iteration = var_state.min_iteration
            res.max_iteration = var_state.max_iteration

    @abc.abstractmethod
    def _get_extra_conditions(self) -> Sequence[dict]:
        pass

    @abc.abstractmethod
    def _process_event(self, event: dict) -> dict:
        pass

    @abc.abstractmethod
    def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
        pass

    def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
        metrics = bucketize(variants, key=attrgetter("metric"))
        metrics_conditions = [
            {
                "bool": {
                    "must": [
                        {"term": {"metric": metric}},
                        self._get_variants_conditions(vs),
                    ]
                }
            }
            for metric, vs in metrics.items()
        ]
        return {"bool": {"should": metrics_conditions}}

    def _get_next_for_current_iteration(
        self, company_id: str, navigate_earlier: bool, state: HistorySampleState
    ) -> Optional[dict]:
        """
        Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
        Only variants for which the iteration falls into their valid range are considered
        Return None if no such variant or sample is found
        """
        if state.navigate_current_metric:
            variants = [
                var_state
                for var_state in state.variant_states
                if var_state.metric == state.metric
            ]
        else:
            variants = state.variant_states

        cmp = operator.lt if navigate_earlier else operator.gt
        variants = [
            var_state
            for var_state in variants
            if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
            and var_state.min_iteration <= state.iteration
        ]
        if not variants:
            return

        must_conditions = [
            {"term": {"task": state.task}},
            {"term": {"iter": state.iteration}},
            self._get_metric_variants_condition(variants),
            *self._get_extra_conditions(),
        ]
        order = "desc" if navigate_earlier else "asc"
        es_req = {
            "size": 1,
            "sort": [{"metric": order}, {"variant": order}],
            "query": {"bool": {"must": must_conditions}},
        }

        with translate_errors_context(), TimingContext(
            "es", "get_next_for_current_iteration"
        ):
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )

        hits = nested_get(es_res, ("hits", "hits"))
        if not hits:
            return

        return hits[0]["_source"]

    def _get_next_for_another_iteration(
        self, company_id: str, navigate_earlier: bool, state: HistorySampleState
    ) -> Optional[dict]:
        """
        Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
        or from the last variant for the previous iteration (otherwise)
        The variants for which the sample falls in invalid range are discarded
        If no suitable sample is found then None is returned
        """
        if state.navigate_current_metric:
            variants = [
                var_state
                for var_state in state.variant_states
                if var_state.metric == state.metric
            ]
        else:
            variants = state.variant_states

        if navigate_earlier:
            range_operator = "lt"
            order = "desc"
            variants = [
                var_state
                for var_state in variants
                if var_state.min_iteration < state.iteration
            ]
        else:
            range_operator = "gt"
            order = "asc"
            variants = variants

        if not variants:
            return

        must_conditions = [
            {"term": {"task": state.task}},
            self._get_metric_variants_condition(variants),
            {"range": {"iter": {range_operator: state.iteration}}},
            *self._get_extra_conditions(),
        ]
        es_req = {
            "size": 1,
            "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
            "query": {"bool": {"must": must_conditions}},
        }
        with translate_errors_context(), TimingContext(
            "es", "get_next_for_another_iteration"
        ):
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )

        hits = nested_get(es_res, ("hits", "hits"))
        if not hits:
            return

        return hits[0]["_source"]

    def get_sample_for_variant(
        self,
        company_id: str,
        task: str,
        metric: str,
        variant: str,
        iteration: Optional[int] = None,
        refresh: bool = False,
        state_id: str = None,
        navigate_current_metric: bool = True,
    ) -> HistorySampleResult:
        """
        Get the sample for the requested iteration or the latest before it
        If the iteration is not passed then get the latest event
        """
        res = HistorySampleResult()
        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        def init_state(state_: HistorySampleState):
            state_.task = task
            state_.metric = metric
            state_.navigate_current_metric = navigate_current_metric
            self._reset_variant_states(company_id=company_id, state=state_)

        def validate_state(state_: HistorySampleState):
            if (
                state_.task != task
                or state_.navigate_current_metric != navigate_current_metric
                or (state_.navigate_current_metric and state_.metric != metric)
            ):
                raise errors.bad_request.InvalidScrollId(
                    "Task and metric stored in the state do not match the passed ones",
                    scroll_id=state_.id,
                )
            # fix old variant states:
            for vs in state_.variant_states:
                if vs.metric is None:
                    vs.metric = metric
            if refresh:
                self._reset_variant_states(company_id=company_id, state=state_)

        state: HistorySampleState
        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state,
        ) as state:
            res.scroll_id = state.id

            var_state = first(
                vs
                for vs in state.variant_states
                if vs.name == variant and vs.metric == metric
            )
            if not var_state:
                return res

            res.min_iteration = var_state.min_iteration
            res.max_iteration = var_state.max_iteration

            must_conditions = [
                {"term": {"task": task}},
                {"term": {"metric": metric}},
                {"term": {"variant": variant}},
                *self._get_extra_conditions(),
            ]
            if iteration is not None:
                must_conditions.append(
                    {
                        "range": {
                            "iter": {"lte": iteration, "gte": var_state.min_iteration}
                        }
                    }
                )
            else:
                must_conditions.append(
                    {"range": {"iter": {"gte": var_state.min_iteration}}}
                )

            es_req = {
                "size": 1,
                "sort": {"iter": "desc"},
                "query": {"bool": {"must": must_conditions}},
            }

            with translate_errors_context(), TimingContext(
                "es", "get_history_sample_for_variant"
            ):
                es_res = search_company_events(
                    self.es,
                    company_id=company_id,
                    event_type=self.event_type,
                    body=es_req,
                )

            hits = nested_get(es_res, ("hits", "hits"))
            if not hits:
                return res

            self._fill_res_and_update_state(
                event=hits[0]["_source"], res=res, state=state
            )
            return res

    def _reset_variant_states(self, company_id: str, state: HistorySampleState):
        metrics = self._get_metric_variant_iterations(
            company_id=company_id,
            task=state.task,
            metric=state.metric if state.navigate_current_metric else None,
        )
        state.variant_states = [
            VariantState(
                metric=metric,
                name=var_name,
                min_iteration=min_iter,
                max_iteration=max_iter,
            )
            for metric, variants in metrics.items()
            for var_name, min_iter, max_iter in variants
        ]

    @abc.abstractmethod
    def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
        pass

    def _get_metric_variant_iterations(
        self, company_id: str, task: str, metric: str,
    ) -> Mapping[str, Tuple[str, str, int, int]]:
        """
        Return valid min and max iterations that the task reported events of the required type
        """
        must = [
            {"term": {"task": task}},
            *self._get_extra_conditions(),
        ]
        if metric is not None:
            must.append({"term": {"metric": metric}})
        query = {"bool": {"must": must}}

        search_args = dict(
            es=self.es, company_id=company_id, event_type=self.event_type
        )
        max_metrics, max_variants = get_max_metric_and_variant_counts(
            query=query, **search_args
        )
        max_variants = int(max_variants // 2)
        min_max_aggs, get_min_max_data = self._get_min_max_aggs()
        es_req: dict = {
            "size": 0,
            "query": query,
            "aggs": {
                "metrics": {
                    "terms": {
                        "field": "metric",
                        "size": max_metrics,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": max_variants,
                                "order": {"_key": "asc"},
                            },
                            "aggs": min_max_aggs,
                        }
                    },
                }
            },
        }

        with translate_errors_context(), TimingContext(
            "es", "get_history_sample_iterations"
        ):
            es_res = search_company_events(body=es_req, **search_args,)

        def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
            variant = variant_bucket["key"]
            min_iter, max_iter = get_min_max_data(variant_bucket)
            return variant, min_iter, max_iter

        return {
            metric_bucket["key"]: [
                get_variant_data(variant_bucket)
                for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
            ]
            for metric_bucket in nested_get(
                es_res, ("aggregations", "metrics", "buckets")
            )
        }