feat: pipeline valve support

This commit is contained in:
Timothy J. Baek 2024-05-27 19:03:26 -07:00
parent abce172b9d
commit cc6d9bb8c0

View File

@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
app.add_middleware(RAGMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
valves = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and model["pipeline"]["type"] == "valve"
and model_id
in [
target_model["id"]
for target_model in model["pipeline"]["pipelines"]
]
]
sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
for valve in sorted_valves:
try:
urlIdx = valve["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/valve",
headers=headers,
json={
"model": valve["id"],
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
pass
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
]
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(PipelineMiddleware)
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
@ -332,6 +409,14 @@ async def get_all_models():
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
# Filter out valve models
models = [
model
for model in models
if "pipeline" not in model or model["pipeline"]["type"] != "valve"
]
if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user":
models = list(