from datetime import datetime

from mongoengine import Q, EmbeddedDocument

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.models import (
    CreateModelRequest,
    CreateModelResponse,
    PublishModelRequest,
    PublishModelResponse,
    ModelTaskPublishResponse,
)
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.model import validate_id
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import (
    parse_from_call,
    get_company_or_none_constraint,
    filter_fields,
)
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext

# Module-level logger obtained from the project config, keyed by this file's path
log = config.logger(__file__)


@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
    """Fetch a single model by id and place it in the call result.

    The id is read from ``call.data["model"]``. The lookup is performed with
    ``allow_public=True`` for the caller's company, and the first returned
    entry is stored under the ``model`` key of the result data.

    Raises:
        errors.bad_request.InvalidModelId: the lookup returned nothing.
    """
    assert isinstance(call, APICall)
    requested_id = call.data["model"]

    with translate_errors_context():
        matches = Model.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query=Q(id=requested_id),
            allow_public=True,
        )
        if not matches:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=requested_id,
                company=call.identity.company,
            )

        found = matches[0]
        # adjust tag fields in the output for the calling client
        conform_output_tags(call, found)
        call.result.data = {"model": found}


@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
    """Return the output model of the task given in ``call.data["task"]``.

    The task is looked up in the caller's company; its ``output.model`` id is
    then resolved using the company-or-none constraint, and the model's dict
    form is stored under the ``model`` key of the result data.

    Raises:
        errors.bad_request.InvalidTaskId: the task was not found.
        errors.bad_request.MissingTaskFields: the task has no ``output`` or
            no ``output.model``.
        errors.bad_request.InvalidModelId: the referenced model was not found.
    """
    assert isinstance(call, APICall)
    task_id = call.data["task"]

    with translate_errors_context():
        task_query = dict(id=task_id, company=call.identity.company)
        task = Task.get(_only=["output"], **task_query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**task_query)

        # both the output section and its model reference must be present
        if not task.output:
            raise errors.bad_request.MissingTaskFields(field="output")
        if not task.output.model:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        model_id = task.output.model
        model_filter = Q(id=model_id) & get_company_or_none_constraint(
            call.identity.company
        )
        model = Model.objects(model_filter).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=call.identity.company,
            )

        as_dict = model.to_proper_dict()
        conform_output_tags(call, as_dict)
        call.result.data = {"model": as_dict}


@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
    """List models using ``Model.get_many_with_join``.

    The call data is used as the query dict (after tag fields are conformed),
    the lookup is timed under ``("mongo", "models_get_all_ex")``, and the
    result is stored under the ``models`` key of the result data.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        # only the database call itself is timed
        with TimingContext("mongo", "models_get_all_ex"):
            found = Model.get_many_with_join(
                company=call.identity.company,
                query_dict=call.data,
                allow_public=True,
            )

        conform_output_tags(call, found)
        call.result.data = dict(models=found)


@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall):
    """List models using ``Model.get_many``.

    The call data serves both as ``parameters`` and as ``query_dict`` (after
    tag fields are conformed), the lookup is timed under
    ``("mongo", "models_get_all")``, and the result is stored under the
    ``models`` key of the result data.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        # only the database call itself is timed
        with TimingContext("mongo", "models_get_all"):
            found = Model.get_many(
                company=call.identity.company,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
            )

        conform_output_tags(call, found)
        call.result.data = dict(models=found)


# Request fields accepted when creating or editing a model. This mapping is
# handed to parse_from_call as its `valid_fields` argument (see
# parse_model_fields) by models.update_for_task and models.edit.
# Each value is None, a builtin container type, or a document class;
# presumably it tells parse_from_call how to check the field (None meaning
# "no check") -- TODO confirm against database.utils.parse_from_call.
create_fields = {
    "name": None,
    "tags": list,
    "system_tags": list,
    "task": Task,
    "comment": None,
    "uri": None,
    "project": Project,
    "parent": Model,
    "framework": None,
    "design": None,
    "labels": dict,
    "ready": None,
}


def parse_model_fields(call, valid_fields):
    """Extract model fields from the call data and conform their tag fields.

    :param call: the API call whose ``data`` is parsed
    :param valid_fields: mapping of accepted fields, forwarded to
        ``parse_from_call`` together with ``Model.get_fields()``
    :return: the parsed fields, after ``conform_tag_fields`` was applied
    """
    parsed = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, parsed)
    return parsed


