Merge pull request #3569 from Semihal/custom-openid-claims

feat: Custom claims for OAuth
This commit is contained in:
Timothy Jaeryang Baek 2024-07-03 15:56:37 -07:00 committed by GitHub
commit 08c024d752
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 21 additions and 2 deletions

View File

@ -39,6 +39,8 @@ from config import (
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
)
import inspect
@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

View File

@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()

View File

@ -2064,7 +2064,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "")
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
@ -2084,6 +2085,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
@ -2094,7 +2096,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,