diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index b23ac782b..552edf7fa 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -39,6 +39,8 @@ from config import ( WEBUI_BANNERS, ENABLE_COMMUNITY_SHARING, AppConfig, + OAUTH_USERNAME_CLAIM, + OAUTH_PICTURE_CLAIM, ) import inspect @@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/config.py b/backend/config.py index 2fcc0ba64..03374ce1f 100644 --- a/backend/config.py +++ b/backend/config.py @@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), ) +OAUTH_USERNAME_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.username_claim", + os.environ.get("OAUTH_USERNAME_CLAIM", "name"), +) + +OAUTH_PICTURE_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.avatar_claim", + os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/main.py b/backend/main.py index 49e068a75..8f818c85b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -2064,7 +2064,8 @@ async def oauth_callback(provider: str, request: Request, response: Response): if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - picture_url = user_data.get("picture", "") + picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") if picture_url: # Download the profile image into a base64 string try: @@ -2084,6 +2085,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): picture_url = "" if not picture_url: picture_url = "/user.png" + username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM role = ( "admin" if Users.get_num_users() == 0 @@ -2094,7 +2096,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): password=get_password_hash( str(uuid.uuid4()) ), # Random password, not used - name=user_data.get("name", "User"), + name=user_data.get(username_claim, "User"), profile_image_url=picture_url, role=role, oauth_sub=provider_sub,