class PipelineMiddleware(BaseHTTPMiddleware):
    """Pass chat request bodies through the "valve" pipelines of the target model.

    For POST requests whose path contains ``/api/chat`` or ``/chat/completions``
    the JSON body is read and handed, in ascending ``priority`` order, to every
    valve model in ``app.state.MODELS`` that lists the requested model among its
    ``pipelines``. The resulting payload is written back onto the request before
    it continues down the middleware stack. Every other request is forwarded
    untouched.
    """

    async def dispatch(self, request: Request, call_next):
        """Apply matching valves to a chat request, then call the next handler.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next handler.

        Returns:
            The response produced by ``call_next``.
        """
        if request.method == "POST" and (
            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
        ):
            log.debug(f"request.url.path: {request.url.path}")

            data = self._parse_json_object(await request.body())

            # A body that is not a JSON object, or that names no model, cannot
            # be matched to a valve. Forward it unchanged so the endpoint can
            # report the problem instead of failing inside the middleware.
            if data is not None and "model" in data:
                for valve in self._valves_for(data["model"]):
                    data = await self._apply_valve(valve, data)

                modified_body_bytes = json.dumps(data).encode("utf-8")
                # Overwrite the body cached on the request with the new payload
                request._body = modified_body_bytes
                # Rebuild the header list so content-length matches the new body
                request.headers.__dict__["_list"] = [
                    (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                    *[
                        (k, v)
                        for k, v in request.headers.raw
                        if k.lower() != b"content-length"
                    ],
                ]

        response = await call_next(request)
        return response

    @staticmethod
    def _parse_json_object(body: bytes):
        """Return *body* decoded as a JSON object, or ``None`` if it is not one.

        An empty body is treated as an empty object.
        """
        try:
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}
        except ValueError:
            # Covers both UnicodeDecodeError and json.JSONDecodeError.
            return None
        return data if isinstance(data, dict) else None

    @staticmethod
    def _valves_for(model_id):
        """Return the valve models targeting *model_id*, lowest priority value first."""
        valves = [
            model
            for model in app.state.MODELS.values()
            if "pipeline" in model
            and model["pipeline"]["type"] == "valve"
            and any(
                target_model["id"] == model_id
                for target_model in model["pipeline"]["pipelines"]
            )
        ]
        return sorted(valves, key=lambda valve: valve["pipeline"]["priority"])

    async def _apply_valve(self, valve, data):
        """POST *data* to one valve and return the payload it answers with.

        This is best-effort: when the valve's server has no API key configured,
        or the call fails for any reason, *data* is returned unchanged so the
        chat request still goes through.
        """
        import asyncio

        try:
            urlIdx = valve["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key == "":
                # No credentials for this server: skip the valve rather than
                # send a request with missing or unrelated authorization.
                return data

            headers = {"Authorization": f"Bearer {key}"}

            def _post():
                # NOTE(review): no timeout is passed, so a stalled valve server
                # holds this request until the connection gives up.
                return requests.post(
                    f"{url}/valve",
                    headers=headers,
                    json={
                        "model": valve["id"],
                        "body": data,
                    },
                )

            # requests is synchronous; run it in the default executor so the
            # event loop keeps serving other requests during the round-trip.
            r = await asyncio.get_running_loop().run_in_executor(None, _post)

            r.raise_for_status()
            return r.json()
        except Exception as e:
            # Deliberately broad: a misbehaving valve must not break the chat.
            log.error(f"Connection error: {e}")
            return data

    async def _receive(self, body: bytes):
        """Build a single, final ASGI ``http.request`` message carrying *body*.

        Not referenced by ``dispatch`` in this class.
        """
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(PipelineMiddleware)