# Source path: open-webui/backend/open_webui/test/test_oauth_google_groups.py
# Listing metadata: 267 lines, 11 KiB, Python

import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import aiohttp
from open_webui.utils.oauth import OAuthManager
from open_webui.config import AppConfig
class TestOAuthGoogleGroups:
    """Mock-driven tests for the Google groups paths of ``OAuthManager``.

    Covers three areas: fetching groups through the Cloud Identity endpoint,
    deriving a user's role, and syncing group membership. Neither the network
    nor the database is touched; every collaborator is replaced by a mock.
    """

    def setup_method(self):
        """Create a fresh ``OAuthManager`` backed by a mock app per test."""
        self.oauth_manager = OAuthManager(app=MagicMock())

    # ------------------------------------------------------------------
    # Helpers (leading underscore keeps pytest from collecting them)
    # ------------------------------------------------------------------

    @staticmethod
    def _async_cm(entered):
        """Return a mock usable in ``async with`` that yields *entered*."""
        manager = MagicMock()
        manager.__aenter__ = AsyncMock(return_value=entered)
        manager.__aexit__ = AsyncMock(return_value=None)
        return manager

    async def _fetch_groups_through(self, session):
        """Run the Cloud Identity fetch with ``aiohttp.ClientSession`` patched.

        The patched ``ClientSession(...)`` call returns an async context
        manager that yields *session*; the fetcher's result is returned.
        """
        session_cm = self._async_cm(session)
        with patch("aiohttp.ClientSession", return_value=session_cm):
            return await self.oauth_manager._fetch_google_groups_via_cloud_identity(
                access_token="test_token",
                user_email="user@company.com",
            )

    # ------------------------------------------------------------------
    # _fetch_google_groups_via_cloud_identity
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_fetch_google_groups_success(self):
        """A 200 reply is expected to produce the ``groupKey.id`` values."""
        # Payload shaped like a searchTransitiveGroups reply with two groups.
        payload = {
            "memberships": [
                {
                    "groupKey": {"id": "admin@company.com"},
                    "group": "groups/123",
                    "displayName": "Admin Group",
                },
                {
                    "groupKey": {"id": "users@company.com"},
                    "group": "groups/456",
                    "displayName": "Users Group",
                },
            ]
        }
        response = MagicMock(status=200, json=AsyncMock(return_value=payload))
        # session.get(...) must itself be usable as ``async with``.
        session = MagicMock(get=MagicMock(return_value=self._async_cm(response)))

        groups = await self._fetch_groups_through(session)

        assert groups == ["admin@company.com", "users@company.com"]

        # Exactly one HTTP request, addressed and authorised as expected.
        session.get.assert_called_once()
        request = session.get.call_args

        url = request.args[0]
        assert "user%40company.com" in url  # "@" percent-encoded as %40
        assert "searchTransitiveGroups" in url

        headers = request.kwargs["headers"]
        assert headers["Authorization"] == "Bearer test_token"
        assert headers["Content-Type"] == "application/json"

    @pytest.mark.asyncio
    async def test_fetch_google_groups_api_error(self):
        """A 403 reply is expected to result in an empty group list."""
        response = MagicMock(
            status=403,
            text=AsyncMock(return_value="Permission denied"),
        )
        session = MagicMock(get=MagicMock(return_value=self._async_cm(response)))

        groups = await self._fetch_groups_through(session)

        assert groups == []

    @pytest.mark.asyncio
    async def test_fetch_google_groups_network_error(self):
        """An ``aiohttp.ClientError`` from ``get`` is expected to yield ``[]``."""
        session = MagicMock()
        # Fail at the moment the request is issued.
        session.get.side_effect = aiohttp.ClientError("Network error")

        groups = await self._fetch_groups_through(session)

        assert groups == []

    # ------------------------------------------------------------------
    # get_user_role
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_get_user_role_with_google_groups(self):
        """Google provider with the Cloud Identity scope: role from fetched groups."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="groups",
            OAUTH_ALLOWED_ROLES=["users@company.com"],
            OAUTH_ADMIN_ROLES=["admin@company.com"],
            DEFAULT_USER_ROLE="pending",
            OAUTH_EMAIL_CLAIM="email",
        )
        claims = {"email": "user@company.com"}

        p_config = patch("open_webui.utils.oauth.auth_manager_config", config)
        p_scope = patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE")
        p_users = patch("open_webui.utils.oauth.Users")
        p_fetch = patch.object(
            self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
        )
        with p_config, p_scope as scope, p_users as users, p_fetch as fetch:
            scope.value = "openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly"
            fetch.return_value = ["admin@company.com", "users@company.com"]
            users.get_num_users.return_value = 5  # not the first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="google",
                access_token="test_token",
            )

        # Membership in the configured admin group should win.
        assert role == "admin"
        fetch.assert_called_once_with("test_token", "user@company.com")

    @pytest.mark.asyncio
    async def test_get_user_role_fallback_to_claims(self):
        """Google provider without the Cloud Identity scope: role from claims."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="groups",
            OAUTH_ALLOWED_ROLES=["users"],
            OAUTH_ADMIN_ROLES=["admin"],
            DEFAULT_USER_ROLE="pending",
            OAUTH_EMAIL_CLAIM="email",
        )
        claims = {
            "email": "user@company.com",
            "groups": ["users"],
        }

        p_config = patch("open_webui.utils.oauth.auth_manager_config", config)
        p_scope = patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE")
        p_users = patch("open_webui.utils.oauth.Users")
        p_fetch = patch.object(
            self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
        )
        with p_config, p_scope as scope, p_users as users, p_fetch as fetch:
            # Configured scope deliberately lacks the Cloud Identity entry.
            scope.value = "openid email profile"
            users.get_num_users.return_value = 5  # not the first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="google",
                access_token="test_token",
            )

        # The role comes from the "groups" claim; no API fetch is attempted.
        assert role == "user"
        fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_user_role_non_google_provider(self):
        """A non-Google provider is expected to rely on claims only."""
        config = MagicMock(
            ENABLE_OAUTH_ROLE_MANAGEMENT=True,
            OAUTH_ROLES_CLAIM="roles",
            OAUTH_ALLOWED_ROLES=["user"],
            OAUTH_ADMIN_ROLES=["admin"],
            DEFAULT_USER_ROLE="pending",
        )
        claims = {"roles": ["user"]}

        p_config = patch("open_webui.utils.oauth.auth_manager_config", config)
        p_users = patch("open_webui.utils.oauth.Users")
        p_fetch = patch.object(
            self.oauth_manager, "_fetch_google_groups_via_cloud_identity"
        )
        with p_config, p_users as users, p_fetch as fetch:
            users.get_num_users.return_value = 5  # not the first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=claims,
                provider="microsoft",
                access_token="test_token",
            )

        assert role == "user"
        fetch.assert_not_called()

    # ------------------------------------------------------------------
    # update_user_groups
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_update_user_groups_with_google_groups(self):
        """``google_groups`` in ``user_data`` is expected to drive group updates."""
        config = MagicMock(
            OAUTH_GROUPS_CLAIM="groups",
            OAUTH_BLOCKED_GROUPS="[]",
            ENABLE_OAUTH_GROUP_CREATION=False,
        )
        member = MagicMock(id="user123")
        claims = {
            "google_groups": ["developers@company.com", "employees@company.com"]
        }

        # One pre-existing group whose name matches one of the Google groups.
        known_group = MagicMock(
            id="group1",
            user_ids=[],
            permissions={"read": True},
            description="Developers group",
        )
        # ``name`` is a Mock constructor argument, so it has to be assigned
        # afterwards to become a plain attribute.
        known_group.name = "developers@company.com"

        p_config = patch("open_webui.utils.oauth.auth_manager_config", config)
        p_groups = patch("open_webui.utils.oauth.Groups")
        with p_config, p_groups as groups_model:
            groups_model.get_groups_by_member_id.return_value = []
            groups_model.get_groups.return_value = [known_group]

            await self.oauth_manager.update_user_groups(
                user=member,
                user_data=claims,
                default_permissions={"read": True},
            )

        # Current memberships are looked up once and some group is updated.
        groups_model.get_groups_by_member_id.assert_called_once_with("user123")
        groups_model.update_group_by_id.assert_called()