feat: multi-pipeline support

Timothy J. Baek 2024-05-21 17:22:13 -07:00
parent 92890701f0
commit eaa4112f46
3 changed files with 84 additions and 14 deletions

main.py

@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Depends, status
+from fastapi import FastAPI, Request, Depends, status, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import StreamingResponse, Response
@@ -14,20 +14,50 @@ from utils import get_last_user_message, stream_message_template
 from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
-from pipelines.pipeline import (
-    get_response,
-    on_startup,
-    on_shutdown,
-)
+import os
+import importlib.util
+
+
+PIPELINES = {}
+
+
+def load_modules_from_directory(directory):
+    for filename in os.listdir(directory):
+        if filename.endswith(".py"):
+            module_name = filename[:-3]  # Remove the .py extension
+            module_path = os.path.join(directory, filename)
+            spec = importlib.util.spec_from_file_location(module_name, module_path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            yield module
+
+
+for loaded_module in load_modules_from_directory("./pipelines"):
+    # Do something with the loaded module
+    print("Loaded:", loaded_module.__name__)
+
+    PIPELINES[loaded_module.__name__] = {
+        "module": loaded_module,
+        "id": loaded_module.__name__,
+        "name": loaded_module.__name__,
+    }
+
+
 from contextlib import asynccontextmanager


 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    await on_startup()
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_startup"):
+            info = await pipeline["module"].on_startup()
+
+            if info:
+                pipeline["id"] = info["id"]
+                pipeline["name"] = info["name"]
+
     yield
-    await on_shutdown()
+
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_shutdown"):
+            await pipeline["module"].on_shutdown()


 app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
@@ -59,17 +89,18 @@ async def check_url(request: Request, call_next):
 @app.get("/v1/models")
 async def get_models():
     """
-    Returns the model that is available inside Dialog in the OpenAI format.
+    Returns the available pipelines
     """
     return {
         "data": [
             {
-                "id": MODEL_ID,
-                "name": MODEL_NAME,
+                "id": pipeline["id"],
+                "name": pipeline["name"],
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
             }
+            for pipeline in PIPELINES.values()
         ]
     }
@@ -79,9 +110,18 @@ async def get_models():
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)

+    if form_data.model not in PIPELINES:
+        return HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Model {form_data.model} not found",
+        )
+
+    get_response = PIPELINES[form_data.model]["module"].get_response
+
     if form_data.stream:

         def stream_content():
             res = get_response(user_message, messages=form_data.messages)

             if isinstance(res, str):
@@ -108,9 +148,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
         return StreamingResponse(stream_content(), media_type="text/event-stream")
     else:
         res = get_response(user_message, messages=form_data.messages)
         message = ""

         if isinstance(res, str):
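
With these changes, every module discovered under ./pipelines is exposed as its own model, and a client selects a pipeline by passing its id as the model name. A minimal client-side sketch of that flow follows; the base URL and the use of the requests library are assumptions for illustration, not part of this commit.

import requests

BASE_URL = "http://localhost:8000"  # assumed uvicorn host/port; adjust to your setup

# Each loaded module in ./pipelines appears as one entry here.
models = requests.get(f"{BASE_URL}/v1/models").json()["data"]
print([m["id"] for m in models])

# Target a specific pipeline by using its id as the OpenAI-style model name.
payload = {
    "model": models[0]["id"],
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
}
print(requests.post(f"{BASE_URL}/v1/chat/completions", json=payload).json())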

(new file — filename not shown)

@@ -0,0 +1,29 @@
+from typing import List, Union, Generator
+from schemas import OpenAIChatMessage
+
+
+def get_response(
+    user_message: str, messages: List[OpenAIChatMessage]
+) -> Union[str, Generator]:
+    # This is where you can add your custom pipelines like RAG.
+    print(messages)
+    print(user_message)
+
+    return f"pipeline response to: {user_message}"
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    print(f"on_startup:{__name__}")
+
+    # Optional: return pipeline metadata
+    # return {
+    #     "id": "pipeline_id",
+    #     "name": "pipeline_name",
+    # }
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
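
Any .py file dropped into ./pipelines that follows this template is picked up by the loader and listed as its own model. As an illustration only (the file name, id, and streaming behaviour below are hypothetical, not part of this commit), a pipeline can return a generator instead of a string and can override its display id/name via the optional dict returned from on_startup:

# pipelines/echo_pipeline.py (hypothetical example)
from typing import List, Union, Generator
from schemas import OpenAIChatMessage


def get_response(
    user_message: str, messages: List[OpenAIChatMessage]
) -> Union[str, Generator]:
    # Returning a generator (rather than a plain string) matches the
    # Union[str, Generator] contract and lets the server stream chunks.
    def chunks():
        for word in user_message.split():
            yield word + " "

    return chunks()


async def on_startup():
    # The returned dict replaces the module-name defaults in PIPELINES
    # (see the lifespan handler in main.py).
    return {"id": "echo_pipeline", "name": "Echo Pipeline"}


async def on_shutdown():
    pass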

(filename not shown)

@@ -10,11 +10,14 @@ def get_response(
     print(messages)
     print(user_message)

-    return f"rag response to: {user_message}"
+    return f"pipeline response to: {user_message}"


 async def on_startup():
     # This function is called when the server is started.
+    print("onstartup")
+    print(__name__)
     pass
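
Since discovery is simply "import every .py file in ./pipelines", a quick sanity check is to run the same loader standalone and confirm each module exposes get_response. A small sketch along those lines (not part of the commit; run it from the repo root so pipeline imports such as schemas resolve):

import os
import importlib.util


def load_modules_from_directory(directory):
    # Same discovery logic as main.py: import every .py file in the folder.
    for filename in os.listdir(directory):
        if filename.endswith(".py"):
            spec = importlib.util.spec_from_file_location(
                filename[:-3], os.path.join(directory, filename)
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            yield module


if __name__ == "__main__":
    for module in load_modules_from_directory("./pipelines"):
        # get_response is required; on_startup/on_shutdown are optional hooks.
        assert hasattr(module, "get_response"), f"{module.__name__}: missing get_response"
        print(f"ok: {module.__name__}")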