from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict

from boltons import iterutils

from apiserver.apierrors import errors
from apiserver.apimodels.tasks import (
    HyperParamKey,
    HyperParamItem,
    ReplaceHyperparams,
    Configuration,
)
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import get_task_for_update, update_task
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import (
    ParameterKeyEscaper,
    mongoengine_safe,
)

log = config.logger(__file__)
task_bll = TaskBLL()


class HyperParams:
    _properties_section = "properties"

    @classmethod
    def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
        only = ("id", "hyperparams")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )

        return {
            task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
            for task in tasks
        }

    @classmethod
    def _get_params_list(
        cls, items: Dict[str, Dict[str, ParamsItem]]
    ) -> Sequence[dict]:
        ret = list(chain.from_iterable(v.values() for v in items.values()))
        return [
            p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
        ]

    @classmethod
    def _normalize_params(cls, params: Sequence) -> bool:
        """
        Lower case properties section and return True if it is the only section
        """
        for p in params:
            if p.section.lower() == cls._properties_section:
                p.section = cls._properties_section

        return all(p.section == cls._properties_section for p in params)

    @classmethod
    def delete_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamKey],
        force: bool,
    ) -> int:
        with TimingContext("mongo", "delete_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            with_param, without_param = iterutils.partition(
                hyperparams, key=lambda p: bool(p.name)
            )
            sections_to_delete = {p.section for p in without_param}
            delete_cmds = {
                f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
                for section in sections_to_delete
            }

            for item in with_param:
                section = ParameterKeyEscaper.escape(item.section)
                if item.section in sections_to_delete:
                    raise errors.bad_request.FieldsConflict(
                        "Cannot delete section field if the whole section was scheduled for deletion"
                    )
                name = ParameterKeyEscaper.escape(item.name)
                delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1

            return update_task(
                task, update_cmds=delete_cmds, set_last_update=not properties_only
            )

    @classmethod
    def edit_params(
        cls,
        company_id: str,
        task_id: str,
        hyperparams: Sequence[HyperParamItem],
        replace_hyperparams: str,
        force: bool,
    ) -> int:
        with TimingContext("mongo", "edit_hyperparams"):
            properties_only = cls._normalize_params(hyperparams)
            task = get_task_for_update(
                company_id=company_id,
                task_id=task_id,
                allow_all_statuses=properties_only,
                force=force,
            )

            update_cmds = dict()
            hyperparams = cls._db_dicts_from_list(hyperparams)
            if replace_hyperparams == ReplaceHyperparams.all:
                update_cmds["set__hyperparams"] = hyperparams
            elif replace_hyperparams == ReplaceHyperparams.section:
                for section, value in hyperparams.items():
                    update_cmds[
                        f"set__hyperparams__{mongoengine_safe(section)}"
                    ] = value
            else:
                for section, section_params in hyperparams.items():
                    for name, value in section_params.items():
                        update_cmds[
                            f"set__hyperparams__{section}__{mongoengine_safe(name)}"
                        ] = value

            return update_task(
                task, update_cmds=update_cmds, set_last_update=not properties_only
            )

    @classmethod
    def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
        sections = iterutils.bucketize(items, key=attrgetter("section"))
        return {
            ParameterKeyEscaper.escape(section): {
                ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
                for param in params
            }
            for section, params in sections.items()
        }

    @classmethod
    def get_configurations(
        cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
    ) -> Dict[str, dict]:
        only = ["id"]
        if names:
            only.extend(
                f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
            )
        else:
            only.append("configuration")
        tasks = task_bll.assert_exists(
            company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
        )

        return {
            task.id: {
                "configuration": [
                    c.to_proper_dict()
                    for c in sorted(task.configuration.values(), key=attrgetter("name"))
                ]
            }
            for task in tasks
        }

    @classmethod
    def get_configuration_names(
        cls, company_id: str, task_ids: Sequence[str]
    ) -> Dict[str, list]:
        with TimingContext("mongo", "get_configuration_names"):
            pipeline = [
                {
                    "$match": {
                        "company": {"$in": [None, "", company_id]},
                        "_id": {"$in": task_ids},
                    }
                },
                {"$project": {"items": {"$objectToArray": "$configuration"}}},
                {"$unwind": "$items"},
                {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
            ]

            tasks = Task.aggregate(pipeline)

            return {
                task["_id"]: {
                    "names": sorted(
                        ParameterKeyEscaper.unescape(name) for name in task["names"]
                    )
                }
                for task in tasks
            }

    @classmethod
    def edit_configuration(
        cls,
        company_id: str,
        task_id: str,
        configuration: Sequence[Configuration],
        replace_configuration: bool,
        force: bool,
    ) -> int:
        with TimingContext("mongo", "edit_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            update_cmds = dict()
            configuration = {
                ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
                for c in configuration
            }
            if replace_configuration:
                update_cmds["set__configuration"] = configuration
            else:
                for name, value in configuration.items():
                    update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value

            return update_task(task, update_cmds=update_cmds)

    @classmethod
    def delete_configuration(
        cls, company_id: str, task_id: str, configuration: Sequence[str], force: bool
    ) -> int:
        with TimingContext("mongo", "delete_configuration"):
            task = get_task_for_update(
                company_id=company_id, task_id=task_id, force=force
            )

            delete_cmds = {
                f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
                for name in set(configuration)
            }

            return update_task(task, update_cmds=delete_cmds)