@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
    """Create, update or override the output model of a task.

    Exactly one of ``uri`` / ``override_model_id`` must be present in the
    call data:

    - ``override_model_id``: the existing company model with that id becomes
      the task's output model.
    - ``uri``: if the task already has an output model it is updated through
      ``_update_model`` (result carries ``created: False``); otherwise a new
      model is created. ``name`` and ``comment`` default to values derived
      from the task when not supplied.

    Unless an existing output model was updated, the task statistics are
    updated with the ``iteration`` value and the new output model id.

    Raises:
        errors.bad_request.MissingRequiredFields: neither or both of
            ``uri`` / ``override_model_id`` were supplied.
        errors.bad_request.InvalidTaskId: the task was not found.
        errors.bad_request.InvalidTaskStatus: the task is not in the
            ``created`` or ``in_progress`` state.
        errors.bad_request.InvalidModelId: ``override_model_id`` does not
            match a model of this company.
    """
    assert isinstance(call, APICall)
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # neither supplied, or both supplied
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.output and task.output.model:
                # model exists, update
                res = _update_model(call, model_id=task.output.model).to_struct()
                res.update({"id": task.output.model, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # Values derived from the task serve as defaults; fields explicitly
            # supplied in the request take precedence. Merging into a single
            # dict (instead of passing both as keyword arguments) avoids a
            # "got multiple values for keyword argument" TypeError when the
            # request contains e.g. project, framework, parent or labels.
            model_fields = dict(
                project=task.project,
                framework=task.execution.framework,
                parent=task.execution.model,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                ready=(task.status == TaskStatus.published),
            )
            model_fields.update(fields)

            # create and save model
            model = Model(
                id=database.utils.id(),
                created=datetime.utcnow(),
                user=call.identity.user,
                company=company_id,
                **model_fields,
            )
            model.save()

        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        # NOTE(review): "created" is True also when an existing model was set
        # through override_model_id -- kept as-is for client compatibility.
        call.result.data = {"id": model.id, "created": True}


@endpoint(
    "models.create",
    request_data_model=CreateModelRequest,
    response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
    """Create a new model from a ``CreateModelRequest``.

    When the request is flagged ``public`` the model is stored with an empty
    company. A supplied project id is checked with ``validate_id`` and a
    supplied task is passed through ``validate_task`` before the model is
    saved. The response holds the new model id and ``created=True``.
    """
    assert isinstance(call, APICall)
    assert isinstance(req_model, CreateModelRequest)
    identity = call.identity

    # public models are stored with an empty company
    company = "" if req_model.public else company

    with translate_errors_context():

        project_id = req_model.project
        if project_id:
            validate_id(Project, company=company, project=project_id)

        task_id = req_model.task
        request_struct = req_model.to_struct()
        if task_id:
            validate_task(call, request_struct)

        model_fields = filter_fields(Model, request_struct)
        conform_tag_fields(call, model_fields)

        # build the document and persist it
        new_model = Model(
            id=database.utils.id(),
            user=identity.user,
            company=company,
            created=datetime.utcnow(),
            **model_fields,
        )
        new_model.save()

        call.result.data_model = CreateModelResponse(id=new_model.id, created=True)


def prepare_update_fields(call, fields):
    """Validate and normalize a mapping of model fields before an update.

    Works on a copy (``fields.copy()``), so the mapping passed in is not
    modified at its top level.

    - If ``uri`` is present and ``ui_cache`` is not, ``ui_cache`` is set to
      an empty dict; a ``ui_cache`` supplied by the caller is kept.
    - If ``task`` is present, the fields are passed through ``validate_task``.
    - If ``labels`` is present, it must be a dict whose keys are strings and
      whose values are integers.
    - Tag fields are conformed using ``conform_tag_fields``.

    :param call: the API call being served
    :param fields: mapping of field names to requested values
    :return: the prepared copy of ``fields``
    :raises errors.bad_request.ValidationError: ``labels`` is not a dict, or
        has a non-string key or a non-integer value
    """
    fields = fields.copy()
    if "uri" in fields:
        # URI provided (model updated): reset the UI cache unless the caller
        # supplied one explicitly
        fields.setdefault("ui_cache", {})
    if "task" in fields:
        validate_task(call, fields)

    if "labels" in fields:
        labels = fields["labels"]

        # Reject non-dict values up front; otherwise the .keys()/.values()
        # calls below would fail with an AttributeError instead of a proper
        # bad-request error.
        if not isinstance(labels, dict):
            raise errors.bad_request.ValidationError(
                "labels must be a dictionary", type=type(labels).__name__
            )

        def find_other_types(iterable, type_):
            """Collect the items of ``iterable`` that are not instances of ``type_``."""
            res = [x for x in iterable if not isinstance(x, type_)]
            try:
                return set(res)
            except TypeError:
                # Un-hashable, probably
                return res

        invalid_keys = find_other_types(labels.keys(), str)
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "labels keys must be strings", keys=invalid_keys
            )

        # NOTE: bool is a subclass of int, so True/False pass this check
        invalid_values = find_other_types(labels.values(), int)
        if invalid_values:
            raise errors.bad_request.ValidationError(
                "labels values must be integers", values=invalid_values
            )

    conform_tag_fields(call, fields)
    return fields


def validate_task(call, fields):
    """Check that the task referenced by ``fields["task"]`` exists for writing.

    The task is looked up with ``Task.get_for_writing`` in the caller's
    company. A falsy result is treated as "no such task", the same way
    ``update_for_task`` treats it. Empty task ids are not rejected, so callers
    that pass a null ``task`` behave as before.

    :param call: the API call being served (provides the caller's company)
    :param fields: mapping containing a ``task`` key
    :raises errors.bad_request.InvalidTaskId: a non-empty task id was given
        but no task was returned for it
    """
    company_id = call.identity.company
    task_id = fields["task"]
    task = Task.get_for_writing(company=company_id, id=task_id, _only=["id"])
    # previously the lookup result was discarded, so unknown ids passed silently
    if task_id and not task:
        raise errors.bad_request.InvalidTaskId(id=task_id, company=company_id)


@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall):
    """Edit a model of the caller's company.

    Fields are parsed from the call data according to ``create_fields`` and
    prepared with ``prepare_update_fields``. A dict value targeting a field
    that currently holds an embedded document is merged on top of that
    document's current contents. When an ``iteration`` is supplied and a task
    id is known (from the model or from the request), the task statistics are
    updated as well.

    Raises:
        errors.bad_request.InvalidModelId: the model was not found.
    """
    identity = call.identity
    model_id = call.data["model"]

    with translate_errors_context():
        lookup = dict(id=model_id, company=identity.company)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        fields = prepare_update_fields(call, parse_model_fields(call, create_fields))

        # overlay dict values on top of existing embedded documents
        for name, requested in fields.items():
            current = getattr(model, name, None)
            if not current:
                continue
            if isinstance(requested, dict) and isinstance(current, EmbeddedDocument):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(requested)
                fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        if not fields:
            # nothing to change
            call.result.data_model = UpdateResponse(updated=0)
            return

        updated = model.update(upsert=False, **fields)
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)


