From 59c3a1811869f11c2d3fc2f361b65c6e81d11236 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Sun, 1 Dec 2024 18:25:44 -0800 Subject: [PATCH] enh: `BYPASS_MODEL_ACCESS_CONTROL` --- backend/open_webui/apps/ollama/main.py | 7 ++++--- backend/open_webui/apps/openai/main.py | 3 ++- backend/open_webui/config.py | 2 +- backend/open_webui/env.py | 3 +++ backend/open_webui/main.py | 5 +++-- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/backend/open_webui/apps/ollama/main.py b/backend/open_webui/apps/ollama/main.py index 71a40cb47..82a37a752 100644 --- a/backend/open_webui/apps/ollama/main.py +++ b/backend/open_webui/apps/ollama/main.py @@ -24,6 +24,7 @@ from open_webui.config import ( from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, + BYPASS_MODEL_ACCESS_CONTROL, ) @@ -359,7 +360,7 @@ async def get_ollama_tags( detail=error_detail, ) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models.get("models", []): @@ -1067,7 +1068,7 @@ async def generate_openai_chat_completion( payload = apply_model_system_prompt_to_body(params, payload, user) # Check if user has access to the model - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: if not ( user.id == model_info.user_id or has_access( @@ -1156,7 +1157,7 @@ async def get_openai_models( detail=error_detail, ) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models: diff --git a/backend/open_webui/apps/openai/main.py b/backend/open_webui/apps/openai/main.py index 31c36a8a1..9193c2be6 100644 --- a/backend/open_webui/apps/openai/main.py +++ b/backend/open_webui/apps/openai/main.py @@ -24,6 +24,7 @@ from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST, ENABLE_FORWARD_USER_INFO_HEADERS, + BYPASS_MODEL_ACCESS_CONTROL, ) from open_webui.constants import ERROR_MESSAGES @@ -422,7 +423,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us error_detail = f"Unexpected error: {str(e)}" raise HTTPException(status_code=500, detail=error_detail) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: # Filter models based on user access control filtered_models = [] for model in models.get("data", []): diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 15d209941..c0a0f63b5 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -702,6 +702,7 @@ ENABLE_LOGIN_FORM = PersistentConfig( os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true", ) + DEFAULT_LOCALE = PersistentConfig( "DEFAULT_LOCALE", "ui.default_locale", @@ -758,7 +759,6 @@ DEFAULT_USER_ROLE = PersistentConfig( os.getenv("DEFAULT_USER_ROLE", "pending"), ) - USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS = ( os.environ.get("USER_PERMISSIONS_WORKSPACE_MODELS_ACCESS", "False").lower() == "true" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 28b5a10a5..e1b350ead 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -329,6 +329,9 @@ WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +BYPASS_MODEL_ACCESS_CONTROL = ( + os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" +) #################################### # WEBUI_SECRET_KEY diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 177bded66..1bf221beb 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -112,6 +112,7 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, WEBUI_URL, + BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, OFFLINE_MODE, ) @@ -621,7 +622,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ) model_info = Models.get_model_by_id(model["id"]) - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: if model.get("arena"): if not has_access( user.id, @@ -1224,7 +1225,7 @@ async def get_models(user=Depends(get_verified_user)): ) # Filter out models that the user does not have access to - if user.role == "user": + if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: filtered_models = [] for model in models: if model.get("arena"):