mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 13:40:55 +00:00
feat: pipeline valve support
This commit is contained in:
parent
abce172b9d
commit
cc6d9bb8c0
@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
||||
app.add_middleware(RAGMiddleware)
|
||||
|
||||
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.method == "POST" and (
|
||||
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
# Decode body to string
|
||||
body_str = body.decode("utf-8")
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = data["model"]
|
||||
|
||||
valves = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and model["pipeline"]["type"] == "valve"
|
||||
and model_id
|
||||
in [
|
||||
target_model["id"]
|
||||
for target_model in model["pipeline"]["pipelines"]
|
||||
]
|
||||
]
|
||||
sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
for valve in sorted_valves:
|
||||
try:
|
||||
urlIdx = valve["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/valve",
|
||||
headers=headers,
|
||||
json={
|
||||
"model": valve["id"],
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
log.error(f"Connection error: {e}")
|
||||
pass
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
# Replace the request body with the modified one
|
||||
request._body = modified_body_bytes
|
||||
# Set custom header to ensure content-length matches new body length
|
||||
request.headers.__dict__["_list"] = [
|
||||
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||
*[
|
||||
(k, v)
|
||||
for k, v in request.headers.raw
|
||||
if k.lower() != b"content-length"
|
||||
],
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
async def _receive(self, body: bytes):
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
|
||||
app.add_middleware(PipelineMiddleware)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
@ -332,6 +409,14 @@ async def get_all_models():
|
||||
@app.get("/api/models")
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
models = await get_all_models()
|
||||
|
||||
# Filter out valve models
|
||||
models = [
|
||||
model
|
||||
for model in models
|
||||
if "pipeline" not in model or model["pipeline"]["type"] != "valve"
|
||||
]
|
||||
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
models = list(
|
||||
|
Loading…
Reference in New Issue
Block a user