pipelines/main.py
2024-05-21 17:22:13 -07:00

183 lines
5.0 KiB
Python

from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator
import time
import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import OpenAIChatCompletionForm
from config import MODEL_ID, MODEL_NAME
import os
import importlib.util
# Registry of loaded pipelines, keyed by module name. Each value is a dict with
# "module" (the module object), "id" and "name".
PIPELINES = {}


def load_modules_from_directory(directory):
    """Import every ``*.py`` file in *directory* and yield the module objects.

    Files are visited in sorted filename order so the load order (and hence the
    insertion order of PIPELINES and the /models listing) is deterministic;
    ``os.listdir`` itself returns entries in arbitrary order. Entries that are
    not regular files (e.g. a directory whose name ends in ``.py``) are skipped.

    Each module is named after its file with the ``.py`` suffix removed. The
    module's top-level code is executed, but it is not added to ``sys.modules``.

    Args:
        directory: Path of the directory to scan (non-recursive).

    Yields:
        The executed module object for each ``.py`` file.
    """
    for filename in sorted(os.listdir(directory)):
        module_path = os.path.join(directory, filename)
        if not filename.endswith(".py") or not os.path.isfile(module_path):
            continue
        module_name = filename[:-3]  # strip the ".py" extension
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        module = importlib.util.module_from_spec(spec)
        # Runs the file's top-level code; any exception it raises propagates.
        spec.loader.exec_module(module)
        yield module
# Register every module found in ./pipelines under its module name. The
# advertised id and name both start out as the module name; a module's
# on_startup hook may replace them when the app starts.
for loaded_module in load_modules_from_directory("./pipelines"):
    print("Loaded:", loaded_module.__name__)
    PIPELINES[loaded_module.__name__] = dict(
        module=loaded_module,
        id=loaded_module.__name__,
        name=loaded_module.__name__,
    )
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run pipeline startup hooks before serving and shutdown hooks afterwards.

    For each registered pipeline whose module defines ``on_startup``, the hook
    is awaited. If it returns a truthy mapping, its ``"id"`` and ``"name"``
    entries replace the pipeline's current id/name; a missing key keeps the
    current value. Modules defining ``on_shutdown`` have it awaited when the
    application stops, including when it stops because of an exception.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    for pipeline in PIPELINES.values():
        if hasattr(pipeline["module"], "on_startup"):
            info = await pipeline["module"].on_startup()
            if info:
                # Fall back to the current values so a hook that reports only
                # one of the keys does not abort startup with a KeyError.
                pipeline["id"] = info.get("id", pipeline["id"])
                pipeline["name"] = info.get("name", pipeline["name"])
    try:
        yield
    finally:
        # Release pipeline resources even if the server exits via an error
        # or cancellation thrown into this generator at the yield point.
        for pipeline in PIPELINES.values():
            if hasattr(pipeline["module"], "on_shutdown"):
                await pipeline["module"].on_shutdown()
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)

# CORS is fully open: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting origins outside of local use.
origins = ["*"]
_CORS_OPTIONS = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Time each HTTP request and report it in the ``X-Process-Time`` header.

    The header value is the elapsed time in seconds as a float string
    (e.g. ``"0.0123"``).
    """
    # perf_counter is monotonic and high-resolution; whole-second
    # int(time.time()) stamps would report "0" for nearly every request.
    start = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response
@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    List the loaded pipelines as OpenAI-style model objects.

    Each entry carries the pipeline's current id and name; ``created`` is the
    time of this request, not of pipeline loading.
    """
    models = []
    for entry in PIPELINES.values():
        models.append(
            {
                "id": entry["id"],
                "name": entry["name"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
        )
    return {"data": models}
def _find_pipeline(model_id):
    """Return the PIPELINES entry serving ``model_id``, or None if none does.

    ``/models`` advertises ``pipeline["id"]``, which an ``on_startup`` hook may
    have changed away from the module-name key, so match on the id first and
    fall back to the registry key.
    """
    for pipeline in PIPELINES.values():
        if pipeline["id"] == model_id:
            return pipeline
    return PIPELINES.get(model_id)


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """OpenAI-compatible chat completion backed by a loaded pipeline.

    The pipeline's ``get_response(user_message, messages=...)`` may return a
    ``str`` or a generator of chunks. With ``form_data.stream`` set, chunks are
    sent as server-sent events followed by a ``finish_reason: "stop"`` chunk and
    ``data: [DONE]``; otherwise the chunks are concatenated into one
    ``chat.completion`` object.

    Raises:
        HTTPException: 404 if no loaded pipeline matches ``form_data.model``.
    """
    user_message = get_last_user_message(form_data.messages)

    pipeline = _find_pipeline(form_data.model)
    if pipeline is None:
        # FastAPI only turns a *raised* HTTPException into an error response;
        # returning the object would not produce a 404.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {form_data.model} not found",
        )

    get_response = pipeline["module"].get_response

    if form_data.stream:

        def stream_content():
            res = get_response(user_message, messages=form_data.messages)
            if isinstance(res, str):
                message = stream_message_template(res)
                yield f"data: {json.dumps(message)}\n\n"
            elif isinstance(res, Generator):
                for chunk in res:
                    message = stream_message_template(chunk)
                    yield f"data: {json.dumps(message)}\n\n"

            finish_message = {
                "id": f"rag-{str(uuid.uuid4())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}
                ],
            }
            yield f"data: {json.dumps(finish_message)}\n\n"
            # An SSE event is only dispatched once a blank line follows it, so
            # the terminator needs the trailing "\n\n" like every other event.
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_content(), media_type="text/event-stream")

    res = get_response(user_message, messages=form_data.messages)
    message = ""
    if isinstance(res, str):
        message = res
    elif isinstance(res, Generator):
        # Single linear join instead of repeated concatenation.
        message = "".join(f"{chunk}" for chunk in res)

    return {
        "id": f"rag-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": message,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }
@app.get("/")
async def get_status():
    """Root endpoint; always answers with ``{"status": True}``."""
    payload = {"status": True}
    return payload