diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 2e1929bb3..a2d114844 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1409,9 +1409,9 @@ app.include_router(ollama.router, prefix="/ollama") app.include_router(openai.router, prefix="/openai") -app.include_router(images.router, prefix="/api/v1/images") -app.include_router(audio.router, prefix="/api/v1/audio") -app.include_router(retrieval.router, prefix="/api/v1/retrieval") +app.include_router(images.router, prefix="/api/v1/images", tags=["images"]) +app.include_router(audio.router, prefix="/api/v1/audio", tags=["audio"]) +app.include_router(retrieval.router, prefix="/api/v1/retrieval", tags=["retrieval"]) app.include_router(configs.router, prefix="/api/v1/configs", tags=["configs"]) diff --git a/backend/open_webui/routers/chat.py b/backend/open_webui/routers/chat.py deleted file mode 100644 index fba1ffa1b..000000000 --- a/backend/open_webui/routers/chat.py +++ /dev/null @@ -1,411 +0,0 @@ -from fastapi import APIRouter, Depends, HTTPException, Response, status -from pydantic import BaseModel - -router = APIRouter() - - -@app.post("/api/chat/completions") -async def generate_chat_completions( - request: Request, - form_data: dict, - user=Depends(get_verified_user), - bypass_filter: bool = False, -): - if BYPASS_MODEL_ACCESS_CONTROL: - bypass_filter = True - - model_list = request.state.models - models = {model["id"]: model for model in model_list} - - model_id = form_data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - - # Check if user has access to the model - if not bypass_filter and user.role == "user": - if model.get("arena"): - if not has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - else: - model_info = Models.get_model_by_id(model_id) - if not model_info: - raise HTTPException( - status_code=404, - detail="Model not found", - ) - elif not ( - user.id == model_info.user_id - or has_access( - user.id, type="read", access_control=model_info.access_control - ) - ): - raise HTTPException( - status_code=403, - detail="Model not found", - ) - - if model["owned_by"] == "arena": - model_ids = model.get("info", {}).get("meta", {}).get("model_ids") - filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode") - if model_ids and filter_mode == "exclude": - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" and model["id"] not in model_ids - ] - - selected_model_id = None - if isinstance(model_ids, list) and model_ids: - selected_model_id = random.choice(model_ids) - else: - model_ids = [ - model["id"] - for model in await get_all_models() - if model.get("owned_by") != "arena" - ] - selected_model_id = random.choice(model_ids) - - form_data["model"] = selected_model_id - - if form_data.get("stream") == True: - - async def stream_wrapper(stream): - yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n" - async for chunk in stream: - yield chunk - - response = await generate_chat_completions( - form_data, user, bypass_filter=True - ) - return StreamingResponse( - stream_wrapper(response.body_iterator), media_type="text/event-stream" - ) - else: - return { - **( - await generate_chat_completions(form_data, user, bypass_filter=True) - ), - "selected_model_id": selected_model_id, - } - - if model.get("pipe"): - # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter - return await generate_function_chat_completion( - form_data, user=user, models=models - ) - if model["owned_by"] == "ollama": - # Using /ollama/api/chat endpoint - form_data = convert_payload_openai_to_ollama(form_data) - form_data = GenerateChatCompletionForm(**form_data) - response = await generate_ollama_chat_completion( - form_data=form_data, user=user, bypass_filter=bypass_filter - ) - if form_data.stream: - response.headers["content-type"] = "text/event-stream" - return StreamingResponse( - convert_streaming_response_ollama_to_openai(response), - headers=dict(response.headers), - ) - else: - return convert_response_ollama_to_openai(response) - else: - return await generate_openai_chat_completion( - form_data, user=user, bypass_filter=bypass_filter - ) - - -@app.post("/api/chat/completed") -async def chat_completed(form_data: dict, user=Depends(get_verified_user)): - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - - model = models[model_id] - sorted_filters = get_sorted_filters(model_id, models) - if "pipeline" in model: - sorted_filters = [model] + sorted_filters - - for filter in sorted_filters: - r = None - try: - urlIdx = filter["urlIdx"] - - url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx] - key = openai_app.state.config.OPENAI_API_KEYS[urlIdx] - - if key != "": - headers = {"Authorization": f"Bearer {key}"} - r = requests.post( - f"{url}/{filter['id']}/filter/outlet", - headers=headers, - json={ - "user": { - "id": user.id, - "name": user.name, - "email": user.email, - "role": user.role, - }, - "body": data, - }, - ) - - r.raise_for_status() - data = r.json() - except Exception as e: - # Handle connection error here - print(f"Connection error: {e}") - - if r is not None: - try: - res = r.json() - if "detail" in res: - return JSONResponse( - status_code=r.status_code, - content=res, - ) - except Exception: - pass - - else: - pass - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - def get_priority(function_id): - function = Functions.get_function_by_id(function_id) - if function is not None and hasattr(function, "valves"): - # TODO: Fix FunctionModel to include vavles - return (function.valves if function.valves else {}).get("priority", 0) - return 0 - - filter_ids = [function.id for function in Functions.get_global_filter_functions()] - if "info" in model and "meta" in model["info"]: - filter_ids.extend(model["info"]["meta"].get("filterIds", [])) - filter_ids = list(set(filter_ids)) - - enabled_filter_ids = [ - function.id - for function in Functions.get_functions_by_type("filter", active_only=True) - ] - filter_ids = [ - filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids - ] - - # Sort filter_ids by priority, using the get_priority function - filter_ids.sort(key=get_priority) - - for filter_id in filter_ids: - filter = Functions.get_function_by_id(filter_id) - if not filter: - continue - - if filter_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[filter_id] - else: - function_module, _, _ = load_function_module_by_id(filter_id) - webui_app.state.FUNCTIONS[filter_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(filter_id) - function_module.valves = function_module.Valves( - **(valves if valves else {}) - ) - - if not hasattr(function_module, "outlet"): - continue - try: - outlet = function_module.outlet - - # Get the signature of the function - sig = inspect.signature(outlet) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": filter_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - filter_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(outlet): - data = await outlet(**params) - else: - data = outlet(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data - - -@app.post("/api/chat/actions/{action_id}") -async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)): - if "." in action_id: - action_id, sub_action_id = action_id.split(".") - else: - sub_action_id = None - - action = Functions.get_function_by_id(action_id) - if not action: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Action not found", - ) - - model_list = await get_all_models() - models = {model["id"]: model for model in model_list} - - data = form_data - model_id = data["model"] - - if model_id not in models: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Model not found", - ) - model = models[model_id] - - __event_emitter__ = get_event_emitter( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - __event_call__ = get_event_call( - { - "chat_id": data["chat_id"], - "message_id": data["id"], - "session_id": data["session_id"], - } - ) - - if action_id in webui_app.state.FUNCTIONS: - function_module = webui_app.state.FUNCTIONS[action_id] - else: - function_module, _, _ = load_function_module_by_id(action_id) - webui_app.state.FUNCTIONS[action_id] = function_module - - if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): - valves = Functions.get_function_valves_by_id(action_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) - - if hasattr(function_module, "action"): - try: - action = function_module.action - - # Get the signature of the function - sig = inspect.signature(action) - params = {"body": data} - - # Extra parameters to be passed to the function - extra_params = { - "__model__": model, - "__id__": sub_action_id if sub_action_id is not None else action_id, - "__event_emitter__": __event_emitter__, - "__event_call__": __event_call__, - } - - # Add extra params in contained in function signature - for key, value in extra_params.items(): - if key in sig.parameters: - params[key] = value - - if "__user__" in sig.parameters: - __user__ = { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - } - - try: - if hasattr(function_module, "UserValves"): - __user__["valves"] = function_module.UserValves( - **Functions.get_user_valves_by_id_and_user_id( - action_id, user.id - ) - ) - except Exception as e: - print(e) - - params = {**params, "__user__": __user__} - - if inspect.iscoroutinefunction(action): - data = await action(**params) - else: - data = action(**params) - - except Exception as e: - print(f"Error: {e}") - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"detail": str(e)}, - ) - - return data diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 082d14ec3..b217b8f45 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -317,6 +317,9 @@ async def get_all_models(request: Request): else: models = {"models": []} + request.app.state.OLLAMA_MODELS = { + model["model"]: model for model in models["models"] + } return models diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 1e9ca4af7..34c5683a8 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -10,15 +10,15 @@ from aiocache import cached import requests +from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse +from pydantic import BaseModel +from starlette.background import BackgroundTask + from open_webui.models.models import Models from open_webui.config import ( CACHE_DIR, - CORS_ALLOW_ORIGIN, - ENABLE_OPENAI_API, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - OPENAI_API_CONFIGS, - AppConfig, ) from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, @@ -29,11 +29,7 @@ from open_webui.env import ( from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS -from fastapi import Depends, FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse -from pydantic import BaseModel -from starlette.background import BackgroundTask + from open_webui.utils.payload import ( apply_model_params_to_body_openai, @@ -48,13 +44,69 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["OPENAI"]) -@app.get("/config") -async def get_config(user=Depends(get_admin_user)): +########################################## +# +# Utility functions +# +########################################## + + +async def send_get_request(url, key=None): + timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + try: + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.get( + url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + ) as response: + return await response.json() + except Exception as e: + # Handle connection error here + log.error(f"Connection error: {e}") + return None + + +async def cleanup_response( + response: Optional[aiohttp.ClientResponse], + session: Optional[aiohttp.ClientSession], +): + if response: + response.close() + if session: + await session.close() + + +def openai_o1_handler(payload): + """ + Handle O1 specific parameters + """ + if "max_tokens" in payload: + # Remove "max_tokens" from the payload + payload["max_completion_tokens"] = payload["max_tokens"] + del payload["max_tokens"] + + # Fix: O1 does not support the "system" parameter, Modify "system" to "user" + if payload["messages"][0]["role"] == "system": + payload["messages"][0]["role"] = "user" + + return payload + + +########################################## +# +# API routes +# +########################################## + +router = APIRouter() + + +@router.get("/config") +async def get_config(request: Request, user=Depends(get_admin_user)): return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } @@ -65,49 +117,56 @@ class OpenAIConfigForm(BaseModel): OPENAI_API_CONFIGS: dict -@app.post("/config/update") -async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)): - app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API - app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS - app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS +@router.post("/config/update") +async def update_config( + request: Request, form_data: OpenAIConfigForm, user=Depends(get_admin_user) +): + request.app.state.config.ENABLE_OPENAI_API = form_data.ENABLE_OPENAI_API + request.app.state.config.OPENAI_API_BASE_URLS = form_data.OPENAI_API_BASE_URLS + request.app.state.config.OPENAI_API_KEYS = form_data.OPENAI_API_KEYS # Check if API KEYS length is same than API URLS length - if len(app.state.config.OPENAI_API_KEYS) != len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) != len( + request.app.state.config.OPENAI_API_BASE_URLS ): - if len(app.state.config.OPENAI_API_KEYS) > len( - app.state.config.OPENAI_API_BASE_URLS + if len(request.app.state.config.OPENAI_API_KEYS) > len( + request.app.state.config.OPENAI_API_BASE_URLS ): - app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[ - : len(app.state.config.OPENAI_API_BASE_URLS) - ] + request.app.state.config.OPENAI_API_KEYS = ( + request.app.state.config.OPENAI_API_KEYS[ + : len(request.app.state.config.OPENAI_API_BASE_URLS) + ] + ) else: - app.state.config.OPENAI_API_KEYS += [""] * ( - len(app.state.config.OPENAI_API_BASE_URLS) - - len(app.state.config.OPENAI_API_KEYS) + request.app.state.config.OPENAI_API_KEYS += [""] * ( + len(request.app.state.config.OPENAI_API_BASE_URLS) + - len(request.app.state.config.OPENAI_API_KEYS) ) - app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS + request.app.state.config.OPENAI_API_CONFIGS = form_data.OPENAI_API_CONFIGS # Remove any extra configs - config_urls = app.state.config.OPENAI_API_CONFIGS.keys() - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): + config_urls = request.app.state.config.OPENAI_API_CONFIGS.keys() + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): if url not in config_urls: - app.state.config.OPENAI_API_CONFIGS.pop(url, None) + request.app.state.config.OPENAI_API_CONFIGS.pop(url, None) return { - "ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API, - "OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS, - "OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS, - "OPENAI_API_CONFIGS": app.state.config.OPENAI_API_CONFIGS, + "ENABLE_OPENAI_API": request.app.state.config.ENABLE_OPENAI_API, + "OPENAI_API_BASE_URLS": request.app.state.config.OPENAI_API_BASE_URLS, + "OPENAI_API_KEYS": request.app.state.config.OPENAI_API_KEYS, + "OPENAI_API_CONFIGS": request.app.state.config.OPENAI_API_CONFIGS, } -@app.post("/audio/speech") +@router.post("/audio/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None try: - idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") + idx = request.app.state.config.OPENAI_API_BASE_URLS.index( + "https://api.openai.com/v1" + ) + body = await request.body() name = hashlib.sha256(body).hexdigest() @@ -120,23 +179,35 @@ async def speech(request: Request, user=Depends(get_verified_user)): if file_path.is_file(): return FileResponse(file_path) - headers = {} - headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + r = None try: r = requests.post( - url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech", + url=f"{url}/audio/speech", data=body, - headers=headers, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, stream=True, ) @@ -155,46 +226,25 @@ async def speech(request: Request, user=Depends(get_verified_user)): except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = r.json() if "error" in res: - error_detail = f"External: {res['error']}" + detail = f"External: {res['error']}" except Exception: - error_detail = f"External: {e}" + detail = f"External: {e}" raise HTTPException( - status_code=r.status_code if r else 500, detail=error_detail + status_code=r.status_code if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", ) except ValueError: raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def aiohttp_get(url, key=None): - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - try: - headers = {"Authorization": f"Bearer {key}"} if key else {} - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.get(url, headers=headers) as response: - return await response.json() - except Exception as e: - # Handle connection error here - log.error(f"Connection error: {e}") - return None - - -async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], -): - if response: - response.close() - if session: - await session.close() - - def merge_models_lists(model_lists): log.debug(f"merge_models_lists {model_lists}") merged_list = [] @@ -212,7 +262,7 @@ def merge_models_lists(model_lists): } for model in models if "api.openai.com" - not in app.state.config.OPENAI_API_BASE_URLS[idx] + not in request.app.state.config.OPENAI_API_BASE_URLS[idx] or not any( name in model["id"] for name in [ @@ -230,40 +280,43 @@ def merge_models_lists(model_lists): return merged_list -async def get_all_models_responses() -> list: - if not app.state.config.ENABLE_OPENAI_API: +async def get_all_models_responses(request: Request) -> list: + if not request.app.state.config.ENABLE_OPENAI_API: return [] # Check if API KEYS length is same than API URLS length - num_urls = len(app.state.config.OPENAI_API_BASE_URLS) - num_keys = len(app.state.config.OPENAI_API_KEYS) + num_urls = len(request.app.state.config.OPENAI_API_BASE_URLS) + num_keys = len(request.app.state.config.OPENAI_API_KEYS) if num_keys != num_urls: # if there are more keys than urls, remove the extra keys if num_keys > num_urls: - new_keys = app.state.config.OPENAI_API_KEYS[:num_urls] - app.state.config.OPENAI_API_KEYS = new_keys + new_keys = request.app.state.config.OPENAI_API_KEYS[:num_urls] + request.app.state.config.OPENAI_API_KEYS = new_keys # if there are more urls than keys, add empty keys else: - app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) + request.app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys) - tasks = [] - for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS): - if url not in app.state.config.OPENAI_API_CONFIGS: - tasks.append( - aiohttp_get(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx]) + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OPENAI_API_BASE_URLS): + if url not in request.app.state.config.OPENAI_API_CONFIGS: + request_tasks.append( + send_get_request( + f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + ) ) else: - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) enable = api_config.get("enable", True) model_ids = api_config.get("model_ids", []) if enable: if len(model_ids) == 0: - tasks.append( - aiohttp_get( - f"{url}/models", app.state.config.OPENAI_API_KEYS[idx] + request_tasks.append( + send_get_request( + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], ) ) else: @@ -281,16 +334,18 @@ async def get_all_models_responses() -> list: ], } - tasks.append(asyncio.ensure_future(asyncio.sleep(0, model_list))) + request_tasks.append( + asyncio.ensure_future(asyncio.sleep(0, model_list)) + ) else: - tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) - responses = await asyncio.gather(*tasks) + responses = await asyncio.gather(*request_tasks) for idx, response in enumerate(responses): if response: - url = app.state.config.OPENAI_API_BASE_URLS[idx] - api_config = app.state.config.OPENAI_API_CONFIGS.get(url, {}) + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get(url, {}) prefix_id = api_config.get("prefix_id", None) @@ -301,15 +356,27 @@ async def get_all_models_responses() -> list: model["id"] = f"{prefix_id}.{model['id']}" log.debug(f"get_all_models:responses() {responses}") - return responses +async def get_filtered_models(models, user): + # Filter models based on user access control + filtered_models = [] + for model in models.get("data", []): + model_info = Models.get_model_by_id(model["id"]) + if model_info: + if user.id == model_info.user_id or has_access( + user.id, type="read", access_control=model_info.access_control + ): + filtered_models.append(model) + return filtered_models + + @cached(ttl=3) -async def get_all_models() -> dict[str, list]: +async def get_all_models(request: Request) -> dict[str, list]: log.info("get_all_models()") - if not app.state.config.ENABLE_OPENAI_API: + if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} responses = await get_all_models_responses() @@ -324,12 +391,15 @@ async def get_all_models() -> dict[str, list]: models = {"data": merge_models_lists(map(extract_data, responses))} log.debug(f"models: {models}") + request.app.state.OPENAI_MODELS = {model["id"]: model for model in models["data"]} return models -@app.get("/models") -@app.get("/models/{url_idx}") -async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)): +@router.get("/models") +@router.get("/models/{url_idx}") +async def get_models( + request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user) +): models = { "data": [], } @@ -337,25 +407,33 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us if url_idx is None: models = await get_all_models() else: - url = app.state.config.OPENAI_API_BASE_URLS[url_idx] - key = app.state.config.OPENAI_API_KEYS[url_idx] - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] + key = request.app.state.config.OPENAI_API_KEYS[url_idx] r = None - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST + ) + ) as session: try: - async with session.get(f"{url}/models", headers=headers) as r: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, + ) as r: if r.status != 200: # Extract response error details if available error_detail = f"HTTP Error: {r.status}" @@ -389,27 +467,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - # Filter models based on user access control - filtered_models = [] - for model in models.get("data", []): - model_info = Models.get_model_by_id(model["id"]) - if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control - ): - filtered_models.append(model) - models["data"] = filtered_models + models["data"] = get_filtered_models(models, user) return models @@ -419,21 +486,24 @@ class ConnectionVerificationForm(BaseModel): key: str -@app.post("/verify") +@router.post("/verify") async def verify_connection( form_data: ConnectionVerificationForm, user=Depends(get_admin_user) ): url = form_data.url key = form_data.key - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - - timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) - async with aiohttp.ClientSession(timeout=timeout) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) + ) as session: try: - async with session.get(f"{url}/models", headers=headers) as r: + async with session.get( + f"{url}/models", + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + }, + ) as r: if r.status != 200: # Extract response error details if available error_detail = f"HTTP Error: {r.status}" @@ -448,26 +518,24 @@ async def verify_connection( except aiohttp.ClientError as e: # ClientError covers all aiohttp requests issues log.exception(f"Client error: {str(e)}") - # Handle aiohttp-specific connection issues, timeout etc. raise HTTPException( status_code=500, detail="Open WebUI: Server Connection Error" ) except Exception as e: log.exception(f"Unexpected error: {e}") - # Generic error handler in case parsing JSON or other steps fail error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) -@app.post("/chat/completions") +@router.post("/chat/completions") async def generate_chat_completion( + request: Request, form_data: dict, user=Depends(get_verified_user), bypass_filter: Optional[bool] = False, ): idx = 0 payload = {**form_data} - if "metadata" in payload: del payload["metadata"] @@ -502,15 +570,7 @@ async def generate_chat_completion( detail="Model not found", ) - # Attemp to get urlIdx from the model - models = await get_all_models() - - # Find the model from the list - model = next( - (model for model in models["data"] if model["id"] == payload.get("model")), - None, - ) - + model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] else: @@ -520,11 +580,11 @@ async def generate_chat_completion( ) # Get the API config for the model - api_config = app.state.config.OPENAI_API_CONFIGS.get( - app.state.config.OPENAI_API_BASE_URLS[idx], {} + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + request.app.state.config.OPENAI_API_BASE_URLS[idx], {} ) - prefix_id = api_config.get("prefix_id", None) + prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") @@ -537,43 +597,26 @@ async def generate_chat_completion( "role": user.role, } - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens" is_o1 = payload["model"].lower().startswith("o1-") - # Change max_completion_tokens to max_tokens (Backward compatible) - if "api.openai.com" not in url and not is_o1: - if "max_completion_tokens" in payload: - # Remove "max_completion_tokens" from the payload - payload["max_tokens"] = payload["max_completion_tokens"] - del payload["max_completion_tokens"] - else: - if is_o1 and "max_tokens" in payload: + if is_o1: + payload = openai_o1_handler(payload) + elif "api.openai.com" not in url: + # Remove "max_tokens" from the payload for backward compatibility + if "max_tokens" in payload: payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] - if "max_tokens" in payload and "max_completion_tokens" in payload: - del payload["max_tokens"] - # Fix: O1 does not support the "system" parameter, Modify "system" to "user" - if is_o1 and payload["messages"][0]["role"] == "system": - payload["messages"][0]["role"] = "user" + # TODO: check if below is needed + # if "max_tokens" in payload and "max_completion_tokens" in payload: + # del payload["max_tokens"] # Convert the modified body back to JSON payload = json.dumps(payload) - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]: - headers["HTTP-Referer"] = "https://openwebui.com/" - headers["X-Title"] = "Open WebUI" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role - r = None session = None streaming = False @@ -583,11 +626,33 @@ async def generate_chat_completion( session = aiohttp.ClientSession( trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) ) + r = await session.request( method="POST", url=f"{url}/chat/completions", data=payload, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) # Check if response is SSE @@ -612,14 +677,18 @@ async def generate_chat_completion( return response except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if isinstance(response, dict): if "error" in response: - error_detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" + detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" elif isinstance(response, str): - error_detail = response + detail = response - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: @@ -627,25 +696,17 @@ async def generate_chat_completion( await session.close() -@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) +@router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) async def proxy(path: str, request: Request, user=Depends(get_verified_user)): - idx = 0 + """ + Deprecated: proxy all requests to OpenAI API + """ body = await request.body() - url = app.state.config.OPENAI_API_BASE_URLS[idx] - key = app.state.config.OPENAI_API_KEYS[idx] - - target_url = f"{url}/{path}" - - headers = {} - headers["Authorization"] = f"Bearer {key}" - headers["Content-Type"] = "application/json" - if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name - headers["X-OpenWebUI-User-Id"] = user.id - headers["X-OpenWebUI-User-Email"] = user.email - headers["X-OpenWebUI-User-Role"] = user.role + idx = 0 + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] r = None session = None @@ -655,11 +716,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): session = aiohttp.ClientSession(trust_env=True) r = await session.request( method=request.method, - url=target_url, + url=f"{url}/{path}", data=body, - headers=headers, + headers={ + "Authorization": f"Bearer {key}", + "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) - r.raise_for_status() # Check if response is SSE @@ -676,18 +749,23 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): else: response_data = await r.json() return response_data + except Exception as e: log.exception(e) - error_detail = "Open WebUI: Server Connection Error" + + detail = None if r is not None: try: res = await r.json() print(res) if "error" in res: - error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" + detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" except Exception: - error_detail = f"External: {e}" - raise HTTPException(status_code=r.status if r else 500, detail=error_detail) + detail = f"External: {e}" + raise HTTPException( + status_code=r.status if r else 500, + detail=detail if detail else "Open WebUI: Server Connection Error", + ) finally: if not streaming and session: if r: diff --git a/backend/open_webui/routers/webui.py b/backend/open_webui/routers/webui.py index 1ac4db152..d3942db97 100644 --- a/backend/open_webui/routers/webui.py +++ b/backend/open_webui/routers/webui.py @@ -89,103 +89,10 @@ from open_webui.utils.payload import ( from open_webui.utils.tools import get_tools -app = FastAPI( - docs_url="/docs" if ENV == "dev" else None, - openapi_url="/openapi.json" if ENV == "dev" else None, - redoc_url=None, -) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -app.state.config = AppConfig() - -app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP -app.state.config.ENABLE_LOGIN_FORM = ENABLE_LOGIN_FORM -app.state.config.ENABLE_API_KEY = ENABLE_API_KEY - -app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN -app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER -app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER - - -app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS -app.state.config.ADMIN_EMAIL = ADMIN_EMAIL - - -app.state.config.DEFAULT_MODELS = DEFAULT_MODELS -app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS -app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE - - -app.state.config.USER_PERMISSIONS = USER_PERMISSIONS -app.state.config.WEBHOOK_URL = WEBHOOK_URL -app.state.config.BANNERS = WEBUI_BANNERS -app.state.config.MODEL_ORDER_LIST = MODEL_ORDER_LIST - -app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING -app.state.config.ENABLE_MESSAGE_RATING = ENABLE_MESSAGE_RATING - -app.state.config.ENABLE_EVALUATION_ARENA_MODELS = ENABLE_EVALUATION_ARENA_MODELS -app.state.config.EVALUATION_ARENA_MODELS = EVALUATION_ARENA_MODELS - -app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM -app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM -app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM - -app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT -app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM -app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES -app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES - -app.state.config.ENABLE_LDAP = ENABLE_LDAP -app.state.config.LDAP_SERVER_LABEL = LDAP_SERVER_LABEL -app.state.config.LDAP_SERVER_HOST = LDAP_SERVER_HOST -app.state.config.LDAP_SERVER_PORT = LDAP_SERVER_PORT -app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = LDAP_ATTRIBUTE_FOR_USERNAME -app.state.config.LDAP_APP_DN = LDAP_APP_DN -app.state.config.LDAP_APP_PASSWORD = LDAP_APP_PASSWORD -app.state.config.LDAP_SEARCH_BASE = LDAP_SEARCH_BASE -app.state.config.LDAP_SEARCH_FILTERS = LDAP_SEARCH_FILTERS -app.state.config.LDAP_USE_TLS = LDAP_USE_TLS -app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE -app.state.config.LDAP_CIPHERS = LDAP_CIPHERS - -app.state.TOOLS = {} -app.state.FUNCTIONS = {} - -app.add_middleware( - CORSMiddleware, - allow_origins=CORS_ALLOW_ORIGIN, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -app.include_router(configs.router, prefix="/configs", tags=["configs"]) - -app.include_router(auths.router, prefix="/auths", tags=["auths"]) -app.include_router(users.router, prefix="/users", tags=["users"]) - -app.include_router(chats.router, prefix="/chats", tags=["chats"]) - -app.include_router(models.router, prefix="/models", tags=["models"]) -app.include_router(knowledge.router, prefix="/knowledge", tags=["knowledge"]) -app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) -app.include_router(tools.router, prefix="/tools", tags=["tools"]) - -app.include_router(memories.router, prefix="/memories", tags=["memories"]) -app.include_router(folders.router, prefix="/folders", tags=["folders"]) - -app.include_router(groups.router, prefix="/groups", tags=["groups"]) -app.include_router(files.router, prefix="/files", tags=["files"]) -app.include_router(functions.router, prefix="/functions", tags=["functions"]) -app.include_router(evaluations.router, prefix="/evaluations", tags=["evaluations"]) - - -app.include_router(utils.router, prefix="/utils", tags=["utils"]) - @app.get("/") async def get_status():