from datetime import datetime
from typing import Callable, Tuple

from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from .metadata import Metadata


class ModelBLL:
    @classmethod
    def get_company_model_by_id(
        cls, company_id: str, model_id: str, only_fields=None
    ) -> Model:
        query = dict(company=company_id, id=model_id)
        qs = Model.objects(**query)
        if only_fields:
            qs = qs.only(*only_fields)
        model = qs.first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)
        return model

    @classmethod
    def publish_model(
        cls,
        model_id: str,
        company_id: str,
        force_publish_task: bool = False,
        publish_task_func: Callable[[str, str, bool], dict] = None,
    ) -> Tuple[int, ModelTaskPublishResponse]:
        model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
        if model.ready:
            raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)

        published_task = None
        if model.task and publish_task_func:
            task = (
                Task.objects(id=model.task, company=company_id)
                .only("id", "status")
                .first()
            )
            if task and task.status != TaskStatus.published:
                task_publish_res = publish_task_func(
                    model.task, company_id, force_publish_task
                )
                published_task = ModelTaskPublishResponse(
                    id=model.task, data=task_publish_res
                )

        updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
        return updated, published_task

    @classmethod
    def delete_model(
        cls, model_id: str, company_id: str, force: bool
    ) -> Tuple[int, Model]:
        model = cls.get_company_model_by_id(
            company_id=company_id,
            model_id=model_id,
            only_fields=("id", "task", "project", "uri"),
        )
        deleted_model_id = f"{deleted_prefix}{model_id}"

        using_tasks = Task.objects(models__input__model=model_id).only("id")
        if using_tasks:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(using_tasks),
                )
            # update deleted model id in using tasks
            Task._get_collection().update_many(
                filter={"_id": {"$in": [t.id for t in using_tasks]}},
                update={"$set": {"models.input.$[elem].model": deleted_model_id}},
                array_filters=[{"elem.model": model_id}],
                upsert=False,
            )

        if model.task:
            task = Task.objects(id=model.task).first()
            if task and task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                if task.models.output and model_id in task.models.output:
                    now = datetime.utcnow()
                    Task._get_collection().update_one(
                        filter={"_id": model.task, "models.output.model": model_id},
                        update={
                            "$set": {
                                "models.output.$[elem].model": deleted_model_id,
                                "output.error": f"model deleted on {now.isoformat()}",
                            },
                            "last_change": now,
                        },
                        array_filters=[{"elem.model": model_id}],
                        upsert=False,
                    )

        del_count = Model.objects(id=model_id, company=company_id).delete()
        return del_count, model

    @classmethod
    def archive_model(cls, model_id: str, company_id: str):
        cls.get_company_model_by_id(
            company_id=company_id, model_id=model_id, only_fields=("id",)
        )
        archived = Model.objects(company=company_id, id=model_id).update(
            add_to_set__system_tags=EntityVisibility.archived.value,
            last_update=datetime.utcnow(),
        )

        return archived

    @classmethod
    def unarchive_model(cls, model_id: str, company_id: str):
        cls.get_company_model_by_id(
            company_id=company_id, model_id=model_id, only_fields=("id",)
        )
        unarchived = Model.objects(company=company_id, id=model_id).update(
            pull__system_tags=EntityVisibility.archived.value,
            last_update=datetime.utcnow(),
        )

        return unarchived