This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 17:50:48 -08:00
parent 87d695caad
commit 3ec0a58cd7
5 changed files with 300 additions and 723 deletions

View File

@ -1409,9 +1409,9 @@ app.include_router(ollama.router, prefix="/ollama")
app.include_router(openai.router, prefix="/openai") app.include_router(openai.router, prefix="/openai")
app.include_router(images.router, prefix="/api/v1/images") app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
app.include_router(audio.router, prefix="/api/v1/audio") app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"])
app.include_router(retrieval.router, prefix="/api/v1/retrieval") app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"])
app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"])

View File

@ -1,411 +0,0 @@
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel
router = APIRouter()
@app.post("/api/chat/completions")
async def generate_chat_completions(
request: Request,
form_data: dict,
user=Depends(get_verified_user),
bypass_filter: bool = False,
):
if BYPASS_MODEL_ACCESS_CONTROL:
bypass_filter = True
model_list = request.state.models
models = {model["id"]: model for model in model_list}
model_id = form_data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
# Check if user has access to the model
if not bypass_filter and user.role == "user":
if model.get("arena"):
if not has_access(
user.id,
type="read",
access_control=model.get("info", {})
.get("meta", {})
.get("access_control", {}),
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
else:
model_info = Models.get_model_by_id(model_id)
if not model_info:
raise HTTPException(
status_code=404,
detail="Model not found",
)
elif not (
user.id == model_info.user_id
or has_access(
user.id, type="read", access_control=model_info.access_control
)
):
raise HTTPException(
status_code=403,
detail="Model not found",
)
if model["owned_by"] == "arena":
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
if model_ids and filter_mode == "exclude":
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena" and model["id"] not in model_ids
]
selected_model_id = None
if isinstance(model_ids, list) and model_ids:
selected_model_id = random.choice(model_ids)
else:
model_ids = [
model["id"]
for model in await get_all_models()
if model.get("owned_by") != "arena"
]
selected_model_id = random.choice(model_ids)
form_data["model"] = selected_model_id
if form_data.get("stream") == True:
async def stream_wrapper(stream):
yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
async for chunk in stream:
yield chunk
response = await generate_chat_completions(
form_data, user, bypass_filter=True
)
return StreamingResponse(
stream_wrapper(response.body_iterator), media_type="text/event-stream"
)
else:
return {
**(
await generate_chat_completions(form_data, user, bypass_filter=True)
),
"selected_model_id": selected_model_id,
}
if model.get("pipe"):
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
return await generate_function_chat_completion(
form_data, user=user, models=models
)
if model["owned_by"] == "ollama":
# Using /ollama/api/chat endpoint
form_data = convert_payload_openai_to_ollama(form_data)
form_data = GenerateChatCompletionForm(**form_data)
response = await generate_ollama_chat_completion(
form_data=form_data, user=user, bypass_filter=bypass_filter
)
if form_data.stream:
response.headers["content-type"] = "text/event-stream"
return StreamingResponse(
convert_streaming_response_ollama_to_openai(response),
headers=dict(response.headers),
)
else:
return convert_response_ollama_to_openai(response)
else:
return await generate_openai_chat_completion(
form_data, user=user, bypass_filter=bypass_filter
)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
data = form_data
model_id = data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
sorted_filters = get_sorted_filters(model_id, models)
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {
"id": user.id,
"name": user.name,
"email": user.email,
"role": user.role,
},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except Exception:
pass
else:
pass
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
# TODO: Fix FunctionModel to include vavles
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
# Sort filter_ids by priority, using the get_priority function
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if not filter:
continue
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, _, _ = load_function_module_by_id(filter_id)
webui_app.state.FUNCTIONS[filter_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
if not hasattr(function_module, "outlet"):
continue
try:
outlet = function_module.outlet
# Get the signature of the function
sig = inspect.signature(outlet)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": filter_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(outlet):
data = await outlet(**params)
else:
data = outlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
@app.post("/api/chat/actions/{action_id}")
async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
if "." in action_id:
action_id, sub_action_id = action_id.split(".")
else:
sub_action_id = None
action = Functions.get_function_by_id(action_id)
if not action:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Action not found",
)
model_list = await get_all_models()
models = {model["id"]: model for model in model_list}
data = form_data
model_id = data["model"]
if model_id not in models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = models[model_id]
__event_emitter__ = get_event_emitter(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
__event_call__ = get_event_call(
{
"chat_id": data["chat_id"],
"message_id": data["id"],
"session_id": data["session_id"],
}
)
if action_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[action_id]
else:
function_module, _, _ = load_function_module_by_id(action_id)
webui_app.state.FUNCTIONS[action_id] = function_module
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
valves = Functions.get_function_valves_by_id(action_id)
function_module.valves = function_module.Valves(**(valves if valves else {}))
if hasattr(function_module, "action"):
try:
action = function_module.action
# Get the signature of the function
sig = inspect.signature(action)
params = {"body": data}
# Extra parameters to be passed to the function
extra_params = {
"__model__": model,
"__id__": sub_action_id if sub_action_id is not None else action_id,
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
}
# Add extra params in contained in function signature
for key, value in extra_params.items():
if key in sig.parameters:
params[key] = value
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
action_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if inspect.iscoroutinefunction(action):
data = await action(**params)
else:
data = action(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data

View File

@ -317,6 +317,9 @@ async def get_all_models(request: Request):
else: else:
models = {"models": []} models = {"models": []}
request.app.state.OLLAMA_MODELS = {
model["model"]: model for model in models["models"]
}
return models return models

View File

@ -10,15 +10,15 @@ from aiocache import cached
import requests import requests
from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.models.models import Models from open_webui.models.models import Models
from open_webui.config import ( from open_webui.config import (
CACHE_DIR, CACHE_DIR,
CORS_ALLOW_ORIGIN,
ENABLE_OPENAI_API,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
OPENAI_API_CONFIGS,
AppConfig,
) )
from open_webui.env import ( from open_webui.env import (
AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT,
@ -29,11 +29,7 @@ from open_webui.env import (
from open_webui.constants import ERROR_MESSAGES from open_webui.constants import ERROR_MESSAGES
from open_webui.env import ENV, SRC_LOG_LEVELS from open_webui.env import ENV, SRC_LOG_LEVELS
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from starlette.background import BackgroundTask
from open_webui.utils.payload import ( from open_webui.utils.payload import (
apply_model_params_to_body_openai, apply_model_params_to_body_openai,
@ -48,13 +44,69 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"]) log.setLevel(SRC_LOG_LEVELS["OPENAI"])
@app.get("/config") ##########################################
async def get_config(user=Depends(get_admin_user)): #
# Utility functions
#
##########################################
async def send_get_request(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(
url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def openai_o1_handler(payload):
"""
Handle O1 specific parameters
"""
if "max_tokens" in payload:
# Remove "max_tokens" from the payload
payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user"
if payload["messages"][0]["role"] == "system":
payload["messages"][0]["role"] = "user"
return payload
##########################################
#
# API routes
#
##########################################
router = APIRouter()
@router.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return { return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
} }
@ -65,49 +117,56 @@ class OpenAIConfigForm(BaseModel):
OPENAI_API_CONFIGS: dict OPENAI_API_CONFIGS: dict
@app.post("/config/update") @router.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): async def update_config(
app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user)
app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS ):
app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API
request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS
request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS
# Check if API KEYS length is same than API URLS length # Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len( if len(request.app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS request.app.state.config.OPENAI_API_BASE_URLS
): ):
if len(app.state.config.OPENAI_API_KEYS) > len( if len(request.app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS request.app.state.config.OPENAI_API_BASE_URLS
): ):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ request.app.state.config.OPENAI_API_KEYS = (
: len(app.state.config.OPENAI_API_BASE_URLS) request.app.state.config.OPENAI_API_KEYS[
: len(request.app.state.config.OPENAI_API_BASE_URLS)
] ]
)
else: else:
app.state.config.OPENAI_API_KEYS += [""] * ( request.app.state.config.OPENAI_API_KEYS += [""] * (
len(app.state.config.OPENAI_API_BASE_URLS) len(request.app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS) - len(request.app.state.config.OPENAI_API_KEYS)
) )
app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS
# Remove any extra configs # Remove any extra configs
config_urls = app.state.config.OPENAI_API_CONFIGS.keys() config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys()
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in config_urls: if url not in config_urls:
app.state.config.OPENAI_API_CONFIGS.pop(url, None) request.app.state.config.OPENAI_API_CONFIGS.pop(url, None)
return { return {
"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API,
"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS,
"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS,
"OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS,
} }
@app.post("/audio/speech") @router.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None
try: try:
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") idx = request.app.state.config.OPENAI_API_BASE_URLS.index(
"https://api.openai.com/v1"
)
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
@ -120,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)):
if file_path.is_file(): if file_path.is_file():
return FileResponse(file_path) return FileResponse(file_path)
headers = {} url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", url=f"{url}/audio/speech",
data=body, data=body,
headers=headers, headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
stream=True, stream=True,
) )
@ -155,46 +226,25 @@ async def speech(request: Request, user=Depends(get_verified_user)):
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']}" detail = f"External: {res['error']}"
except Exception: except Exception:
error_detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException( raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail status_code=r.status_code if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
) )
except ValueError: except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
async def aiohttp_get(url, key=None):
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
try:
headers = {"Authorization": f"Bearer {key}"} if key else {}
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists): def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}") log.debug(f"merge_models_lists {model_lists}")
merged_list = [] merged_list = []
@ -212,7 +262,7 @@ def merge_models_lists(model_lists):
} }
for model in models for model in models
if "api.openai.com" if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx] not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any( or not any(
name in model["id"] name in model["id"]
for name in [ for name in [
@ -230,40 +280,43 @@ def merge_models_lists(model_lists):
return merged_list return merged_list
async def get_all_models_responses() -> list: async def get_all_models_responses(request: Request) -> list:
if not app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return [] return []
# Check if API KEYS length is same than API URLS length # Check if API KEYS length is same than API URLS length
num_urls = len(app.state.config.OPENAI_API_BASE_URLS) num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS)
num_keys = len(app.state.config.OPENAI_API_KEYS) num_keys = len(request.app.state.config.OPENAI_API_KEYS)
if num_keys != num_urls: if num_keys != num_urls:
# if there are more keys than urls, remove the extra keys # if there are more keys than urls, remove the extra keys
if num_keys > num_urls: if num_keys > num_urls:
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls] new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls]
app.state.config.OPENAI_API_KEYS = new_keys request.app.state.config.OPENAI_API_KEYS = new_keys
# if there are more urls than keys, add empty keys # if there are more urls than keys, add empty keys
else: else:
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
tasks = [] request_tasks = []
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS):
if url not in app.state.config.OPENAI_API_CONFIGS: if url not in request.app.state.config.OPENAI_API_CONFIGS:
tasks.append( request_tasks.append(
aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) send_get_request(
f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx]
)
) )
else: else:
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
enable = api_config.get("enable", True) enable = api_config.get("enable", True)
model_ids = api_config.get("model_ids", []) model_ids = api_config.get("model_ids", [])
if enable: if enable:
if len(model_ids) == 0: if len(model_ids) == 0:
tasks.append( request_tasks.append(
aiohttp_get( send_get_request(
f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] f"{url}/models",
request.app.state.config.OPENAI_API_KEYS[idx],
) )
) )
else: else:
@ -281,16 +334,18 @@ async def get_all_models_responses() -> list:
], ],
} }
tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) request_tasks.append(
asyncio.ensure_future(asyncio.sleep(0, model_list))
)
else: else:
tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*request_tasks)
for idx, response in enumerate(responses): for idx, response in enumerate(responses):
if response: if response:
url = app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {})
prefix_id = api_config.get("prefix_id", None) prefix_id = api_config.get("prefix_id", None)
@ -301,15 +356,27 @@ async def get_all_models_responses() -> list:
model["id"] = f"{prefix_id}.{model['id']}" model["id"] = f"{prefix_id}.{model['id']}"
log.debug(f"get_all_models:responses() {responses}") log.debug(f"get_all_models:responses() {responses}")
return responses return responses
async def get_filtered_models(models, user):
# Filter models based on user access control
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
return filtered_models
@cached(ttl=3) @cached(ttl=3)
async def get_all_models() -> dict[str, list]: async def get_all_models(request: Request) -> dict[str, list]:
log.info("get_all_models()") log.info("get_all_models()")
if not app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []} return {"data": []}
responses = await get_all_models_responses() responses = await get_all_models_responses()
@ -324,12 +391,15 @@ async def get_all_models() -> dict[str, list]:
models = {"data": merge_models_lists(map(extract_data, responses))} models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}") log.debug(f"models: {models}")
request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]}
return models return models
@app.get("/models") @router.get("/models")
@app.get("/models/{url_idx}") @router.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): async def get_models(
request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
models = { models = {
"data": [], "data": [],
} }
@ -337,25 +407,33 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
else: else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx] url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx]
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None r = None
async with aiohttp.ClientSession(
timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) timeout=aiohttp.ClientTimeout(
async with aiohttp.ClientSession(timeout=timeout) as session: total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST
)
) as session:
try: try:
async with session.get(f"{url}/models", headers=headers) as r: async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) as r:
if r.status != 200: if r.status != 200:
# Extract response error details if available # Extract response error details if available
error_detail = f"HTTP Error: {r.status}" error_detail = f"HTTP Error: {r.status}"
@ -389,27 +467,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues # ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}") log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException( raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error" status_code=500, detail="Open WebUI: Server Connection Error"
) )
except Exception as e: except Exception as e:
log.exception(f"Unexpected error: {e}") log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}" error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
# Filter models based on user access control models["data"] = get_filtered_models(models, user)
filtered_models = []
for model in models.get("data", []):
model_info = Models.get_model_by_id(model["id"])
if model_info:
if user.id == model_info.user_id or has_access(
user.id, type="read", access_control=model_info.access_control
):
filtered_models.append(model)
models["data"] = filtered_models
return models return models
@ -419,21 +486,24 @@ class ConnectionVerificationForm(BaseModel):
key: str key: str
@app.post("/verify") @router.post("/verify")
async def verify_connection( async def verify_connection(
form_data: ConnectionVerificationForm, user=Depends(get_admin_user) form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
): ):
url = form_data.url url = form_data.url
key = form_data.key key = form_data.key
headers = {} async with aiohttp.ClientSession(
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
async with aiohttp.ClientSession(timeout=timeout) as session: ) as session:
try: try:
async with session.get(f"{url}/models", headers=headers) as r: async with session.get(
f"{url}/models",
headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
},
) as r:
if r.status != 200: if r.status != 200:
# Extract response error details if available # Extract response error details if available
error_detail = f"HTTP Error: {r.status}" error_detail = f"HTTP Error: {r.status}"
@ -448,26 +518,24 @@ async def verify_connection(
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
# ClientError covers all aiohttp requests issues # ClientError covers all aiohttp requests issues
log.exception(f"Client error: {str(e)}") log.exception(f"Client error: {str(e)}")
# Handle aiohttp-specific connection issues, timeout etc.
raise HTTPException( raise HTTPException(
status_code=500, detail="Open WebUI: Server Connection Error" status_code=500, detail="Open WebUI: Server Connection Error"
) )
except Exception as e: except Exception as e:
log.exception(f"Unexpected error: {e}") log.exception(f"Unexpected error: {e}")
# Generic error handler in case parsing JSON or other steps fail
error_detail = f"Unexpected error: {str(e)}" error_detail = f"Unexpected error: {str(e)}"
raise HTTPException(status_code=500, detail=error_detail) raise HTTPException(status_code=500, detail=error_detail)
@app.post("/chat/completions") @router.post("/chat/completions")
async def generate_chat_completion( async def generate_chat_completion(
request: Request,
form_data: dict, form_data: dict,
user=Depends(get_verified_user), user=Depends(get_verified_user),
bypass_filter: Optional[bool] = False, bypass_filter: Optional[bool] = False,
): ):
idx = 0 idx = 0
payload = {**form_data} payload = {**form_data}
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]
@ -502,15 +570,7 @@ async def generate_chat_completion(
detail="Model not found", detail="Model not found",
) )
# Attemp to get urlIdx from the model model = request.app.state.OPENAI_MODELS.get(model_id)
models = await get_all_models()
# Find the model from the list
model = next(
(model for model in models["data"] if model["id"] == payload.get("model")),
None,
)
if model: if model:
idx = model["urlIdx"] idx = model["urlIdx"]
else: else:
@ -520,11 +580,11 @@ async def generate_chat_completion(
) )
# Get the API config for the model # Get the API config for the model
api_config = app.state.config.OPENAI_API_CONFIGS.get( api_config = request.app.state.config.OPENAI_API_CONFIGS.get(
app.state.config.OPENAI_API_BASE_URLS[idx], {} request.app.state.config.OPENAI_API_BASE_URLS[idx], {}
) )
prefix_id = api_config.get("prefix_id", None)
prefix_id = api_config.get("prefix_id", None)
if prefix_id: if prefix_id:
payload["model"] = payload["model"].replace(f"{prefix_id}.", "") payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
@ -537,43 +597,26 @@ async def generate_chat_completion(
"role": user.role, "role": user.role,
} }
url = app.state.config.OPENAI_API_BASE_URLS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx]
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
is_o1 = payload["model"].lower().startswith("o1-") is_o1 = payload["model"].lower().startswith("o1-")
# Change max_completion_tokens to max_tokens (Backward compatible) if is_o1:
if "api.openai.com" not in url and not is_o1: payload = openai_o1_handler(payload)
if "max_completion_tokens" in payload: elif "api.openai.com" not in url:
# Remove "max_completion_tokens" from the payload # Remove "max_tokens" from the payload for backward compatibility
payload["max_tokens"] = payload["max_completion_tokens"] if "max_tokens" in payload:
del payload["max_completion_tokens"]
else:
if is_o1 and "max_tokens" in payload:
payload["max_completion_tokens"] = payload["max_tokens"] payload["max_completion_tokens"] = payload["max_tokens"]
del payload["max_tokens"] del payload["max_tokens"]
if "max_tokens" in payload and "max_completion_tokens" in payload:
del payload["max_tokens"]
# Fix: O1 does not support the "system" parameter, Modify "system" to "user" # TODO: check if below is needed
if is_o1 and payload["messages"][0]["role"] == "system": # if "max_tokens" in payload and "max_completion_tokens" in payload:
payload["messages"][0]["role"] = "user" # del payload["max_tokens"]
# Convert the modified body back to JSON # Convert the modified body back to JSON
payload = json.dumps(payload) payload = json.dumps(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None r = None
session = None session = None
streaming = False streaming = False
@ -583,11 +626,33 @@ async def generate_chat_completion(
session = aiohttp.ClientSession( session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
) )
r = await session.request( r = await session.request(
method="POST", method="POST",
url=f"{url}/chat/completions", url=f"{url}/chat/completions",
data=payload, data=payload,
headers=headers, headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"HTTP-Referer": "https://openwebui.com/",
"X-Title": "Open WebUI",
}
if "openrouter.ai" in url
else {}
),
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) )
# Check if response is SSE # Check if response is SSE
@ -612,14 +677,18 @@ async def generate_chat_completion(
return response return response
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if isinstance(response, dict): if isinstance(response, dict):
if "error" in response: if "error" in response:
error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}"
elif isinstance(response, str): elif isinstance(response, str):
error_detail = response detail = response
raise HTTPException(status_code=r.status if r else 500, detail=error_detail) raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally: finally:
if not streaming and session: if not streaming and session:
if r: if r:
@ -627,25 +696,17 @@ async def generate_chat_completion(
await session.close() await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)): async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0 """
Deprecated: proxy all requests to OpenAI API
"""
body = await request.body() body = await request.body()
url = app.state.config.OPENAI_API_BASE_URLS[idx] idx = 0
key = app.state.config.OPENAI_API_KEYS[idx] url = request.app.state.config.OPENAI_API_BASE_URLS[idx]
key = request.app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
if ENABLE_FORWARD_USER_INFO_HEADERS:
headers["X-OpenWebUI-User-Name"] = user.name
headers["X-OpenWebUI-User-Id"] = user.id
headers["X-OpenWebUI-User-Email"] = user.email
headers["X-OpenWebUI-User-Role"] = user.role
r = None r = None
session = None session = None
@ -655,11 +716,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
session = aiohttp.ClientSession(trust_env=True) session = aiohttp.ClientSession(trust_env=True)
r = await session.request( r = await session.request(
method=request.method, method=request.method,
url=target_url, url=f"{url}/{path}",
data=body, data=body,
headers=headers, headers={
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
**(
{
"X-OpenWebUI-User-Name": user.name,
"X-OpenWebUI-User-Id": user.id,
"X-OpenWebUI-User-Email": user.email,
"X-OpenWebUI-User-Role": user.role,
}
if ENABLE_FORWARD_USER_INFO_HEADERS
else {}
),
},
) )
r.raise_for_status() r.raise_for_status()
# Check if response is SSE # Check if response is SSE
@ -676,18 +749,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
else: else:
response_data = await r.json() response_data = await r.json()
return response_data return response_data
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
detail = None
if r is not None: if r is not None:
try: try:
res = await r.json() res = await r.json()
print(res) print(res)
if "error" in res: if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except Exception: except Exception:
error_detail = f"External: {e}" detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail) raise HTTPException(
status_code=r.status if r else 500,
detail=detail if detail else "Open WebUI: Server Connection Error",
)
finally: finally:
if not streaming and session: if not streaming and session:
if r: if r:

View File

@ -89,103 +89,10 @@ from open_webui.utils.payload import (
from open_webui.utils.tools import get_tools from open_webui.utils.tools import get_tools
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None,
openapi_url="/openapi.json" if ENV == "dev" else None,
redoc_url=None,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"]) log.setLevel(SRC_LOG_LEVELS["MAIN"])
app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM
app.state.config.ENABLE_API_KEY = ENABLE_API_KEY
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING
app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS
app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES
app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES
app.state.config.ENABLE_LDAP = ENABLE_LDAP
app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL
app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST
app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT
app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME
app.state.config.LDAP_APP_DN = LDAP_APP_DN
app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD
app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE
app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS
app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(folders.router, prefix="/folders", tags=["folders"])
app.include_router(groups.router, prefix="/groups", tags=["groups"])
app.include_router(files.router, prefix="/files", tags=["files"])
app.include_router(functions.router, prefix="/functions", tags=["functions"])
app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@app.get("/") @app.get("/")
async def get_status(): async def get_status():