from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool


from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator


import time
import json
import uuid

from utils import get_last_user_message, stream_message_template
from schemas import FilterForm, OpenAIChatCompletionForm

import os
import importlib.util

import logging

from concurrent.futures import ThreadPoolExecutor


import os

####################################
# Load .env file
####################################

# python-dotenv is treated as optional: when the package cannot be imported
# the server prints a notice and carries on without loading a .env file.
try:
    from dotenv import load_dotenv, find_dotenv

    # find_dotenv locates a file named ".env"; load_dotenv then reads it.
    load_dotenv(find_dotenv(".env"))
except ImportError:
    print("dotenv not installed, skipping...")


# Registry of routable pipelines, filled by on_startup():
# pipeline id -> descriptor dict with "module", "id", "name" and, depending on
# the pipeline type, "manifold" or "filter"/"pipelines"/"priority".
PIPELINES = {}
# Instantiated Pipeline objects, filled by on_startup():
# pipeline id (pipeline.id, or the module name as fallback) -> instance.
PIPELINE_MODULES = {}


def on_startup():
    """Discover pipeline modules in ``./pipelines`` and register them.

    Every ``*.py`` file in the directory is imported and its ``Pipeline``
    class is instantiated.  The instance is stored in ``PIPELINE_MODULES``
    under ``pipeline.id`` (falling back to the module name) and one or more
    descriptor dicts are added to ``PIPELINES``:

    * ``type == "manifold"``: one entry per item of ``pipeline.pipelines``,
      keyed ``"<pipeline_id>.<item id>"`` and flagged ``"manifold": True``.
    * ``type == "filter"``: a single entry flagged ``"filter": True`` that
      also carries the filter's target ``pipelines`` and ``priority``.
    * no ``type`` attribute, or any other value: a single plain entry, so a
      loaded module is never left unreachable.

    Raises:
        OSError: if ``./pipelines`` cannot be listed.
        AttributeError: if a module does not define ``Pipeline``.
    """

    def load_modules_from_directory(directory):
        """Import each ``.py`` file in *directory* and yield the module."""
        for filename in os.listdir(directory):
            if not filename.endswith(".py"):
                continue
            module_name = filename[:-3]  # Remove the .py extension
            module_path = os.path.join(directory, filename)
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            yield module

    for loaded_module in load_modules_from_directory("./pipelines"):
        # Lazy %-style argument: the message needs a placeholder, otherwise
        # the logging machinery reports a formatting error.
        logging.info("Loaded: %s", loaded_module.__name__)

        pipeline = loaded_module.Pipeline()
        pipeline_id = getattr(pipeline, "id", loaded_module.__name__)

        PIPELINE_MODULES[pipeline_id] = pipeline

        pipeline_type = getattr(pipeline, "type", None)

        if pipeline_type == "manifold":
            # A manifold exposes several models through one module.
            for p in pipeline.pipelines:
                manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'

                manifold_pipeline_name = p["name"]
                if hasattr(pipeline, "name"):
                    # The manifold's own name is used as a prefix.
                    manifold_pipeline_name = (
                        f"{pipeline.name}{manifold_pipeline_name}"
                    )

                PIPELINES[manifold_pipeline_id] = {
                    "module": pipeline_id,
                    "id": manifold_pipeline_id,
                    "name": manifold_pipeline_name,
                    "manifold": True,
                }
        elif pipeline_type == "filter":
            PIPELINES[pipeline_id] = {
                "module": pipeline_id,
                "id": pipeline_id,
                "name": getattr(pipeline, "name", pipeline_id),
                "filter": True,
                "pipelines": getattr(pipeline, "pipelines", []),
                "priority": getattr(pipeline, "priority", 0),
            }
        else:
            # Untyped pipelines and unrecognised type values are both
            # registered as regular pipelines.
            PIPELINES[pipeline_id] = {
                "module": pipeline_id,
                "id": pipeline_id,
                "name": getattr(pipeline, "name", pipeline_id),
            }


# Runs at import time, so PIPELINES and PIPELINE_MODULES are populated before
# the FastAPI app below is created.
on_startup()


from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run pipeline startup hooks before serving and shutdown hooks after.

    Each object in ``PIPELINE_MODULES`` that defines ``on_startup`` has it
    awaited before the application starts; each one that defines
    ``on_shutdown`` has it awaited when the application stops.  Both hooks
    are optional and are expected to be coroutine functions.
    """
    for module in PIPELINE_MODULES.values():
        if hasattr(module, "on_startup"):
            await module.on_startup()

    try:
        yield
    finally:
        # Run shutdown hooks even when an exception is thrown in at the
        # yield, so pipelines can still release what they acquired.
        for module in PIPELINE_MODULES.values():
            if hasattr(module, "on_shutdown"):
                await module.on_shutdown()


# Application instance: interactive docs at /docs, ReDoc disabled, and the
# lifespan handler wired in to run pipeline startup/shutdown hooks.
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)

# Expose the registry on app.state. This is the same dict object as the
# module-level PIPELINES, not a copy.
app.state.PIPELINES = PIPELINES


# CORS configuration: every origin, method and header is allowed, with
# credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside local/trusted deployments.
origins = ["*"]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def check_url(request: Request, call_next):
    """Stamp every response with an ``X-Process-Time`` header.

    The value is the difference between two ``int(time.time())`` readings
    taken around the downstream handler, i.e. elapsed wall-clock time in
    whole seconds.
    """
    began = int(time.time())
    response = await call_next(request)
    elapsed = int(time.time()) - began

    response.headers["X-Process-Time"] = f"{elapsed}"
    return response


@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    Returns the available pipelines

    Each registered pipeline is reported as a model object. Entries flagged
    as filters additionally carry a "pipeline" section with their type,
    target pipelines and priority.
    """
    models = []

    for entry in PIPELINES.values():
        model = {
            "id": entry["id"],
            "name": entry["name"],
            "object": "model",
            "created": int(time.time()),
            "owned_by": "openai",
        }

        # Only filter entries get the extra "pipeline" metadata.
        if entry.get("filter", False):
            model["pipeline"] = {
                "type": "filter" if entry.get("filter") else "pipeline",
                "pipelines": entry.get("pipelines", []),
                "priority": entry.get("priority", 0),
            }

        models.append(model)

    return {"data": models}


