refac/fix: oauth discovery urls

Co-Authored-By: jamie-dit <80016430+jamie-dit@users.noreply.github.com>
This commit is contained in:
Timothy Jaeryang Baek
2026-01-01 14:01:18 +04:00
parent f981843852
commit 89565c58c6

View File

@@ -246,10 +246,66 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
return parsed, base_url
def get_discovery_urls(server_url) -> list[str]:
parsed, base_url = get_parsed_and_base_url(server_url)
async def get_authorization_server_discovery_urls(server_url: str) -> list[str]:
"""
https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
"""
urls = []
authorization_servers = []
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(
server_url,
json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1},
headers={"Content-Type": "application/json"},
ssl=AIOHTTP_CLIENT_SESSION_SSL,
) as response:
if response.status == 401:
match = re.search(
r'resource_metadata="([^"]+)"',
response.headers.get("WWW-Authenticate", ""),
)
if match:
resource_metadata_url = match.group(1)
log.debug(
f"Found resource_metadata URL: {resource_metadata_url}"
)
# Step 2: Fetch Protected Resource metadata
async with session.get(
resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
) as resource_response:
if resource_response.status == 200:
resource_metadata = await resource_response.json()
# Step 3: Extract authorization_servers
servers = resource_metadata.get(
"authorization_servers", []
)
if servers:
authorization_servers = servers
log.debug(
f"Discovered authorization servers: {servers}"
)
except Exception as e:
log.debug(f"MCP Protected Resource discovery failed: {e}")
discovery_urls = []
for auth_server in authorization_servers:
auth_server = auth_server.rstrip("/")
discovery_urls.extend(
[
f"{auth_server}/.well-known/oauth-authorization-server",
f"{auth_server}/.well-known/openid-configuration",
]
)
return discovery_urls
async def get_discovery_urls(server_url) -> list[str]:
urls = await get_authorization_server_discovery_urls(server_url)
parsed, base_url = get_parsed_and_base_url(server_url)
if parsed.path and parsed.path != "/":
# Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
@@ -303,7 +359,7 @@ async def get_oauth_client_info_with_dynamic_client_registration(
)
# Attempt to fetch OAuth server metadata to get registration endpoint & scopes
discovery_urls = get_discovery_urls(oauth_server_url)
discovery_urls = await get_discovery_urls(oauth_server_url)
for url in discovery_urls:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(