pipelines/main.py
Timothy J. Baek 88bf7ec93c fix
2024-05-21 18:36:19 -07:00

198 lines
5.6 KiB
Python

from fastapi import FastAPI, Request, Depends, status, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.concurrency import run_in_threadpool
from starlette.responses import StreamingResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Generator
import time
import json
import uuid
from utils import get_last_user_message, stream_message_template
from schemas import OpenAIChatCompletionForm
from config import MODEL_ID, MODEL_NAME
import os
import importlib.util
PIPELINES = {}
def load_modules_from_directory(directory):
for filename in os.listdir(directory):
if filename.endswith(".py"):
module_name = filename[:-3] # Remove the .py extension
module_path = os.path.join(directory, filename)
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
yield module
for loaded_module in load_modules_from_directory("./pipelines"):
# Do something with the loaded module
print("Loaded:", loaded_module.__name__)
pipeline = loaded_module.Pipeline()
PIPELINES[loaded_module.__name__] = {
"module": pipeline,
"id": loaded_module.__name__,
"name": loaded_module.__name__,
}
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
for pipeline in PIPELINES.values():
if hasattr(pipeline["module"], "on_startup"):
info = await pipeline["module"].on_startup()
if info:
pipeline["id"] = info["id"]
pipeline["name"] = info["name"]
yield
for pipeline in PIPELINES.values():
if hasattr(pipeline["module"], "on_shutdown"):
await pipeline["module"].on_shutdown()
app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
app.state.PIPELINES = PIPELINES
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def check_url(request: Request, call_next):
start_time = int(time.time())
response = await call_next(request)
process_time = int(time.time()) - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
@app.get("/models")
@app.get("/v1/models")
async def get_models():
"""
Returns the available pipelines
"""
return {
"data": [
{
"id": pipeline["id"],
"name": pipeline["name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for pipeline in PIPELINES.values()
]
}
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
user_message = get_last_user_message(form_data.messages)
if form_data.model not in app.state.PIPELINES:
return HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {form_data.model} not found",
)
def job():
get_response = app.state.PIPELINES[form_data.model]["module"].get_response
if form_data.stream:
def stream_content():
res = get_response(user_message, messages=form_data.messages)
if isinstance(res, str):
message = stream_message_template(res)
yield f"data: {json.dumps(message)}\n\n"
elif isinstance(res, Generator):
for message in res:
print(message)
message = stream_message_template(message)
yield f"data: {json.dumps(message)}\n\n"
finish_message = {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream")
else:
res = get_response(user_message, messages=form_data.messages)
message = ""
if isinstance(res, str):
message = res
elif isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data.model}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": MODEL_ID,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await run_in_threadpool(job)
@app.get("/")
async def get_status():
return {"status": True}