from typing import Sequence

import attr
from mongoengine import Q

from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
    GetParamsRequest,
    ProjectTagsRequest,
    ProjectTaskParentsRequest,
    ProjectHyperparamValuesRequest,
    ProjectsGetRequest,
    DeleteRequest,
    MoveRequest,
    MergeRequest,
    ProjectOrNoneRequest,
    ProjectRequest,
    ProjectModelMetadataValuesRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_cleanup import (
    delete_project,
    validate_project_delete,
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.project import Project
from apiserver.database.utils import (
    parse_from_call,
    get_company_or_none_constraint,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
    conform_tag_fields,
    conform_output_tags,
    get_tags_filter_dictionary,
    sort_tags_response,
)

# Shared business-logic objects used by the endpoint handlers in this module
org_bll = OrgBLL()
project_bll = ProjectBLL()
project_queries = ProjectQueries()

# Fields read from the call data by `projects.create` and `projects.update`;
# passed to parse_from_call together with Project.get_fields().
# NOTE(review): the values (None / list) are presumably per-field type hints or
# converters consumed by parse_from_call - confirm against its implementation.
create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}


@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
    """
    Return a single project by its id.

    The project is looked up by the requested id combined with the constraint
    built by get_company_or_none_constraint for the caller's company. The
    project dictionary (after output tag conforming) is stored in the call
    result under the "project" key.

    :raise errors.bad_request.InvalidProjectId: no matching project was found
    """
    assert isinstance(call, APICall)
    requested_id = call.data["project"]

    with translate_errors_context():
        id_match = Q(id=requested_id)
        lookup = id_match & get_company_or_none_constraint(call.identity.company)
        found = Project.objects(lookup).first()
        if not found:
            raise errors.bad_request.InvalidProjectId(id=requested_id)

        as_dict = found.to_proper_dict()
        conform_output_tags(call, as_dict)

        call.result.data = {"project": as_dict}


def _hidden_query(search_hidden: bool, ids: Sequence) -> Q:
    """
    Build the visibility condition for a projects search.

    Entries whose system tags equal the hidden marker are excluded, unless the
    caller asked to search hidden entries or passed explicit ids, in which
    case an empty (match-all) condition is returned.
    """
    if not search_hidden and not ids:
        return Q(system_tags__ne=EntityVisibility.hidden.value)

    return Q()


def _adjust_search_parameters(data: dict, shallow_search: bool):
    """
    1. Make sure that there is no external query on path
    2. If not shallow_search and parent is provided then parent can be at any place in path
    3. If shallow_search and no parent provided then use a top level parent
    """
    data.pop("path", None)
    if not shallow_search:
        if "parent" in data:
            data["path"] = data.pop("parent")
        return

    if "parent" not in data:
        data["parent"] = [None]


@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
    """
    Extended projects search.

    Queries projects using the call data as the query dictionary and
    optionally enriches each returned project with own-contents info,
    statistics, sub projects and dataset statistics, depending on the
    request flags.

    The response holds the projects under the "projects" key, merged with
    whatever was placed in ret_params by Project.get_many_with_join.

    NOTE(review): the two early exits return a dict instead of assigning
    call.result.data - this assumes the endpoint framework uses a returned
    dict as the response; confirm.
    """
    data = call.data
    conform_tag_fields(call, data)
    allow_public = not request.non_public
    # Capture the ids requested by the caller now: data["id"] may be replaced
    # below by the active-users filter, while the original ids are still
    # needed for the hidden-projects condition
    requested_ids = data.get("id")
    if isinstance(requested_ids, str):
        requested_ids = [requested_ids]
    _adjust_search_parameters(
        data, shallow_search=request.shallow_search,
    )
    user_active_project_ids = None
    if request.active_users:
        # Narrow the search to the project ids reported for the given users
        ids, user_active_project_ids = project_bll.get_projects_with_active_user(
            company=company_id,
            users=request.active_users,
            project_ids=requested_ids,
            allow_public=allow_public,
        )
        if not ids:
            # No project qualifies for the active users filter
            return {"projects": []}
        data["id"] = ids

    # Passed to get_many_with_join and merged into the response as is
    ret_params = {}
    projects: Sequence[dict] = Project.get_many_with_join(
        company=company_id,
        query_dict=data,
        query=_hidden_query(search_hidden=request.search_hidden, ids=requested_ids),
        allow_public=allow_public,
        ret_params=ret_params,
    )
    if not projects:
        return {"projects": projects, **ret_params}

    # De-duplicated ids of the returned projects
    project_ids = list({project["id"] for project in projects})
    if request.check_own_contents:
        contents = project_bll.calc_own_contents(
            company=company_id,
            project_ids=project_ids,
            filter_=request.include_stats_filter,
            users=request.active_users,
        )
        for project in projects:
            # Projects without an entry in contents are left unchanged
            project.update(**contents.get(project["id"], {}))

    # Tags are conformed before the stats related keys are attached below
    conform_output_tags(call, projects)
    if request.include_stats:
        stats, children = project_bll.get_project_stats(
            company=company_id,
            project_ids=project_ids,
            specific_state=request.stats_for_state,
            include_children=request.stats_with_children,
            search_hidden=request.search_hidden,
            filter_=request.include_stats_filter,
            users=request.active_users,
            user_active_project_ids=user_active_project_ids,
        )

        for project in projects:
            # Direct indexing: assumes get_project_stats returns an entry for
            # every requested project id (KeyError otherwise) - TODO confirm
            project["stats"] = stats[project["id"]]
            project["sub_projects"] = children[project["id"]]

    if request.include_dataset_stats:
        dataset_stats = project_bll.get_dataset_stats(
            company=company_id, project_ids=project_ids, users=request.active_users,
        )
        for project in projects:
            # Projects with no dataset stats entry get None
            project["dataset_stats"] = dataset_stats.get(project["id"])

    call.result.data = {"projects": projects, **ret_params}


