From edc15d0d7ce0a56f1b8fc601cd23cbabb9ad7e34 Mon Sep 17 00:00:00 2001 From: "Willnow, Patrick" Date: Thu, 10 Oct 2024 23:00:05 +0200 Subject: [PATCH] rewrite oauth role management logic to allow any custom roles to be used for oauth role to open webui role mapping --- backend/open_webui/apps/webui/main.py | 4 +- backend/open_webui/config.py | 17 ++- backend/open_webui/main.py | 169 ++++++++++++++------------ 3 files changed, 105 insertions(+), 85 deletions(-) diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 8682cd767..0208c0ea9 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -32,7 +32,7 @@ from open_webui.config import ( ENABLE_MESSAGE_RATING, ENABLE_SIGNUP, JWT_EXPIRES_IN, - ENABLE_OAUTH_ROLE_MAPPING, + ENABLE_OAUTH_ROLE_MANAGEMENT, OAUTH_ROLES_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, @@ -95,7 +95,7 @@ app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM app.state.config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM -app.state.config.ENABLE_OAUTH_ROLE_MAPPING = ENABLE_OAUTH_ROLE_MAPPING +app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT app.state.config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM app.state.MODELS = {} diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index 18e0c398e..c171eaadb 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -394,10 +394,10 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) -ENABLE_OAUTH_ROLE_MAPPING = PersistentConfig( - "ENABLE_OAUTH_ROLE_MAPPING", +ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_ROLE_MANAGEMENT", "oauth.enable_role_mapping", - os.environ.get("ENABLE_OAUTH_ROLE_MAPPING", "False").lower() == "true", + os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", ) OAUTH_ROLES_CLAIM = PersistentConfig( @@ -406,6 +406,17 @@ OAUTH_ROLES_CLAIM = PersistentConfig( os.environ.get("OAUTH_ROLES_CLAIM", "roles"), ) +OAUTH_ALLOWED_ROLES = PersistentConfig( + "OAUTH_ALLOWED_ROLES", + "oauth.allowed_roles", + [role.strip() for role in os.environ.get("OAUTH_ALLOWED_ROLES", "pending,user,admin").split(",")], +) + +OAUTH_ADMIN_ROLES = PersistentConfig( + "OAUTH_ADMIN_ROLES", + "oauth.admin_roles", + [role.strip() for role in os.environ.get("OAUTH_ADMIN_ROLES", "admin").split(",")], +) def load_oauth_providers(): OAUTH_PROVIDERS.clear() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 7374b7f62..8095a66ca 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -16,7 +16,6 @@ from typing import Optional import aiohttp import requests - from open_webui.apps.audio.main import app as audio_app from open_webui.apps.images.main import app as images_app from open_webui.apps.ollama.main import app as ollama_app @@ -47,11 +46,9 @@ from open_webui.apps.webui.models.models import Models from open_webui.apps.webui.models.users import UserModel, Users from open_webui.apps.webui.utils import load_function_module_by_id - from authlib.integrations.starlette_client import OAuth from authlib.oidc.core import UserInfo - from open_webui.config import ( CACHE_DIR, CORS_ALLOW_ORIGIN, @@ -151,7 +148,6 @@ if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -210,7 +206,6 @@ app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.WEBHOOK_URL = WEBHOOK_URL - app.state.config.TASK_MODEL = TASK_MODEL app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -238,14 +233,14 @@ def get_task_model_id(default_model_id): # Check if the user has a custom task model and use that model if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if ( - app.state.config.TASK_MODEL - and app.state.config.TASK_MODEL in app.state.MODELS + app.state.config.TASK_MODEL + and app.state.config.TASK_MODEL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL else: if ( - app.state.config.TASK_MODEL_EXTERNAL - and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS + app.state.config.TASK_MODEL_EXTERNAL + and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS ): task_model_id = app.state.config.TASK_MODEL_EXTERNAL @@ -382,7 +377,7 @@ async def get_content_from_response(response) -> Optional[str]: async def chat_completion_tools_handler( - body: dict, user: UserModel, extra_params: dict + body: dict, user: UserModel, extra_params: dict ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) @@ -608,8 +603,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): if prompt is None: raise Exception("No user message found") if ( - rag_app.state.config.RELEVANCE_THRESHOLD == 0 - and context_string.strip() == "" + rag_app.state.config.RELEVANCE_THRESHOLD == 0 + and context_string.strip() == "" ): log.debug( f"With a 0 relevancy threshold for RAG, the context cannot be empty" @@ -676,6 +671,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): app.add_middleware(ChatCompletionMiddleware) + ################################## # # Pipeline Middleware @@ -688,15 +684,15 @@ def get_sorted_filters(model_id): model for model in app.state.MODELS.values() if "pipeline" in model - and "type" in model["pipeline"] - and model["pipeline"]["type"] == "filter" - and ( - model["pipeline"]["pipelines"] == ["*"] - or any( - model_id == target_model_id - for target_model_id in model["pipeline"]["pipelines"] - ) - ) + and "type" in model["pipeline"] + and model["pipeline"]["type"] == "filter" + and ( + model["pipeline"]["pipelines"] == ["*"] + or any( + model_id == target_model_id + for target_model_id in model["pipeline"]["pipelines"] + ) + ) ] sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"]) return sorted_filters @@ -798,7 +794,6 @@ class PipelineMiddleware(BaseHTTPMiddleware): app.add_middleware(PipelineMiddleware) - app.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOW_ORIGIN, @@ -844,8 +839,8 @@ async def update_embedding_function(request: Request, call_next): @app.middleware("http") async def inspect_websocket(request: Request, call_next): if ( - "/ws/socket.io" in request.url.path - and request.query_params.get("transport") == "websocket" + "/ws/socket.io" in request.url.path + and request.query_params.get("transport") == "websocket" ): upgrade = (request.headers.get("Upgrade") or "").lower() connection = (request.headers.get("Connection") or "").lower().split(",") @@ -913,8 +908,8 @@ async def get_all_models(): if custom_model.base_model_id is None: for model in models: if ( - custom_model.id == model["id"] - or custom_model.id == model["id"].split(":")[0] + custom_model.id == model["id"] + or custom_model.id == model["id"].split(":")[0] ): model["name"] = custom_model.name model["info"] = custom_model.model_dump() @@ -931,8 +926,8 @@ async def get_all_models(): for model in models: if ( - custom_model.base_model_id == model["id"] - or custom_model.base_model_id == model["id"].split(":")[0] + custom_model.base_model_id == model["id"] + or custom_model.base_model_id == model["id"].split(":")[0] ): owned_by = model["owned_by"] if "pipe" in model: @@ -1727,7 +1722,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)): @app.post("/api/pipelines/upload") async def upload_pipeline( - urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) + urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user) ): print("upload_pipeline", urlIdx, file.filename) # Check if the uploaded file is a python file @@ -1904,9 +1899,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use @app.get("/api/pipelines/{pipeline_id}/valves") async def get_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -1942,9 +1937,9 @@ async def get_pipeline_valves( @app.get("/api/pipelines/{pipeline_id}/valves/spec") async def get_pipeline_valves_spec( - urlIdx: Optional[int], - pipeline_id: str, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + user=Depends(get_admin_user), ): r = None try: @@ -1979,10 +1974,10 @@ async def get_pipeline_valves_spec( @app.post("/api/pipelines/{pipeline_id}/valves/update") async def update_pipeline_valves( - urlIdx: Optional[int], - pipeline_id: str, - form_data: dict, - user=Depends(get_admin_user), + urlIdx: Optional[int], + pipeline_id: str, + form_data: dict, + user=Depends(get_admin_user), ): r = None try: @@ -2106,7 +2101,7 @@ class ModelFilterConfigForm(BaseModel): @app.post("/api/config/model/filter") async def update_model_filter_config( - form_data: ModelFilterConfigForm, user=Depends(get_admin_user) + form_data: ModelFilterConfigForm, user=Depends(get_admin_user) ): app.state.config.ENABLE_MODEL_FILTER = form_data.enabled app.state.config.MODEL_FILTER_LIST = form_data.models @@ -2155,7 +2150,7 @@ async def get_app_latest_release_version(): try: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get( - "https://api.github.com/repos/open-webui/open-webui/releases/latest" + "https://api.github.com/repos/open-webui/open-webui/releases/latest" ) as response: response.raise_for_status() data = await response.json() @@ -2198,6 +2193,53 @@ if len(OAUTH_PROVIDERS) > 0: ) +def get_user_role(user: UserModel, user_data: UserInfo) -> str: + if user and Users.get_num_users() == 1: + # If the user is the only user, assign the role "admin" - actually repairs role for single user on login + return "admin" + if not user and Users.get_num_users() == 0: + # If there are no users, assign the role "admin", as the first user will be an admin + return "admin" + + if webui_app.state.config.ENABLE_OAUTH_ROLE_MANAGEMENT: + oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM + oauth_allowed_roles = webui_app.state.config.OAUTH_ALLOWED_ROLES + oauth_admin_roles = webui_app.state.config.OAUTH_ADMIN_ROLES + oauth_roles = None + role = "pending" # Default/fallback role if no matching roles are found + + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else None + + # If any roles are found, check if they match the allowed or admin roles + if oauth_roles: + # If role management is enabled, and matching roles are provided, use the roles + for allowed_role in oauth_allowed_roles: + # If the user has any of the allowed roles, assign the role "user" + if allowed_role in oauth_roles: + role = "user" + break + for admin_role in oauth_admin_roles: + # If the user has any of the admin roles, assign the role "admin" + if admin_role in oauth_roles: + role = "admin" + break + else: + if not user: + # If role management is disabled, use the default role for new users + role = webui_app.state.config.DEFAULT_USER_ROLE + else: + # If role management is disabled, use the existing role for existing users + role = user.role + + return role + + @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): if provider not in OAUTH_PROVIDERS: @@ -2244,34 +2286,6 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Check if the user exists user = Users.get_user_by_oauth_sub(provider_sub) - # print all user data content for debugging - log.info(f"User data: {user_data}") - if user: - role = user.role - if Users.get_num_users() == 1: - role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_claim = webui_app.state.config.OAUTH_ROLES_CLAIM - oauth_roles = None - - if oauth_claim: - claim_data = user_data - nested_claims = oauth_claim.split(".") - for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim, {}) - oauth_roles = claim_data if isinstance(claim_data, list) else None - - log.info(f"User {user.name} has OAuth roles: {oauth_roles}") - if oauth_roles: - for allowed_role in ["pending", "user", "admin"]: - role = allowed_role if allowed_role in oauth_roles else role - else: - # If role mapping is enabled, but no roles are provided, fall back to pending - role = "pending" - log.info(f"Applied role: {role} to user {user.name}") - - if role != user.role: - Users.update_user_role_by_id(user.id, role) if not user: # If the user does not exist, check if merging is enabled @@ -2282,6 +2296,11 @@ async def oauth_callback(provider: str, request: Request, response: Response): # Update the user with the new oauth sub Users.update_user_oauth_sub_by_id(user.id, provider_sub) + if user: + determined_role = get_user_role(user, user_data) + if user.role != determined_role: + Users.update_user_role_by_id(user.id, determined_role) + if not user: # If the user does not exist, check if signups are enabled if ENABLE_OAUTH_SIGNUP.value: @@ -2313,17 +2332,7 @@ async def oauth_callback(provider: str, request: Request, response: Response): picture_url = "/user.png" username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM - role = webui_app.state.config.DEFAULT_USER_ROLE - if Users.get_num_users() == 0: - role = "admin" - elif webui_app.state.config.ENABLE_OAUTH_ROLE_MAPPING: - oauth_roles = user_data.get(webui_app.state.config.OAUTH_ROLE_CLAIM) - if oauth_roles: - for allowed_role in ["pending", "user", "admin"]: - role = allowed_role if allowed_role in oauth_roles else role - else: - # If role mapping is enabled, but no roles are provided, fall back to pending - role = "pending" + role = get_user_role(None, user_data) user = Auths.insert_new_auth( email=email,