@app.post("/filter")
@app.post("/v1/filter")
async def filter(form_data: FilterForm):
    """Apply the filter pipeline named by ``form_data.model`` to a request body.

    Returns whatever the pipeline's ``filter`` coroutine returns for
    ``(form_data.body, form_data.user)``.

    Raises:
        HTTPException: 404 when the model is not registered or is not a
            filter; 500 when the filter itself raises.
    """
    model_id = form_data.model
    registry = app.state.PIPELINES

    is_known_filter = model_id in registry and registry[model_id].get(
        "filter", False
    )
    if not is_known_filter:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"filter {model_id} not found",
        )

    filter_pipeline = PIPELINE_MODULES[model_id]

    try:
        return await filter_pipeline.filter(form_data.body, form_data.user)
    except Exception as e:
        print(e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error in filter {model_id}: {str(e)}",
        )


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """Run a chat completion through the pipeline named by ``form_data.model``.

    The pipeline's ``pipe`` callable is executed in a worker thread (via
    ``run_in_threadpool``) and receives the last user message, the model id
    (for manifolds, the part after the first ``"."``), the dumped message
    list and the dumped request body.

    Returns:
        When ``form_data.stream`` is truthy, a ``text/event-stream``
        ``StreamingResponse``. A ``str`` result is sent as one chunk; an
        iterator result is forwarded item by item. Either is followed by a
        ``finish_reason: "stop"`` chunk and ``data: [DONE]``.

        Otherwise, a ``dict`` result is returned as-is, a pydantic model is
        dumped to a dict, and a ``str`` or iterator result is wrapped in a
        ``chat.completion`` payload (iterator items are concatenated).

    Raises:
        HTTPException: 404 when the model is not registered or is a filter.
    """
    user_message = get_last_user_message(form_data.messages)
    messages = [message.model_dump() for message in form_data.messages]

    if form_data.model not in app.state.PIPELINES or app.state.PIPELINES[
        form_data.model
    ].get("filter", False):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {form_data.model} not found",
        )

    def job():
        logging.info("chat completion requested for pipeline %s", form_data.model)

        pipeline = app.state.PIPELINES[form_data.model]
        pipeline_id = form_data.model

        if pipeline.get("manifold", False):
            # "<manifold id>.<sub id>": the module is looked up by the
            # manifold id and only the sub id is handed to pipe().
            manifold_id, pipeline_id = pipeline_id.split(".", 1)
            pipe = PIPELINE_MODULES[manifold_id].pipe
        else:
            pipe = PIPELINE_MODULES[pipeline_id].pipe

        def run_pipe():
            """Invoke the pipeline with the request data."""
            return pipe(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump(),
            )

        if form_data.stream:

            def stream_content():
                """Yield the pipeline output as server-sent events."""
                # The pipe is called here, so it runs only once the response
                # body starts being consumed.
                res = run_pipe()

                logging.info(f"stream:true:{res}")

                if isinstance(res, str):
                    message = stream_message_template(form_data.model, res)
                    logging.info(f"stream_content:str:{message}")
                    yield f"data: {json.dumps(message)}\n\n"
                elif isinstance(res, Iterator):
                    for line in res:
                        if isinstance(line, BaseModel):
                            line = f"data: {line.model_dump_json()}"
                        elif isinstance(line, (bytes, bytearray)):
                            # Only byte payloads need decoding; an explicit
                            # type check replaces a catch-all except.
                            line = line.decode("utf-8")

                        logging.info(f"stream_content:Generator:{line}")

                        if line.startswith("data:"):
                            # Already an SSE data line: forward unchanged.
                            yield f"{line}\n\n"
                        else:
                            chunk = stream_message_template(form_data.model, line)
                            yield f"data: {json.dumps(chunk)}\n\n"
                else:
                    # Nothing streamable was produced.
                    return

                # Every streamed result, str or any iterator, is closed with
                # a stop chunk and the [DONE] sentinel.
                finish_message = {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": form_data.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }

                yield f"data: {json.dumps(finish_message)}\n\n"
                # An SSE event is only dispatched once a blank line follows
                # it, so the sentinel needs the same terminator.
                yield "data: [DONE]\n\n"

            return StreamingResponse(stream_content(), media_type="text/event-stream")

        res = run_pipe()
        logging.info(f"stream:false:{res}")

        if isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        if isinstance(res, str):
            message = res
        elif isinstance(res, Iterator):
            # Accept any iterator, matching the streaming branch, and build
            # the text with a single join.
            message = "".join(f"{chunk}" for chunk in res)
        else:
            message = ""

        logging.info(f"stream:false:{message}")
        return {
            "id": f"{form_data.model}-{str(uuid.uuid4())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": form_data.model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": message,
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        }

    return await run_in_threadpool(job)


@app.get("/")
async def get_status():
    """Report that the server is up; the payload is always ``{"status": True}``."""
    return dict(status=True)