Remove Depends(get_session) from the /chat/completions endpoint to prevent database connections from being held during the entire duration of LLM calls (30-60+ seconds for streaming responses). Previously, the database session was acquired at request start and held until the streaming response completed. Under concurrent load, this exhausted the connection pool, causing QueuePool timeout errors for other database operations. The fix allows Models.get_model_by_id() and has_access() to manage their own short-lived sessions internally, releasing the connection immediately after the quick authorization checks complete - before the slow external LLM API call begins.
1163 lines
38 KiB
Python
1163 lines
38 KiB
Python
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import aiohttp
|
|
from aiocache import cached
|
|
import requests
|
|
|
|
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
|
|
|
|
from fastapi import Depends, HTTPException, Request, APIRouter
|
|
from fastapi.responses import (
|
|
FileResponse,
|
|
StreamingResponse,
|
|
JSONResponse,
|
|
PlainTextResponse,
|
|
)
|
|
from pydantic import BaseModel
|
|
from starlette.background import BackgroundTask
|
|
from sqlalchemy.orm import Session
|
|
|
|
from open_webui.internal.db import get_session
|
|
|
|
from open_webui.models.models import Models
|
|
from open_webui.config import (
|
|
CACHE_DIR,
|
|
)
|
|
from open_webui.env import (
|
|
MODELS_CACHE_TTL,
|
|
AIOHTTP_CLIENT_SESSION_SSL,
|
|
AIOHTTP_CLIENT_TIMEOUT,
|
|
AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST,
|
|
ENABLE_FORWARD_USER_INFO_HEADERS,
|
|
BYPASS_MODEL_ACCESS_CONTROL,
|
|
)
|
|
from open_webui.models.users import UserModel
|
|
|
|
from open_webui.constants import ERROR_MESSAGES
|
|
|
|
|
|
from open_webui.utils.payload import (
|
|
apply_model_params_to_body_openai,
|
|
apply_system_prompt_to_body,
|
|
)
|
|
from open_webui.utils.misc import (
|
|
convert_logit_bias_input_to_json,
|
|
stream_chunks_handler,
|
|
)
|
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
from open_webui.utils.access_control import has_access
|
|
from open_webui.utils.headers import include_user_info_headers
|
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
|
|
##########################################
|
|
#
|
|
# Utility functions
|
|
#
|
|
##########################################
|
|
|
|
|
|
async def send_get_request(url, key=None, user: UserModel = None):
    """
    Issue a GET request to ``url`` and return the decoded JSON body.

    Args:
        url: Address to fetch.
        key: Optional API key; when truthy it is sent as a Bearer token.
        user: Optional user; when set and ENABLE_FORWARD_USER_INFO_HEADERS is
            enabled, the headers are passed through include_user_info_headers.

    Returns:
        The parsed JSON body, or None when any exception occurs (the error is
        logged instead of being raised).
    """
    request_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    try:
        request_headers = {}
        if key:
            request_headers["Authorization"] = f"Bearer {key}"

        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers = include_user_info_headers(request_headers, user)

        async with aiohttp.ClientSession(
            timeout=request_timeout, trust_env=True
        ) as client:
            async with client.get(
                url,
                headers=request_headers,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as resp:
                return await resp.json()
    except Exception as e:
        # Best effort: callers receive None instead of an exception.
        log.error(f"Connection error: {e}")
        return None
|
|
|
|
|
|
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    """
    Close an upstream response and its client session, when present.

    Args:
        response: Response to close; skipped when falsy.
        session: Client session to close; skipped when falsy.
    """
    if response:
        response.close()

    if not session:
        return
    await session.close()
|
|
|
|
|
|
def openai_reasoning_model_handler(payload):
    """
    Handle reasoning model specific parameters.

    The payload is modified in place and also returned.

    Args:
        payload (dict): Chat completion request body. "model" is read only
            when the first message has the "system" role.

    Returns:
        dict: The same payload, with "max_tokens" renamed to
        "max_completion_tokens" and a leading "system" message re-labelled.
    """
    if "max_tokens" in payload:
        # Convert "max_tokens" to "max_completion_tokens" for all reasoning models
        payload["max_completion_tokens"] = payload["max_tokens"]
        del payload["max_tokens"]

    # A payload without messages (missing key or empty list) has no system
    # role to convert, so it is returned as-is instead of raising.
    messages = payload.get("messages") or []
    first_message = messages[0] if messages else None

    # Handle system role conversion based on model type
    if isinstance(first_message, dict) and first_message.get("role") == "system":
        model_lower = payload["model"].lower()
        # Legacy models use "user" role instead of "system"
        if model_lower.startswith("o1-mini") or model_lower.startswith("o1-preview"):
            first_message["role"] = "user"
        else:
            first_message["role"] = "developer"

    return payload
|
|
|
|
|
|
async def get_headers_and_cookies(
    request: Request,
    url,
    key=None,
    config=None,
    metadata: Optional[dict] = None,
    user: UserModel = None,
):
    """
    Build the headers and cookies for a request to an upstream API.

    Args:
        request: Incoming request; its cookies, ``state.token`` and
            ``app.state.oauth_manager`` are used by some auth types.
        url: Upstream base URL; OpenRouter URLs get two extra headers.
        key: API key used for the default "bearer" auth type.
        config: Connection config; "auth_type" and "headers" are read from
            it. None is treated as an empty config.
        metadata: Optional request metadata; "chat_id" is forwarded as a
            header when user info forwarding is active.
        user: Optional user for user info headers and the OAuth token lookup.

    Returns:
        tuple: ``(headers, cookies)`` for the upstream request.
    """
    # The parameter defaults to None, but every lookup below needs a mapping.
    config = config or {}

    cookies = {}
    headers = {"Content-Type": "application/json"}
    if "openrouter.ai" in url:
        headers["HTTP-Referer"] = "https://openwebui.com/"
        headers["X-Title"] = "Open WebUI"

    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        headers = include_user_info_headers(headers, user)
        if metadata and metadata.get("chat_id"):
            headers["X-OpenWebUI-Chat-Id"] = metadata.get("chat_id")

    token = None
    auth_type = config.get("auth_type")

    if auth_type == "bearer" or auth_type is None:
        # Default to bearer if not specified. A missing key yields no token,
        # so the literal header "Bearer None" is never sent.
        token = f"{key}" if key is not None else None
    elif auth_type == "none":
        token = None
    elif auth_type == "session":
        cookies = request.cookies
        token = request.state.token.credentials
    elif auth_type == "system_oauth":
        cookies = request.cookies

        oauth_token = None
        try:
            if request.cookies.get("oauth_session_id", None):
                oauth_token = await request.app.state.oauth_manager.get_oauth_token(
                    user.id,
                    request.cookies.get("oauth_session_id", None),
                )
        except Exception as e:
            # Best effort: without a token the request goes out unauthenticated.
            log.error(f"Error getting OAuth token: {e}")

        if oauth_token:
            token = f"{oauth_token.get('access_token', '')}"

    elif auth_type in ("azure_ad", "microsoft_entra_id"):
        token = get_microsoft_entra_id_access_token()

    if token:
        headers["Authorization"] = f"Bearer {token}"

    # Headers from the connection config are merged last and so override
    # anything set above.
    custom_headers = config.get("headers")
    if custom_headers and isinstance(custom_headers, dict):
        headers = {**headers, **custom_headers}

    return headers, cookies
|
|
|
|
|
|
def get_microsoft_entra_id_access_token():
    """
    Obtain a Microsoft Entra ID access token for Azure OpenAI.

    A bearer token provider is created from DefaultAzureCredential for the
    Cognitive Services scope and invoked once.

    Returns:
        The token string, or None if any step raises (the error is logged).
    """
    scope = "https://cognitiveservices.azure.com/.default"
    try:
        credential = DefaultAzureCredential()
        provider = get_bearer_token_provider(credential, scope)
        return provider()
    except Exception as e:
        log.error(f"Error getting Microsoft Entra ID access token: {e}")
        return None
|
|
|
|
|
|
##########################################
|
|
#
|
|
# API routes
|
|
#
|
|
##########################################
|
|
|
|
# Router on which the endpoint decorators in this module register their routes.
router = APIRouter()
|
|
|
|
|
|
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
    # Report the current OpenAI connection settings stored on the application
    # config. The caller is resolved through the get_admin_user dependency.
    config = request.app.state.config
    exposed_settings = (
        "ENABLE_OPENAI_API",
        "OPENAI_API_BASE_URLS",
        "OPENAI_API_KEYS",
        "OPENAI_API_CONFIGS",
    )
    return {setting: getattr(config, setting) for setting in exposed_settings}
|
|
|
|
|
|
# Request body accepted by the "/config/update" route.
class OpenAIConfigForm(BaseModel):
    # Optional on/off switch; defaults to None when omitted.
    ENABLE_OPENAI_API: Optional[bool] = None

    # Upstream base URLs and their API keys; update_config pads or truncates
    # the key list to the length of the URL list.
    OPENAI_API_BASE_URLS: list[str]
    OPENAI_API_KEYS: list[str]

    # Per-connection settings; update_config keeps only entries keyed by the
    # string index of an existing URL.
    OPENAI_API_CONFIGS: dict
|
|
|
|
|
|
@router.post("/config/update")
async def update_config(
    request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
):
    # Store the submitted OpenAI connection settings on the application config
    # and return the resulting values. The caller is resolved through the
    # get_admin_user dependency.
    config = request.app.state.config

    config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
    config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
    config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS

    # Keep exactly one key per URL: drop surplus keys, pad missing ones with "".
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)
    if key_count > url_count:
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        config.OPENAI_API_KEYS += [""] * (url_count - key_count)

    config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS

    # Remove the API configs whose key is not the index of a configured URL.
    valid_indices = {str(position) for position in range(url_count)}
    config.OPENAI_API_CONFIGS = {
        index_key: api_config
        for index_key, api_config in config.OPENAI_API_CONFIGS.items()
        if index_key in valid_indices
    }

    return {
        "ENABLE_OPENAI_API": config.ENABLE_OPENAI_API,
        "OPENAI_API_BASE_URLS": config.OPENAI_API_BASE_URLS,
        "OPENAI_API_KEYS": config.OPENAI_API_KEYS,
        "OPENAI_API_CONFIGS": config.OPENAI_API_CONFIGS,
    }
