# pipelines/main.py
# Author: Timothy J. Baek — commit 92890701f0 "refac" (2024-05-21 17:02:12 -07:00)
# 145 lines · 3.6 KiB · Python
from fastapi import FastAPI, Request, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator
import time
import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import OpenAIChatCompletionForm
from config import MODEL_ID, MODEL_NAME
from pipelines.pipeline import (
get_response,
on_startup,
on_shutdown,
)
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the pipeline startup/shutdown hooks around the app's lifetime.

    The shutdown hook is placed in a ``finally`` so it still runs when the
    application exits with an error while serving; the original only ran it
    on a clean exit from ``yield``.
    """
    await on_startup()
    try:
        yield
    finally:
        await on_shutdown()
# Application instance: Swagger UI at /docs, ReDoc disabled, and the
# pipeline lifecycle hooks wired in through the lifespan context manager.
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)

# Permit cross-origin browser clients (e.g. a web UI served from another
# host) to call this API from anywhere.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Time each request and report the duration in the X-Process-Time header.

    The original computed ``int(time.time())`` deltas, which truncate to
    whole seconds — the header was almost always "0" — and ``time.time()``
    can jump backwards on clock adjustments.  ``time.perf_counter()`` is
    monotonic and sub-second accurate; the header now carries fractional
    seconds.
    """
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response
@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    Returns the model that is available inside Dialog in the OpenAI format.
    """
    model = {
        "id": MODEL_ID,
        "name": MODEL_NAME,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "openai",
    }
    return {"data": [model]}
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """Serve an OpenAI-compatible chat completion from the pipeline.

    When ``form_data.stream`` is true, the response is a Server-Sent-Events
    stream of ``chat.completion.chunk`` objects followed by a stop chunk and
    the ``[DONE]`` sentinel; otherwise the pipeline output is collected into
    a single ``chat.completion`` payload.
    """
    user_message = get_last_user_message(form_data.messages)

    if form_data.stream:

        def stream_content():
            res = get_response(user_message, messages=form_data.messages)

            if isinstance(res, str):
                yield f"data: {json.dumps(stream_message_template(res))}\n\n"
            elif isinstance(res, Generator):
                for chunk in res:
                    yield f"data: {json.dumps(stream_message_template(chunk))}\n\n"

            # Final chunk with an empty delta and finish_reason="stop",
            # mirroring the OpenAI streaming protocol.
            finish_message = {
                "id": f"rag-{str(uuid.uuid4())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
                ],
            }
            yield f"data: {json.dumps(finish_message)}\n\n"
            # Every SSE event must be terminated by a blank line; the
            # original emitted "data: [DONE]" without "\n\n", which strict
            # clients never flush (and the f-string had no placeholders).
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_content(), media_type="text/event-stream")

    res = get_response(user_message, messages=form_data.messages)
    if isinstance(res, str):
        message = res
    elif isinstance(res, Generator):
        # join() replaces the original quadratic string-concatenation loop;
        # str() preserves the old f-string coercion of non-str chunks.
        message = "".join(str(chunk) for chunk in res)
    else:
        message = ""

    return {
        "id": f"rag-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": message,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }
@app.get("/")
async def get_status():
    """Liveness probe: report that the service is up and reachable."""
    return {"status": True}