refac
Co-Authored-By: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
This commit is contained in:
@@ -455,8 +455,13 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list:
|
||||
async def get_filtered_models(models, user, db=None):
|
||||
# Filter models based on user access control
|
||||
model_ids = [model["id"] for model in models.get("data", [])]
|
||||
model_infos = {model_info.id: model_info for model_info in Models.get_models_by_ids(model_ids, db=db)}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id, db=db)}
|
||||
model_infos = {
|
||||
model_info.id: model_info
|
||||
for model_info in Models.get_models_by_ids(model_ids, db=db)
|
||||
}
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user.id, db=db)
|
||||
}
|
||||
|
||||
# Batch-fetch accessible resource IDs in a single query instead of N has_access calls
|
||||
accessible_model_ids = AccessGrants.get_accessible_resource_ids(
|
||||
@@ -1215,6 +1220,115 @@ async def embeddings(request: Request, form_data: dict, user):
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
@router.post("/responses")
|
||||
async def responses(request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
Forward requests to the OpenAI Responses API endpoint.
|
||||
Routes to the correct upstream backend based on the model field.
|
||||
"""
|
||||
body = await request.body()
|
||||
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Invalid payload: expected JSON object",
|
||||
)
|
||||
|
||||
idx = 0
|
||||
model_id = payload.get("model")
|
||||
if model_id:
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if not models or model_id not in models:
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if model_id in models:
|
||||
idx = models[model_id]["urlIdx"]
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
str(idx),
|
||||
request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support
|
||||
)
|
||||
|
||||
r = None
|
||||
session = None
|
||||
streaming = False
|
||||
|
||||
try:
|
||||
headers, cookies = await get_headers_and_cookies(
|
||||
request, url, key, api_config, user=user
|
||||
)
|
||||
|
||||
if api_config.get("azure", False):
|
||||
api_version = api_config.get("api_version", "2023-03-15-preview")
|
||||
|
||||
auth_type = api_config.get("auth_type", "bearer")
|
||||
if auth_type not in ("azure_ad", "microsoft_entra_id"):
|
||||
headers["api-key"] = key
|
||||
|
||||
headers["api-version"] = api_version
|
||||
|
||||
model = payload.get("model", "")
|
||||
request_url = (
|
||||
f"{url}/openai/deployments/{model}/responses?api-version={api_version}"
|
||||
)
|
||||
else:
|
||||
request_url = f"{url}/responses"
|
||||
|
||||
session = aiohttp.ClientSession(
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
|
||||
)
|
||||
r = await session.request(
|
||||
method="POST",
|
||||
url=request_url,
|
||||
data=body,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
)
|
||||
|
||||
# Check if response is SSE
|
||||
if "text/event-stream" in r.headers.get("Content-Type", ""):
|
||||
streaming = True
|
||||
return StreamingResponse(
|
||||
stream_wrapper(r, session),
|
||||
status_code=r.status,
|
||||
headers=dict(r.headers),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
response_data = await r.json()
|
||||
except Exception:
|
||||
response_data = await r.text()
|
||||
|
||||
if r.status >= 400:
|
||||
if isinstance(response_data, (dict, list)):
|
||||
return JSONResponse(status_code=r.status, content=response_data)
|
||||
else:
|
||||
return PlainTextResponse(
|
||||
status_code=r.status, content=response_data
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=r.status if r else 500,
|
||||
detail="Open WebUI: Server Connection Error",
|
||||
)
|
||||
finally:
|
||||
if not streaming:
|
||||
await cleanup_response(r, session)
|
||||
|
||||
|
||||
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
"""
|
||||
@@ -1223,7 +1337,24 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
|
||||
body = await request.body()
|
||||
|
||||
# Parse JSON body to resolve model-based routing
|
||||
payload = None
|
||||
if body:
|
||||
try:
|
||||
payload = json.loads(body)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
payload = None
|
||||
|
||||
idx = 0
|
||||
model_id = payload.get("model") if isinstance(payload, dict) else None
|
||||
if model_id:
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if not models or model_id not in models:
|
||||
await get_all_models(request, user=user)
|
||||
models = request.app.state.OPENAI_MODELS
|
||||
if model_id in models:
|
||||
idx = models[model_id]["urlIdx"]
|
||||
|
||||
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = request.app.state.config.OPENAI_API_KEYS[idx]
|
||||
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
|
||||
|
||||
Reference in New Issue
Block a user