|
|
|
|
|
|
@router.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    # Forward a text-to-speech request to the connection whose base URL is
    # exactly "https://api.openai.com/v1" and cache the audio on disk, keyed
    # by the SHA-256 of the request body. An identical body is served from
    # the cache. Responds 401 when no such connection is configured.
    idx = None
    try:
        # list.index raises ValueError when the URL is absent; handled at the
        # bottom of this function.
        idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
            "https://api.openai.com/v1"
        )

        body = await request.body()
        name = hashlib.sha256(body).hexdigest()

        SPEECH_CACHE_DIR = CACHE_DIR / "audio" / "speech"
        SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
        file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

        # Check if the file already exists in the cache
        if file_path.is_file():
            return FileResponse(file_path)

        url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
        key = request.app.state.config.OPENAI_API_KEYS[idx]
        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
            str(idx),
            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
        )

        headers, cookies = await get_headers_and_cookies(
            request, url, key, api_config, user=user
        )

        r = None
        audio_started = False
        audio_complete = False
        try:
            # NOTE(review): requests.post is a synchronous call inside an
            # async handler; left as-is to keep this change focused.
            r = requests.post(
                url=f"{url}/audio/speech",
                data=body,
                headers=headers,
                cookies=cookies,
                stream=True,
            )

            r.raise_for_status()

            # Save the streaming content to a file
            audio_started = True
            with open(file_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
            audio_complete = True

            with open(file_body_path, "w") as f:
                json.dump(json.loads(body.decode("utf-8")), f)

            # Return the saved file
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)

            if audio_started and not audio_complete:
                # Remove a partially written file so that a truncated
                # download is not served from the cache on the next request.
                try:
                    file_path.unlink(missing_ok=True)
                except OSError as cleanup_error:
                    log.warning(
                        f"Could not remove incomplete speech file: {cleanup_error}"
                    )

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"External: {res['error']}"
                except Exception:
                    detail = f"External: {e}"

            # Compare against None explicitly: a requests.Response is falsy
            # for 4xx/5xx, so a truthiness test would always report 500.
            # Only an upstream error status is propagated; a failure after a
            # successful upstream call is reported as 500.
            status_code = 500
            if r is not None and r.status_code >= 400:
                status_code = r.status_code

            raise HTTPException(
                status_code=status_code,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )
        finally:
            # stream=True keeps the connection open until the response is
            # closed.
            if r is not None:
                r.close()

    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
