From 0c3f9a16e3c3ecd882b69bea2363902889a3c4c8 Mon Sep 17 00:00:00 2001 From: Sergey Mihaylin Date: Fri, 28 Jun 2024 16:31:40 +0300 Subject: [PATCH] custom env for set custom claims for openid --- backend/apps/webui/main.py | 5 +++++ backend/config.py | 12 ++++++++++++ backend/main.py | 6 ++++-- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/backend/apps/webui/main.py b/backend/apps/webui/main.py index 28b1b4aac..e7f0683c6 100644 --- a/backend/apps/webui/main.py +++ b/backend/apps/webui/main.py @@ -39,6 +39,8 @@ from config import ( WEBUI_BANNERS, ENABLE_COMMUNITY_SHARING, AppConfig, + OAUTH_USERNAME_CLAIM, + OAUTH_PICTURE_CLAIM ) import inspect @@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING +app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM +app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM + app.state.MODELS = {} app.state.TOOLS = {} app.state.FUNCTIONS = {} diff --git a/backend/config.py b/backend/config.py index 3a825f53a..cd184aab8 100644 --- a/backend/config.py +++ b/backend/config.py @@ -395,6 +395,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig( os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), ) +OAUTH_USERNAME_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.username_claim", + os.environ.get("OAUTH_USERNAME_CLAIM", "name"), +) + +OAUTH_PICTURE_CLAIM = PersistentConfig( + "OAUTH_USERNAME_CLAIM", + "oauth.oidc.avatar_claim", + os.environ.get("OAUTH_PICTURE_CLAIM", "picture"), +) + def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/main.py b/backend/main.py index aae305c5e..b4fd10c21 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1920,11 +1920,13 @@ async def oauth_callback(provider: str, request: Request, response: Response): # If the user does not exist, check if signups are enabled if ENABLE_OAUTH_SIGNUP.value: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(user_data.get("email", "").lower()) + email_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM + existing_user = Users.get_user_by_email(user_data.get(email_claim, "").lower()) if existing_user: raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) - picture_url = user_data.get("picture", "") + picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM + picture_url = user_data.get(picture_claim, "") if picture_url: # Download the profile image into a base64 string try: