# pipelines/main.py
# Last commit: 8c3f262ffd "refac" by Timothy J. Baek (2024-05-26 23:54:26 -07:00)
# Original file: 269 lines, 8.2 KiB, Python

import importlib.util
import inspect
import json
import logging
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Generator, Iterator, List, Union

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ConfigDict
from starlette.responses import Response, StreamingResponse

from schemas import OpenAIChatCompletionForm
from utils import get_last_user_message, stream_message_template
# Model registry served by /models and routed by /chat/completions.
# Keys are the model ids clients send; each value is a dict with
# "module" (the key of the owning instance in PIPELINE_MODULES), "id",
# "name", and "manifold": True for entries contributed by a manifold.
PIPELINES = {}
# Instantiated Pipeline objects, keyed by pipeline id.
PIPELINE_MODULES = {}


def on_startup(directory="./pipelines"):
    """Import every pipeline module in *directory* and register its models.

    For each ``*.py`` file (in sorted filename order) the module is executed
    and its ``Pipeline`` class instantiated. The instance is stored in
    ``PIPELINE_MODULES`` under its pipeline id: its ``id`` attribute if it
    has one, otherwise the module name. Models are then added to
    ``PIPELINES``:

    * a manifold pipeline (truthy ``manifold`` attribute) adds one entry per
      item of its ``pipelines`` list, keyed ``"<pipeline id>.<item id>"``;
    * any other pipeline adds a single entry keyed by its pipeline id, so the
      key always equals the entry's ``"id"``.

    Modules that define no ``Pipeline`` class are skipped with a warning.

    Args:
        directory: Folder to scan. Defaults to ``"./pipelines"``, resolved
            against the current working directory.

    Raises:
        FileNotFoundError: If *directory* does not exist.
    """

    def load_modules_from_directory(path):
        # Sorted so load/registration order is deterministic; os.listdir
        # returns entries in arbitrary order.
        for filename in sorted(os.listdir(path)):
            if not filename.endswith(".py"):
                continue
            module_name = filename[:-3]  # Remove the .py extension
            module_path = os.path.join(path, filename)
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            yield module

    for loaded_module in load_modules_from_directory(directory):
        module_name = loaded_module.__name__
        # logging formats lazily with %-style placeholders; an extra argument
        # without a placeholder makes the record fail to format when emitted.
        logging.info("Loaded: %s", module_name)

        pipeline_cls = getattr(loaded_module, "Pipeline", None)
        if pipeline_cls is None:
            # e.g. a shared helper module living next to the pipelines.
            logging.warning("Skipping %s: no Pipeline class defined", module_name)
            continue

        pipeline = pipeline_cls()
        pipeline_id = getattr(pipeline, "id", module_name)
        PIPELINE_MODULES[pipeline_id] = pipeline

        if getattr(pipeline, "manifold", False):
            # A manifold exposes several models from one Pipeline instance.
            for p in pipeline.pipelines:
                manifold_pipeline_id = f'{pipeline_id}.{p["id"]}'
                manifold_pipeline_name = p["name"]
                if hasattr(pipeline, "name"):
                    manifold_pipeline_name = f"{pipeline.name}{manifold_pipeline_name}"
                PIPELINES[manifold_pipeline_id] = {
                    "module": pipeline_id,
                    "id": manifold_pipeline_id,
                    "name": manifold_pipeline_name,
                    "manifold": True,
                }
        else:
            # Keyed by pipeline_id (not the module name) so the id advertised
            # by /models is exactly the key the chat endpoint looks up.
            PIPELINES[pipeline_id] = {
                "module": pipeline_id,
                "id": pipeline_id,
                "name": getattr(pipeline, "name", module_name),
            }
# Load pipelines at import time, so PIPELINES / PIPELINE_MODULES are already
# populated when the FastAPI app below is created and when `lifespan` walks
# PIPELINE_MODULES to run each pipeline's hooks.
on_startup()
# NOTE(review): this import belongs with the others at the top of the file.
from contextlib import asynccontextmanager
async def _run_pipeline_hooks(hook_name):
    """Invoke ``hook_name`` on every loaded pipeline that defines it.

    Hooks may be coroutine functions or plain callables: the return value is
    awaited only when it is awaitable.
    """
    for module in PIPELINE_MODULES.values():
        hook = getattr(module, hook_name, None)
        if hook is None:
            continue
        result = hook()
        if inspect.isawaitable(result):
            await result


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler for the loaded pipelines.

    Runs every pipeline's ``on_startup`` hook before the application starts
    serving and every ``on_shutdown`` hook after it stops. The shutdown hooks
    sit in a ``finally`` block so they still run when an exception is thrown
    into the lifespan context.

    Args:
        app: The FastAPI application (required by the lifespan protocol,
            unused here).
    """
    await _run_pipeline_hooks("on_startup")
    try:
        yield
    finally:
        await _run_pipeline_hooks("on_shutdown")
# The ASGI application: Swagger UI at /docs, ReDoc disabled, and pipeline
# startup/shutdown hooks wired in through `lifespan`.
app = FastAPI(lifespan=lifespan, docs_url="/docs", redoc_url=None)

# Make the model registry reachable from request handlers via app.state.
app.state.PIPELINES = PIPELINES

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.middleware("http")
async def check_url(request: Request, call_next):
    """Time each HTTP request and report it in the ``X-Process-Time`` header.

    The header value is the elapsed handling time in whole seconds (truncated).

    Args:
        request: The incoming request.
        call_next: Continues the middleware chain and returns the response.

    Returns:
        The downstream response with ``X-Process-Time`` added.
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the
    # measurement. Truncating the elapsed duration (rather than each of the
    # two timestamps) keeps a short request that straddles a second boundary
    # from being reported as taking 1s.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed_seconds = int(time.perf_counter() - started)
    response.headers["X-Process-Time"] = str(elapsed_seconds)
    return response