@endpoint("projects.get_all")
def get_all(call: APICall):
    """
    Basic projects search.

    Uses the call data both as the query dictionary and as the query
    parameters for Project.get_many (public projects allowed). The found
    projects are stored under the "projects" key of the result, merged with
    the content of the ret_params dictionary passed to Project.get_many.
    """
    query_params = call.data
    conform_tag_fields(call, query_params)

    shallow = query_params.get("shallow_search", False)
    _adjust_search_parameters(query_params, shallow_search=shallow)

    visibility = _hidden_query(
        search_hidden=query_params.get("search_hidden"), ids=query_params.get("id"),
    )
    extra = {}
    found = Project.get_many(
        company=call.identity.company,
        query_dict=query_params,
        query=visibility,
        parameters=query_params,
        allow_public=True,
        ret_params=extra,
    )

    conform_output_tags(call, found)
    call.result.data = {"projects": found, **extra}


@endpoint(
    "projects.create", required_fields=["name"], response_data_model=IdResponse,
)
def create(call: APICall):
    """
    Create a new project from the fields listed in create_fields.

    :return: IdResponse holding the value returned by ProjectBLL.create
    """
    caller = call.identity

    with translate_errors_context():
        project_fields = parse_from_call(
            call.data, create_fields, Project.get_fields()
        )
        conform_tag_fields(call, project_fields, validate=True)

        new_id = ProjectBLL.create(
            user=caller.user, company=caller.company, **project_fields,
        )
        return IdResponse(id=new_id)


@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    update

    :summary: Update project information.
              See `projects.create` for parameters.
    :return: updated - the value returned by ProjectBLL.update
             fields - the fields passed to the update (tags conformed for output)
    """
    changes = parse_from_call(
        call.data, create_fields, Project.get_fields(), discard_none_values=False
    )
    conform_tag_fields(call, changes, validate=True)

    update_result = ProjectBLL.update(
        company=call.identity.company, project_id=call.data["project"], **changes
    )

    conform_output_tags(call, changes)
    call.result.data_model = UpdateResponse(updated=update_result, fields=changes)


def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Call org_bll.reset_tags for the task tags and then the model tags of the projects."""
    for entity in (Tags.Task, Tags.Model):
        org_bll.reset_tags(company, entity, projects=projects)


@endpoint("projects.move", request_data_model=MoveRequest)
def move(call: APICall, company: str, request: MoveRequest):
    """
    Move a project to a new location and reset the tags of the affected projects.

    The first value returned by ProjectBLL.move_project is stored in the
    result under the "moved" key.
    """
    outcome = ProjectBLL.move_project(
        company=company,
        user=call.identity.user,
        project_id=request.project,
        new_location=request.new_location,
    )
    moved_count, touched_projects = outcome

    _reset_cached_tags(company, projects=list(touched_projects))
    call.result.data = {"moved": moved_count}


@endpoint("projects.merge", request_data_model=MergeRequest)
def merge(call: APICall, company: str, request: MergeRequest):
    """
    Merge the source project into the destination project and reset the tags
    of the affected projects.

    The first two values returned by ProjectBLL.merge_project are stored in
    the result under the "moved_entities" and "moved_projects" keys.
    """
    outcome = ProjectBLL.merge_project(
        company, source_id=request.project, destination_id=request.destination_project
    )
    entities_moved, projects_moved, touched_projects = outcome

    _reset_cached_tags(company, projects=list(touched_projects))

    call.result.data = {
        "moved_entities": entities_moved,
        "moved_projects": projects_moved,
    }


@endpoint("projects.validate_delete")
def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
    """Store the result of validate_project_delete for the requested project in the call result."""
    report = validate_project_delete(company=company_id, project_id=request.project)
    call.result.data = report


@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
    """
    Delete a project and reset the tags of the affected projects.

    The first value returned by delete_project is converted with attr.asdict
    and stored as the call result.
    """
    outcome, touched_projects = delete_project(
        company=company_id,
        user=call.identity.user,
        project_id=request.project,
        force=request.force,
        delete_contents=request.delete_contents,
    )

    _reset_cached_tags(company_id, projects=list(touched_projects))

    # Shallow copy of the attrs-generated dictionary
    call.result.data = dict(attr.asdict(outcome))