def _update_model(call: APICall, model_id=None):
    """Update a model of the caller's company from the raw call data.

    :param call: the API call being served
    :param model_id: id of the model to update; when falsy, the id is taken
        from ``call.data["model"]``
    :return: an ``UpdateResponse`` with the update count and updated fields
        reported by ``Model.safe_update``
    :raises errors.bad_request.InvalidModelId: the model was not found
    """
    identity = call.identity
    if not model_id:
        model_id = call.data["model"]

    with translate_errors_context():
        # make sure the model exists in this company
        lookup = dict(id=model_id, company=identity.company)
        model = Model.objects(**lookup).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**lookup)

        data = prepare_update_fields(call, call.data)

        # report the iteration to the task, if both were provided
        task_id, iteration = data.get("task"), data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        count, changed = Model.safe_update(call.identity.company, model.id, data)
        conform_output_tags(call, changed)
        return UpdateResponse(updated=count, fields=changed)


@endpoint(
    "models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
    """Update the model named in the call data; see ``_update_model``."""
    response = _update_model(call)
    call.result.data_model = response


@endpoint(
    "models.set_ready",
    request_data_model=PublishModelRequest,
    response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
    """Delegate to ``TaskBLL.model_set_ready`` and wrap its result.

    The response carries the ``updated`` value returned by the BLL call and,
    when published-task data was returned, a ``ModelTaskPublishResponse``
    built from it (otherwise ``None``).
    """
    updated, published_task_data = TaskBLL.model_set_ready(
        model_id=req_model.model,
        company_id=company,
        publish_task=req_model.publish_task,
        force_publish_task=req_model.force_publish_task,
    )

    published_task = None
    if published_task_data:
        published_task = ModelTaskPublishResponse(**published_task_data)

    call.result.data_model = PublishModelResponse(
        updated=updated, published_task=published_task
    )


@endpoint("models.delete", required_fields=["model"])
def delete(call):
    """Delete a model of the caller's company.

    This handler was previously also named ``update``, which redefined (and
    shadowed) the ``models.update`` handler's module-level name. The endpoint
    is registered by its string name, which is unchanged.

    Deletion is refused unless ``force`` is set in the call data when:

    - tasks reference the model as their execution model, or
    - the task that created the model exists and is published.

    With ``force``, such references are rewritten to ``__DELETED__<model id>``
    (and the creating task's output error is set) before the model is removed.
    The result data holds ``deleted``: whether any document was removed.

    Raises:
        errors.bad_request.InvalidModelId: the model was not found.
        errors.bad_request.ModelInUse: tasks use the model and ``force`` is
            not set.
        errors.bad_request.ModelCreatingTaskExists: the creating task is
            published and ``force`` is not set.
    """
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = call.data["model"]
    force = call.data.get("force", False)

    with translate_errors_context():
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).only("id", "task").first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        # placeholder id left behind in tasks that referenced the model
        deleted_model_id = f"__DELETED__{model_id}"

        using_tasks = Task.objects(execution__model=model_id).only("id")
        if using_tasks:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(using_tasks),
                )
            # update deleted model id in using tasks
            using_tasks.update(
                execution__model=deleted_model_id, upsert=False, multi=True
            )

        if model.task:
            task = Task.objects(id=model.task).first()
            if task and task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                # mark the creating task's output as pointing to a deleted model
                task.update(
                    output__model=deleted_model_id,
                    output__error=f"model deleted on {datetime.utcnow().isoformat()}",
                    upsert=False,
                )

        del_count = Model.objects(**query).delete()
        call.result.data = dict(deleted=del_count > 0)