|
|
|
|
|
async def get_all_models_responses(request: Request, user: UserModel) -> list:
    """
    Collect one model-list response per configured OpenAI base URL.

    For each URL, the result is a fetch of ``{url}/models``, a synthetic list
    built from the connection's "model_ids", or None when the connection is
    disabled. Models in each truthy response are then annotated in place with
    the connection's prefix, tags and connection type.

    Args:
        request: Request giving access to the application config.
        user: User passed on to send_get_request.

    Returns:
        list: Responses in the same order as OPENAI_API_BASE_URLS, or an
        empty list when the OpenAI API is disabled.
    """
    config = request.app.state.config
    if not config.ENABLE_OPENAI_API:
        return []

    # Keep exactly one key per URL: drop surplus keys, pad missing ones with "".
    url_count = len(config.OPENAI_API_BASE_URLS)
    key_count = len(config.OPENAI_API_KEYS)
    if key_count > url_count:
        config.OPENAI_API_KEYS = config.OPENAI_API_KEYS[:url_count]
    elif key_count < url_count:
        config.OPENAI_API_KEYS += [""] * (url_count - key_count)

    def resolve_api_config(position, base_url):
        # The index-keyed entry wins; the URL-keyed entry is legacy support.
        # A connection with neither resolves to {}.
        return config.OPENAI_API_CONFIGS.get(
            str(position),
            config.OPENAI_API_CONFIGS.get(base_url, {}),
        )

    def already_resolved(value):
        # Awaitable yielding `value`, so gather() keeps one slot per URL.
        return asyncio.ensure_future(asyncio.sleep(0, value))

    pending = []
    for idx, url in enumerate(config.OPENAI_API_BASE_URLS):
        api_config = resolve_api_config(idx, url)
        enabled = api_config.get("enable", True)
        model_ids = api_config.get("model_ids", [])

        if not enabled:
            pending.append(already_resolved(None))
        elif len(model_ids) == 0:
            # No explicit model list: ask the upstream for its models.
            pending.append(
                send_get_request(
                    f"{url}/models",
                    config.OPENAI_API_KEYS[idx],
                    user=user,
                )
            )
        else:
            static_models = [
                {
                    "id": model_id,
                    "name": model_id,
                    "owned_by": "openai",
                    "openai": {"id": model_id},
                    "urlIdx": idx,
                }
                for model_id in model_ids
            ]
            pending.append(
                already_resolved({"object": "list", "data": static_models})
            )

    responses = await asyncio.gather(*pending)

    for idx, response in enumerate(responses):
        if not response:
            continue

        url = config.OPENAI_API_BASE_URLS[idx]
        api_config = resolve_api_config(idx, url)

        connection_type = api_config.get("connection_type", "external")
        prefix_id = api_config.get("prefix_id", None)
        tags = api_config.get("tags", [])

        if isinstance(response, list):
            model_list = response
        else:
            model_list = response.get("data", [])
        if not isinstance(model_list, list):
            # Catch non-list responses
            model_list = []

        for model in model_list:
            # Remove name key if its value is None #16689
            if "name" in model and model["name"] is None:
                del model["name"]

            if prefix_id:
                model["id"] = f"{prefix_id}.{model.get('id', model.get('name', ''))}"

            if tags:
                model["tags"] = tags

            if connection_type:
                model["connection_type"] = connection_type

    log.debug(f"get_all_models:responses() {responses}")
    return responses
