diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index e49c251a1..937edb37b 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -403,12 +403,24 @@ OAUTH_EMAIL_CLAIM = PersistentConfig( os.environ.get("OAUTH_EMAIL_CLAIM", "email"), ) +OAUTH_GROUPS_CLAIM = PersistentConfig( + "OAUTH_GROUPS_CLAIM", + "oauth.oidc.group_claim", + os.environ.get("OAUTH_GROUP_CLAIM", "groups"), +) + ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( "ENABLE_OAUTH_ROLE_MANAGEMENT", "oauth.enable_role_mapping", os.environ.get("ENABLE_OAUTH_ROLE_MANAGEMENT", "False").lower() == "true", ) +ENABLE_OAUTH_GROUP_MANAGEMENT = PersistentConfig( + "ENABLE_OAUTH_GROUP_MANAGEMENT", + "oauth.enable_group_mapping", + os.environ.get("ENABLE_OAUTH_GROUP_MANAGEMENT", "False").lower() == "true", +) + OAUTH_ROLES_CLAIM = PersistentConfig( "OAUTH_ROLES_CLAIM", "oauth.roles_claim", diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index f0ab7a345..b84fd8248 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -14,13 +14,16 @@ from starlette.responses import RedirectResponse from open_webui.models.auths import Auths from open_webui.models.users import Users +from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm from open_webui.config import ( DEFAULT_USER_ROLE, ENABLE_OAUTH_SIGNUP, OAUTH_MERGE_ACCOUNTS_BY_EMAIL, OAUTH_PROVIDERS, ENABLE_OAUTH_ROLE_MANAGEMENT, + ENABLE_OAUTH_GROUP_MANAGEMENT, OAUTH_ROLES_CLAIM, + OAUTH_GROUPS_CLAIM, OAUTH_EMAIL_CLAIM, OAUTH_PICTURE_CLAIM, OAUTH_USERNAME_CLAIM, @@ -44,7 +47,9 @@ auth_manager_config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE auth_manager_config.ENABLE_OAUTH_SIGNUP = ENABLE_OAUTH_SIGNUP auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL = OAUTH_MERGE_ACCOUNTS_BY_EMAIL auth_manager_config.ENABLE_OAUTH_ROLE_MANAGEMENT = ENABLE_OAUTH_ROLE_MANAGEMENT +auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT = ENABLE_OAUTH_GROUP_MANAGEMENT auth_manager_config.OAUTH_ROLES_CLAIM = OAUTH_ROLES_CLAIM +auth_manager_config.OAUTH_GROUPS_CLAIM = OAUTH_GROUPS_CLAIM auth_manager_config.OAUTH_EMAIL_CLAIM = OAUTH_EMAIL_CLAIM auth_manager_config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM auth_manager_config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM @@ -118,6 +123,59 @@ class OAuthManager: role = user.role return role + + def update_user_groups(self, user, user_data, default_permissions): + oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM + + user_oauth_groups: list[str] = user_data.get(oauth_claim, list()) + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) + all_available_groups: list[GroupModel] = Groups.get_groups() + + print(user_oauth_groups) + print() + print(user_current_groups) + print() + print(all_available_groups) + + + # Remove groups that user is no longer a part of + for group_model in user_current_groups: + if group_model.name not in user_oauth_groups: + # Remove group from user + print(f"Found group to remove from user: {group_model.name}") + + user_ids = group_model.user_ids + user_ids = [i for i in user_ids if i != user.id] + + # In case a group is created, but perms are never assigned to the group by hitting "save" + group_permissions = group_model.permissions + if not group_permissions: + group_permissions = default_permissions + + update_form = GroupUpdateForm(name=group_model.name, description=group_model.description, + permissions=group_permissions, + user_ids=user_ids) + Groups.update_group_by_id(id=group_model.id, form_data=update_form, overwrite=False) + + + # Add user to new groups + for group_model in all_available_groups: + if group_model.name in user_oauth_groups and not any(gm.name == group_model.name for gm in user_current_groups): + # Add user to group + print(f"Found group to add to user: {group_model.name}") + + user_ids = group_model.user_ids + user_ids.append(user.id) + + # In case a group is created, but perms are never assigned to the group by hitting "save" + group_permissions = group_model.permissions + if not group_permissions: + group_permissions = default_permissions + + update_form = GroupUpdateForm(name=group_model.name, description=group_model.description, + permissions=group_permissions, + user_ids=user_ids) + Groups.update_group_by_id(id=group_model.id, form_data=update_form, overwrite=False) async def handle_login(self, provider, request): if provider not in OAUTH_PROVIDERS: @@ -254,6 +312,11 @@ class OAuthManager: expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), ) + if auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT: + print("OAUTH GROUP MANAGEMENT ACTIVE\n\n\n\n\n\n\n") + self.update_user_groups(user=user, user_data=user_data, + default_permissions=request.app.state.config.USER_PERMISSIONS) + # Set the cookie token response.set_cookie( key="token",