# Mirror of https://github.com/open-webui/open-webui
# (synced 2025-06-22 18:07:17 +00:00) — 267 lines, 11 KiB, Python.
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
import aiohttp
|
|
from open_webui.utils.oauth import OAuthManager
|
|
from open_webui.config import AppConfig
|
|
|
|
|
|
class TestOAuthGoogleGroups:
    """Basic tests for Google OAuth Groups functionality"""

    def setup_method(self):
        """Create a fresh OAuthManager (backed by a mocked app) for each test."""
        self.oauth_manager = OAuthManager(app=MagicMock())

    @staticmethod
    def _mock_client_session(response=None, get_error=None):
        """Build a fake ``aiohttp.ClientSession`` async context manager.

        Mirrors the shape the tests assume for the Cloud Identity fetch:
        ``aiohttp.ClientSession()`` used as an async context manager yielding a
        session, and ``session.get(...)`` used as an async context manager
        yielding the HTTP response.

        Args:
            response: Object yielded by ``session.get(...)``'s ``__aenter__``.
                Ignored when ``get_error`` is given.
            get_error: Exception raised by ``session.get(...)`` itself, used to
                simulate a network failure before any response exists.

        Returns:
            ``(session_context, session)``: use ``session_context`` as the
            ``return_value`` of the patched ``aiohttp.ClientSession`` and
            inspect ``session.get`` for the recorded request.
        """
        session = MagicMock()
        if get_error is not None:
            session.get.side_effect = get_error
        else:
            get_context = MagicMock()
            get_context.__aenter__ = AsyncMock(return_value=response)
            get_context.__aexit__ = AsyncMock(return_value=None)
            session.get.return_value = get_context

        session_context = MagicMock()
        session_context.__aenter__ = AsyncMock(return_value=session)
        session_context.__aexit__ = AsyncMock(return_value=None)
        return session_context, session

    async def _fetch_groups_with(self, session_context):
        """Run the Cloud Identity group fetch with ``aiohttp.ClientSession`` patched.

        Always uses ``access_token="test_token"`` and
        ``user_email="user@company.com"``; returns whatever the fetch returns.
        """
        with patch("aiohttp.ClientSession", return_value=session_context):
            return await self.oauth_manager._fetch_google_groups_via_cloud_identity(
                access_token="test_token",
                user_email="user@company.com",
            )

    @pytest.mark.asyncio
    async def test_fetch_google_groups_success(self):
        """Test successful Google groups fetching with proper aiohttp mocking"""
        # Mock response data from Google Cloud Identity API
        mock_response_data = {
            "memberships": [
                {
                    "groupKey": {"id": "admin@company.com"},
                    "group": "groups/123",
                    "displayName": "Admin Group",
                },
                {
                    "groupKey": {"id": "users@company.com"},
                    "group": "groups/456",
                    "displayName": "Users Group",
                },
            ]
        }

        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value=mock_response_data)
        session_context, session = self._mock_client_session(response=mock_response)

        groups = await self._fetch_groups_with(session_context)

        # Each membership's groupKey.id is returned, in response order.
        assert groups == ["admin@company.com", "users@company.com"]

        # Exactly one HTTP request was issued.
        session.get.assert_called_once()
        get_call = session.get.call_args

        # The URL (first positional argument) targets searchTransitiveGroups
        # and carries the user's email URL-encoded ("@" -> "%40").
        url = get_call.args[0]
        assert "user%40company.com" in url
        assert "searchTransitiveGroups" in url

        # The bearer token and JSON content type are passed as headers=...
        headers = get_call.kwargs["headers"]
        assert headers["Authorization"] == "Bearer test_token"
        assert headers["Content-Type"] == "application/json"

    @pytest.mark.asyncio
    async def test_fetch_google_groups_api_error(self):
        """Test handling of API errors when fetching groups"""
        # A non-200 response whose body can be read as text.
        mock_response = MagicMock()
        mock_response.status = 403
        mock_response.text = AsyncMock(return_value="Permission denied")
        session_context, session = self._mock_client_session(response=mock_response)

        groups = await self._fetch_groups_with(session_context)

        # The request was attempted, and the error yields an empty list.
        session.get.assert_called_once()
        assert groups == []

    @pytest.mark.asyncio
    async def test_fetch_google_groups_network_error(self):
        """Test handling of network errors when fetching groups"""
        # session.get() itself raises, so no response is ever produced.
        session_context, session = self._mock_client_session(
            get_error=aiohttp.ClientError("Network error")
        )

        groups = await self._fetch_groups_with(session_context)

        # The request was attempted, and the network error yields an empty list.
        session.get.assert_called_once()
        assert groups == []

    @pytest.mark.asyncio
    async def test_get_user_role_with_google_groups(self):
        """Test role assignment using Google groups"""
        # Role lists are expressed as Google group emails.
        mock_config = MagicMock()
        mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
        mock_config.OAUTH_ROLES_CLAIM = "groups"
        mock_config.OAUTH_ALLOWED_ROLES = ["users@company.com"]
        mock_config.OAUTH_ADMIN_ROLES = ["admin@company.com"]
        mock_config.DEFAULT_USER_ROLE = "pending"
        mock_config.OAUTH_EMAIL_CLAIM = "email"

        user_data = {"email": "user@company.com"}

        # Isolate the test from real config, scope settings, Users and the
        # network-backed group fetch.
        with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
                patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \
                patch("open_webui.utils.oauth.Users") as mock_users, \
                patch.object(
                    self.oauth_manager,
                    "_fetch_google_groups_via_cloud_identity",
                    new_callable=AsyncMock,
                ) as mock_fetch:
            # Scope includes cloud-identity.groups.readonly, so groups are
            # expected to be looked up via the Cloud Identity API.
            mock_scope.value = "openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly"
            mock_fetch.return_value = ["admin@company.com", "users@company.com"]
            mock_users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=user_data,
                provider="google",
                access_token="test_token",
            )

            # Should assign admin role since user is in admin group
            assert role == "admin"
            # The fetch must actually be awaited, not merely called.
            mock_fetch.assert_awaited_once_with("test_token", "user@company.com")

    @pytest.mark.asyncio
    async def test_get_user_role_fallback_to_claims(self):
        """Test fallback to traditional claims when Google groups fail"""
        mock_config = MagicMock()
        mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
        mock_config.OAUTH_ROLES_CLAIM = "groups"
        mock_config.OAUTH_ALLOWED_ROLES = ["users"]
        mock_config.OAUTH_ADMIN_ROLES = ["admin"]
        mock_config.DEFAULT_USER_ROLE = "pending"
        mock_config.OAUTH_EMAIL_CLAIM = "email"

        # The role comes from the "groups" claim in the token data itself.
        user_data = {
            "email": "user@company.com",
            "groups": ["users"],
        }

        with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
                patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \
                patch("open_webui.utils.oauth.Users") as mock_users, \
                patch.object(
                    self.oauth_manager,
                    "_fetch_google_groups_via_cloud_identity",
                    new_callable=AsyncMock,
                ) as mock_fetch:
            # Mock scope without Cloud Identity
            mock_scope.value = "openid email profile"
            mock_users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=user_data,
                provider="google",
                access_token="test_token",
            )

            # Should use traditional claims since Cloud Identity scope not present
            assert role == "user"
            mock_fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_user_role_non_google_provider(self):
        """Test that non-Google providers use traditional claims"""
        mock_config = MagicMock()
        mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True
        mock_config.OAUTH_ROLES_CLAIM = "roles"
        mock_config.OAUTH_ALLOWED_ROLES = ["user"]
        mock_config.OAUTH_ADMIN_ROLES = ["admin"]
        mock_config.DEFAULT_USER_ROLE = "pending"

        user_data = {"roles": ["user"]}

        with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
                patch("open_webui.utils.oauth.Users") as mock_users, \
                patch.object(
                    self.oauth_manager,
                    "_fetch_google_groups_via_cloud_identity",
                    new_callable=AsyncMock,
                ) as mock_fetch:
            mock_users.get_num_users.return_value = 5  # Not first user

            role = await self.oauth_manager.get_user_role(
                user=None,
                user_data=user_data,
                provider="microsoft",
                access_token="test_token",
            )

            # Should use traditional claims for non-Google providers
            assert role == "user"
            mock_fetch.assert_not_called()

    @pytest.mark.asyncio
    async def test_update_user_groups_with_google_groups(self):
        """Test group management using Google groups from user_data"""
        mock_config = MagicMock()
        mock_config.OAUTH_GROUPS_CLAIM = "groups"
        mock_config.OAUTH_BLOCKED_GROUPS = "[]"
        mock_config.ENABLE_OAUTH_GROUP_CREATION = False

        mock_user = MagicMock()
        mock_user.id = "user123"

        # user_data has no "groups" entry (the configured OAUTH_GROUPS_CLAIM),
        # so a group update here indicates "google_groups" was used instead.
        user_data = {
            "google_groups": ["developers@company.com", "employees@company.com"]
        }

        # Only "developers@company.com" exists locally, and the user currently
        # belongs to no groups.
        mock_existing_group = MagicMock()
        mock_existing_group.name = "developers@company.com"
        mock_existing_group.id = "group1"
        mock_existing_group.user_ids = []
        mock_existing_group.permissions = {"read": True}
        mock_existing_group.description = "Developers group"

        with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \
                patch("open_webui.utils.oauth.Groups") as mock_groups:
            mock_groups.get_groups_by_member_id.return_value = []
            mock_groups.get_groups.return_value = [mock_existing_group]

            await self.oauth_manager.update_user_groups(
                user=mock_user,
                user_data=user_data,
                default_permissions={"read": True},
            )

            # Should use Google groups instead of traditional claims
            mock_groups.get_groups_by_member_id.assert_called_once_with("user123")
            # One matching existing group, no memberships to remove and group
            # creation disabled -> exactly one update.
            mock_groups.update_group_by_id.assert_called_once()