This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 17:50:48 -08:00
parent 87d695caad
commit 3ec0a58cd7
5 changed files with 300 additions and 723 deletions

View File

@@ -10,15 +10,15 @@ from aiocache import cached
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.models.models import Models
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_OPENAI_API,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
AppConfig,
)
from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT,
@@ -29,11 +29,7 @@ from open_webui.env import (
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
@@ -48,13 +44,69 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
##########################################
#
# Utility functions
#
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def openai_o1_handler(payload):
"""
Handle O1 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
return payload
##########################################
#
# API routes
#
##########################################
router = APIRouter()
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
"ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
}
@@ -65,49 +117,56 @@ class OpenAIConfigForm(BaseModel):
OPENAI_API_CONFIGS: dict
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
@router.post("/config/update")
async def update_config(
request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
):
request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
if len(request.app.state.config.OPENAI_API_KEYS) != len(
request.app.state.config.OPENAI_API_BASE_URLS
):
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
if len(request.app.state.config.OPENAI_API_KEYS) > len(
request.app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
request.app.state.config.OPENAI_API_KEYS = (
request.app.state.config.OPENAI_API_KEYS[
: len(request.app.state.config.OPENAI_API_BASE_URLS)
]
)
else:
app.state.config.OPENAI_API_KEYS += [""] * (
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
request.app.state.config.OPENAI_API_KEYS += [""] * (
len(request.app.state.config.OPENAI_API_BASE_URLS)
- len(request.app.state.config.OPENAI_API_KEYS)
)
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
# Remove any extra configs
config_urls = app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in config_urls:
app.state.config.OPENAI_API_CONFIGS.pop(url, None)
request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS,
"ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
}
@app.post("/audio/speech")
@router.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
"https://api.openai.com/v1"
)
body = await request.body()
name = hashlib.sha256(body).hexdigest()
@@ -120,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)):
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
r = None
try:
r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
url=f"{url}/audio/speech",
data=body,
headers=headers,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
stream=True,
)
@@ -155,46 +226,25 @@ async def speech(request: Request, user=Depends(get_verified_user)):
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
status_code=r.status_code if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
@@ -212,7 +262,7 @@ def merge_models_lists(model_lists):
}
for model in models
if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
@@ -230,40 +280,43 @@ def merge_models_lists(model_lists):
return merged_list
async def get_all_models_responses() -> list:
if not app.state.config.ENABLE_OPENAI_API:
async def get_all_models_responses(request: Request) -> list:
if not request.app.state.config.ENABLE_OPENAI_API:
return []
# Check if API KEYS length is same than API URLS length
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(app.state.config.OPENAI_API_KEYS)
num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(request.app.state.config.OPENAI_API_KEYS)
if num_keys != num_urls:
# if there are more keys than urls, remove the extra keys
if num_keys > num_urls:
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
app.state.config.OPENAI_API_KEYS = new_keys
new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls]
request.app.state.config.OPENAI_API_KEYS = new_keys
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = []
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS):
if url not in app.state.config.OPENAI_API_CONFIGS:
tasks.append(
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
request_tasks = []
for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in request.app.state.config.OPENAI_API_CONFIGS:
request_tasks.append(
send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
)
)
else:
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True)
model_ids = api_config.get("model_ids", [])
if enable:
if len(model_ids) == 0:
tasks.append(
aiohttp_get(
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]
request_tasks.append(
send_get_request(
f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
)
)
else:
@@ -281,16 +334,18 @@ async def get_all_models_responses() -> list:
],
}
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list)))
request_tasks.append(
asyncio.ensure_future(asyncio.sleep(0, model_list))
)
else:
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*tasks)
responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses):
if response:
url = app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {})
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None)
@@ -301,15 +356,27 @@ async def get_all_models_responses() -> list:
model["id"] = f"{prefix_id}.{model['id']}"
log.debug(f"get_all_models:responses() {responses}")
return responses
async def get_filtered_models(models, user):
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
return filtered_models
@cached(ttl=3)
async def get_all_models() -> dict[str, list]:
async def get_all_models(request: Request) -> dict[str, list]:
log.info("get_all_models()")
if not app.state.config.ENABLE_OPENAI_API:
if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []}
responses = await get_all_models_responses()
@@ -324,12 +391,15 @@ async def get_all_models() -> dict[str, list]:
models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}")
request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
return models
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
@router.get("/models")
@router.get("/models/{url_idx}")
async def get_models(
request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
models = {
"data": [],
}
@@ -337,25 +407,33 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
if url_idx is None:
models = await get_all_models()
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx]
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
r = None
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(
total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
@@ -389,27 +467,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models["data"] = filtered_models
models["data"] = get_filtered_models(models, user)
return models
@@ -419,21 +486,24 @@ class ConnectionVerificationForm(BaseModel):
key: str
@app.post("/verify")
@router.post("/verify")
async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
url = form_data.url
key = form_data.key
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
) as session:
try:
async with session.get(f"{url}/models", headers=headers) as r:
async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
},
) as r:
if r.status != 200:
# Extract response error details if available
error_detail = f"HTTP Error: {r.status}"
@@ -448,26 +518,24 @@ async def verify_connection(
except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error"
)
except Exception as e:
log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail)
@app.post("/chat/completions")
@router.post("/chat/completions")
async def generate_chat_completion(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False,
):
idx = 0
payload = {**form_data}
if "metadata" in payload:
del payload["metadata"]
@@ -502,15 +570,7 @@ async def generate_chat_completion(
detail="Model not found",
)
# Attemp to get urlIdx from the model
models = await get_all_models()
# Find the model from the list
model = next(
(model for model in models["data"] if model["id"] == payload.get("model")),
None,
)
model = request.app.state.OPENAI_MODELS.get(model_id)
if model:
idx = model["urlIdx"]
else:
@@ -520,11 +580,11 @@ async def generate_chat_completion(
)
# Get the API config for the model
api_config = app.state.config.OPENAI_API_CONFIGS.get(
app.state.config.OPENAI_API_BASE_URLS[idx], {}
api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
)
prefix_id = api_config.get("prefix_id", None)
prefix_id = api_config.get("prefix_id", None)
if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@@ -537,43 +597,26 @@ async def generate_chat_completion(
"role": user.role,
}
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-")
# Change max_completion_tokens to max_tokens (Backward compatible)
if "api.openai.com" not in url and not is_o1:
if "max_completion_tokens" in payload:
# Remove "max_completion_tokens" from the payload
payload["max_tokens"] = payload["max_completion_tokens"]
del payload["max_completion_tokens"]
else:
if is_o1 and "max_tokens" in payload:
if is_o1:
payload = openai_o1_handler(payload)
elif "api.openai.com" not in url:
# Remove "max_tokens" from the payload for backward compatibility
if "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if is_o1 and payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
# TODO: check if below is needed
# if "max_tokens" in payload and "max_completion_tokens" in payload:
# del payload["max_tokens"]
# Convert the modified body back to JSON
payload = json.dumps(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None
session = None
streaming = False
@@ -583,11 +626,33 @@ async def generate_chat_completion(
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
)
# Check if response is SSE
@@ -612,14 +677,18 @@ async def generate_chat_completion(
return response
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if isinstance(response, dict):
if "error" in response:
error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
elif isinstance(response, str):
error_detail = response
detail = response
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r:
@@ -627,25 +696,17 @@ async def generate_chat_completion(
await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
"""
Deprecated: proxy all requests to OpenAI API
"""
body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
idx = 0
url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
r = None
session = None
@@ -655,11 +716,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=target_url,
url=f"{url}/{path}",
data=body,
headers=headers,
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
)
r.raise_for_status()
# Check if response is SSE
@@ -676,18 +749,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
detail = f"External: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally:
if not streaming and session:
if r: