refac/fix: oauth discovery urls
Co-Authored-By: jamie-dit <80016430+jamie-dit@users.noreply.github.com>
This commit is contained in:
@@ -246,10 +246,66 @@ def get_parsed_and_base_url(server_url) -> tuple[urllib.parse.ParseResult, str]:
|
||||
return parsed, base_url
|
||||
|
||||
|
||||
def get_discovery_urls(server_url) -> list[str]:
|
||||
parsed, base_url = get_parsed_and_base_url(server_url)
|
||||
async def get_authorization_server_discovery_urls(server_url: str) -> list[str]:
|
||||
"""
|
||||
https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||||
"""
|
||||
|
||||
urls = []
|
||||
authorization_servers = []
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
server_url,
|
||||
json={"jsonrpc": "2.0", "method": "initialize", "params": {}, "id": 1},
|
||||
headers={"Content-Type": "application/json"},
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
match = re.search(
|
||||
r'resource_metadata="([^"]+)"',
|
||||
response.headers.get("WWW-Authenticate", ""),
|
||||
)
|
||||
if match:
|
||||
resource_metadata_url = match.group(1)
|
||||
log.debug(
|
||||
f"Found resource_metadata URL: {resource_metadata_url}"
|
||||
)
|
||||
|
||||
# Step 2: Fetch Protected Resource metadata
|
||||
async with session.get(
|
||||
resource_metadata_url, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as resource_response:
|
||||
if resource_response.status == 200:
|
||||
resource_metadata = await resource_response.json()
|
||||
|
||||
# Step 3: Extract authorization_servers
|
||||
servers = resource_metadata.get(
|
||||
"authorization_servers", []
|
||||
)
|
||||
if servers:
|
||||
authorization_servers = servers
|
||||
log.debug(
|
||||
f"Discovered authorization servers: {servers}"
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(f"MCP Protected Resource discovery failed: {e}")
|
||||
|
||||
discovery_urls = []
|
||||
for auth_server in authorization_servers:
|
||||
auth_server = auth_server.rstrip("/")
|
||||
discovery_urls.extend(
|
||||
[
|
||||
f"{auth_server}/.well-known/oauth-authorization-server",
|
||||
f"{auth_server}/.well-known/openid-configuration",
|
||||
]
|
||||
)
|
||||
|
||||
return discovery_urls
|
||||
|
||||
|
||||
async def get_discovery_urls(server_url) -> list[str]:
|
||||
urls = await get_authorization_server_discovery_urls(server_url)
|
||||
parsed, base_url = get_parsed_and_base_url(server_url)
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
# Generate discovery URLs based on https://modelcontextprotocol.io/specification/draft/basic/authorization#authorization-server-metadata-discovery
|
||||
@@ -303,7 +359,7 @@ async def get_oauth_client_info_with_dynamic_client_registration(
|
||||
)
|
||||
|
||||
# Attempt to fetch OAuth server metadata to get registration endpoint & scopes
|
||||
discovery_urls = get_discovery_urls(oauth_server_url)
|
||||
discovery_urls = await get_discovery_urls(oauth_server_url)
|
||||
for url in discovery_urls:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
|
||||
Reference in New Issue
Block a user