diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index df79284cf..096041e40 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -207,9 +207,39 @@ class GroupTable: except Exception: return False - def sync_user_groups_by_group_names( + def create_groups_by_group_names( self, user_id: str, group_names: list[str] - ) -> bool: + ) -> list[GroupModel]: + + # check for existing groups + existing_groups = self.get_groups() + existing_group_names = {group.name for group in existing_groups} + + new_groups = [] + + with get_db() as db: + for group_name in group_names: + if group_name not in existing_group_names: + new_group = GroupModel( + id=str(uuid.uuid4()), + user_id=user_id, + name=group_name, + description="", + created_at=int(time.time()), + updated_at=int(time.time()), + ) + try: + result = Group(**new_group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + new_groups.append(GroupModel.model_validate(result)) + except Exception as e: + log.exception(e) + continue + return new_groups + + def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: with get_db() as db: try: groups = db.query(Group).filter(Group.name.in_(group_names)).all() diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 7414e3a86..60a12db4b 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -228,18 +228,23 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not connection_app.bind(): raise HTTPException(400, detail="Application account bind failed") - ENABLE_LDAP_GROUP_MANAGEMENT = request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT + ENABLE_LDAP_GROUP_MANAGEMENT = ( + request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT + ) + ENABLE_LDAP_GROUP_CREATION = request.app.state.config.ENABLE_LDAP_GROUP_CREATION LDAP_ATTRIBUTE_FOR_GROUPS = request.app.state.config.LDAP_ATTRIBUTE_FOR_GROUPS - + search_attributes = [ f"{LDAP_ATTRIBUTE_FOR_USERNAME}", f"{LDAP_ATTRIBUTE_FOR_MAIL}", "cn", ] - + if ENABLE_LDAP_GROUP_MANAGEMENT: search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}") - log.info(f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes") + log.info( + f"LDAP Group Management enabled. Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes" + ) log.info(f"LDAP search attributes: {search_attributes}") @@ -273,55 +278,64 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry: group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS] log.info(f"LDAP raw group DNs for user {username}: {group_dns}") - + if group_dns: log.info(f"LDAP group_dns original: {group_dns}") log.info(f"LDAP group_dns type: {type(group_dns)}") log.info(f"LDAP group_dns length: {len(group_dns)}") - - if hasattr(group_dns, 'value'): + + if hasattr(group_dns, "value"): group_dns = group_dns.value log.info(f"Extracted .value property: {group_dns}") - elif hasattr(group_dns, '__iter__') and not isinstance(group_dns, (str, bytes)): + elif hasattr(group_dns, "__iter__") and not isinstance( + group_dns, (str, bytes) + ): group_dns = list(group_dns) log.info(f"Converted to list: {group_dns}") - elif not isinstance(group_dns, list): - group_dns = [group_dns] - + if isinstance(group_dns, list): group_dns = [str(item) for item in group_dns] else: group_dns = [str(group_dns)] - - log.info(f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}") - - for i, group_dn in enumerate(group_dns): - group_dn_str = str(group_dn) - log.info(f"Processing group DN #{i+1}: {group_dn_str}") - + + log.info( + f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}" + ) + + for group_idx, group_dn in enumerate(group_dns): + group_dn = str(group_dn) + log.info(f"Processing group DN #{group_idx + 1}: {group_dn}") + try: - cn_part = None - dn_parts = group_dn_str.split(',') - log.debug(f"DN parts: {dn_parts}") - - for part in dn_parts: - part = part.strip() - if part.upper().startswith('CN='): - cn_part = part[3:] + group_cn = None + + for item in group_dn.split(","): + item = item.strip() + if item.upper().startswith("CN="): + group_cn = item[3:] break - - if cn_part: - user_groups.append(cn_part) + + if group_cn: + user_groups.append(group_cn) + else: - log.warning(f"Could not extract CN from group DN: {group_dn_str}") + log.warning( + f"Could not extract CN from group DN: {group_dn}" + ) except Exception as e: - log.warning(f"Failed to extract group name from DN {group_dn_str}: {e}") - - log.info(f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})") + log.warning( + f"Failed to extract group name from DN {group_dn}: {e}" + ) + + log.info( + f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})" + ) else: log.info(f"No groups found for user {username}") elif ENABLE_LDAP_GROUP_MANAGEMENT: - log.warning(f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry") + log.warning( + f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry" + ) if username == form_data.user.lower(): connection_user = Connection( @@ -398,26 +412,19 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): user.id, request.app.state.config.USER_PERMISSIONS ) - if ENABLE_LDAP_GROUP_MANAGEMENT and user_groups and request.app.state.config.ENABLE_LDAP_GROUP_CREATION: - from open_webui.models.groups import GroupForm - existing_groups = Groups.get_groups() - existing_group_names = [grp.name for grp in existing_groups] - log.info(f"Existing groups: {existing_group_names}") - - for i, g in enumerate(user_groups): - if not any(grp.name == g for grp in existing_groups): - try: - Groups.insert_new_group(user.id, GroupForm(name=g, description=f"{LDAP_SERVER_LABEL}")) - log.info(f"Successfully created group '{g}'") - except Exception as e: - log.error(f"Failed to create group '{g}': {e}") - else: - log.info(f"Group {g} already exists") + if ( + user.role != "admin" + and ENABLE_LDAP_GROUP_MANAGEMENT + and user_groups + ): + if ENABLE_LDAP_GROUP_CREATION: + Groups.create_groups_by_group_names(user.id, user_groups) - if ENABLE_LDAP_GROUP_MANAGEMENT and user_groups and user.role != "admin": try: - Groups.sync_user_groups_by_group_names(user.id, user_groups) - log.info(f"Successfully synced groups for user {user.id}: {user_groups}") + Groups.sync_groups_by_group_names(user.id, user_groups) + log.info( + f"Successfully synced groups for user {user.id}: {user_groups}" + ) except Exception as e: log.error(f"Failed to sync groups for user {user.id}: {e}") @@ -473,7 +480,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): group_names = [name.strip() for name in group_names if name.strip()] if group_names: - Groups.sync_user_groups_by_group_names(user.id, group_names) + Groups.sync_groups_by_group_names(user.id, group_names) elif WEBUI_AUTH == False: admin_email = "admin@localhost"