@endpoint(
    "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
)
def get_unique_metric_variants(
    call: APICall, company_id: str, request: ProjectOrNoneRequest
):
    """
    Store under "metrics" the result of project_queries.get_unique_metric_variants.

    When no project is given in the request, None is passed as the projects list.
    """
    project_ids = None
    if request.project:
        project_ids = [request.project]

    found_metrics = project_queries.get_unique_metric_variants(
        company_id, project_ids, include_subprojects=request.include_subprojects,
    )

    call.result.data = {"metrics": found_metrics}


@endpoint("projects.get_model_metadata_keys")
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
    """
    Return one page of the results of project_queries.get_model_metadata_keys.

    The result holds the "total", "remaining" and "keys" values returned by
    the query. When no project is given in the request, None is passed as the
    projects list.
    """
    project_ids = None
    if request.project:
        project_ids = [request.project]

    page_result = project_queries.get_model_metadata_keys(
        company_id,
        project_ids=project_ids,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )
    total_count, remaining_count, found_keys = page_result

    call.result.data = {
        "total": total_count,
        "remaining": remaining_count,
        "keys": found_keys,
    }


@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
    call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
    """
    Store under "total" and "values" the pair returned by
    project_queries.get_model_metadata_distinct_values for the requested key.
    """
    query_result = project_queries.get_model_metadata_distinct_values(
        company_id,
        project_ids=request.projects,
        key=request.key,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    total_count, distinct_values = query_result

    call.result.data = {"total": total_count, "values": distinct_values}


@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.9",
    request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
    """
    Return one page of the results of
    project_queries.get_aggregated_project_parameters.

    The result holds the "total", "remaining" and "parameters" values returned
    by the query. When no project is given in the request, None is passed as
    the projects list.
    """
    project_ids = None
    if request.project:
        project_ids = [request.project]

    page_result = project_queries.get_aggregated_project_parameters(
        company_id,
        project_ids=project_ids,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )
    total_count, remaining_count, found_parameters = page_result

    call.result.data = {
        "total": total_count,
        "remaining": remaining_count,
        "parameters": found_parameters,
    }


@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """
    Store under "total" and "values" the pair returned by
    project_queries.get_task_hyperparam_distinct_values for the requested
    section and name.
    """
    query_result = project_queries.get_task_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        include_subprojects=request.include_subprojects,
        allow_public=request.allow_public,
    )
    total_count, distinct_values = query_result

    call.result.data = {"total": total_count, "values": distinct_values}


@endpoint("projects.get_project_tags")
def get_project_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the tags (and system tags) reported by project_bll.get_project_tags.

    The function carries a name of its own: the previous name (get_tags) was
    redefined by later handlers in this module, so this definition was never
    reachable under that name. The registered endpoint name is unchanged.

    :param call: the API call; its result data receives the sorted tags response
    :param company: company id passed to the tags query
    :param request: holds include_system, filter and projects for the query
    """
    tags, system_tags = project_bll.get_project_tags(
        company,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})


@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_task_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the task tags reported by org_bll.get_tags for the requested projects.

    The function carries a name of its own: the previous name (get_tags) was
    redefined by a later handler in this module, so this definition was never
    reachable under that name. The registered endpoint name is unchanged.

    :param call: the API call; its result data receives the sorted tags response
    :param company: company id passed to the tags query
    :param request: holds include_system, filter and projects for the query
    """
    ret = org_bll.get_tags(
        company,
        Tags.Task,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = sort_tags_response(ret)


@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
    """
    Return the model tags reported by org_bll.get_tags for the requested
    projects, passed through sort_tags_response.
    """
    tags_filter = get_tags_filter_dictionary(request.filter)
    model_tags = org_bll.get_tags(
        company,
        Tags.Model,
        include_system=request.include_system,
        filter_=tags_filter,
        projects=request.projects,
    )
    call.result.data = sort_tags_response(model_tags)


@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Call Project.set_public with enabled=True for the requested ids and store its result."""
    outcome = Project.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
    )
    call.result.data = outcome


@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Call Project.set_public with enabled=False for the requested ids and store
    its result as the call result.

    Named after the endpoint it serves: it was previously called make_public,
    which shadowed the handler of `projects.make_public` at module level.
    The registered endpoint name is unchanged.

    :param call: the API call; its result data receives the set_public result
    :param company_id: company id passed to Project.set_public
    :param request: holds the ids of the projects to update
    """
    call.result.data = Project.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
    )


@endpoint(
    "projects.get_task_parents",
    min_version="2.12",
    request_data_model=ProjectTaskParentsRequest,
)
def get_task_parents(
    call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
    """
    Store under "parents" the result of project_bll.get_task_parents for the
    requested projects, sub projects flag and tasks state.
    """
    found_parents = project_bll.get_task_parents(
        company_id,
        projects=request.projects,
        include_subprojects=request.include_subprojects,
        state=request.tasks_state,
    )
    call.result.data = {"parents": found_parents}