Skip to content

Commit 3ff6257

Browse files
committed
refactor: pull out oauth handle protected resource response
1 parent a2e96d9 commit 3ff6257

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
extract_scope_from_www_auth,
2727
get_client_metadata_scopes,
2828
get_discovery_urls,
29+
handle_protected_resource_response,
2930
)
3031
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
3132
from mcp.shared.auth import (
@@ -502,8 +503,15 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
502503
"GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
503504
)
504505
discovery_response = yield discovery_request
505-
discovery_success = await self._handle_protected_resource_response(discovery_response)
506+
discovery_success, prm, auth_server_url = await handle_protected_resource_response(
507+
discovery_response
508+
)
506509
if discovery_success:
510+
assert prm is not None
511+
assert auth_server_url is not None
512+
self.context.protected_resource_metadata = prm
513+
if prm.authorization_servers:
514+
self.context.auth_server_url = auth_server_url
507515
break
508516

509517
if not discovery_success:

src/mcp/client/auth/utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1+
import logging
12
import re
23
from urllib.parse import urljoin, urlparse
34

45
from httpx import Response
6+
from pydantic import ValidationError
57

8+
from mcp.client.auth import OAuthFlowError
69
from mcp.shared.auth import ProtectedResourceMetadata
710

11+
logger = logging.getLogger(__name__)
12+
813

914
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
1015
"""
@@ -134,3 +139,53 @@ def get_discovery_urls(auth_server_url: str) -> list[str]:
134139
urls.append(oidc_fallback)
135140

136141
return urls
142+
143+
144+
async def handle_protected_resource_response(
145+
response: Response,
146+
) -> tuple[bool, ProtectedResourceMetadata | None, str | None]:
147+
"""
148+
Handle protected resource metadata discovery response.
149+
150+
Per SEP-985, supports fallback when discovery fails at one URL.
151+
152+
Returns:
153+
True if metadata was successfully discovered, False if we should try next URL
154+
"""
155+
if response.status_code == 200:
156+
try:
157+
content = await response.aread()
158+
metadata = ProtectedResourceMetadata.model_validate_json(content)
159+
auth_server_url: str | None = None
160+
if metadata.authorization_servers:
161+
auth_server_url = str(metadata.authorization_servers[0])
162+
return True, metadata, auth_server_url
163+
164+
except ValidationError:
165+
# Invalid metadata - try next URL
166+
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
167+
return False, None, None
168+
elif response.status_code == 404:
169+
# Not found - try next URL in fallback chain
170+
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
171+
return False, None, None
172+
else:
173+
# Other error - fail immediately
174+
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")
175+
176+
177+
# async def discovery_process(discovery_urls: list[str]) -> AsyncGenerator[Request, Response]:
178+
# discovery_success = False
179+
# prm, auth_url = None, None
180+
# for url in discovery_urls:
181+
# discovery_request = Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
182+
# discovery_response, prm, auth_url = yield discovery_request
183+
#
184+
# discovery_success = await handle_protected_resource_response(discovery_response)
185+
# if discovery_success:
186+
# break
187+
#
188+
# if discovery_success:
189+
# return
190+
# else:
191+
# raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found")

0 commit comments

Comments
 (0)