diff --git a/main.py b/main.py
index 7e9d66f..3b9fff2 100644
--- a/main.py
+++ b/main.py
@@ -13,7 +13,7 @@ import json
 import uuid
 
 from utils import get_last_user_message, stream_message_template
-from schemas import OpenAIChatCompletionForm
+from schemas import ValveForm, OpenAIChatCompletionForm
 
 import os
 import importlib.util
@@ -79,11 +79,12 @@ def on_startup():
             PIPELINES[loaded_module.__name__] = {
                 "module": pipeline_id,
                 "id": pipeline_id,
-                "name": (
-                    pipeline.name
-                    if hasattr(pipeline, "name")
-                    else loaded_module.__name__
+                "name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id),
+                "valve": hasattr(pipeline, "valve"),
+                "pipelines": (
+                    pipeline.pipelines if hasattr(pipeline, "pipelines") else []
                 ),
+                "priority": pipeline.priority if hasattr(pipeline, "priority") else 0,
             }
 
 
@@ -146,23 +147,44 @@ async def get_models():
                 "object": "model",
                 "created": int(time.time()),
                 "owned_by": "openai",
-                "pipeline": True,
+                "pipeline": {
+                    "type": "pipeline" if not pipeline.get("valve") else "valve",
+                    "pipelines": pipeline.get("pipelines", []),
+                    "priority": pipeline.get("priority", 0),
+                },
             }
             for pipeline in PIPELINES.values()
         ]
     }
 
 
+@app.post("/valve")
+@app.post("/v1/valve")
+async def valve(form_data: ValveForm):
+    if form_data.model not in app.state.PIPELINES or not app.state.PIPELINES[
+        form_data.model
+    ].get("valve", False):
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Valve {form_data.model} not found",
+        )
+
+    pipeline = PIPELINE_MODULES[form_data.model]
+    return await pipeline.control_valve(form_data.body)
+
+
 @app.post("/chat/completions")
 @app.post("/v1/chat/completions")
 async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
     user_message = get_last_user_message(form_data.messages)
     messages = [message.model_dump() for message in form_data.messages]
 
-    if form_data.model not in app.state.PIPELINES:
-        return HTTPException(
+    if form_data.model not in app.state.PIPELINES or app.state.PIPELINES[
+        form_data.model
+    ].get("valve", False):
+        raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail=f"Model {form_data.model} not found",
+            detail=f"Pipeline {form_data.model} not found",
         )
 
     def job():
diff --git a/pipelines/examples/valve_pipeline.py b/pipelines/examples/valve_pipeline.py
new file mode 100644
index 0000000..af8a6db
--- /dev/null
+++ b/pipelines/examples/valve_pipeline.py
@@ -0,0 +1,37 @@
+from typing import List, Union, Generator, Iterator
+from schemas import OpenAIChatMessage
+
+
+class Pipeline:
+    def __init__(self):
+        # Pipeline valves are only compatible with Open WebUI.
+        # You can think of a valve pipeline as middleware that edits the form data before it is sent to the OpenAI API.
+        self.valve = True
+        self.id = "valve_pipeline"
+        self.name = "Valve"
+
+        # Assign a priority level to the valve pipeline.
+        # The priority level determines the order in which valve pipelines are executed:
+        # the lower the number, the higher the priority.
+        self.priority = 0
+
+        # List the target pipelines (models) that this valve will be connected to.
+        self.pipelines = [
+            {"id": "llama3:latest"},
+        ]
+        pass
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+        print(f"on_shutdown:{__name__}")
+        pass
+
+    async def control_valve(self, body: dict) -> dict:
+        print(f"control_valve:{__name__}")
+        print(body)
+        return body
diff --git a/pipelines/valve_pipeline.py b/pipelines/valve_pipeline.py
new file mode 100644
index 0000000..af8a6db
--- /dev/null
+++ b/pipelines/valve_pipeline.py
@@ -0,0 +1,37 @@
+from typing import List, Union, Generator, Iterator
+from schemas import OpenAIChatMessage
+
+
+class Pipeline:
+    def __init__(self):
+        # Pipeline valves are only compatible with Open WebUI.
+        # You can think of a valve pipeline as middleware that edits the form data before it is sent to the OpenAI API.
+        self.valve = True
+        self.id = "valve_pipeline"
+        self.name = "Valve"
+
+        # Assign a priority level to the valve pipeline.
+        # The priority level determines the order in which valve pipelines are executed:
+        # the lower the number, the higher the priority.
+        self.priority = 0
+
+        # List the target pipelines (models) that this valve will be connected to.
+        self.pipelines = [
+            {"id": "llama3:latest"},
+        ]
+        pass
+
+    async def on_startup(self):
+        # This function is called when the server is started.
+        print(f"on_startup:{__name__}")
+        pass
+
+    async def on_shutdown(self):
+        # This function is called when the server is stopped.
+        print(f"on_shutdown:{__name__}")
+        pass
+
+    async def control_valve(self, body: dict) -> dict:
+        print(f"control_valve:{__name__}")
+        print(body)
+        return body
diff --git a/schemas.py b/schemas.py
index 6f90d68..852102b 100644
--- a/schemas.py
+++ b/schemas.py
@@ -15,3 +15,9 @@ class OpenAIChatCompletionForm(BaseModel):
     messages: List[OpenAIChatMessage]
 
     model_config = ConfigDict(extra="allow")
+
+
+class ValveForm(BaseModel):
+    model: str
+    body: dict
+    model_config = ConfigDict(extra="allow")