|
|
|
|
|
|
async def get_filtered_models(models, user, db=None):
    """
    Return the entries of ``models["data"]`` that the user may read.

    A model is kept when Models.get_model_by_id finds a record for its id and
    the user either owns that record or has_access grants "read" on it.
    Models without a record are dropped.

    Args:
        models: Mapping with an optional "data" list of model dicts.
        user: User whose ``id`` is checked.
        db: Optional value passed through to the lookups as ``db``.

    Returns:
        list: The readable model dicts, in their original order.
    """

    def can_read(model):
        info = Models.get_model_by_id(model["id"], db=db)
        if not info:
            return False
        if user.id == info.user_id:
            return True
        return has_access(
            user.id, type="read", access_control=info.access_control, db=db
        )

    return [model for model in models.get("data", []) if can_read(model)]
|
|
|
|
|
|
@cached(
    ttl=MODELS_CACHE_TTL,
    key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models",
)
async def get_all_models(request: Request, user: UserModel) -> dict[str, list]:
    """
    Merge the model lists of all OpenAI connections into one listing.

    The merged mapping (model id -> model entry) is also stored on
    ``request.app.state.OPENAI_MODELS``. The first connection that provides
    a given id wins.

    Args:
        request: Request giving access to the application config and state.
        user: User passed on to get_all_models_responses.

    Returns:
        dict: ``{"data": [...]}`` with the merged model entries; the list is
        empty when the OpenAI API is disabled.
    """
    log.info("get_all_models()")

    if not request.app.state.config.ENABLE_OPENAI_API:
        return {"data": []}

    responses = await get_all_models_responses(request, user=user)

    # Substrings of api.openai.com model ids that are left out of the listing.
    unsupported_markers = [
        "babbage",
        "dall-e",
        "davinci",
        "embedding",
        "tts",
        "whisper",
    ]

    def extract_data(response):
        # Accept both {"data": [...]} payloads and bare lists.
        if response and "data" in response:
            return response["data"]
        return response if isinstance(response, list) else None

    def is_supported_openai_models(model_id):
        return not any(marker in model_id for marker in unsupported_markers)

    def get_merged_models(model_lists):
        log.debug(f"merge_models_lists {model_lists}")
        merged = {}

        for idx, model_list in enumerate(model_lists):
            if model_list is None or "error" in model_list:
                continue

            for model in model_list:
                model_id = model.get("id") or model.get("name")

                from_openai = (
                    "api.openai.com"
                    in request.app.state.config.OPENAI_API_BASE_URLS[idx]
                )
                if from_openai and not is_supported_openai_models(model_id):
                    # Skip unwanted OpenAI models
                    continue

                if model_id and model_id not in merged:
                    merged[model_id] = {
                        **model,
                        "name": model.get("name", model_id),
                        "owned_by": "openai",
                        "openai": model,
                        "connection_type": model.get("connection_type", "external"),
                        "urlIdx": idx,
                    }

        return merged

    models = get_merged_models(map(extract_data, responses))
    log.debug(f"models: {models}")

    request.app.state.OPENAI_MODELS = models
    return {"data": list(models.values())}