# OpenAI-compatible model listing: one "model" object per registry entry.
# The docstring is kept verbatim because FastAPI publishes it as the
# endpoint's OpenAPI description.
@app.get("/models")
@app.get("/v1/models")
async def get_models():
    """
    Returns the available pipelines
    """

    def as_model_card(entry):
        # "created" is stamped at request time, once per entry.
        return {
            "id": entry["id"],
            "name": entry["name"],
            "object": "model",
            "created": int(time.time()),
            "owned_by": "openai",
            "pipeline": True,
        }

    cards = [as_model_card(entry) for entry in PIPELINES.values()]
    return {"data": cards}
def _finish_chunk(model):
    """Build the closing ``chat.completion.chunk`` with finish_reason "stop"."""
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


def _completion_response(model, content):
    """Wrap *content* in a non-streaming ``chat.completion`` object."""
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": content,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }


@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """Run the requested pipeline and return an OpenAI-compatible chat completion."""
    # Raise, not return: a returned HTTPException would be serialized as an
    # ordinary 200 response body instead of producing a 404.
    if form_data.model not in app.state.PIPELINES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {form_data.model} not found",
        )

    user_message = get_last_user_message(form_data.messages)
    messages = [message.model_dump() for message in form_data.messages]

    def job():
        # Executed via run_in_threadpool, presumably because a pipeline's
        # get_response may block.
        logging.info("model: %s", form_data.model)
        pipeline = app.state.PIPELINES[form_data.model]
        module_id = pipeline["module"]
        # Registry entries hold only the *id* of the owning Pipeline instance;
        # the instance itself lives in PIPELINE_MODULES.
        get_response = PIPELINE_MODULES[module_id].get_response

        pipeline_id = form_data.model
        if pipeline.get("manifold", False):
            # Manifold model ids are "<module id>.<sub id>". Strip the prefix
            # instead of split(".") so sub ids containing dots stay intact.
            pipeline_id = pipeline_id[len(module_id) + 1 :]

        if form_data.stream:

            def stream_content():
                res = get_response(
                    user_message=user_message,
                    model_id=pipeline_id,
                    messages=messages,
                    body=form_data.model_dump(),
                )
                logging.info("stream:true:%s", res)

                if isinstance(res, str):
                    message = stream_message_template(form_data.model, res)
                    logging.info("stream_content:str:%s", message)
                    yield f"data: {json.dumps(message)}\n\n"

                if isinstance(res, Iterator):
                    for line in res:
                        if isinstance(line, BaseModel):
                            line = f"data: {line.model_dump_json()}"
                        # Chunks may arrive as raw bytes; decode them to text.
                        if isinstance(line, (bytes, bytearray)):
                            line = line.decode("utf-8")
                        logging.info("stream_content:Generator:%s", line)
                        if line.startswith("data:"):
                            # Already an SSE event: forward it unchanged.
                            yield f"{line}\n\n"
                        else:
                            chunk = stream_message_template(form_data.model, line)
                            yield f"data: {json.dumps(chunk)}\n\n"

                # Only str and Generator results get a synthesized finish
                # chunk; other iterators are forwarded as-is and presumably
                # carry their own terminating events.
                if isinstance(res, (str, Generator)):
                    yield f"data: {json.dumps(_finish_chunk(form_data.model))}\n\n"
                    # SSE events are dispatched on a blank line; without the
                    # trailing "\n\n" spec-compliant clients drop [DONE].
                    yield "data: [DONE]\n\n"

            return StreamingResponse(stream_content(), media_type="text/event-stream")

        res = get_response(
            user_message=user_message,
            model_id=pipeline_id,
            messages=messages,
            body=form_data.model_dump(),
        )
        logging.info("stream:false:%s", res)

        # Pipelines may already return a complete response object.
        if isinstance(res, dict):
            return res
        if isinstance(res, BaseModel):
            return res.model_dump()

        message = ""
        if isinstance(res, str):
            message = res
        elif isinstance(res, Iterator):
            # Collapse streamed chunks (from any iterator, not only
            # generators) into one reply without quadratic concatenation.
            message = "".join(f"{chunk}" for chunk in res)

        logging.info("stream:false:%s", message)
        return _completion_response(form_data.model, message)

    return await run_in_threadpool(job)
# Root health-check endpoint: unconditionally answers {"status": true}.
@app.get("/")
async def get_status():
    payload = {"status": True}
    return payload