diff --git a/main.py b/main.py index 2f9241b..5888a6c 100644 --- a/main.py +++ b/main.py @@ -77,6 +77,9 @@ def on_startup(): "id": manifold_pipeline_id, "name": manifold_pipeline_name, "manifold": True, + "valves": ( + pipeline.valves if hasattr(pipeline, "valves") else None + ), } if pipeline.type == "filter": PIPELINES[pipeline_id] = { @@ -92,12 +95,14 @@ def on_startup(): "priority": ( pipeline.priority if hasattr(pipeline, "priority") else 0 ), + "valves": pipeline.valves if hasattr(pipeline, "valves") else None, } else: PIPELINES[pipeline_id] = { "module": pipeline_id, "id": pipeline_id, "name": (pipeline.name if hasattr(pipeline, "name") else pipeline_id), + "valves": pipeline.valves if hasattr(pipeline, "valves") else None, } @@ -160,37 +165,91 @@ async def get_models(): "object": "model", "created": int(time.time()), "owned_by": "openai", - **( - { - "pipeline": { + "pipeline": { + **( + { "type": ( "pipeline" if not pipeline.get("filter") else "filter" ), "pipelines": pipeline.get("pipelines", []), "priority": pipeline.get("priority", 0), } - } - if pipeline.get("filter", False) - else {} - ), + if pipeline.get("filter", False) + else {} + ), + "valves": "valves" in pipeline, + }, } for pipeline in PIPELINES.values() ] } -@app.post("/filter") -@app.post("/v1/filter") -async def filter(form_data: FilterForm): - if form_data.model not in app.state.PIPELINES or not app.state.PIPELINES[ - form_data.model +@app.get("/{pipeline_id}/valves") +async def get_valves(pipeline_id: str): + if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ + pipeline_id + ].get("valves", False): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Valves {pipeline_id} not found", + ) + + pipeline = PIPELINE_MODULES[pipeline_id] + return pipeline.valves + + +@app.get("/{pipeline_id}/valves/spec") +async def get_valves_spec(pipeline_id: str): + if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ + pipeline_id + ].get("valves", False): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Valves {pipeline_id} not found", + ) + + pipeline = PIPELINE_MODULES[pipeline_id] + return pipeline.valves.schema() + + +@app.post("/{pipeline_id}/valves/update") +async def update_valves(pipeline_id: str, form_data: dict): + if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ + pipeline_id + ].get("valves", False): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Valves {pipeline_id} not found", + ) + + pipeline = PIPELINE_MODULES[pipeline_id] + + try: + ValvesModel = pipeline.valves.__class__ + valves = ValvesModel(**form_data.valves) + pipeline.valves = valves + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"{str(e)}", + ) + + return pipeline.valves + + +@app.post("/{pipeline_id}/filter") +async def filter(pipeline_id: str, form_data: FilterForm): + if pipeline_id not in app.state.PIPELINES or not app.state.PIPELINES[ + pipeline_id ].get("filter", False): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"filter {form_data.model} not found", + detail=f"filter {pipeline_id} not found", ) - pipeline = PIPELINE_MODULES[form_data.model] + pipeline = PIPELINE_MODULES[pipeline_id] try: body = await pipeline.filter(form_data.body, form_data.user) diff --git a/pipelines/rate_limit_filter_pipeline.py b/pipelines/rate_limit_filter_pipeline.py index 9854382..d24ba68 100644 --- a/pipelines/rate_limit_filter_pipeline.py +++ b/pipelines/rate_limit_filter_pipeline.py @@ -10,6 +10,9 @@ class Pipeline: # You can think of filter pipeline as a middleware that can be used to edit the form data before it is sent to the OpenAI API. self.type = "filter" + # Assign a unique identifier to the filter pipeline. + # The identifier must be unique across all filter pipelines. + # The identifier must be an alphanumeric string that can include underscores or hyphens. It cannot contain spaces, special characters, slashes, or backslashes. self.id = "rate_limit_filter_pipeline" self.name = "Rate Limit Filter" diff --git a/schemas.py b/schemas.py index 58c0cc4..c35a75b 100644 --- a/schemas.py +++ b/schemas.py @@ -18,7 +18,6 @@ class OpenAIChatCompletionForm(BaseModel): class FilterForm(BaseModel): - model: str body: dict user: Optional[dict] = None model_config = ConfigDict(extra="allow")