|
|
|
|
|
|
@router.get("/models")
@router.get("/models/{url_idx}")
async def get_models(
    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    # List models. Without url_idx the merged listing from get_all_models is
    # used; with url_idx only that connection is queried. For users with role
    # "user" the listing is filtered through get_filtered_models unless
    # BYPASS_MODEL_ACCESS_CONTROL is set. Upstream failures are reported as
    # HTTP 500.
    models = {
        "data": [],
    }

    if url_idx is None:
        models = await get_all_models(request, user=user)
    else:
        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = request.app.state.config.OPENAI_API_KEYS[url_idx]

        api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
            str(url_idx),
            request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
        )

        async with aiohttp.ClientSession(
            trust_env=True,
            timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST),
        ) as session:
            try:
                headers, cookies = await get_headers_and_cookies(
                    request, url, key, api_config, user=user
                )

                if api_config.get("azure", False):
                    # Azure connections are not queried; the configured
                    # model ids are returned instead.
                    models = {
                        "data": api_config.get("model_ids", []) or [],
                        "object": "list",
                    }
                else:
                    async with session.get(
                        f"{url}/models",
                        headers=headers,
                        cookies=cookies,
                        ssl=AIOHTTP_CLIENT_SESSION_SSL,
                    ) as r:
                        if r.status != 200:
                            # Extract response error details if available
                            error_detail = f"HTTP Error: {r.status}"
                            res = await r.json()
                            if "error" in res:
                                error_detail = f"External Error: {res['error']}"
                            raise Exception(error_detail)

                        response_data = await r.json()

                        # Check if we're calling OpenAI API based on the URL
                        if "api.openai.com" in url:
                            # Filter models according to the specified conditions
                            response_data["data"] = [
                                model
                                for model in response_data.get("data", [])
                                if not any(
                                    name in model["id"]
                                    for name in [
                                        "babbage",
                                        "dall-e",
                                        "davinci",
                                        "embedding",
                                        "tts",
                                        "whisper",
                                    ]
                                )
                            ]

                        models = response_data
            except aiohttp.ClientError as e:
                # ClientError covers all aiohttp requests issues
                log.exception(f"Client error: {str(e)}")
                raise HTTPException(
                    status_code=500, detail="Open WebUI: Server Connection Error"
                )
            except Exception as e:
                log.exception(f"Unexpected error: {e}")
                error_detail = f"Unexpected error: {str(e)}"
                raise HTTPException(status_code=500, detail=error_detail)

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        filtered = await get_filtered_models(models, user)
        # Build a new dict instead of assigning into `models`: it may be the
        # object returned by the @cached get_all_models, and writing one
        # user's filtered list into it could change what later callers get.
        models = {**models, "data": filtered}

    return models
|
|
|
|
|
|
# Request body accepted by the "/verify" route.
class ConnectionVerificationForm(BaseModel):
    # Base URL and API key of the connection to test.
    url: str
    key: str

    # Optional connection settings; verify_connection treats None as {}.
    config: Optional[dict] = None
