mirror of
https://github.com/open-webui/open-webui
synced 2024-11-16 21:42:58 +00:00
feat: pipeline valve support
This commit is contained in:
parent
abce172b9d
commit
cc6d9bb8c0
@ -229,6 +229,83 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||||||
app.add_middleware(RAGMiddleware)
|
app.add_middleware(RAGMiddleware)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
if request.method == "POST" and (
|
||||||
|
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
|
||||||
|
):
|
||||||
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
|
|
||||||
|
# Read the original request body
|
||||||
|
body = await request.body()
|
||||||
|
# Decode body to string
|
||||||
|
body_str = body.decode("utf-8")
|
||||||
|
# Parse string to JSON
|
||||||
|
data = json.loads(body_str) if body_str else {}
|
||||||
|
|
||||||
|
model_id = data["model"]
|
||||||
|
|
||||||
|
valves = [
|
||||||
|
model
|
||||||
|
for model in app.state.MODELS.values()
|
||||||
|
if "pipeline" in model
|
||||||
|
and model["pipeline"]["type"] == "valve"
|
||||||
|
and model_id
|
||||||
|
in [
|
||||||
|
target_model["id"]
|
||||||
|
for target_model in model["pipeline"]["pipelines"]
|
||||||
|
]
|
||||||
|
]
|
||||||
|
sorted_valves = sorted(valves, key=lambda x: x["pipeline"]["priority"])
|
||||||
|
|
||||||
|
for valve in sorted_valves:
|
||||||
|
try:
|
||||||
|
urlIdx = valve["urlIdx"]
|
||||||
|
|
||||||
|
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
|
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
|
if key != "":
|
||||||
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/valve",
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"model": valve["id"],
|
||||||
|
"body": data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
except Exception as e:
|
||||||
|
# Handle connection error here
|
||||||
|
log.error(f"Connection error: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||||
|
# Replace the request body with the modified one
|
||||||
|
request._body = modified_body_bytes
|
||||||
|
# Set custom header to ensure content-length matches new body length
|
||||||
|
request.headers.__dict__["_list"] = [
|
||||||
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
||||||
|
*[
|
||||||
|
(k, v)
|
||||||
|
for k, v in request.headers.raw
|
||||||
|
if k.lower() != b"content-length"
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _receive(self, body: bytes):
|
||||||
|
return {"type": "http.request", "body": body, "more_body": False}
|
||||||
|
|
||||||
|
|
||||||
|
app.add_middleware(PipelineMiddleware)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def check_url(request: Request, call_next):
|
async def check_url(request: Request, call_next):
|
||||||
if len(app.state.MODELS) == 0:
|
if len(app.state.MODELS) == 0:
|
||||||
@ -332,6 +409,14 @@ async def get_all_models():
|
|||||||
@app.get("/api/models")
|
@app.get("/api/models")
|
||||||
async def get_models(user=Depends(get_verified_user)):
|
async def get_models(user=Depends(get_verified_user)):
|
||||||
models = await get_all_models()
|
models = await get_all_models()
|
||||||
|
|
||||||
|
# Filter out valve models
|
||||||
|
models = [
|
||||||
|
model
|
||||||
|
for model in models
|
||||||
|
if "pipeline" not in model or model["pipeline"]["type"] != "valve"
|
||||||
|
]
|
||||||
|
|
||||||
if app.state.config.ENABLE_MODEL_FILTER:
|
if app.state.config.ENABLE_MODEL_FILTER:
|
||||||
if user.role == "user":
|
if user.role == "user":
|
||||||
models = list(
|
models = list(
|
||||||
|
Loading…
Reference in New Issue
Block a user