From abc9b63093d65f4d74342db85b7d5df1809aa0f0 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Fri, 13 Feb 2026 14:55:13 -0600 Subject: [PATCH] refac Co-Authored-By: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> --- backend/open_webui/routers/openai.py | 135 ++++++++++++++++++++++++++- 1 file changed, 133 insertions(+), 2 deletions(-) diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index f8688a9c9..a5c08b1d6 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -455,8 +455,13 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: async def get_filtered_models(models, user, db=None): # Filter models based on user access control model_ids = [model["id"] for model in models.get("data", [])] - model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)} - user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)} + model_infos = { + model_info.id: model_info + for model_info in Models.get_models_by_ids(model_ids, db=db) + } + user_group_ids = { + group.id for group in Groups.get_groups_by_member_id(user.id, db=db) + } # Batch-fetch accessible resource IDs in a single query instead of N has_access calls accessible_model_ids = AccessGrants.get_accessible_resource_ids( @@ -1215,6 +1220,115 @@ async def embeddings(request: Request, form_data: dict, user): await cleanup_response(r, session) +@router.post("/responses") +async def responses(request: Request, user=Depends(get_verified_user)): + """ + Forward requests to the OpenAI Responses API endpoint. + Routes to the correct upstream backend based on the model field. + """ + body = await request.body() + + try: + payload = json.loads(body) + except (json.JSONDecodeError, ValueError): + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + if not isinstance(payload, dict): + raise HTTPException( + status_code=400, + detail="Invalid payload: expected JSON object", + ) + + idx = 0 + model_id = payload.get("model") + if model_id: + models = request.app.state.OPENAI_MODELS + if not models or model_id not in models: + await get_all_models(request, user=user) + models = request.app.state.OPENAI_MODELS + if model_id in models: + idx = models[model_id]["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + + r = None + session = None + streaming = False + + try: + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) + + if api_config.get("azure", False): + api_version = api_config.get("api_version", "2023-03-15-preview") + + auth_type = api_config.get("auth_type", "bearer") + if auth_type not in ("azure_ad", "microsoft_entra_id"): + headers["api-key"] = key + + headers["api-version"] = api_version + + model = payload.get("model", "") + request_url = ( + f"{url}/openai/deployments/{model}/responses?api-version={api_version}" + ) + else: + request_url = f"{url}/responses" + + session = aiohttp.ClientSession( + trust_env=True, + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), + ) + r = await session.request( + method="POST", + url=request_url, + data=body, + headers=headers, + cookies=cookies, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) + + # Check if response is SSE + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + stream_wrapper(r, session), + status_code=r.status, + headers=dict(r.headers), + ) + else: + try: + response_data = await r.json() + except Exception: + response_data = await r.text() + + if r.status >= 400: + if isinstance(response_data, (dict, list)): + return JSONResponse(status_code=r.status, content=response_data) + else: + return PlainTextResponse( + status_code=r.status, content=response_data + ) + + return response_data + + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=r.status if r else 500, + detail="Open WebUI: Server Connection Error", + ) + finally: + if not streaming: + await cleanup_response(r, session) + + @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): """ @@ -1223,7 +1337,24 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): body = await request.body() + # Parse JSON body to resolve model-based routing + payload = None + if body: + try: + payload = json.loads(body) + except (json.JSONDecodeError, ValueError): + payload = None + idx = 0 + model_id = payload.get("model") if isinstance(payload, dict) else None + if model_id: + models = request.app.state.OPENAI_MODELS + if not models or model_id not in models: + await get_all_models(request, user=user) + models = request.app.state.OPENAI_MODELS + if model_id in models: + idx = models[model_id]["urlIdx"] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] api_config = request.app.state.config.OPENAI_API_CONFIGS.get(