diff --git a/main.py b/main.py index 3b9fff2..798f961 100644 --- a/main.py +++ b/main.py @@ -170,7 +170,7 @@ async def valve(form_data: ValveForm): ) pipeline = PIPELINE_MODULES[form_data.model] - return await pipeline.control_valve(form_data.body) + return await pipeline.control_valve(form_data.body, form_data.user) @app.post("/chat/completions") diff --git a/pipelines/valve_pipeline.py b/pipelines/valve_pipeline.py index af8a6db..d05a586 100644 --- a/pipelines/valve_pipeline.py +++ b/pipelines/valve_pipeline.py @@ -1,4 +1,4 @@ -from typing import List, Union, Generator, Iterator +from typing import List, Optional from schemas import OpenAIChatMessage @@ -31,7 +31,10 @@ class Pipeline: print(f"on_shutdown:{__name__}") pass - async def control_valve(self, body: dict) -> dict: + async def control_valve(self, body: dict, user: Optional[dict] = None) -> dict: print(f"get_response:{__name__}") + print(body) + print(user) + return body diff --git a/schemas.py b/schemas.py index 852102b..84cedd5 100644 --- a/schemas.py +++ b/schemas.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from pydantic import BaseModel, ConfigDict @@ -20,4 +20,5 @@ class OpenAIChatCompletionForm(BaseModel): class ValveForm(BaseModel): model: str body: dict + user: Optional[dict] = None model_config = ConfigDict(extra="allow")