feat: multi-pipeline support

Timothy J. Baek 2024-05-21 17:22:13 -07:00
parent 92890701f0
commit eaa4112f46
3 changed files with 84 additions and 14 deletions

main.py

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Depends, status
+from fastapi import FastAPI, Request, Depends, status, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import StreamingResponse, Response
@@ -14,20 +14,50 @@ from utils import get_last_user_message, stream_message_template
 from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
-from pipelines.pipeline import (
-    get_response,
-    on_startup,
-    on_shutdown,
-)
+import os
+import importlib.util
+
+PIPELINES = {}
+
+
+def load_modules_from_directory(directory):
+    for filename in os.listdir(directory):
+        if filename.endswith(".py"):
+            module_name = filename[:-3]  # Remove the .py extension
+            module_path = os.path.join(directory, filename)
+            spec = importlib.util.spec_from_file_location(module_name, module_path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            yield module
+
+
+for loaded_module in load_modules_from_directory("./pipelines"):
+    # Register each discovered module under its module name; on_startup
+    # may later override the default id/name.
+    print("Loaded:", loaded_module.__name__)
+
+    PIPELINES[loaded_module.__name__] = {
+        "module": loaded_module,
+        "id": loaded_module.__name__,
+        "name": loaded_module.__name__,
+    }
+
 from contextlib import asynccontextmanager
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    await on_startup()
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_startup"):
+            info = await pipeline["module"].on_startup()
+            if info:
+                pipeline["id"] = info["id"]
+                pipeline["name"] = info["name"]
+
     yield
-    await on_shutdown()
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_shutdown"):
+            await pipeline["module"].on_shutdown()
 
 
 app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
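
In effect, every .py file in ./pipelines now becomes one selectable model, keyed by module name. A minimal standalone sketch of that mapping, assuming a ./pipelines directory sits next to main.py (the directory contents are illustrative):

import os


def list_pipeline_ids(directory="./pipelines"):
    # Each .py file's stem becomes a pipeline's default id and name,
    # mirroring load_modules_from_directory above.
    return [f[:-3] for f in os.listdir(directory) if f.endswith(".py")]


print(list_pipeline_ids())  # e.g. ['pipeline'] if only pipeline.py exists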
@@ -59,17 +89,18 @@ async def check_url(request: Request, call_next):
 @app.get("/v1/models")
 async def get_models():
     """
-    Returns the model that is available inside Dialog in the OpenAI format.
+    Returns the available pipelines in the OpenAI model format.
     """
     return {
         "data": [
             {
-                "id": MODEL_ID,
-                "name": MODEL_NAME,
+                "id": pipeline["id"],
+                "name": pipeline["name"],
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
             }
+            for pipeline in PIPELINES.values()
         ]
     }
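
For reference, a sketch of the payload GET /v1/models now produces, assuming a single loaded module named pipeline whose on_startup returned no metadata; the timestamp is illustrative:

# Illustrative response shape; values are examples, not captured output.
example_models_response = {
    "data": [
        {
            "id": "pipeline",       # module name unless on_startup overrode it
            "name": "pipeline",
            "object": "model",
            "created": 1716337333,  # int(time.time()) at request time
            "owned_by": "openai",
        }
    ]
}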
@@ -79,9 +110,18 @@ async def get_models():
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)
 
+    if form_data.model not in PIPELINES:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Model {form_data.model} not found",
+        )
+
+    get_response = PIPELINES[form_data.model]["module"].get_response
+
     if form_data.stream:
 
         def stream_content():
             res = get_response(user_message, messages=form_data.messages)
             if isinstance(res, str):
@@ -108,9 +148,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
         return StreamingResponse(stream_content(), media_type="text/event-stream")
     else:
         res = get_response(user_message, messages=form_data.messages)
-        message = ""
-
         if isinstance(res, str):
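
A hedged client-side sketch of how the model field selects a pipeline; the host, port, and model id are assumptions, not taken from this commit:

# Hypothetical client call; adjust the base URL to wherever this server runs.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumed host/port
    json={
        "model": "pipeline",  # must match a key in PIPELINES, else 404
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
)
print(resp.status_code, resp.json())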


@@ -0,0 +1,29 @@
+from typing import List, Union, Generator
+from schemas import OpenAIChatMessage
+
+
+def get_response(
+    user_message: str, messages: List[OpenAIChatMessage]
+) -> Union[str, Generator]:
+    # This is where you can add your custom pipelines like RAG.
+    print(messages)
+    print(user_message)
+
+    return f"pipeline response to: {user_message}"
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    print(f"on_startup:{__name__}")
+
+    # Optional: return pipeline metadata to override the default id/name.
+    # return {
+    #     "id": "pipeline_id",
+    #     "name": "pipeline_name",
+    # }
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
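
Saved anywhere in ./pipelines (this view omits the new file's path), the module above is discovered automatically. A quick local check of the get_response contract, appended to the module as a sketch; OpenAIChatMessage is stubbed with SimpleNamespace, and running it directly assumes schemas.py is importable:

# Hypothetical check appended to the example module above.
from types import SimpleNamespace

if __name__ == "__main__":
    stub = SimpleNamespace(role="user", content="Hello")
    print(get_response("Hello", messages=[stub]))
    # -> pipeline response to: Hello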


@@ -10,11 +10,14 @@ def get_response(
     print(messages)
     print(user_message)
 
-    return f"rag response to: {user_message}"
+    return f"pipeline response to: {user_message}"
 
 
 async def on_startup():
     # This function is called when the server is started.
-    print("onstartup")
+    print(__name__)
     pass
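
Because on_startup here returns nothing, this pipeline keeps its module name as its id. A sketch of the override path the lifespan handler supports; the id and name values are hypothetical:

async def on_startup():
    # Returning metadata renames this pipeline in /v1/models.
    return {"id": "my_pipeline", "name": "My Pipeline"}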