from fastapi import FastAPI, Request, Depends, status
from fastapi.middleware.cors import CORSMiddleware

from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator


import time
import json
import uuid

from utils import get_last_user_message, stream_message_template
from schemas import OpenAIChatCompletionForm
from config import MODEL_ID, MODEL_NAME

from pipelines.pipeline import get_response

app = FastAPI(docs_url="/docs", redoc_url=None)


origins = ["*"]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def check_url(request: Request, call_next):
    start_time = int(time.time())
    response = await call_next(request)
    process_time = int(time.time()) - start_time
    response.headers["X-Process-Time"] = str(process_time)

    return response


@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    Returns the model that is available inside Dialog in the OpenAI format.
    """
    return {
        "data": [
            {
                "id": MODEL_ID,
                "name": MODEL_NAME,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
        ]
    }


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    user_message = get_last_user_message(form_data.messages)

    if form_data.stream:

        def stream_content():
            res = get_response(user_message, messages=form_data.messages)

            if isinstance(res, str):
                message = stream_message_template(res)
                yield f"data: {json.dumps(message)}\n\n"

            elif isinstance(res, Generator):
                for message in res:
                    message = stream_message_template(message)
                    yield f"data: {json.dumps(message)}\n\n"

            finish_message = {
                "id": f"rag-{str(uuid.uuid4())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
                ],
            }

            yield f"data: {json.dumps(finish_message)}\n\n"
            yield f"data: [DONE]"

        return StreamingResponse(stream_content(), media_type="text/event-stream")
    else:

        res = get_response(user_message, messages=form_data.messages)

        message = ""

        if isinstance(res, str):
            message = res

        elif isinstance(res, Generator):
            for stream in res:
                message = f"{message}{stream}"

        return {
            "id": f"rag-{str(uuid.uuid4())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": MODEL_ID,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": message,
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
        }


@app.get("/")
async def get_status():
    return {"status": True}