diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index d30eec3a7..c849eb25a 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -246,10 +246,66 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]: return parsed, base_url -def get_discovery_urls(server_url) -> list[str]: - parsed, base_url = get_parsed_and_base_url(server_url) +async def get_authorization_server_discovery_urls(server_url: str) -> list[str]: + """ + https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization + """ - urls = [] + authorization_servers = [] + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.post( + server_url, + json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1}, + headers={"Content-Type": "application/json"}, + ssl=AIOHTTP_CLIENT_SESSION_SSL, + ) as response: + if response.status == 401: + match = re.search( + r'resource_metadata="([^"]+)"', + response.headers.get("WWW-Authenticate", ""), + ) + if match: + resource_metadata_url = match.group(1) + log.debug( + f"Found resource_metadata URL: {resource_metadata_url}" + ) + + # Step 2: Fetch Protected Resource metadata + async with session.get( + resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as resource_response: + if resource_response.status == 200: + resource_metadata = await resource_response.json() + + # Step 3: Extract authorization_servers + servers = resource_metadata.get( + "authorization_servers", [] + ) + if servers: + authorization_servers = servers + log.debug( + f"Discovered authorization servers: {servers}" + ) + except Exception as e: + log.debug(f"MCP Protected Resource discovery failed: {e}") + + discovery_urls = [] + for auth_server in authorization_servers: + auth_server = auth_server.rstrip("/") + discovery_urls.extend( + [ + f"{auth_server}/.well-known/oauth-authorization-server", + f"{auth_server}/.well-known/openid-configuration", + ] + ) + + return discovery_urls + + +async def get_discovery_urls(server_url) -> list[str]: + urls = await get_authorization_server_discovery_urls(server_url) + parsed, base_url = get_parsed_and_base_url(server_url) if parsed.path and parsed.path != "/": # Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery @@ -303,7 +359,7 @@ async def get_oauth_client_info_with_dynamic_client_registration( ) # Attempt to fetch OAuth server metadata to get registration endpoint & scopes - discovery_urls = get_discovery_urls(oauth_server_url) + discovery_urls = await get_discovery_urls(oauth_server_url) for url in discovery_urls: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get(