diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 12812ae0a..cf6a594ed 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -468,12 +468,20 @@ OAUTH_ALLOWED_DOMAINS = PersistentConfig( def load_oauth_providers(): OAUTH_PROVIDERS.clear() if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: + def google_oauth_register(client): + client.register( + name="google", + client_id=GOOGLE_CLIENT_ID.value, + client_secret=GOOGLE_CLIENT_SECRET.value, + server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + client_kwargs={ + "scope": GOOGLE_OAUTH_SCOPE.value + }, + redirect_uri=GOOGLE_REDIRECT_URI.value, + ) OAUTH_PROVIDERS["google"] = { - "client_id": GOOGLE_CLIENT_ID.value, - "client_secret": GOOGLE_CLIENT_SECRET.value, - "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration", - "scope": GOOGLE_OAUTH_SCOPE.value, "redirect_uri": GOOGLE_REDIRECT_URI.value, + "register": google_oauth_register, } if ( @@ -481,13 +489,21 @@ def load_oauth_providers(): and MICROSOFT_CLIENT_SECRET.value and MICROSOFT_CLIENT_TENANT_ID.value ): + def microsoft_oauth_register(client): + client.register( + name="microsoft", + client_id=MICROSOFT_CLIENT_ID.value, + client_secret=MICROSOFT_CLIENT_SECRET.value, + server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration", + client_kwargs={ + "scope": MICROSOFT_OAUTH_SCOPE.value, + }, + redirect_uri=MICROSOFT_REDIRECT_URI.value, + ) OAUTH_PROVIDERS["microsoft"] = { - "client_id": MICROSOFT_CLIENT_ID.value, - "client_secret": MICROSOFT_CLIENT_SECRET.value, - "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration", - "scope": MICROSOFT_OAUTH_SCOPE.value, "redirect_uri": MICROSOFT_REDIRECT_URI.value, "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value", + "register": microsoft_oauth_register, } if ( @@ -495,13 +511,20 @@ def load_oauth_providers(): and OAUTH_CLIENT_SECRET.value and OPENID_PROVIDER_URL.value ): + def oidc_oauth_register(client): + client.register( + name="oidc", + client_id=OAUTH_CLIENT_ID.value, + client_secret=OAUTH_CLIENT_SECRET.value, + server_metadata_url=OPENID_PROVIDER_URL.value, + client_kwargs={ + "scope": OAUTH_SCOPES.value, + }, + redirect_uri=OPENID_REDIRECT_URI.value, + ) OAUTH_PROVIDERS["oidc"] = { - "client_id": OAUTH_CLIENT_ID.value, - "client_secret": OAUTH_CLIENT_SECRET.value, - "server_metadata_url": OPENID_PROVIDER_URL.value, - "scope": OAUTH_SCOPES.value, "name": OAUTH_PROVIDER_NAME.value, - "redirect_uri": OPENID_REDIRECT_URI.value, + "register": oidc_oauth_register, } diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 386ea05ca..1ae6d4aa7 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -63,17 +63,8 @@ auth_manager_config.JWT_EXPIRES_IN = JWT_EXPIRES_IN class OAuthManager: def __init__(self): self.oauth = OAuth() - for provider_name, provider_config in OAUTH_PROVIDERS.items(): - self.oauth.register( - name=provider_name, - client_id=provider_config["client_id"], - client_secret=provider_config["client_secret"], - server_metadata_url=provider_config["server_metadata_url"], - client_kwargs={ - "scope": provider_config["scope"], - }, - redirect_uri=provider_config["redirect_uri"], - ) + for _, provider_config in OAUTH_PROVIDERS.items(): + provider_config["register"](self.oauth) def get_client(self, provider_name): return self.oauth.create_client(provider_name) @@ -207,7 +198,7 @@ class OAuthManager: log.warning(f"OAuth callback failed, user data is missing: {token}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - sub = user_data.get("sub") + sub = user_data.get(OAUTH_PROVIDERS[provider].get("sub_claim", "sub")) if not sub: log.warning(f"OAuth callback failed, sub is missing: {user_data}") raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)