|
|
|
|
|
|
@router.post("/verify")
async def verify_connection(
    request: Request,
    form_data: ConnectionVerificationForm,
    user=Depends(get_admin_user),
):
    # Test a connection by requesting its model list. The upstream body is
    # returned as-is on HTTP 200; any other status is relayed with the
    # upstream status code. Exceptions are reported as HTTP 500.
    url = form_data.url
    key = form_data.key
    api_config = form_data.config or {}

    client_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST)
    async with aiohttp.ClientSession(
        trust_env=True,
        timeout=client_timeout,
    ) as session:
        try:
            headers, cookies = await get_headers_and_cookies(
                request, url, key, api_config, user=user
            )

            if api_config.get("azure", False):
                # Only set api-key header if not using Azure Entra ID authentication
                auth_type = api_config.get("auth_type", "bearer")
                if auth_type not in ("azure_ad", "microsoft_entra_id"):
                    headers["api-key"] = key

                api_version = api_config.get("api_version", "") or "2023-03-15-preview"
                models_url = f"{url}/openai/models?api-version={api_version}"
            else:
                models_url = f"{url}/models"

            async with session.get(
                models_url,
                headers=headers,
                cookies=cookies,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                # Prefer the JSON body; fall back to raw text when it cannot
                # be decoded.
                try:
                    response_data = await r.json()
                except Exception:
                    response_data = await r.text()

                if r.status == 200:
                    return response_data

                if isinstance(response_data, (dict, list)):
                    return JSONResponse(status_code=r.status, content=response_data)
                return PlainTextResponse(status_code=r.status, content=response_data)

        except aiohttp.ClientError as e:
            # ClientError covers all aiohttp requests issues
            log.exception(f"Client error: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
|
|
|
|
|
|
def get_azure_allowed_params(api_version: str) -> set[str]:
    """
    Return the payload keys that are kept for an Azure OpenAI request.

    Args:
        api_version: Azure API version string. It is compared as a plain
            string against "2024-09-01-preview".

    Returns:
        set[str]: The base parameter names, plus "stream_options" when
        ``api_version`` compares greater than or equal to the threshold.
    """
    base_params = (
        "messages",
        "temperature",
        "role",
        "content",
        "contentPart",
        "contentPartImage",
        "enhancements",
        "dataSources",
        "n",
        "stream",
        "stop",
        "max_tokens",
        "presence_penalty",
        "frequency_penalty",
        "logit_bias",
        "user",
        "function_call",
        "functions",
        "tools",
        "tool_choice",
        "top_p",
        "log_probs",
        "top_logprobs",
        "response_format",
        "seed",
        "max_completion_tokens",
        "reasoning_effort",
    )
    allowed_params = set(base_params)

    try:
        supports_stream_options = api_version >= "2024-09-01-preview"
        if supports_stream_options:
            allowed_params.add("stream_options")
    except ValueError:
        log.debug(
            f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters."
        )

    return allowed_params
|
|
|
|
|
|
def is_openai_reasoning_model(model: str) -> bool:
    """Return True when the lower-cased model name starts with o1, o3, o4 or gpt-5."""
    reasoning_prefixes = ("o1", "o3", "o4", "gpt-5")
    normalized = model.lower()
    return normalized.startswith(reasoning_prefixes)
|
|
|
|
|
|
def convert_to_azure_payload(url, payload: dict, api_version: str):
    """
    Adapt a chat payload and base URL to the Azure OpenAI deployment layout.

    For reasoning models "max_tokens" is renamed to "max_completion_tokens"
    and a "temperature" other than 1 is dropped; both changes are made on the
    passed-in dict. The returned payload is a new dict holding only the keys
    allowed for ``api_version``.

    Args:
        url: Base URL of the Azure resource.
        payload: Request body; "model" selects the deployment.
        api_version: Azure API version used to pick the allowed keys.

    Returns:
        tuple: ``(deployment_url, filtered_payload)``.
    """
    model = payload.get("model", "")

    # Filter allowed parameters based on Azure OpenAI API
    permitted_keys = get_azure_allowed_params(api_version)

    # Special handling for o-series models
    if is_openai_reasoning_model(model):
        # Convert max_tokens to max_completion_tokens for o-series models
        if "max_tokens" in payload:
            payload["max_completion_tokens"] = payload.pop("max_tokens")

        # Remove temperature if not 1 for o-series models
        if "temperature" in payload and payload["temperature"] != 1:
            log.debug(
                f"Removing temperature parameter for o-series model {model} as only default value (1) is supported"
            )
            del payload["temperature"]

    # Filter out unsupported parameters
    filtered_payload = {
        name: value for name, value in payload.items() if name in permitted_keys
    }

    deployment_url = f"{url}/openai/deployments/{model}"
    return deployment_url, filtered_payload
|
|
|
|
|
|
@router.post("/chat/completions")
async def generate_chat_completion(
    request: Request,
    form_data: dict,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
    bypass_system_prompt: bool = False,
):
    # Forward a chat completion request to the connection that serves the
    # requested model.
    #
    # Steps: apply the stored model record (base model, params, system
    # prompt), check access, resolve the connection, adapt the payload
    # (prefix, reasoning models, Azure), then relay the upstream response.
    # SSE responses are streamed; others are returned as JSON or plain text
    # with the upstream status on errors.
    #
    # NOTE: We intentionally do NOT use Depends(get_session) here.
    # Database operations (get_model_by_id, has_access) manage their own short-lived sessions.
    # This prevents holding a connection during the entire LLM call (30-60+ seconds),
    # which would exhaust the connection pool under concurrent load.
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    idx = 0

    payload = {**form_data}
    metadata = payload.pop("metadata", None)

    model_id = form_data.get("model")
    model_info = Models.get_model_by_id(model_id)

    # Check model info and override the payload
    if model_info:
        if model_info.base_model_id:
            base_model_id = (
                request.base_model_id
                if hasattr(request, "base_model_id")
                else model_info.base_model_id
            )  # Use request's base_model_id if available
            payload["model"] = base_model_id
            model_id = base_model_id

        params = model_info.params.model_dump()

        if params:
            system = params.pop("system", None)

            payload = apply_model_params_to_body_openai(params, payload)
            if not bypass_system_prompt:
                payload = apply_system_prompt_to_body(system, payload, metadata, user)

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id,
                    type="read",
                    access_control=model_info.access_control,
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        # Models without a stored record are only usable by admins.
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    # get_all_models populates request.app.state.OPENAI_MODELS.
    await get_all_models(request, user=user)
    model = request.app.state.OPENAI_MODELS.get(model_id)
    if model:
        idx = model["urlIdx"]
    else:
        raise HTTPException(
            status_code=404,
            detail="Model not found",
        )

    # Get the API config for the model
    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
        request.app.state.config.OPENAI_API_CONFIGS.get(
            request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
        ),  # Legacy support
    )

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        # Strip only the leading "<prefix_id>." that the model listing adds;
        # a plain replace would also remove later occurrences inside the id.
        payload["model"] = payload["model"].removeprefix(f"{prefix_id}.")

    # Add user info to the payload if the model is a pipeline
    if "pipeline" in model and model.get("pipeline"):
        payload["user"] = {
            "name": user.name,
            "id": user.id,
            "email": user.email,
            "role": user.role,
        }

    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]

    # Check if model is a reasoning model that needs special handling
    if is_openai_reasoning_model(payload["model"]):
        payload = openai_reasoning_model_handler(payload)
    elif "api.openai.com" not in url:
        # Remove "max_completion_tokens" from the payload for backward compatibility
        if "max_completion_tokens" in payload:
            payload["max_tokens"] = payload["max_completion_tokens"]
            del payload["max_completion_tokens"]

    if "max_tokens" in payload and "max_completion_tokens" in payload:
        del payload["max_tokens"]

    # Convert the modified body back to JSON
    if "logit_bias" in payload and payload["logit_bias"]:
        logit_bias = convert_logit_bias_input_to_json(payload["logit_bias"])

        if logit_bias:
            payload["logit_bias"] = json.loads(logit_bias)

    headers, cookies = await get_headers_and_cookies(
        request, url, key, api_config, metadata, user=user
    )

    if api_config.get("azure", False):
        # Same fallback as verify_connection: an empty or None configured
        # version also falls back to the default.
        api_version = api_config.get("api_version", "") or "2023-03-15-preview"
        request_url, payload = convert_to_azure_payload(url, payload, api_version)

        # Only set api-key header if not using Azure Entra ID authentication
        auth_type = api_config.get("auth_type", "bearer")
        if auth_type not in ("azure_ad", "microsoft_entra_id"):
            headers["api-key"] = key

        headers["api-version"] = api_version
        request_url = f"{request_url}/chat/completions?api-version={api_version}"
    else:
        request_url = f"{url}/chat/completions"

    payload = json.dumps(payload)

    r = None
    session = None
    streaming = False
    response = None

    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )

        r = await session.request(
            method="POST",
            url=request_url,
            data=payload,
            headers=headers,
            cookies=cookies,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            streaming_response = StreamingResponse(
                stream_chunks_handler(r.content),
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
            # Set the flag only once the response object exists, so that a
            # failure while building it still runs the cleanup in `finally`.
            # After this point the background task closes the connection.
            streaming = True
            return streaming_response
        else:
            try:
                response = await r.json()
            except Exception as e:
                log.error(e)
                response = await r.text()

            if r.status >= 400:
                if isinstance(response, (dict, list)):
                    return JSONResponse(status_code=r.status, content=response)
                else:
                    return PlainTextResponse(status_code=r.status, content=response)

            return response
    except Exception as e:
        log.exception(e)

        raise HTTPException(
            status_code=r.status if r else 500,
            detail="Open WebUI: Server Connection Error",
        )
    finally:
        if not streaming:
            await cleanup_response(r, session)
|
|
|
|
|
|
async def embeddings(request: Request, form_data: dict, user):
    """
    Calls the embeddings endpoint for OpenAI-compatible providers.

    The connection is chosen from the model's "urlIdx" entry in
    ``request.app.state.OPENAI_MODELS``; the first connection (index 0) is
    used when the model is not listed there.

    Args:
        request (Request): The FastAPI request context.
        form_data (dict): OpenAI-compatible embeddings payload.
        user (UserModel): The authenticated user.

    Returns:
        The decoded upstream response body. Upstream statuses >= 400 are
        relayed as a JSONResponse or PlainTextResponse with that status, and
        an SSE upstream response is relayed as a StreamingResponse.

    Raises:
        HTTPException: When the upstream request fails with an exception.
    """
    idx = 0
    # Prepare payload/body
    body = json.dumps(form_data)
    # Find correct backend url/key based on model
    await get_all_models(request, user=user)
    model_id = form_data.get("model")
    models = request.app.state.OPENAI_MODELS
    if model_id in models:
        idx = models[model_id]["urlIdx"]

    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]
    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
        request.app.state.config.OPENAI_API_CONFIGS.get(url, {}),  # Legacy support
    )

    r = None
    session = None
    streaming = False

    headers, cookies = await get_headers_and_cookies(
        request, url, key, api_config, user=user
    )
    try:
        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method="POST",
            url=f"{url}/embeddings",
            data=body,
            headers=headers,
            cookies=cookies,
            # Pass the configured SSL setting, as the other upstream calls in
            # this module do.
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # From here on the background task closes the connection.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            try:
                response_data = await r.json()
            except Exception:
                response_data = await r.text()

            if r.status >= 400:
                if isinstance(response_data, (dict, list)):
                    return JSONResponse(status_code=r.status, content=response_data)
                else:
                    return PlainTextResponse(
                        status_code=r.status, content=response_data
                    )

            return response_data
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=r.status if r else 500,
            detail="Open WebUI: Server Connection Error",
        )
    finally:
        if not streaming:
            await cleanup_response(r, session)
