rewrite oauth role management logic to allow any custom roles to be used for oauth role to open webui role mapping

This commit is contained in:
Willnow, Patrick 2024-10-10 23:00:05 +02:00
parent f751d22a20
commit edc15d0d7c
3 changed files with 105 additions and 85 deletions

View File

@ -32,7 +32,7 @@ from open_webui.config import (
ENABLE_MESSAGE_RATING,
ENABLE_SIGNUP,
JWT_EXPIRES_IN,
ENABLE_OAUTH_ROLE_MAPPING,
ENABLE_OAUTH_ROLE_MANAGEMENT,
OAUTH_ROLES_CLAIM,
OAUTH_EMAIL_CLAIM,
OAUTH_PICTURE_CLAIM,
@ -95,7 +95,7 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM
app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING
app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT
app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM
app.state.MODELS = {}

View File

@ -394,10 +394,10 @@ OAUTH_EMAIL_CLAIM = PersistentConfig(
os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)
ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig(
"ENABLE_OAUTH_ROLE_MAPPING",
ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig(
"ENABLE_OAUTH_ROLE_MANAGEMENT",
"oauth.enable_role_mapping",
os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true",
os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true",
)
OAUTH_ROLES_CLAIM = PersistentConfig(
@ -406,6 +406,17 @@ OAUTH_ROLES_CLAIM = PersistentConfig(
os.environ.get("OAUTH_ROLES_CLAIM", "roles"),
)
OAUTH_ALLOWED_ROLES = PersistentConfig(
"OAUTH_ALLOWED_ROLES",
"oauth.allowed_roles",
[role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")],
)
OAUTH_ADMIN_ROLES = PersistentConfig(
"OAUTH_ADMIN_ROLES",
"oauth.admin_roles",
[role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")],
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()

View File

@ -16,7 +16,6 @@ from typing import Optional
import aiohttp
import requests
from open_webui.apps.audio.main import app as audio_app
from open_webui.apps.images.main import app as images_app
from open_webui.apps.ollama.main import app as ollama_app
@ -47,11 +46,9 @@ from open_webui.apps.webui.models.models import Models
from open_webui.apps.webui.models.users import UserModel, Users
from open_webui.apps.webui.utils import load_function_module_by_id
from authlib.integrations.starlette_client import OAuth
from authlib.oidc.core import UserInfo
from open_webui.config import (
CACHE_DIR,
CORS_ALLOW_ORIGIN,
@ -151,7 +148,6 @@ if SAFE_MODE:
print("SAFE MODE ENABLED")
Functions.deactivate_all_functions()
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MAIN"])
@ -210,7 +206,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@ -676,6 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
@ -798,7 +794,6 @@ class PipelineMiddleware(BaseHTTPMiddleware):
app.add_middleware(PipelineMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ALLOW_ORIGIN,
@ -2198,6 +2193,53 @@ if len(OAUTH_PROVIDERS) > 0:
)
def get_user_role(user: UserModel, user_data: UserInfo) -> str:
if user and Users.get_num_users() == 1:
# If the user is the only user, assign the role "admin" - actually repairs role for single user on login
return "admin"
if not user and Users.get_num_users() == 0:
# If there are no users, assign the role "admin", as the first user will be an admin
return "admin"
if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT:
oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES
oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES
oauth_roles = None
role = "pending" # Default/fallback role if no matching roles are found
# Next block extracts the roles from the user data, accepting nested claims of any depth
if oauth_claim and oauth_allowed_roles and oauth_admin_roles:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
# If any roles are found, check if they match the allowed or admin roles
if oauth_roles:
# If role management is enabled, and matching roles are provided, use the roles
for allowed_role in oauth_allowed_roles:
# If the user has any of the allowed roles, assign the role "user"
if allowed_role in oauth_roles:
role = "user"
break
for admin_role in oauth_admin_roles:
# If the user has any of the admin roles, assign the role "admin"
if admin_role in oauth_roles:
role = "admin"
break
else:
if not user:
# If role management is disabled, use the default role for new users
role = webui_app.state.config.DEFAULT_USER_ROLE
else:
# If role management is disabled, use the existing role for existing users
role = user.role
return role
@app.get("/oauth/{provider}/login")
async def oauth_login(provider: str, request: Request):
if provider not in OAUTH_PROVIDERS:
@ -2244,34 +2286,6 @@ async def oauth_callback(provider: str, request: Request, response: Response):
# Check if the user exists
user = Users.get_user_by_oauth_sub(provider_sub)
# print all user data content for debugging
log.info(f"User data: {user_data}")
if user:
role = user.role
if Users.get_num_users() == 1:
role = "admin"
elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM
oauth_roles = None
if oauth_claim:
claim_data = user_data
nested_claims = oauth_claim.split(".")
for nested_claim in nested_claims:
claim_data = claim_data.get(nested_claim, {})
oauth_roles = claim_data if isinstance(claim_data, list) else None
log.info(f"User {user.name} has OAuth roles: {oauth_roles}")
if oauth_roles:
for allowed_role in ["pending", "user", "admin"]:
role = allowed_role if allowed_role in oauth_roles else role
else:
# If role mapping is enabled, but no roles are provided, fall back to pending
role = "pending"
log.info(f"Applied role: {role} to user {user.name}")
if role != user.role:
Users.update_user_role_by_id(user.id, role)
if not user:
# If the user does not exist, check if merging is enabled
@ -2282,6 +2296,11 @@ async def oauth_callback(provider: str, request: Request, response: Response):
# Update the user with the new oauth sub
Users.update_user_oauth_sub_by_id(user.id, provider_sub)
if user:
determined_role = get_user_role(user, user_data)
if user.role != determined_role:
Users.update_user_role_by_id(user.id, determined_role)
if not user:
# If the user does not exist, check if signups are enabled
if ENABLE_OAUTH_SIGNUP.value:
@ -2313,17 +2332,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = webui_app.state.config.DEFAULT_USER_ROLE
if Users.get_num_users() == 0:
role = "admin"
elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING:
oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM)
if oauth_roles:
for allowed_role in ["pending", "user", "admin"]:
role = allowed_role if allowed_role in oauth_roles else role
else:
# If role mapping is enabled, but no roles are provided, fall back to pending
role = "pending"
role = get_user_role(None, user_data)
user = Auths.insert_new_auth(
email=email,