Mirror of https://github.com/open-webui/pipelines (synced 2025-05-15 01:45:43 +00:00)
feat: multi-pipeline support
This commit is contained in:
commit eaa4112f46 (parent 92890701f0)
main.py (64 lines changed)
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, Request, Depends, status
+from fastapi import FastAPI, Request, Depends, status, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 
 from starlette.responses import StreamingResponse, Response
@@ -14,20 +14,50 @@ from utils import get_last_user_message, stream_message_template
 from schemas import OpenAIChatCompletionForm
 from config import MODEL_ID, MODEL_NAME
 
-from pipelines.pipeline import (
-    get_response,
-    on_startup,
-    on_shutdown,
-)
+import os
+import importlib.util
+
+
+PIPELINES = {}
 
 
+def load_modules_from_directory(directory):
+    for filename in os.listdir(directory):
+        if filename.endswith(".py"):
+            module_name = filename[:-3]  # Remove the .py extension
+            module_path = os.path.join(directory, filename)
+            spec = importlib.util.spec_from_file_location(module_name, module_path)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+            yield module
+
+
+for loaded_module in load_modules_from_directory("./pipelines"):
+    # Do something with the loaded module
+    print("Loaded:", loaded_module.__name__)
+    PIPELINES[loaded_module.__name__] = {
+        "module": loaded_module,
+        "id": loaded_module.__name__,
+        "name": loaded_module.__name__,
+    }
+
+
 from contextlib import asynccontextmanager
 
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    await on_startup()
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_startup"):
+            info = await pipeline["module"].on_startup()
+            if info:
+                pipeline["id"] = info["id"]
+                pipeline["name"] = info["name"]
     yield
-    await on_shutdown()
+
+    for pipeline in PIPELINES.values():
+        if hasattr(pipeline["module"], "on_shutdown"):
+            await pipeline["module"].on_shutdown()
 
 
 app = FastAPI(docs_url="/docs", redoc_url=None, lifespan=lifespan)
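Note: the loader above follows the standard importlib plugin-discovery pattern (spec_from_file_location, module_from_spec, exec_module). A minimal standalone sketch of the same idea, independent of this commit, assuming only a local ./pipelines directory of .py files and the hook names used in this repo:

import os
import importlib.util


def discover_plugins(directory: str):
    # Yield a live module object for every .py file found in `directory`.
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith(".py"):
            continue
        name = filename[:-3]  # strip the .py extension
        spec = importlib.util.spec_from_file_location(
            name, os.path.join(directory, filename)
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)  # runs the plugin's top-level code
        yield module


if __name__ == "__main__":
    for mod in discover_plugins("./pipelines"):
        hooks = [h for h in ("get_response", "on_startup", "on_shutdown") if hasattr(mod, h)]
        print(mod.__name__, "exposes:", hooks)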
@@ -59,17 +89,18 @@ async def check_url(request: Request, call_next):
 @app.get("/v1/models")
 async def get_models():
     """
-    Returns the model that is available inside Dialog in the OpenAI format.
+    Returns the available pipelines
     """
     return {
         "data": [
             {
-                "id": MODEL_ID,
-                "name": MODEL_NAME,
+                "id": pipeline["id"],
+                "name": pipeline["name"],
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
             }
+            for pipeline in PIPELINES.values()
         ]
     }
 
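Since the endpoint keeps the OpenAI /v1/models response shape, every loaded pipeline now appears as its own model entry. A hedged client-side sketch; the base URL and port are assumptions, not part of this commit:

import requests  # assumption: the requests package is available on the client

resp = requests.get("http://localhost:8000/v1/models")  # host/port assumed
resp.raise_for_status()
for entry in resp.json()["data"]:
    # One entry per module discovered in ./pipelines, or the id/name
    # returned by that module's on_startup().
    print(entry["id"], "-", entry["name"])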
@@ -79,9 +110,18 @@ async def get_models():
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)
 
+    if form_data.model not in PIPELINES:
+        return HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Model {form_data.model} not found",
+        )
+
+    get_response = PIPELINES[form_data.model]["module"].get_response
+
     if form_data.stream:
 
         def stream_content():
 
             res = get_response(user_message, messages=form_data.messages)
 
             if isinstance(res, str):
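Routing now keys off the request's model field: it has to match a key in PIPELINES (the pipeline module's file name), otherwise a 404 is produced, and the matching module's get_response is used for the rest of the handler. A hedged request sketch, assuming the route is the OpenAI-style /v1/chat/completions and that a module loaded from ./pipelines is registered under the name "pipeline":

import requests  # assumption: the requests package is available on the client

payload = {
    "model": "pipeline",  # assumed module name of a .py file in ./pipelines
    "messages": [{"role": "user", "content": "Hello there"}],
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)  # URL assumed
print(resp.status_code, resp.json())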
@@ -108,9 +148,7 @@ async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
 
         return StreamingResponse(stream_content(), media_type="text/event-stream")
     else:
-
         res = get_response(user_message, messages=form_data.messages)
-
         message = ""
 
         if isinstance(res, str):
pipelines/examples/pipeline.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from typing import List, Union, Generator
+from schemas import OpenAIChatMessage
+
+
+def get_response(
+    user_message: str, messages: List[OpenAIChatMessage]
+) -> Union[str, Generator]:
+    # This is where you can add your custom pipelines like RAG.
+
+    print(messages)
+    print(user_message)
+
+    return f"pipeline response to: {user_message}"
+
+
+async def on_startup():
+    # This function is called when the server is started.
+    print(f"on_startup:{__name__}")
+
+    # Optional: return pipeline metadata
+    # return {
+    #     "id": "pipeline_id",
+    #     "name": "pipeline_name",
+    # }
+
+
+async def on_shutdown():
+    # This function is called when the server is stopped.
+    pass
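Any module exposing the same three hooks can be dropped into the pipelines directory and is picked up on the next start. A hypothetical streaming variant (the file name, id, name, and chunking below are illustrative, and the generator branch relies on the streaming handling in main.py that sits outside this diff):

# hypothetical ./pipelines/echo_stream.py
from typing import List, Union, Generator

from schemas import OpenAIChatMessage


def get_response(
    user_message: str, messages: List[OpenAIChatMessage]
) -> Union[str, Generator]:
    # Returning a generator instead of a str lets main.py stream the reply.
    def chunks():
        for word in user_message.split():
            yield word + " "

    return chunks()


async def on_startup():
    # Optional metadata consumed by the FastAPI lifespan hook in main.py.
    return {"id": "echo_stream", "name": "Echo Stream"}


async def on_shutdown():
    pass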
@@ -10,11 +10,14 @@ def get_response(
     print(messages)
     print(user_message)
 
-    return f"rag response to: {user_message}"
+    return f"pipeline response to: {user_message}"
 
 
 async def on_startup():
     # This function is called when the server is started.
+    print("onstartup")
+    print(__name__)
+
     pass
 
 