diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 6c6f197dd..8682cd767 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -32,6 +32,8 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, + ENABLE_OAUTH_ROLE_MAPPING, + OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, @@ -93,6 +95,9 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM +app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING +app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index f9921d9cb..18e0c398e 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -278,18 +278,6 @@ ENABLE_OAUTH_SIGNUP = PersistentConfig( os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true", ) -ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( - "ENABLE_OAUTH_ROLE_MAPPING", - "oauth.enable_role_mapping", - os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", -) - -OAUTH_ROLES_CLAIM = PersistentConfig( - "OAUTH_ROLES_CLAIM", - "oauth.roles_claim", - os.environ.get("OAUTH_ROLES_CLAIM", "roles"), -) - OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig( "OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "oauth.merge_accounts_by_email", @@ -395,7 +383,7 @@ OAUTH_USERNAME_CLAIM = PersistentConfig( ) OAUTH_PICTURE_CLAIM = PersistentConfig( - "OAUTH_USERNAME_CLAIM", + "OAUTH_PICTURE_CLAIM", "oauth.oidc.avatar_claim", os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), ) @@ -406,6 +394,18 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( + "ENABLE_OAUTH_ROLE_MAPPING", + "oauth.enable_role_mapping", + os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", +) + +OAUTH_ROLES_CLAIM = PersistentConfig( + "OAUTH_ROLES_CLAIM", + "oauth.roles_claim", + os.environ.get("OAUTH_ROLES_CLAIM", "roles"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 77d486fb7..e24a5a969 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2249,7 +2249,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): role = user.role if Users.get_num_users() == 1: role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: + elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING.value: oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) if oauth_roles: for allowed_role in ["pending", "user", "admin"]: