diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 62e53e34c..6864306a2 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -919,7 +919,7 @@ async def get_models(request: Request, user=Depends(get_verified_user)): return filtered_models - models = await get_all_models(request) + models = await get_all_models(request, user=user) # Filter out filter pipelines models = [ @@ -959,7 +959,7 @@ async def chat_completion( user=Depends(get_verified_user), ): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 732dd36f9..d99416c83 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -14,6 +14,11 @@ from urllib.parse import urlparse import aiohttp from aiocache import cached import requests +from open_webui.models.users import UserModel + +from open_webui.env import ( + ENABLE_FORWARD_USER_INFO_HEADERS, +) from fastapi import ( Depends, @@ -66,12 +71,26 @@ log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + "Content-Type": "application/json", + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -96,6 +115,7 @@ async def send_post_request( stream: bool = True, key: Optional[str] = None, content_type: Optional[str] = None, + user: UserModel = None, ): r = None @@ -110,6 +130,16 @@ async def send_post_request( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -191,7 +221,19 @@ async def verify_connection( try: async with session.get( f"{url}/api/version", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) as r: if r.status != 200: detail = f"HTTP Error: {r.status}" @@ -254,7 +296,7 @@ async def update_config( @cached(ttl=3) -async def get_all_models(request: Request): +async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: request_tasks = [] @@ -262,7 +304,7 @@ async def get_all_models(request: Request): if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support ): - request_tasks.append(send_get_request(f"{url}/api/tags")) + request_tasks.append(send_get_request(f"{url}/api/tags", user=user)) else: api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( str(idx), @@ -275,7 +317,9 @@ async def get_all_models(request: Request): key = api_config.get("key", None) if enable: - request_tasks.append(send_get_request(f"{url}/api/tags", key)) + request_tasks.append( + send_get_request(f"{url}/api/tags", key, user=user) + ) else: request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) @@ -360,7 +404,7 @@ async def get_ollama_tags( models = [] if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -370,7 +414,19 @@ async def get_ollama_tags( r = requests.request( method="GET", url=f"{url}/api/tags", - headers={**({"Authorization": f"Bearer {key}"} if key else {})}, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, ) r.raise_for_status() @@ -477,6 +533,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_u url, {} ), # Legacy support ).get("key", None), + user=user, ) for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS) ] @@ -509,6 +566,7 @@ async def pull_model( url=f"{url}/api/pull", payload=json.dumps(payload), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -527,7 +585,7 @@ async def push_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -545,6 +603,7 @@ async def push_model( url=f"{url}/api/push", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -571,6 +630,7 @@ async def create_model( url=f"{url}/api/create", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -588,7 +648,7 @@ async def copy_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.source in models: @@ -609,6 +669,16 @@ async def copy_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -643,7 +713,7 @@ async def delete_model( user=Depends(get_admin_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name in models: @@ -665,6 +735,16 @@ async def delete_model( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, ) r.raise_for_status() @@ -693,7 +773,7 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS if form_data.name not in models: @@ -714,6 +794,16 @@ async def show_model_info( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -757,7 +847,7 @@ async def embed( log.info(f"generate_ollama_batch_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -783,6 +873,16 @@ async def embed( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -826,7 +926,7 @@ async def embeddings( log.info(f"generate_ollama_embeddings {form_data}") if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -852,6 +952,16 @@ async def embeddings( headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, data=form_data.model_dump_json(exclude_none=True).encode(), ) @@ -901,7 +1011,7 @@ async def generate_completion( user=Depends(get_verified_user), ): if url_idx is None: - await get_all_models(request) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS model = form_data.model @@ -931,6 +1041,7 @@ async def generate_completion( url=f"{url}/api/generate", payload=form_data.model_dump_json(exclude_none=True).encode(), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1060,6 +1171,7 @@ async def generate_chat_completion( stream=form_data.stream, key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", + user=user, ) @@ -1162,6 +1274,7 @@ async def generate_openai_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1240,6 +1353,7 @@ async def generate_openai_chat_completion( payload=json.dumps(payload), stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), + user=user, ) @@ -1253,7 +1367,7 @@ async def get_openai_models( models = [] if url_idx is None: - model_list = await get_all_models(request) + model_list = await get_all_models(request, user=user) models = [ { "id": model["model"], diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 78aae9980..1adbfdfcd 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -26,6 +26,7 @@ from open_webui.env import ( ENABLE_FORWARD_USER_INFO_HEADERS, BYPASS_MODEL_ACCESS_CONTROL, ) +from open_webui.models.users import UserModel from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENV, SRC_LOG_LEVELS @@ -51,12 +52,25 @@ log.setLevel(SRC_LOG_LEVELS["OPENAI"]) ########################################## -async def send_get_request(url, key=None): +async def send_get_request(url, key=None, user: UserModel = None): timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST) try: async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: async with session.get( - url, headers={**({"Authorization": f"Bearer {key}"} if key else {})} + url, + headers={ + **({"Authorization": f"Bearer {key}"} if key else {}), + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + }, ) as response: return await response.json() except Exception as e: @@ -247,7 +261,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) -async def get_all_models_responses(request: Request) -> list: +async def get_all_models_responses(request: Request, user: UserModel) -> list: if not request.app.state.config.ENABLE_OPENAI_API: return [] @@ -271,7 +285,9 @@ async def get_all_models_responses(request: Request) -> list: ): request_tasks.append( send_get_request( - f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx] + f"{url}/models", + request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -291,6 +307,7 @@ async def get_all_models_responses(request: Request) -> list: send_get_request( f"{url}/models", request.app.state.config.OPENAI_API_KEYS[idx], + user=user, ) ) else: @@ -352,13 +369,13 @@ async def get_filtered_models(models, user): @cached(ttl=3) -async def get_all_models(request: Request) -> dict[str, list]: +async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") if not request.app.state.config.ENABLE_OPENAI_API: return {"data": []} - responses = await get_all_models_responses(request) + responses = await get_all_models_responses(request, user=user) def extract_data(response): if response and "data" in response: @@ -418,7 +435,7 @@ async def get_models( } if url_idx is None: - models = await get_all_models(request) + models = await get_all_models(request, user=user) else: url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx] key = request.app.state.config.OPENAI_API_KEYS[url_idx] @@ -515,6 +532,16 @@ async def verify_connection( headers={ "Authorization": f"Bearer {key}", "Content-Type": "application/json", + **( + { + "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), }, ) as r: if r.status != 200: @@ -587,7 +614,7 @@ async def generate_chat_completion( detail="Model not found", ) - await get_all_models(request) + await get_all_models(request, user=user) model = request.app.state.OPENAI_MODELS.get(model_id) if model: idx = model["urlIdx"] diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index 209fb02dc..739d6515e 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -285,7 +285,7 @@ chat_completion = generate_chat_completion async def chat_completed(request: Request, form_data: dict, user: Any): if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { @@ -351,7 +351,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A raise Exception(f"Action not found: {action_id}") if not request.app.state.MODELS: - await get_all_models(request) + await get_all_models(request, user=user) if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index a2c0eadca..149e41a41 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -22,6 +22,7 @@ from open_webui.config import ( ) from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL +from open_webui.models.users import UserModel logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) @@ -29,17 +30,17 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) -async def get_all_base_models(request: Request): +async def get_all_base_models(request: Request, user: UserModel = None): function_models = [] openai_models = [] ollama_models = [] if request.app.state.config.ENABLE_OPENAI_API: - openai_models = await openai.get_all_models(request) + openai_models = await openai.get_all_models(request, user=user) openai_models = openai_models["data"] if request.app.state.config.ENABLE_OLLAMA_API: - ollama_models = await ollama.get_all_models(request) + ollama_models = await ollama.get_all_models(request, user=user) ollama_models = [ { "id": model["model"], @@ -58,8 +59,8 @@ async def get_all_base_models(request: Request): return models -async def get_all_models(request): - models = await get_all_base_models(request) +async def get_all_models(request, user: UserModel = None): + models = await get_all_base_models(request, user=user) # If there are no models, return an empty list if len(models) == 0: