from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool


from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator, Iterator


import time
import json
import uuid

from utils import get_last_user_message, stream_message_template
from schemas import OpenAIChatCompletionForm

import os
import importlib.util

import logging

from concurrent.futures import ThreadPoolExecutor

# Servable models, keyed by the model id clients send. Each value is a metadata
# dict with "module", "id", "name" (and "manifold": True for manifold
# sub-pipelines). Populated by on_startup().
PIPELINES: dict = dict()
# Pipeline instances, keyed by pipeline id (the "module" value stored above).
PIPELINE_MODULES: dict = dict()


def on_startup(directory="./pipelines"):
    """Import every ``*.py`` file in *directory* and register its ``Pipeline``.

    Each module must define a ``Pipeline`` class. One instance is created per
    module and stored in ``PIPELINE_MODULES`` under its pipeline id, which is
    ``pipeline.id`` when present and otherwise the module (file) name.

    ``PIPELINES`` receives one metadata entry per servable model:

    * a manifold pipeline (truthy ``pipeline.manifold``) contributes one entry
      per item of ``pipeline.pipelines``, keyed ``"<pipeline_id>.<item id>"``,
      named ``pipeline.name + item name`` (or just the item name);
    * any other pipeline contributes a single entry keyed by its pipeline id.

    Every entry's ``"module"`` field is the ``PIPELINE_MODULES`` key of the
    instance that serves it.

    Args:
        directory: Folder to scan. Defaults to ``"./pipelines"``, the
            previously hard-coded location.
    """

    def load_modules_from_directory(directory):
        # os.listdir order is arbitrary; sorting makes registration (and thus
        # the order of /models) deterministic across runs and platforms.
        for filename in sorted(os.listdir(directory)):
            if filename.endswith(".py"):
                module_name = filename[:-3]  # Remove the .py extension
                module_path = os.path.join(directory, filename)
                spec = importlib.util.spec_from_file_location(module_name, module_path)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                yield module

    for loaded_module in load_modules_from_directory(directory):
        # logging formats lazily with %-style args; the name must be bound to a
        # placeholder (a bare extra arg caused a formatting error at emit time).
        logging.info("Loaded: %s", loaded_module.__name__)

        pipeline = loaded_module.Pipeline()
        pipeline_id = getattr(pipeline, "id", loaded_module.__name__)

        PIPELINE_MODULES[pipeline_id] = pipeline

        if getattr(pipeline, "manifold", False):
            for p in pipeline.pipelines:
                manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'

                manifold_pipeline_name = p["name"]
                if hasattr(pipeline, "name"):
                    manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"

                PIPELINES[manifold_pipeline_id] = {
                    "module": pipeline_id,
                    "id": manifold_pipeline_id,
                    "name": manifold_pipeline_name,
                    "manifold": True,
                }
        else:
            # Key by pipeline_id (not the module name) so the key matches the
            # advertised "id"; otherwise a pipeline with a custom id is listed
            # by /models but cannot be found by the chat endpoint.
            PIPELINES[pipeline_id] = {
                "module": pipeline_id,
                "id": pipeline_id,
                "name": getattr(pipeline, "name", loaded_module.__name__),
            }


# Import and register every pipeline at module import time, before the FastAPI
# app below is constructed. The pipelines' own on_startup/on_shutdown hooks are
# not run here; lifespan() invokes them when the server starts and stops.
on_startup()


