diff --git a/backend/open_webui/apps/webui/main.py b/backend/open_webui/apps/webui/main.py index 4535501fd..ae54ab29a 100644 --- a/backend/open_webui/apps/webui/main.py +++ b/backend/open_webui/apps/webui/main.py @@ -110,6 +110,8 @@ app.state.config.ADMIN_EMAIL = ADMIN_EMAIL app.state.config.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE + + app.state.config.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.WEBHOOK_URL = WEBHOOK_URL app.state.config.BANNERS = WEBUI_BANNERS diff --git a/backend/open_webui/apps/webui/routers/auths.py b/backend/open_webui/apps/webui/routers/auths.py index feea350cc..d3592f03b 100644 --- a/backend/open_webui/apps/webui/routers/auths.py +++ b/backend/open_webui/apps/webui/routers/auths.py @@ -40,10 +40,12 @@ from open_webui.utils.utils import ( get_password_hash, ) from open_webui.utils.webhook import post_webhook +from open_webui.utils.access_control import get_permissions + from typing import Optional, List -from ldap3 import Server, Connection, ALL, Tls from ssl import CERT_REQUIRED, PROTOCOL_TLS +from ldap3 import Server, Connection, ALL, Tls from ldap3.utils.conv import escape_filter_chars router = APIRouter() @@ -58,6 +60,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"]) class SessionUserResponse(Token, UserResponse): expires_at: Optional[int] = None + permissions: Optional[dict] = None @router.get("/", response_model=SessionUserResponse) @@ -90,6 +93,10 @@ async def get_session_user( secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -99,6 +106,7 @@ async def get_session_user( "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } @@ -163,40 +171,67 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): LDAP_APP_PASSWORD = request.app.state.config.LDAP_APP_PASSWORD LDAP_USE_TLS = request.app.state.config.LDAP_USE_TLS LDAP_CA_CERT_FILE = request.app.state.config.LDAP_CA_CERT_FILE - LDAP_CIPHERS = request.app.state.config.LDAP_CIPHERS if request.app.state.config.LDAP_CIPHERS else 'ALL' + LDAP_CIPHERS = ( + request.app.state.config.LDAP_CIPHERS + if request.app.state.config.LDAP_CIPHERS + else "ALL" + ) if not ENABLE_LDAP: raise HTTPException(400, detail="LDAP authentication is not enabled") try: - tls = Tls(validate=CERT_REQUIRED, version=PROTOCOL_TLS, ca_certs_file=LDAP_CA_CERT_FILE, ciphers=LDAP_CIPHERS) + tls = Tls( + validate=CERT_REQUIRED, + version=PROTOCOL_TLS, + ca_certs_file=LDAP_CA_CERT_FILE, + ciphers=LDAP_CIPHERS, + ) except Exception as e: log.error(f"An error occurred on TLS: {str(e)}") raise HTTPException(400, detail=str(e)) try: - server = Server(host=LDAP_SERVER_HOST, port=LDAP_SERVER_PORT, get_info=ALL, use_ssl=LDAP_USE_TLS, tls=tls) - connection_app = Connection(server, LDAP_APP_DN, LDAP_APP_PASSWORD, auto_bind='NONE', authentication='SIMPLE') + server = Server( + host=LDAP_SERVER_HOST, + port=LDAP_SERVER_PORT, + get_info=ALL, + use_ssl=LDAP_USE_TLS, + tls=tls, + ) + connection_app = Connection( + server, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + auto_bind="NONE", + authentication="SIMPLE", + ) if not connection_app.bind(): raise HTTPException(400, detail="Application account bind failed") search_success = connection_app.search( search_base=LDAP_SEARCH_BASE, - search_filter=f'(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})', - attributes=[f'{LDAP_ATTRIBUTE_FOR_USERNAME}', 'mail', 'cn'] + search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", + attributes=[f"{LDAP_ATTRIBUTE_FOR_USERNAME}", "mail", "cn"], ) if not search_success: raise HTTPException(400, detail="User not found in the LDAP server") entry = connection_app.entries[0] - username = str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}']).lower() - mail = str(entry['mail']) - cn = str(entry['cn']) + username = str(entry[f"{LDAP_ATTRIBUTE_FOR_USERNAME}"]).lower() + mail = str(entry["mail"]) + cn = str(entry["cn"]) user_dn = entry.entry_dn if username == form_data.user.lower(): - connection_user = Connection(server, user_dn, form_data.password, auto_bind='NONE', authentication='SIMPLE') + connection_user = Connection( + server, + user_dn, + form_data.password, + auto_bind="NONE", + authentication="SIMPLE", + ) if not connection_user.bind(): raise HTTPException(400, f"Authentication failed for {form_data.user}") @@ -205,14 +240,12 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): try: hashed = get_password_hash(form_data.password) - user = Auths.insert_new_auth( - mail, - hashed, - cn - ) + user = Auths.insert_new_auth(mail, hashed, cn) if not user: - raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) + raise HTTPException( + 500, detail=ERROR_MESSAGES.CREATE_USER_ERROR + ) except HTTPException: raise @@ -224,7 +257,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if user: token = create_token( data={"id": user.id}, - expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN), + expires_delta=parse_duration( + request.app.state.config.JWT_EXPIRES_IN + ), ) # Set the cookie token @@ -246,7 +281,10 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) else: - raise HTTPException(400, f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}") + raise HTTPException( + 400, + f"User {form_data.user} does not match the record. Search result: {str(entry[f'{LDAP_ATTRIBUTE_FOR_USERNAME}'])}", + ) except Exception as e: raise HTTPException(400, detail=str(e)) @@ -325,6 +363,10 @@ async def signin(request: Request, response: Response, form_data: SigninForm): secure=WEBUI_SESSION_COOKIE_SECURE, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -334,6 +376,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) @@ -426,6 +469,10 @@ async def signup(request: Request, response: Response, form_data: SignupForm): }, ) + user_permissions = get_permissions( + user.id, request.app.state.config.USER_PERMISSIONS + ) + return { "token": token, "token_type": "Bearer", @@ -435,6 +482,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "permissions": user_permissions, } else: raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) @@ -583,19 +631,18 @@ class LdapServerConfig(BaseModel): label: str host: str port: Optional[int] = None - attribute_for_username: str = 'uid' + attribute_for_username: str = "uid" app_dn: str app_dn_password: str search_base: str - search_filters: str = '' + search_filters: str = "" use_tls: bool = True certificate_path: Optional[str] = None - ciphers: Optional[str] = 'ALL' + ciphers: Optional[str] = "ALL" + @router.get("/admin/config/ldap/server", response_model=LdapServerConfig) -async def get_ldap_server( - request: Request, user=Depends(get_admin_user) -): +async def get_ldap_server(request: Request, user=Depends(get_admin_user)): return { "label": request.app.state.config.LDAP_SERVER_LABEL, "host": request.app.state.config.LDAP_SERVER_HOST, @@ -607,26 +654,38 @@ async def get_ldap_server( "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "use_tls": request.app.state.config.LDAP_USE_TLS, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "ciphers": request.app.state.config.LDAP_CIPHERS + "ciphers": request.app.state.config.LDAP_CIPHERS, } + @router.post("/admin/config/ldap/server") async def update_ldap_server( request: Request, form_data: LdapServerConfig, user=Depends(get_admin_user) ): - required_fields = ['label', 'host', 'attribute_for_username', 'app_dn', 'app_dn_password', 'search_base'] + required_fields = [ + "label", + "host", + "attribute_for_username", + "app_dn", + "app_dn_password", + "search_base", + ] for key in required_fields: value = getattr(form_data, key) if not value: raise HTTPException(400, detail=f"Required field {key} is empty") if form_data.use_tls and not form_data.certificate_path: - raise HTTPException(400, detail="TLS is enabled but certificate file path is missing") + raise HTTPException( + 400, detail="TLS is enabled but certificate file path is missing" + ) request.app.state.config.LDAP_SERVER_LABEL = form_data.label request.app.state.config.LDAP_SERVER_HOST = form_data.host request.app.state.config.LDAP_SERVER_PORT = form_data.port - request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = form_data.attribute_for_username + request.app.state.config.LDAP_ATTRIBUTE_FOR_USERNAME = ( + form_data.attribute_for_username + ) request.app.state.config.LDAP_APP_DN = form_data.app_dn request.app.state.config.LDAP_APP_PASSWORD = form_data.app_dn_password request.app.state.config.LDAP_SEARCH_BASE = form_data.search_base @@ -646,18 +705,23 @@ async def update_ldap_server( "search_filters": request.app.state.config.LDAP_SEARCH_FILTERS, "use_tls": request.app.state.config.LDAP_USE_TLS, "certificate_path": request.app.state.config.LDAP_CA_CERT_FILE, - "ciphers": request.app.state.config.LDAP_CIPHERS + "ciphers": request.app.state.config.LDAP_CIPHERS, } + @router.get("/admin/config/ldap") async def get_ldap_config(request: Request, user=Depends(get_admin_user)): return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} + class LdapConfigForm(BaseModel): enable_ldap: Optional[bool] = None + @router.post("/admin/config/ldap") -async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user)): +async def update_ldap_config( + request: Request, form_data: LdapConfigForm, user=Depends(get_admin_user) +): request.app.state.config.ENABLE_LDAP = form_data.enable_ldap return {"ENABLE_LDAP": request.app.state.config.ENABLE_LDAP} diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 93774a4ae..316c09193 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -2,6 +2,38 @@ from typing import Optional, Union, List, Dict from open_webui.apps.webui.models.groups import Groups +def get_permissions( + user_id: str, + default_permissions: Dict[str, bool] = {}, +) -> dict: + """ + Get all permissions for a user by combining the permissions of all groups the user is a member of. + If a permission is defined in multiple groups, the most permissive value is used. + """ + + def merge_permissions( + permissions: Dict[str, bool], new_permissions: Dict[str, bool] + ) -> Dict[str, bool]: + """Merge two permission dictionaries, keeping the most permissive value.""" + for key, value in new_permissions.items(): + if key not in permissions: + permissions[key] = value + else: + permissions[key] = ( + permissions[key] or value + ) # Use the most permissive value + + return permissions + + user_groups = Groups.get_groups_by_member_id(user_id) + user_permissions = default_permissions.copy() + + for group in user_groups: + user_permissions = merge_permissions(user_permissions, group.permissions) + + return user_permissions + + def has_permission( user_id: str, permission_key: str,