from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple

from elasticsearch import Elasticsearch

import database
import es_factory
from apierrors import errors
from bll.queue.queue_metrics import QueueMetrics
from bll.workers import WorkerBLL
from config import config
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry

log = config.logger(__file__)


class QueueBLL(object):
    """
    Business logic for queues: creating, querying, updating and deleting Queue
    documents, managing the task entries inside a queue, and reporting queue
    metrics to Elasticsearch.
    """

    def __init__(
        self,
        worker_bll: Optional[WorkerBLL] = None,
        es: Optional[Elasticsearch] = None,
    ):
        """
        :param worker_bll: workers BLL used to find the workers serving each queue.
            A new WorkerBLL() is created when not provided
        :param es: Elasticsearch client used for queue metrics. When not provided,
            a connection is obtained from es_factory.connect("workers")
        """
        self.worker_bll = worker_bll or WorkerBLL()
        self.es = es or es_factory.connect("workers")
        self._metrics = QueueMetrics(self.es)

    @property
    def metrics(self) -> QueueMetrics:
        """Queue metrics helper bound to this instance's Elasticsearch client"""
        return self._metrics

    @staticmethod
    def create(
        company_id: str,
        name: str,
        tags: Optional[Sequence[str]] = None,
        system_tags: Optional[Sequence[str]] = None,
    ) -> Queue:
        """
        Create and save a new queue
        :param company_id: owning company id
        :param name: queue name
        :param tags: user tags (an empty list is stored when not given)
        :param system_tags: system tags (an empty list is stored when not given)
        :return: the saved Queue document
        """
        with translate_errors_context():
            # a single timestamp so that 'created' and 'last_update' are identical
            now = datetime.utcnow()
            queue = Queue(
                id=database.utils.id(),
                company=company_id,
                created=now,
                name=name,
                tags=tags or [],
                system_tags=system_tags or [],
                last_update=now,
            )
            queue.save()
            return queue

    def get_by_id(
        self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
    ) -> Queue:
        """
        Get queue by id
        :param company_id: company the queue must belong to
        :param queue_id: queue id
        :param only: optional list of fields to load (all fields when not given)
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            qs = Queue.objects(**query)
            if only:
                qs = qs.only(*only)
            queue = qs.first()
            if not queue:
                raise errors.bad_request.InvalidQueueId(**query)

            return queue

    @classmethod
    def get_queue_with_task(cls, company_id: str, queue_id: str, task_id: str) -> Queue:
        """
        Get the queue by id, provided that it currently contains an entry for the task
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue is not
            found or the task is not in it
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            queue = Queue.objects(entries__task=task_id, **query).first()
            if not queue:
                raise errors.bad_request.InvalidQueueOrTaskNotQueued(
                    task=task_id, **query
                )

            return queue

    def get_default(self, company_id: str) -> Queue:
        """
        Get the default queue (the single company queue tagged with the "default"
        system tag). Only the 'id' and 'name' fields are loaded.
        :raise errors.bad_request.NoDefaultQueue: if the default queue not found
        :raise errors.bad_request.MultipleDefaultQueues: if more than one default queue is found
        """
        with translate_errors_context():
            # Materialize once: evaluating the QuerySet for truthiness, len() and
            # first() separately would issue several DB queries, and the results
            # could disagree if the data changes in between.
            default_queues = list(
                Queue.objects(company=company_id, system_tags="default").only(
                    "id", "name"
                )
            )
            if not default_queues:
                raise errors.bad_request.NoDefaultQueue()
            if len(default_queues) > 1:
                raise errors.bad_request.MultipleDefaultQueues(
                    queues=tuple(q.id for q in default_queues)
                )

            return default_queues[0]

    def update(
        self, company_id: str, queue_id: str, **update_fields
    ) -> Tuple[int, dict]:
        """
        Partial update of the queue from update_fields
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :return: number of updated objects and updated fields dictionary
        """
        with translate_errors_context():
            # validate the queue exists (raises InvalidQueueId otherwise)
            self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
            return Queue.safe_update(company_id, queue_id, update_fields)

    def delete(self, company_id: str, queue_id: str, force: bool) -> None:
        """
        Delete the queue
        :param force: allow deleting a queue that still has entries
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
        """
        with translate_errors_context():
            queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
            if queue.entries and not force:
                raise errors.bad_request.QueueNotEmpty(
                    "use force=true to delete", id=queue_id
                )
            queue.delete()

    def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
        """
        Get all the company queues according to the query.
        query_dict is forwarded to Queue.get_many both as the request parameters
        and as the query dict.
        """
        with translate_errors_context():
            return Queue.get_many(
                company=company_id, parameters=query_dict, query_dict=query_dict
            )

    def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
        """
        Get infos on all the company queues, including queue tasks and workers.
        Each returned item gets a "workers" list with one dict per worker
        listening on that queue: {"name": worker id, "ip": worker ip,
        "task": the worker's current task struct or None}.
        """
        # extra projection of the entries' task names (presumably resolved by
        # get_many_with_join — confirm against Queue.get_extra_projection)
        projection = Queue.get_extra_projection("entries.task.name")
        with translate_errors_context():
            res = Queue.get_many_with_join(
                company=company_id,
                query_dict=query_dict,
                override_projection=projection,
            )

            # map queue id -> workers that listen on it (a worker may serve several)
            queue_workers = defaultdict(list)
            for worker in self.worker_bll.get_all(company_id):
                for queue in worker.queues:
                    queue_workers[queue].append(worker)

            for item in res:
                item["workers"] = [
                    {
                        "name": w.id,
                        "ip": w.ip,
                        "task": w.task.to_struct() if w.task else None,
                    }
                    for w in queue_workers.get(item["id"], [])
                ]

        return res

    def add_task(self, company_id: str, queue_id: str, task_id: str) -> int:
        """
        Append the task to the end of the queue
        :return: the number of updated queue documents (as returned by update_one)
        :raise errors.bad_request.InvalidQueueId: if the queue is not found
        :raise errors.bad_request.TaskAlreadyQueued: if the task is already in the queue
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue update operation failed
        """
        with translate_errors_context():
            queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
            if any(e.task == task_id for e in queue.entries):
                raise errors.bad_request.TaskAlreadyQueued(task=task_id)

            # metrics are logged from the queue state read above, before the push
            self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])

            # one timestamp for both the entry and the queue's last_update
            now = datetime.utcnow()
            entry = Entry(added=now, task=task_id)
            query = dict(id=queue_id, company=company_id)
            # the entries__task__ne condition guards against the task having been
            # queued concurrently since the check above
            res = Queue.objects(entries__task__ne=task_id, **query).update_one(
                push__entries=entry, last_update=now, upsert=False
            )
            if not res:
                raise errors.bad_request.InvalidQueueOrTaskNotQueued(
                    task=task_id, **query
                )

            return res

    def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
        """
        Atomically pop and return the first task from the queue (or None)
        :raise errors.bad_request.InvalidQueueId: if the queue does not exist
        """
        with translate_errors_context():
            query = dict(id=queue_id, company=company_id)
            # pop__entries=-1 removes the first entry. modify() returns the
            # document as it was before the change (mongoengine default new=False),
            # so queue.entries[0] below is the entry that was popped.
            queue = Queue.objects(**query).modify(pop__entries=-1, upsert=False)
            if not queue:
                raise errors.bad_request.InvalidQueueId(**query)

            self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])

            if not queue.entries:
                return None

            # best effort: failing to bump last_update must not lose the popped task
            try:
                Queue.objects(**query).update(last_update=datetime.utcnow())
            except Exception:
                log.exception("Error while updating Queue.last_update")

            return queue.entries[0]

    def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
        """
        Removes the task from the queue and returns the number of removed items
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
        """
        with translate_errors_context():
            queue = self.get_queue_with_task(
                company_id=company_id, queue_id=queue_id, task_id=task_id
            )
            self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])

            entries_to_remove = [e for e in queue.entries if e.task == task_id]
            query = dict(id=queue_id, company=company_id)
            res = Queue.objects(entries__task=task_id, **query).update_one(
                pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
            )

            # 0 when the task was removed concurrently between the read and the update
            return len(entries_to_remove) if res else 0

    def reposition_task(
        self,
        company_id: str,
        queue_id: str,
        task_id: str,
        pos_func: Callable[[int], int],
    ) -> int:
        """
        Moves the task in the queue to the position calculated by pos_func
        :param pos_func: receives the task's current index in the queue and
            returns the requested index. A negative result moves the task to the
            end of the queue.
        :return: the value returned by pos_func (the requested position)
        :raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not in the queue
        :raise errors.bad_request.RemovedDuringReposition: if the task was removed
            from the queue before it could be moved
        :raise errors.bad_request.FailedAddingDuringReposition: if re-inserting the
            task failed after it was pulled from the queue
        """
        with translate_errors_context():
            queue = self.get_queue_with_task(
                company_id=company_id, queue_id=queue_id, task_id=task_id
            )

            position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
            new_position = pos_func(position)

            if new_position != position:
                entry = queue.entries[position]
                query = dict(id=queue_id, company=company_id)
                # NOTE: the move is two separate updates (pull, then push), not
                # one atomic operation. If the push fails the entry has already
                # been removed from the queue.
                updated = Queue.objects(entries__task=task_id, **query).update_one(
                    pull__entries=entry, last_update=datetime.utcnow()
                )
                if not updated:
                    raise errors.bad_request.RemovedDuringReposition(
                        task=task_id, **query
                    )
                inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
                # without $position, $push appends to the end of the array
                if new_position >= 0:
                    inst["$push"]["entries"]["$position"] = new_position
                res = Queue.objects(entries__task__ne=task_id, **query).update_one(
                    __raw__=inst
                )
                if not res:
                    raise errors.bad_request.FailedAddingDuringReposition(
                        task=task_id, **query
                    )

            return new_position