|
|
|
|
|
|
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    """
    Deprecated: proxy all requests to OpenAI API
    """
    # The request is always forwarded to the first configured connection
    # (index 0), with the incoming method and body. For Azure connections
    # the body is rewritten through convert_to_azure_payload.

    body = await request.body()

    idx = 0
    url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
    key = request.app.state.config.OPENAI_API_KEYS[idx]
    api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
        str(idx),
        request.app.state.config.OPENAI_API_CONFIGS.get(
            request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
        ),  # Legacy support
    )

    r = None
    session = None
    streaming = False

    try:
        headers, cookies = await get_headers_and_cookies(
            request, url, key, api_config, user=user
        )

        if api_config.get("azure", False):
            # Same fallback as verify_connection: an empty or None
            # configured version also falls back to the default.
            api_version = api_config.get("api_version", "") or "2023-03-15-preview"

            # Only set api-key header if not using Azure Entra ID authentication
            auth_type = api_config.get("auth_type", "bearer")
            if auth_type not in ("azure_ad", "microsoft_entra_id"):
                headers["api-key"] = key

            headers["api-version"] = api_version

            payload = json.loads(body)
            url, payload = convert_to_azure_payload(url, payload, api_version)
            body = json.dumps(payload).encode()

            request_url = f"{url}/{path}?api-version={api_version}"
        else:
            request_url = f"{url}/{path}"

        session = aiohttp.ClientSession(trust_env=True)
        r = await session.request(
            method=request.method,
            url=request_url,
            data=body,
            headers=headers,
            cookies=cookies,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        )

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            # From here on the background task closes the connection.
            streaming = True
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=dict(r.headers),
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            try:
                response_data = await r.json()
            except Exception:
                response_data = await r.text()

            if r.status >= 400:
                if isinstance(response_data, (dict, list)):
                    return JSONResponse(status_code=r.status, content=response_data)
                else:
                    return PlainTextResponse(
                        status_code=r.status, content=response_data
                    )

            return response_data

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=r.status if r else 500,
            detail="Open WebUI: Server Connection Error",
        )
    finally:
        if not streaming:
            await cleanup_response(r, session)
|