diff --git a/main.py b/main.py
index 002e1ee..595b31d 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Depends, status
+from fastapi import FastAPI, Request, Depends, status, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import StreamingResponse, Response
 
@@ -14,20 +14,50 @@
 from utils import get_last_user_message, stream_message_template
 from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
 
-from pipelines.pipeline import (
-    get_response,
-    on_startup,
-    on_shutdown,
-)
+import os
+import importlib.util
+
+
+PIPELINES = {}
+
+
+def load_modules_from_directory(directory):
+    for filename in os.listdir(directory):
+        if filename.endswith(".py"):
+            module_name = filename[:-3]  # Remove the .py extension
+            module_path = os.path.join(directory, filename)
+            spec = importlib.util.spec_from_file_location(module_name, module_path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            yield module
+
+
+for loaded_module in load_modules_from_directory("./pipelines"):
+    # Register each module under its filename-derived id and name.
+    print("Loaded:", loaded_module.__name__)
+
+    PIPELINES[loaded_module.__name__] = {
+        "module": loaded_module,
+        "id": loaded_module.__name__,
+        "name": loaded_module.__name__,
+    }
+
 
 from contextlib import asynccontextmanager
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    await on_startup()
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_startup"):
+            info = await pipeline["module"].on_startup()
+            if info:
+                pipeline["id"] = info["id"]
+                pipeline["name"] = info["name"]
     yield
-    await on_shutdown()
+
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_shutdown"):
+            await pipeline["module"].on_shutdown()
 
 
 app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
@@ -59,17 +89,18 @@ async def check_url(request: Request, call_next):
 
 @app.get("/v1/models")
 async def get_models():
     """
-    Returns the model that is available inside Dialog in the OpenAI format.
+    Returns the available pipelines in the OpenAI model format.
     """
     return {
         "data": [
             {
-                "id": MODEL_ID,
-                "name": MODEL_NAME,
+                "id": pipeline["id"],
+                "name": pipeline["name"],
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
             }
+            for pipeline in PIPELINES.values()
         ]
     }
@@ -79,9 +110,19 @@ async def get_models():
 
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)
 
+    if form_data.model not in PIPELINES:
+        # Raise (not return) so FastAPI actually sends the 404 response.
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Model {form_data.model} not found",
+        )
+
+    get_response = PIPELINES[form_data.model]["module"].get_response
+
     if form_data.stream:
 
         def stream_content():
+            res = get_response(user_message, messages=form_data.messages)
 
             if isinstance(res, str):
@@ -108,9 +149,8 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
 
         return StreamingResponse(stream_content(), media_type="text/event-stream")
     else:
         res = get_response(user_message, messages=form_data.messages)
-
         message = ""
         if isinstance(res, str):
diff --git a/pipelines/examples/pipeline.py b/pipelines/examples/pipeline.py
new file mode 100644
index 0000000..f867969
--- /dev/null
+++ b/pipelines/examples/pipeline.py
@@ -0,0 +1,29 @@
+from typing import List, Union, Generator
+from schemas import OpenAIChatMessage
+
+
+def get_response(
+    user_message: str, messages: List[OpenAIChatMessage]
+) -> Union[str, Generator]:
+    # This is where you can add your custom pipelines like RAG.
+
+    print(messages)
+    print(user_message)
+
+    return f"pipeline response to: {user_message}"
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    print(f"on_startup:{__name__}")
+
+    # Optional: return pipeline metadata
+    # return {
+    #     "id": "pipeline_id",
+    #     "name": "pipeline_name",
+    # }
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
diff --git a/pipelines/pipeline.py b/pipelines/pipeline.py
index daa701f..d23a8eb 100644
--- a/pipelines/pipeline.py
+++ b/pipelines/pipeline.py
@@ -10,11 +10,13 @@ def get_response(
     print(messages)
     print(user_message)
 
-    return f"rag response to: {user_message}"
+    return f"pipeline response to: {user_message}"
 
 
 async def on_startup():
     # This function is called when the server is started.
+    print(f"on_startup:{__name__}")
+
     pass
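
Since get_response may return either a str or a Generator (the endpoint branches on isinstance), a pipeline can stream its reply chunk by chunk. Below is a minimal sketch of such a module; the filename pipelines/streaming_pipeline.py, the id/name strings, and the word-splitting logic are illustrative assumptions, not part of this diff:

    # pipelines/streaming_pipeline.py (hypothetical example, not in this diff)
    from typing import List, Union, Generator

    from schemas import OpenAIChatMessage


    def get_response(
        user_message: str, messages: List[OpenAIChatMessage]
    ) -> Union[str, Generator]:
        # Yield the reply word by word; when stream=True, the endpoint wraps
        # each yielded chunk with stream_message_template().
        def token_stream():
            for token in f"streaming response to: {user_message}".split():
                yield token + " "

        return token_stream()


    async def on_startup():
        # Returning metadata overrides the filename-derived id/name; the
        # lifespan hook reads both keys, so return both or return nothing.
        return {"id": "streaming_pipeline", "name": "Streaming Pipeline"}


    async def on_shutdown():
        pass

Dropping this file into ./pipelines is enough to register it; the loader picks up every top-level .py module in that directory at import time.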
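
For a quick end-to-end check of the loader and the new endpoints, the sketch below assumes the server runs on uvicorn's default localhost:8000 (the port is not pinned anywhere in this diff) and that the requests package is available; the model id "pipeline" is derived from the filename pipelines/pipeline.py:

    import requests  # assumption: installed separately, not a project dependency

    BASE_URL = "http://localhost:8000"  # assumption: uvicorn's default host/port

    # List the loaded pipelines (one entry per top-level module in ./pipelines).
    print(requests.get(f"{BASE_URL}/v1/models").json())

    # Send a non-streaming completion to the filename-derived "pipeline" model.
    response = requests.post(
        f"{BASE_URL}/v1/chat/completions",
        json={
            "model": "pipeline",
            "stream": False,
            "messages": [{"role": "user", "content": "hello"}],
        },
    )
    print(response.json())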