From 9a691c038731a00aa61eca9e33f966fd63621fcd Mon Sep 17 00:00:00 2001 From: Patrick Willnow Date: Thu, 3 Oct 2024 11:12:14 +0200 Subject: [PATCH 01/14] Add oauth role mapping also add node env to allow local build to succeed --- Dockerfile | 1 + backend/open_webui/config.py | 12 ++++++++++++ backend/open_webui/main.py | 27 ++++++++++++++++++++++----- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/Dockerfile b/Dockerfile index c944f54e6..5e7f80bc8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ RUN npm ci COPY . . ENV APP_BUILD_HASH=${BUILD_HASH} +ENV NODE_OPTIONS="--max_old_space_size=8192" RUN npm run build ######## WebUI backend ######## diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f531a8728..f9921d9cb 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -278,6 +278,18 @@ ENABLE_OAUTH_SIGNUP = PersistentConfig( os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", ) +ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( + "ENABLE_OAUTH_ROLE_MAPPING", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig( "OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "oauth.merge_accounts_by_email", diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 4af48906b..77d486fb7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2245,6 +2245,18 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) + if user: + role = user.role + if Users.get_num_users() == 1: + role = "admin" + elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: + oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) + if oauth_roles: + for allowed_role in ["pending", "user", "admin"]: + role = allowed_role if allowed_role in oauth_roles else role + if role != user.role: + Users.update_user_role_by_id(user.id, role) + if not user: # If the user does not exist, check if merging is enabled if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: @@ -2284,11 +2296,16 @@ async def oauth_callback(provider: str, request: Request, response: Response): if not picture_url: picture_url = "/user.png" username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = ( - "admin" - if Users.get_num_users() == 0 - else webui_app.state.config.DEFAULT_USER_ROLE - ) + + role = webui_app.state.config.DEFAULT_USER_ROLE + if Users.get_num_users() == 0: + role = "admin" + elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: + oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) + if oauth_roles: + for allowed_role in ["pending", "user", "admin"]: + role = allowed_role if allowed_role in oauth_roles else role + user = Auths.insert_new_auth( email=email, password=get_password_hash( From dc921786418736d97cabe13625b1bd992063280f Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 3 Oct 2024 20:55:32 +0200 Subject: [PATCH 02/14] Fix missing key mapping --- backend/open_webui/apps/webui/main.py | 5 +++++ backend/open_webui/config.py | 26 +++++++++++++------------- backend/open_webui/main.py | 2 +- 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 6c6f197dd..8682cd767 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -32,6 +32,8 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, + ENABLE_OAUTH_ROLE_MAPPING, + OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, @@ -93,6 +95,9 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f9921d9cb..18e0c398e 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -278,18 +278,6 @@ ENABLE_OAUTH_SIGNUP = PersistentConfig( os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", ) -ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( - "ENABLE_OAUTH_ROLE_MAPPING", - "oauth.enable_role_mapping", - os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", -) - -OAUTH_ROLES_CLAIM = PersistentConfig( - "OAUTH_ROLES_CLAIM", - "oauth.roles_claim", - os.environ.get("OAUTH_ROLES_CLAIM", "roles"), -) - OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig( "OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "oauth.merge_accounts_by_email", @@ -395,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig( ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", + "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) @@ -406,6 +394,18 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( + "ENABLE_OAUTH_ROLE_MAPPING", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 77d486fb7..e24a5a969 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2249,7 +2249,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): role = user.role if Users.get_num_users() == 1: role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: + elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING.value: oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) if oauth_roles: for allowed_role in ["pending", "user", "admin"]: From c9d948f2847820a3df460bca5d0f22f5e0cb5598 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 3 Oct 2024 21:38:56 +0200 Subject: [PATCH 03/14] Remove copy pasta error of calling value on bool --- backend/open_webui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e24a5a969..77d486fb7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2249,7 +2249,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): role = user.role if Users.get_num_users() == 1: role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING.value: + elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) if oauth_roles: for allowed_role in ["pending", "user", "admin"]: From 0a7373dae18fdc80d354bdb9ce8409cd580de214 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 3 Oct 2024 22:56:52 +0200 Subject: [PATCH 04/14] add pending as role fallback add logging to determine correct handling of oauth roles --- backend/open_webui/main.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 77d486fb7..9442ab9af 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2250,10 +2250,16 @@ async def oauth_callback(provider: str, request: Request, response: Response): if Users.get_num_users() == 1: role = "admin" elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) + oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLES_CLAIM) + log.info(f"User {user.name} has OAuth roles: {oauth_roles}") if oauth_roles: for allowed_role in ["pending", "user", "admin"]: role = allowed_role if allowed_role in oauth_roles else role + log.info(f"Applied role: {role} to user {user.name}") + else: + # If role mapping is enabled, but no roles are provided, fall back to pending + role = "pending" + if role != user.role: Users.update_user_role_by_id(user.id, role) @@ -2305,6 +2311,9 @@ async def oauth_callback(provider: str, request: Request, response: Response): if oauth_roles: for allowed_role in ["pending", "user", "admin"]: role = allowed_role if allowed_role in oauth_roles else role + else: + # If role mapping is enabled, but no roles are provided, fall back to pending + role = "pending" user = Auths.insert_new_auth( email=email, From 5b2e1ca7cdbeba0070182fca1155f9b07c4c91fe Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 3 Oct 2024 23:06:05 +0200 Subject: [PATCH 05/14] add more logging --- backend/open_webui/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 9442ab9af..89aed7fcb 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2244,7 +2244,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) - + # print all user data content for debugging + log.info(f"User data: {user_data}") if user: role = user.role if Users.get_num_users() == 1: From 8e4776ada16bb94ecf3381fe7ef2653426a25453 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 3 Oct 2024 23:25:00 +0200 Subject: [PATCH 06/14] add handling nested claims... --- backend/open_webui/main.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 89aed7fcb..6b601d446 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2251,7 +2251,16 @@ async def oauth_callback(provider: str, request: Request, response: Response): if Users.get_num_users() == 1: role = "admin" elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLES_CLAIM) + oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM + oauth_roles = user_data.get(oauth_claim) # Works for simple claims with no nesting + if "." in oauth_claim: + # Implementation to handle nested claims of arbitrary depth + nested_claims = oauth_claim.split(".") + claim_data = user_data.get(nested_claims[0]) + for nested_claim in nested_claims[1:]: + claim_data = claim_data.get(nested_claim) + oauth_roles = claim_data + log.info(f"User {user.name} has OAuth roles: {oauth_roles}") if oauth_roles: for allowed_role in ["pending", "user", "admin"]: From 79b9c8a677796c11c7f44f40b5be7500920c53b3 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Fri, 4 Oct 2024 00:05:36 +0200 Subject: [PATCH 07/14] handling no claim received when nested expected --- backend/open_webui/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 6b601d446..551357f90 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2253,7 +2253,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM oauth_roles = user_data.get(oauth_claim) # Works for simple claims with no nesting - if "." in oauth_claim: + if oauth_roles and "." in oauth_claim: # Implementation to handle nested claims of arbitrary depth nested_claims = oauth_claim.split(".") claim_data = user_data.get(nested_claims[0]) From 6ddd8c72410baa72203f9012e238cda7cd0c5d50 Mon Sep 17 00:00:00 2001 From: Patrick Willnow Date: Fri, 4 Oct 2024 10:14:20 +0200 Subject: [PATCH 08/14] fix logic --- backend/open_webui/main.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 551357f90..ce6a05184 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2252,13 +2252,15 @@ async def oauth_callback(provider: str, request: Request, response: Response): role = "admin" elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM - oauth_roles = user_data.get(oauth_claim) # Works for simple claims with no nesting + oauth_roles = user_data.get(oauth_claim) # Works for simple claims with no nesting if oauth_roles and "." in oauth_claim: # Implementation to handle nested claims of arbitrary depth nested_claims = oauth_claim.split(".") - claim_data = user_data.get(nested_claims[0]) - for nested_claim in nested_claims[1:]: + claim_data = user_data + for nested_claim in nested_claims: claim_data = claim_data.get(nested_claim) + if claim_data is None: + break oauth_roles = claim_data log.info(f"User {user.name} has OAuth roles: {oauth_roles}") From f751d22a208b3e2d776ecf06ac0889f62b739f68 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Fri, 4 Oct 2024 13:26:49 +0200 Subject: [PATCH 09/14] Refinement --- backend/open_webui/main.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index ce6a05184..7374b7f62 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2252,25 +2252,23 @@ async def oauth_callback(provider: str, request: Request, response: Response): role = "admin" elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM - oauth_roles = user_data.get(oauth_claim) # Works for simple claims with no nesting - if oauth_roles and "." in oauth_claim: - # Implementation to handle nested claims of arbitrary depth - nested_claims = oauth_claim.split(".") + oauth_roles = None + + if oauth_claim: claim_data = user_data + nested_claims = oauth_claim.split(".") for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim) - if claim_data is None: - break - oauth_roles = claim_data + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None log.info(f"User {user.name} has OAuth roles: {oauth_roles}") if oauth_roles: for allowed_role in ["pending", "user", "admin"]: role = allowed_role if allowed_role in oauth_roles else role - log.info(f"Applied role: {role} to user {user.name}") else: # If role mapping is enabled, but no roles are provided, fall back to pending role = "pending" + log.info(f"Applied role: {role} to user {user.name}") if role != user.role: Users.update_user_role_by_id(user.id, role) From edc15d0d7ce0a56f1b8fc601cd23cbabb9ad7e34 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 10 Oct 2024 23:00:05 +0200 Subject: [PATCH 10/14] rewrite oauth role management logic to allow any custom roles to be used for oauth role to open webui role mapping --- backend/open_webui/apps/webui/main.py | 4 +- backend/open_webui/config.py | 17 ++- backend/open_webui/main.py | 169 ++++++++++++++------------ 3 files changed, 105 insertions(+), 85 deletions(-) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 8682cd767..0208c0ea9 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -32,7 +32,7 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MAPPING, + ENABLE_OAUTH_ROLE_MANAGEMENT, OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, @@ -95,7 +95,7 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM -app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM app.state.MODELS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 18e0c398e..c171eaadb 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -394,10 +394,10 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) -ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( - "ENABLE_OAUTH_ROLE_MAPPING", +ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_ROLE_MANAGEMENT", "oauth.enable_role_mapping", - os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", + os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", ) OAUTH_ROLES_CLAIM = PersistentConfig( @@ -406,6 +406,17 @@ OAUTH_ROLES_CLAIM = PersistentConfig( os.environ.get("OAUTH_ROLES_CLAIM", "roles"), ) +OAUTH_ALLOWED_ROLES = PersistentConfig( + "OAUTH_ALLOWED_ROLES", + "oauth.allowed_roles", + [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")], +) + +OAUTH_ADMIN_ROLES = PersistentConfig( + "OAUTH_ADMIN_ROLES", + "oauth.admin_roles", + [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], +) def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7374b7f62..8095a66ca 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -16,7 +16,6 @@ from typing import Optional import aiohttp import requests - from open_webui.apps.audio.main import app as audio_app from open_webui.apps.images.main import app as images_app from open_webui.apps.ollama.main import app as ollama_app @@ -47,11 +46,9 @@ from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users from open_webui.apps.webui.utils import load_function_module_by_id - from authlib.integrations.starlette_client import OAuth from authlib.oidc.core import UserInfo - from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, @@ -151,7 +148,6 @@ if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -210,7 +206,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -238,14 +233,14 @@ def get_task_model_id(default_model_id): # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL else: if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL @@ -382,7 +377,7 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -608,8 +603,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if prompt is None: raise Exception("No user message found") if ( - rag_app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" + rag_app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" ): log.debug( f"With a 0 relevancy threshold for RAG, the context cannot be empty" @@ -676,6 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) + ################################## # # Pipeline Middleware @@ -688,15 +684,15 @@ def get_sorted_filters(model_id): model for model in app.state.MODELS.values() if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) return sorted_filters @@ -798,7 +794,6 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) - app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -844,8 +839,8 @@ async def update_embedding_function(request: Request, call_next): @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" ): upgrade = (request.headers.get("Upgrade") or "").lower() connection = (request.headers.get("Connection") or "").lower().split(",") @@ -913,8 +908,8 @@ async def get_all_models(): if custom_model.base_model_id is None: for model in models: if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] ): model["name"] = custom_model.name model["info"] = custom_model.model_dump() @@ -931,8 +926,8 @@ async def get_all_models(): for model in models: if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] if "pipe" in model: @@ -1727,7 +1722,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)): @app.post("/api/pipelines/upload") async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file @@ -1904,9 +1899,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use @app.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -1942,9 +1937,9 @@ async def get_pipeline_valves( @app.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -1979,10 +1974,10 @@ async def get_pipeline_valves_spec( @app.post("/api/pipelines/{pipeline_id}/valves/update") async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), ): r = None try: @@ -2106,7 +2101,7 @@ class ModelFilterConfigForm(BaseModel): @app.post("/api/config/model/filter") async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models @@ -2155,7 +2150,7 @@ async def get_app_latest_release_version(): try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest" + "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: response.raise_for_status() data = await response.json() @@ -2198,6 +2193,53 @@ if len(OAUTH_PROVIDERS) > 0: ) +def get_user_role(user: UserModel, user_data: UserInfo) -> str: + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" + + if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM + oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES + oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = webui_app.state.config.DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + + @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): if provider not in OAUTH_PROVIDERS: @@ -2244,34 +2286,6 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) - # print all user data content for debugging - log.info(f"User data: {user_data}") - if user: - role = user.role - if Users.get_num_users() == 1: - role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM - oauth_roles = None - - if oauth_claim: - claim_data = user_data - nested_claims = oauth_claim.split(".") - for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim, {}) - oauth_roles = claim_data if isinstance(claim_data, list) else None - - log.info(f"User {user.name} has OAuth roles: {oauth_roles}") - if oauth_roles: - for allowed_role in ["pending", "user", "admin"]: - role = allowed_role if allowed_role in oauth_roles else role - else: - # If role mapping is enabled, but no roles are provided, fall back to pending - role = "pending" - log.info(f"Applied role: {role} to user {user.name}") - - if role != user.role: - Users.update_user_role_by_id(user.id, role) if not user: # If the user does not exist, check if merging is enabled @@ -2282,6 +2296,11 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Update the user with the new oauth sub Users.update_user_oauth_sub_by_id(user.id, provider_sub) + if user: + determined_role = get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) + if not user: # If the user does not exist, check if signups are enabled if ENABLE_OAUTH_SIGNUP.value: @@ -2313,17 +2332,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): picture_url = "/user.png" username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = webui_app.state.config.DEFAULT_USER_ROLE - if Users.get_num_users() == 0: - role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) - if oauth_roles: - for allowed_role in ["pending", "user", "admin"]: - role = allowed_role if allowed_role in oauth_roles else role - else: - # If role mapping is enabled, but no roles are provided, fall back to pending - role = "pending" + role = get_user_role(None, user_data) user = Auths.insert_new_auth( email=email, From 1c5b6987e20ee39ec6c99bfb00a0c373e8ea2d77 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Fri, 11 Oct 2024 14:08:11 +0200 Subject: [PATCH 11/14] add missing env mapping --- backend/open_webui/apps/webui/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 0208c0ea9..586b14f92 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -37,12 +37,14 @@ from open_webui.config import ( OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, SHOW_ADMIN_DETAILS, USER_PERMISSIONS, WEBHOOK_URL, WEBUI_AUTH, WEBUI_BANNERS, - AppConfig, + AppConfig, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES, ) from open_webui.env import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, @@ -97,6 +99,8 @@ app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +app.state.config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +app.state.config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES app.state.MODELS = {} app.state.TOOLS = {} From 08ff4947549f84bd2f49cd836a2b8ad3a10db96a Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Wed, 16 Oct 2024 09:42:47 +0200 Subject: [PATCH 12/14] WIP - refactoring oauth functions to enable refresh functionality --- backend/open_webui/apps/webui/main.py | 2 +- backend/open_webui/main.py | 190 +-------------------- backend/open_webui/utils/oauth.py | 234 ++++++++++++++++++++++++++ 3 files changed, 239 insertions(+), 187 deletions(-) create mode 100644 backend/open_webui/utils/oauth.py diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 586b14f92..5849e0217 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -44,7 +44,7 @@ from open_webui.config import ( WEBHOOK_URL, WEBUI_AUTH, WEBUI_BANNERS, - AppConfig, OAUTH_ALLOWED_ROLES, OAUTH_ADMIN_ROLES, + AppConfig, ) from open_webui.env import ( WEBUI_AUTH_TRUSTED_EMAIL_HEADER, diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 8095a66ca..fcb2db055 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -46,8 +46,6 @@ from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users from open_webui.apps.webui.utils import load_function_module_by_id -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo from open_webui.config import ( CACHE_DIR, @@ -144,6 +142,8 @@ from open_webui.utils.response import ( convert_streaming_response_ollama_to_openai, ) +from open_webui.utils.oauth import oauth_manager + if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -2168,20 +2168,6 @@ async def get_app_latest_release_version(): # OAuth Login & Callback ############################ -oauth = OAuth() - -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) - # SessionMiddleware is used by authlib for oauth if len(OAUTH_PROVIDERS) > 0: app.add_middleware( @@ -2193,65 +2179,9 @@ if len(OAUTH_PROVIDERS) > 0: ) -def get_user_role(user: UserModel, user_data: UserInfo) -> str: - if user and Users.get_num_users() == 1: - # If the user is the only user, assign the role "admin" - actually repairs role for single user on login - return "admin" - if not user and Users.get_num_users() == 0: - # If there are no users, assign the role "admin", as the first user will be an admin - return "admin" - - if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT: - oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM - oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES - oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES - oauth_roles = None - role = "pending" # Default/fallback role if no matching roles are found - - # Next block extracts the roles from the user data, accepting nested claims of any depth - if oauth_claim and oauth_allowed_roles and oauth_admin_roles: - claim_data = user_data - nested_claims = oauth_claim.split(".") - for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim, {}) - oauth_roles = claim_data if isinstance(claim_data, list) else None - - # If any roles are found, check if they match the allowed or admin roles - if oauth_roles: - # If role management is enabled, and matching roles are provided, use the roles - for allowed_role in oauth_allowed_roles: - # If the user has any of the allowed roles, assign the role "user" - if allowed_role in oauth_roles: - role = "user" - break - for admin_role in oauth_admin_roles: - # If the user has any of the admin roles, assign the role "admin" - if admin_role in oauth_roles: - role = "admin" - break - else: - if not user: - # If role management is disabled, use the default role for new users - role = webui_app.state.config.DEFAULT_USER_ROLE - else: - # If role management is disabled, use the existing role for existing users - role = user.role - - return role - - @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider - ) - client = oauth.create_client(provider) - if client is None: - raise HTTPException(404) - return await client.authorize_redirect(request, redirect_uri) + return oauth_manager.handle_login(provider, request) # OAuth login logic is as follows: @@ -2262,119 +2192,7 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is alreayd taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth.create_client(provider) - try: - token = await client.authorize_access_token(request) - except Exception as e: - log.warning(f"OAuth callback error: {e}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - email_claim = webui_app.state.config.OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() - # We currently mandate that email addresses are provided - if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, check if merging is enabled - if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: - # Check if the user exists by email - user = Users.get_user_by_email(email) - if user: - # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) - - if user: - determined_role = get_user_role(user, user_data) - if user.role != determined_role: - Users.update_user_role_by_id(user.id, determined_role) - - if not user: - # If the user does not exist, check if signups are enabled - if ENABLE_OAUTH_SIGNUP.value: - # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) - if existing_user: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - - picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") - if picture_url: - # Download the profile image into a base64 string - try: - async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: - picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) - guessed_mime_type = mimetypes.guess_type(picture_url)[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") - picture_url = "" - if not picture_url: - picture_url = "/user.png" - username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - - role = get_user_role(None, user_data) - - user = Auths.insert_new_auth( - email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get(username_claim, "User"), - profile_image_url=picture_url, - role=role, - oauth_sub=provider_sub, - ) - - if webui_app.state.config.WEBHOOK_URL: - post_webhook( - webui_app.state.config.WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, - ) - else: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(webui_app.state.config.JWT_EXPIRES_IN), - ) - - # Set the cookie token - response.set_cookie( - key="token", - value=jwt_token, - httponly=True, # Ensures the cookie is not accessible via JavaScript - ) - - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + return oauth_manager.handle_callback(provider, request, response) @app.get("/manifest.json") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py new file mode 100644 index 000000000..3e163686e --- /dev/null +++ b/backend/open_webui/utils/oauth.py @@ -0,0 +1,234 @@ +import base64 +import mimetypes +import uuid + +import aiohttp +import logging +from fastapi import ( + HTTPException, + Request, + status, +) +from starlette.responses import RedirectResponse, Response, StreamingResponse +from authlib.oidc.core import UserInfo + +from open_webui.apps.webui.models.auths import Auths +from open_webui.apps.webui.models.users import Users, UserModel +from open_webui.config import ( + DEFAULT_USER_ROLE, + ENABLE_OAUTH_SIGNUP, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PROVIDERS, + ENABLE_OAUTH_ROLE_MANAGEMENT, + OAUTH_ROLES_CLAIM, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_USERNAME_CLAIM, + OAUTH_ALLOWED_ROLES, + OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, +) + +from authlib.integrations.starlette_client import OAuth + +from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from open_webui.utils.misc import parse_duration +from open_webui.utils.utils import get_password_hash, create_token +from open_webui.utils.webhook import post_webhook + +log = logging.getLogger(__name__) + +oauth_manager = {} +oauth_manager.oauth = OAuth() + +for provider_name, provider_config in OAUTH_PROVIDERS.items(): + oauth_manager.oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + redirect_uri=provider_config["redirect_uri"], + ) + +oauth_manager.get_client = oauth_manager.oauth.create_client + +def get_user_role(user: UserModel, user_data: UserInfo) -> str: + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" + + if ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = OAUTH_ROLES_CLAIM + oauth_allowed_roles = OAUTH_ALLOWED_ROLES + oauth_admin_roles = OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + +oauth_manager.get_user_role = get_user_role + +async def handle_login(provider: str, request: Request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + # If the provider has a custom redirect URL, use that, otherwise automatically generate one + redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( + "oauth_callback", provider=provider + ) + client = oauth_manager.get_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) + +oauth_manager.handle_login = handle_login + +async def handle_callback(provider: str, request: Request, response: Response): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = oauth_manager.get_client(provider) + try: + token = await client.authorize_access_token(request) + except Exception as e: + log.warning(f"OAuth callback error: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + log.warning(f"OAuth callback failed, sub is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + email_claim = OAUTH_EMAIL_CLAIM + email = user_data.get(email_claim, "").lower() + # We currently mandate that email addresses are provided + if not email: + log.warning(f"OAuth callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) + + if not user: + # If the user does not exist, check if merging is enabled + if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: + # Check if the user exists by email + user = Users.get_user_by_email(email) + if user: + # Update the user with the new oauth sub + Users.update_user_oauth_sub_by_id(user.id, provider_sub) + + if user: + determined_role = get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) + + if not user: + # If the user does not exist, check if signups are enabled + if ENABLE_OAUTH_SIGNUP.value: + # Check if an existing user with the same email already exists + existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + if existing_user: + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) + + picture_claim = OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") + if picture_url: + # Download the profile image into a base64 string + try: + async with aiohttp.ClientSession() as session: + async with session.get(picture_url) as resp: + picture = await resp.read() + base64_encoded_picture = base64.b64encode(picture).decode( + "utf-8" + ) + guessed_mime_type = mimetypes.guess_type(picture_url)[0] + if guessed_mime_type is None: + # assume JPG, browsers are tolerant enough of image formats + guessed_mime_type = "image/jpeg" + picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" + except Exception as e: + log.error(f"Error downloading profile image '{picture_url}': {e}") + picture_url = "" + if not picture_url: + picture_url = "/user.png" + username_claim = OAUTH_USERNAME_CLAIM + + role = get_user_role(None, user_data) + + user = Auths.insert_new_auth( + email=email, + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get(username_claim, "User"), + profile_image_url=picture_url, + role=role, + oauth_sub=provider_sub, + ) + + if WEBHOOK_URL: + post_webhook( + WEBHOOK_URL, + WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) + + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(JWT_EXPIRES_IN), + ) + + # Set the cookie token + response.set_cookie( + key="token", + value=jwt_token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + ) + + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) + +oauth_manager.handle_callback = handle_callback \ No newline at end of file From 8eebd6bce1ce3ce7fec173d0bbc0bc0bdbf3a119 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Wed, 16 Oct 2024 16:32:57 +0200 Subject: [PATCH 13/14] Finish reorganizing oauth code --- backend/open_webui/main.py | 178 +++++++------- backend/open_webui/utils/oauth.py | 385 +++++++++++++++--------------- 2 files changed, 279 insertions(+), 284 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index fcb2db055..b0f9178a3 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,4 +1,5 @@ -import base64 +import inspect +import asyncio import inspect import json import logging @@ -7,89 +8,11 @@ import os import shutil import sys import time -import uuid -import asyncio - from contextlib import asynccontextmanager from typing import Optional import aiohttp import requests - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app -from open_webui.apps.ollama.main import app as ollama_app -from open_webui.apps.ollama.main import ( - GenerateChatCompletionForm, - generate_chat_completion as generate_ollama_chat_completion, - generate_openai_chat_completion as generate_ollama_openai_chat_completion, -) -from open_webui.apps.ollama.main import get_all_models as get_ollama_models -from open_webui.apps.openai.main import app as openai_app -from open_webui.apps.openai.main import ( - generate_chat_completion as generate_openai_chat_completion, -) -from open_webui.apps.openai.main import get_all_models as get_openai_models -from open_webui.apps.rag.main import app as rag_app -from open_webui.apps.rag.utils import get_rag_context, rag_template -from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup -from open_webui.apps.socket.main import get_event_call, get_event_emitter -from open_webui.apps.webui.internal.db import Session -from open_webui.apps.webui.main import app as webui_app -from open_webui.apps.webui.main import ( - generate_function_chat_completion, - get_pipe_models, -) -from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.functions import Functions -from open_webui.apps.webui.models.models import Models -from open_webui.apps.webui.models.users import UserModel, Users -from open_webui.apps.webui.utils import load_function_module_by_id - - -from open_webui.config import ( - CACHE_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_EXPORT, - ENABLE_MODEL_FILTER, - ENABLE_OAUTH_SIGNUP, - ENABLE_OLLAMA_API, - ENABLE_OPENAI_API, - ENV, - FRONTEND_BUILD_DIR, - MODEL_FILTER_LIST, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, - OAUTH_PROVIDERS, - ENABLE_SEARCH_QUERY, - SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, - STATIC_DIR, - TASK_MODEL, - TASK_MODEL_EXTERNAL, - TITLE_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - WEBHOOK_URL, - WEBUI_AUTH, - WEBUI_NAME, - AppConfig, - run_migrations, - reset_config, -) -from open_webui.constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES -from open_webui.env import ( - CHANGELOG, - GLOBAL_LOG_LEVEL, - SAFE_MODE, - SRC_LOG_LEVELS, - VERSION, - WEBUI_BUILD_HASH, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, - WEBUI_URL, - RESET_CONFIG_ON_START, -) from fastapi import ( Depends, FastAPI, @@ -108,16 +31,88 @@ from sqlalchemy import text from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware -from starlette.responses import RedirectResponse, Response, StreamingResponse - -from open_webui.utils.security_headers import SecurityHeadersMiddleware +from starlette.responses import Response, StreamingResponse +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app +from open_webui.apps.ollama.main import ( + GenerateChatCompletionForm, + generate_chat_completion as generate_ollama_chat_completion, +) +from open_webui.apps.ollama.main import app as ollama_app +from open_webui.apps.ollama.main import get_all_models as get_ollama_models +from open_webui.apps.openai.main import app as openai_app +from open_webui.apps.openai.main import ( + generate_chat_completion as generate_openai_chat_completion, +) +from open_webui.apps.openai.main import get_all_models as get_openai_models +from open_webui.apps.rag.main import app as rag_app +from open_webui.apps.rag.utils import get_rag_context, rag_template +from open_webui.apps.socket.main import app as socket_app, periodic_usage_pool_cleanup +from open_webui.apps.socket.main import get_event_call, get_event_emitter +from open_webui.apps.webui.internal.db import Session +from open_webui.apps.webui.main import app as webui_app +from open_webui.apps.webui.main import ( + generate_function_chat_completion, + get_pipe_models, +) +from open_webui.apps.webui.models.functions import Functions +from open_webui.apps.webui.models.models import Models +from open_webui.apps.webui.models.users import UserModel, Users +from open_webui.apps.webui.utils import load_function_module_by_id +from open_webui.config import ( + CACHE_DIR, + CORS_ALLOW_ORIGIN, + DEFAULT_LOCALE, + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + ENABLE_MODEL_FILTER, + ENABLE_OLLAMA_API, + ENABLE_OPENAI_API, + ENV, + FRONTEND_BUILD_DIR, + MODEL_FILTER_LIST, + OAUTH_PROVIDERS, + ENABLE_SEARCH_QUERY, + SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE, + STATIC_DIR, + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TITLE_GENERATION_PROMPT_TEMPLATE, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + WEBHOOK_URL, + WEBUI_AUTH, + WEBUI_NAME, + AppConfig, + run_migrations, + reset_config, +) +from open_webui.constants import ERROR_MESSAGES, TASKS +from open_webui.env import ( + CHANGELOG, + GLOBAL_LOG_LEVEL, + SAFE_MODE, + SRC_LOG_LEVELS, + VERSION, + WEBUI_BUILD_HASH, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, + WEBUI_URL, + RESET_CONFIG_ON_START, +) from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, - parse_duration, prepend_to_first_user_message_content, ) +from open_webui.utils.oauth import oauth_manager +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from open_webui.utils.security_headers import SecurityHeadersMiddleware from open_webui.utils.task import ( moa_response_generation_template, search_query_generation_template, @@ -126,23 +121,12 @@ from open_webui.utils.task import ( ) from open_webui.utils.tools import get_tools from open_webui.utils.utils import ( - create_token, decode_token, get_admin_user, get_current_user, get_http_authorization_cred, - get_password_hash, get_verified_user, ) -from open_webui.utils.webhook import post_webhook - -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, -) - -from open_webui.utils.oauth import oauth_manager if SAFE_MODE: print("SAFE MODE ENABLED") @@ -220,6 +204,8 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( app.state.MODELS = {} + + ################################## # # ChatCompletion Middleware @@ -2181,7 +2167,7 @@ if len(OAUTH_PROVIDERS) > 0: @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): - return oauth_manager.handle_login(provider, request) + return await oauth_manager.handle_login(provider, request) # OAuth login logic is as follows: @@ -2192,7 +2178,7 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is alreayd taken @app.get("/oauth/{provider}/callback") async def oauth_callback(provider: str, request: Request, response: Response): - return oauth_manager.handle_callback(provider, request, response) + return await oauth_manager.handle_callback(provider, request, response) @app.get("/manifest.json") diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 3e163686e..e15edc0a6 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,19 +1,19 @@ import base64 +import logging import mimetypes import uuid import aiohttp -import logging +from authlib.integrations.starlette_client import OAuth +from authlib.oidc.core import UserInfo from fastapi import ( HTTPException, - Request, status, ) -from starlette.responses import RedirectResponse, Response, StreamingResponse -from authlib.oidc.core import UserInfo +from starlette.responses import RedirectResponse from open_webui.apps.webui.models.auths import Auths -from open_webui.apps.webui.models.users import Users, UserModel +from open_webui.apps.webui.models.users import Users from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, @@ -25,210 +25,219 @@ from open_webui.config import ( OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, + OAUTH_ADMIN_ROLES, WEBHOOK_URL, JWT_EXPIRES_IN, AppConfig, ) - -from authlib.integrations.starlette_client import OAuth - -from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from open_webui.constants import ERROR_MESSAGES from open_webui.utils.misc import parse_duration from open_webui.utils.utils import get_password_hash, create_token from open_webui.utils.webhook import post_webhook log = logging.getLogger(__name__) -oauth_manager = {} -oauth_manager.oauth = OAuth() +auth_manager_config = AppConfig() +auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE +auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP +auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL +auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM +auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +auth_manager_config.OAUTH_ALLOWED_ROLES = OAUTH_ALLOWED_ROLES +auth_manager_config.OAUTH_ADMIN_ROLES = OAUTH_ADMIN_ROLES +auth_manager_config.WEBHOOK_URL = WEBHOOK_URL +auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN -for provider_name, provider_config in OAUTH_PROVIDERS.items(): - oauth_manager.oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) -oauth_manager.get_client = oauth_manager.oauth.create_client +class OAuthManager: + def __init__(self): + self.oauth = OAuth() + for provider_name, provider_config in OAUTH_PROVIDERS.items(): + self.oauth.register( + name=provider_name, + client_id=provider_config["client_id"], + client_secret=provider_config["client_secret"], + server_metadata_url=provider_config["server_metadata_url"], + client_kwargs={ + "scope": provider_config["scope"], + }, + redirect_uri=provider_config["redirect_uri"], + ) -def get_user_role(user: UserModel, user_data: UserInfo) -> str: - if user and Users.get_num_users() == 1: - # If the user is the only user, assign the role "admin" - actually repairs role for single user on login - return "admin" - if not user and Users.get_num_users() == 0: - # If there are no users, assign the role "admin", as the first user will be an admin - return "admin" + def get_client(self, provider_name): + return self.oauth.create_client(provider_name) - if ENABLE_OAUTH_ROLE_MANAGEMENT: - oauth_claim = OAUTH_ROLES_CLAIM - oauth_allowed_roles = OAUTH_ALLOWED_ROLES - oauth_admin_roles = OAUTH_ADMIN_ROLES - oauth_roles = None - role = "pending" # Default/fallback role if no matching roles are found + def get_user_role(self, user, user_data): + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" - # Next block extracts the roles from the user data, accepting nested claims of any depth - if oauth_claim and oauth_allowed_roles and oauth_admin_roles: - claim_data = user_data - nested_claims = oauth_claim.split(".") - for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim, {}) - oauth_roles = claim_data if isinstance(claim_data, list) else None + if auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = auth_manager_config.OAUTH_ROLES_CLAIM + oauth_allowed_roles = auth_manager_config.OAUTH_ALLOWED_ROLES + oauth_admin_roles = auth_manager_config.OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = auth_manager_config.DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + + async def handle_login(self, provider, request): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + # If the provider has a custom redirect URL, use that, otherwise automatically generate one + redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( + "oauth_callback", provider=provider + ) + client = self.get_client(provider) + if client is None: + raise HTTPException(404) + return await client.authorize_redirect(request, redirect_uri) + + async def handle_callback(self, provider, request, response): + if provider not in OAUTH_PROVIDERS: + raise HTTPException(404) + client = self.get_client(provider) + try: + token = await client.authorize_access_token(request) + except Exception as e: + log.warning(f"OAuth callback error: {e}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + user_data: UserInfo = token["userinfo"] + + sub = user_data.get("sub") + if not sub: + log.warning(f"OAuth callback failed, sub is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + provider_sub = f"{provider}@{sub}" + email_claim = auth_manager_config.OAUTH_EMAIL_CLAIM + email = user_data.get(email_claim, "").lower() + # We currently mandate that email addresses are provided + if not email: + log.warning(f"OAuth callback failed, email is missing: {user_data}") + raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) + + # Check if the user exists + user = Users.get_user_by_oauth_sub(provider_sub) - # If any roles are found, check if they match the allowed or admin roles - if oauth_roles: - # If role management is enabled, and matching roles are provided, use the roles - for allowed_role in oauth_allowed_roles: - # If the user has any of the allowed roles, assign the role "user" - if allowed_role in oauth_roles: - role = "user" - break - for admin_role in oauth_admin_roles: - # If the user has any of the admin roles, assign the role "admin" - if admin_role in oauth_roles: - role = "admin" - break - else: if not user: - # If role management is disabled, use the default role for new users - role = DEFAULT_USER_ROLE - else: - # If role management is disabled, use the existing role for existing users - role = user.role + # If the user does not exist, check if merging is enabled + if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: + # Check if the user exists by email + user = Users.get_user_by_email(email) + if user: + # Update the user with the new oauth sub + Users.update_user_oauth_sub_by_id(user.id, provider_sub) - return role + if user: + determined_role = self.get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) -oauth_manager.get_user_role = get_user_role + if not user: + # If the user does not exist, check if signups are enabled + if auth_manager_config.ENABLE_OAUTH_SIGNUP.value: + # Check if an existing user with the same email already exists + existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + if existing_user: + raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) -async def handle_login(provider: str, request: Request): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - # If the provider has a custom redirect URL, use that, otherwise automatically generate one - redirect_uri = OAUTH_PROVIDERS[provider].get("redirect_uri") or request.url_for( - "oauth_callback", provider=provider - ) - client = oauth_manager.get_client(provider) - if client is None: - raise HTTPException(404) - return await client.authorize_redirect(request, redirect_uri) + picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") + if picture_url: + # Download the profile image into a base64 string + try: + async with aiohttp.ClientSession() as session: + async with session.get(picture_url) as resp: + picture = await resp.read() + base64_encoded_picture = base64.b64encode(picture).decode( + "utf-8" + ) + guessed_mime_type = mimetypes.guess_type(picture_url)[0] + if guessed_mime_type is None: + # assume JPG, browsers are tolerant enough of image formats + guessed_mime_type = "image/jpeg" + picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" + except Exception as e: + log.error(f"Error downloading profile image '{picture_url}': {e}") + picture_url = "" + if not picture_url: + picture_url = "/user.png" + username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM -oauth_manager.handle_login = handle_login + role = self.get_user_role(None, user_data) -async def handle_callback(provider: str, request: Request, response: Response): - if provider not in OAUTH_PROVIDERS: - raise HTTPException(404) - client = oauth_manager.get_client(provider) - try: - token = await client.authorize_access_token(request) - except Exception as e: - log.warning(f"OAuth callback error: {e}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - user_data: UserInfo = token["userinfo"] - - sub = user_data.get("sub") - if not sub: - log.warning(f"OAuth callback failed, sub is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - provider_sub = f"{provider}@{sub}" - email_claim = OAUTH_EMAIL_CLAIM - email = user_data.get(email_claim, "").lower() - # We currently mandate that email addresses are provided - if not email: - log.warning(f"OAuth callback failed, email is missing: {user_data}") - raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Check if the user exists - user = Users.get_user_by_oauth_sub(provider_sub) - - if not user: - # If the user does not exist, check if merging is enabled - if OAUTH_MERGE_ACCOUNTS_BY_EMAIL.value: - # Check if the user exists by email - user = Users.get_user_by_email(email) - if user: - # Update the user with the new oauth sub - Users.update_user_oauth_sub_by_id(user.id, provider_sub) - - if user: - determined_role = get_user_role(user, user_data) - if user.role != determined_role: - Users.update_user_role_by_id(user.id, determined_role) - - if not user: - # If the user does not exist, check if signups are enabled - if ENABLE_OAUTH_SIGNUP.value: - # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) - if existing_user: - raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - - picture_claim = OAUTH_PICTURE_CLAIM - picture_url = user_data.get(picture_claim, "") - if picture_url: - # Download the profile image into a base64 string - try: - async with aiohttp.ClientSession() as session: - async with session.get(picture_url) as resp: - picture = await resp.read() - base64_encoded_picture = base64.b64encode(picture).decode( - "utf-8" - ) - guessed_mime_type = mimetypes.guess_type(picture_url)[0] - if guessed_mime_type is None: - # assume JPG, browsers are tolerant enough of image formats - guessed_mime_type = "image/jpeg" - picture_url = f"data:{guessed_mime_type};base64,{base64_encoded_picture}" - except Exception as e: - log.error(f"Error downloading profile image '{picture_url}': {e}") - picture_url = "" - if not picture_url: - picture_url = "/user.png" - username_claim = OAUTH_USERNAME_CLAIM - - role = get_user_role(None, user_data) - - user = Auths.insert_new_auth( - email=email, - password=get_password_hash( - str(uuid.uuid4()) - ), # Random password, not used - name=user_data.get(username_claim, "User"), - profile_image_url=picture_url, - role=role, - oauth_sub=provider_sub, - ) - - if WEBHOOK_URL: - post_webhook( - WEBHOOK_URL, - WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - { - "action": "signup", - "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name), - "user": user.model_dump_json(exclude_none=True), - }, + user = Auths.insert_new_auth( + email=email, + password=get_password_hash( + str(uuid.uuid4()) + ), # Random password, not used + name=user_data.get(username_claim, "User"), + profile_image_url=picture_url, + role=role, + oauth_sub=provider_sub, ) - else: - raise HTTPException( - status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) - jwt_token = create_token( - data={"id": user.id}, - expires_delta=parse_duration(JWT_EXPIRES_IN), - ) + if auth_manager_config.WEBHOOK_URL: + post_webhook( + auth_manager_config.WEBHOOK_URL, + auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + { + "action": "signup", + "message": auth_manager_config.WEBHOOK_MESSAGES.USER_SIGNUP(user.name), + "user": user.model_dump_json(exclude_none=True), + }, + ) + else: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) - # Set the cookie token - response.set_cookie( - key="token", - value=jwt_token, - httponly=True, # Ensures the cookie is not accessible via JavaScript - ) + jwt_token = create_token( + data={"id": user.id}, + expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), + ) - # Redirect back to the frontend with the JWT token - redirect_url = f"{request.base_url}auth#token={jwt_token}" - return RedirectResponse(url=redirect_url) + # Set the cookie token + response.set_cookie( + key="token", + value=jwt_token, + httponly=True, # Ensures the cookie is not accessible via JavaScript + ) -oauth_manager.handle_callback = handle_callback \ No newline at end of file + # Redirect back to the frontend with the JWT token + redirect_url = f"{request.base_url}auth#token={jwt_token}" + return RedirectResponse(url=redirect_url) + +oauth_manager = OAuthManager() \ No newline at end of file From b1554be3f221226155f93a2746b6e77578f39aa4 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Wed, 16 Oct 2024 16:58:03 +0200 Subject: [PATCH 14/14] Fix imports --- backend/open_webui/main.py | 42 +++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index bc3244cdb..e06f870a7 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,4 +1,3 @@ -import inspect import asyncio import inspect import json @@ -13,12 +12,32 @@ from typing import Optional import aiohttp import requests +from fastapi import ( + Depends, + FastAPI, + File, + Form, + HTTPException, + Request, + UploadFile, + status, +) +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from pydantic import BaseModel +from sqlalchemy import text +from starlette.exceptions import HTTPException as StarletteHTTPException +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.sessions import SessionMiddleware +from starlette.responses import Response, StreamingResponse +from open_webui.apps.audio.main import app as audio_app +from open_webui.apps.images.main import app as images_app from open_webui.apps.ollama.main import ( app as ollama_app, get_all_models as get_ollama_models, generate_chat_completion as generate_ollama_chat_completion, - generate_openai_chat_completion as generate_ollama_openai_chat_completion, GenerateChatCompletionForm, ) from open_webui.apps.openai.main import ( @@ -26,38 +45,24 @@ from open_webui.apps.openai.main import ( generate_chat_completion as generate_openai_chat_completion, get_all_models as get_openai_models, ) - from open_webui.apps.retrieval.main import app as retrieval_app from open_webui.apps.retrieval.utils import get_rag_context, rag_template - from open_webui.apps.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, get_event_call, get_event_emitter, ) - +from open_webui.apps.webui.internal.db import Session from open_webui.apps.webui.main import ( app as webui_app, generate_function_chat_completion, get_pipe_models, ) -from open_webui.apps.webui.internal.db import Session - -from open_webui.apps.webui.models.auths import Auths from open_webui.apps.webui.models.functions import Functions from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users - from open_webui.apps.webui.utils import load_function_module_by_id - -from open_webui.apps.audio.main import app as audio_app -from open_webui.apps.images.main import app as images_app - -from authlib.integrations.starlette_client import OAuth -from authlib.oidc.core import UserInfo - - from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, @@ -82,10 +87,9 @@ from open_webui.config import ( WEBUI_AUTH, WEBUI_NAME, AppConfig, - run_migrations, reset_config, ) -from open_webui.constants import ERROR_MESSAGES, TASKS +from open_webui.constants import TASKS from open_webui.env import ( CHANGELOG, GLOBAL_LOG_LEVEL,