mirror of
https://github.com/open-webui/open-webui
synced 2024-11-24 21:13:59 +00:00
Merge pull request #3569 from Semihal/custom-openid-claims
feat: Custom claims for OAuth
This commit is contained in:
commit
08c024d752
@ -39,6 +39,8 @@ from config import (
|
|||||||
WEBUI_BANNERS,
|
WEBUI_BANNERS,
|
||||||
ENABLE_COMMUNITY_SHARING,
|
ENABLE_COMMUNITY_SHARING,
|
||||||
AppConfig,
|
AppConfig,
|
||||||
|
OAUTH_USERNAME_CLAIM,
|
||||||
|
OAUTH_PICTURE_CLAIM,
|
||||||
)
|
)
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
|
|||||||
|
|
||||||
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
||||||
|
|
||||||
|
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
|
||||||
|
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
|
||||||
|
|
||||||
app.state.MODELS = {}
|
app.state.MODELS = {}
|
||||||
app.state.TOOLS = {}
|
app.state.TOOLS = {}
|
||||||
app.state.FUNCTIONS = {}
|
app.state.FUNCTIONS = {}
|
||||||
|
@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
|
|||||||
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
|
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
OAUTH_USERNAME_CLAIM = PersistentConfig(
|
||||||
|
"OAUTH_USERNAME_CLAIM",
|
||||||
|
"oauth.oidc.username_claim",
|
||||||
|
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
OAUTH_PICTURE_CLAIM = PersistentConfig(
|
||||||
|
"OAUTH_USERNAME_CLAIM",
|
||||||
|
"oauth.oidc.avatar_claim",
|
||||||
|
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_oauth_providers():
|
def load_oauth_providers():
|
||||||
OAUTH_PROVIDERS.clear()
|
OAUTH_PROVIDERS.clear()
|
||||||
|
@ -2064,7 +2064,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|||||||
if existing_user:
|
if existing_user:
|
||||||
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
||||||
|
|
||||||
picture_url = user_data.get("picture", "")
|
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
|
||||||
|
picture_url = user_data.get(picture_claim, "")
|
||||||
if picture_url:
|
if picture_url:
|
||||||
# Download the profile image into a base64 string
|
# Download the profile image into a base64 string
|
||||||
try:
|
try:
|
||||||
@ -2084,6 +2085,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|||||||
picture_url = ""
|
picture_url = ""
|
||||||
if not picture_url:
|
if not picture_url:
|
||||||
picture_url = "/user.png"
|
picture_url = "/user.png"
|
||||||
|
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
|
||||||
role = (
|
role = (
|
||||||
"admin"
|
"admin"
|
||||||
if Users.get_num_users() == 0
|
if Users.get_num_users() == 0
|
||||||
@ -2094,7 +2096,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|||||||
password=get_password_hash(
|
password=get_password_hash(
|
||||||
str(uuid.uuid4())
|
str(uuid.uuid4())
|
||||||
), # Random password, not used
|
), # Random password, not used
|
||||||
name=user_data.get("name", "User"),
|
name=user_data.get(username_claim, "User"),
|
||||||
profile_image_url=picture_url,
|
profile_image_url=picture_url,
|
||||||
role=role,
|
role=role,
|
||||||
oauth_sub=provider_sub,
|
oauth_sub=provider_sub,
|
||||||
|
Loading…
Reference in New Issue
Block a user