from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run each registered pipeline's startup hook, serve, then its shutdown hook.

    Hooks are the optional ``on_startup`` / ``on_shutdown`` attributes of the
    instances in ``PIPELINE_MODULES``. A hook may be a coroutine function or a
    plain function; its result is awaited only when it is awaitable, so
    synchronous hooks no longer raise ``TypeError``.

    Shutdown hooks run in ``finally`` so they also fire if an exception is
    thrown into the context at ``yield``.
    """
    import inspect

    async def _call_hook(pipeline, hook_name):
        # Skip pipelines that do not define (or explicitly null out) the hook.
        hook = getattr(pipeline, hook_name, None)
        if hook is None:
            return
        result = hook()
        if inspect.isawaitable(result):
            await result

    for pipeline in PIPELINE_MODULES.values():
        await _call_hook(pipeline, "on_startup")
    try:
        yield
    finally:
        for pipeline in PIPELINE_MODULES.values():
            await _call_hook(pipeline, "on_shutdown")


origins = ["*"]

# Swagger UI at /docs, ReDoc disabled; pipeline hooks run via lifespan().
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
# Expose the module-level registry (same dict object, not a copy) to handlers.
app.state.PIPELINES = PIPELINES

# NOTE(review): a wildcard origin together with allow_credentials=True may let
# any site make credentialed cross-origin requests -- confirm this is intended.
_cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)


@app.middleware("http")
async def check_url(request: Request, call_next):
    """Time each HTTP request and report it in the ``X-Process-Time`` header.

    The header value is the elapsed time in seconds as a float string
    (e.g. ``"0.0123"``).
    """
    # perf_counter is monotonic and sub-second. The previous int(time.time())
    # truncated to whole seconds, so nearly every request reported "0".
    started = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - started)
    return response


@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """List every registered pipeline as an OpenAI-style model card.

    Each card carries the entry's ``id`` and ``name`` from ``PIPELINES`` plus
    fixed ``object``/``owned_by``/``pipeline`` fields and the current Unix time
    as ``created``.
    """

    def _as_model_card(entry):
        return {
            "id": entry["id"],
            "name": entry["name"],
            "object": "model",
            "created": int(time.time()),
            "owned_by": "openai",
            "pipeline": True,
        }

    return {"data": list(map(_as_model_card, PIPELINES.values()))}


def _stop_chunk(model):
    """Return the terminal ``chat.completion.chunk`` dict (``finish_reason="stop"``)."""
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


def _chat_completion(model, content):
    """Return a non-streaming ``chat.completion`` dict with one assistant message."""
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """OpenAI-compatible chat completion served by a registered pipeline.

    ``form_data.model`` selects an entry of ``app.state.PIPELINES``. The owning
    pipeline's ``get_response(user_message=, model_id=, messages=, body=)`` is
    called. For manifold entries ``model_id`` is the sub-pipeline id (the part
    after ``"<pipeline id>."``); otherwise it is ``form_data.model``.

    With ``form_data.stream`` set, returns a ``text/event-stream`` of ``data:``
    events that always ends with a stop chunk and ``data: [DONE]``. Otherwise
    returns a dict: the pipeline's own dict, a pydantic model's dump, or a
    ``chat.completion`` built from a string or the joined chunks of an iterator.

    Raises:
        HTTPException: 404 when the requested model is not registered.
    """
    user_message = get_last_user_message(form_data.messages)
    messages = [message.model_dump() for message in form_data.messages]

    pipeline = app.state.PIPELINES.get(form_data.model)
    if pipeline is None:
        # Must be raised: a *returned* HTTPException is serialized into a
        # 200 response body instead of producing a 404.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {form_data.model} not found",
        )

    # Every registry entry records the PIPELINE_MODULES key of its owner under
    # "module". Resolving through it (rather than splitting the requested id on
    # the first ".") stays correct for pipeline ids that contain dots or differ
    # from the module file name.
    module_id = pipeline["module"]
    if pipeline.get("manifold", False):
        # Manifold ids are "<module_id>.<sub id>"; the pipeline expects the sub id.
        model_id = form_data.model[len(module_id) + 1 :]
    else:
        model_id = form_data.model
    get_response = PIPELINE_MODULES[module_id].get_response

    if form_data.stream:

        def stream_content():
            # Starlette iterates a sync generator in its threadpool, so the
            # (possibly blocking) get_response call does not block the loop.
            res = get_response(
                user_message=user_message,
                model_id=model_id,
                messages=messages,
                body=form_data.model_dump(),
            )
            logging.info("stream:true:%s", res)

            if isinstance(res, str):
                message = stream_message_template(form_data.model, res)
                logging.info("stream_content:str:%s", message)
                yield f"data: {json.dumps(message)}\n\n"
            elif isinstance(res, Iterator):
                for line in res:
                    if isinstance(line, BaseModel):
                        line = f"data: {line.model_dump_json()}"
                    elif isinstance(line, bytes):
                        line = line.decode("utf-8")

                    logging.info("stream_content:Generator:%s", line)

                    if line.startswith("data:"):
                        # Already an SSE payload (e.g. proxied upstream stream).
                        yield f"{line}\n\n"
                    else:
                        chunk = stream_message_template(form_data.model, line)
                        yield f"data: {json.dumps(chunk)}\n\n"
            else:
                logging.warning(
                    "stream_content: unsupported response type %s", type(res).__name__
                )

            # Always terminate the stream; each SSE event needs the trailing
            # blank line, including [DONE].
            yield f"data: {json.dumps(_stop_chunk(form_data.model))}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_content(), media_type="text/event-stream")

    def job():
        res = get_response(
            user_message=user_message,
            model_id=model_id,
            messages=messages,
            body=form_data.model_dump(),
        )
        logging.info("stream:false:%s", res)

        if isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        if isinstance(res, str):
            message = res
        elif isinstance(res, Iterator):
            # join is linear; decode bytes chunks like the streaming path does.
            message = "".join(
                chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk)
                for chunk in res
            )
        else:
            message = ""

        logging.info("stream:false:%s", message)
        return _chat_completion(form_data.model, message)

    # get_response may block; keep it off the event loop.
    return await run_in_threadpool(job)


@app.get("/")
async def get_status():
    """Return the constant payload ``{"status": True}`` for the root path."""
    return dict(status=True)