Skip to content

Commit 320f58c

Browse files
committed
fix: attempt to query resource PRM before root PRM
1 parent b8cb367 commit 320f58c

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

src/mcp/client/auth.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,24 @@ def __init__(
203203
)
204204
self._initialized = False
205205

206-
async def _discover_protected_resource(self) -> httpx.Request:
206+
def _build_well_known_path_protected_resource(self, pathname: str) -> str:
207+
"""Construct well-known path for OAuth protected resource metadata discovery."""
208+
well_known_path = f"/.well-known/oauth-protected-resource{pathname}"
209+
if pathname.endswith("/"):
210+
# Strip trailing slash from pathname to avoid double slashes
211+
well_known_path = well_known_path[:-1]
212+
return well_known_path
213+
214+
async def _discover_protected_resource(self, is_fallback: bool = False) -> httpx.Request:
207215
"""Build discovery request for protected resource metadata."""
208216
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
209-
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
217+
auth_url_parsed = urlparse(self.context.server_url)
218+
pathname = auth_url_parsed.path if not is_fallback else "/"
219+
well_known_path = self._build_well_known_path_protected_resource(pathname)
220+
url = urljoin(auth_base_url, well_known_path)
210221
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
211222

212-
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
223+
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
213224
"""Handle discovery response."""
214225
if response.status_code == 200:
215226
try:
@@ -218,8 +229,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
218229
self.context.protected_resource_metadata = metadata
219230
if metadata.authorization_servers:
220231
self.context.auth_server_url = str(metadata.authorization_servers[0])
232+
return True
221233
except ValidationError:
222234
pass
235+
return False
223236

224237
def _build_well_known_path(self, pathname: str) -> str:
225238
"""Construct well-known path for OAuth metadata discovery."""
@@ -497,7 +510,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
497510
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
498511
discovery_request = await self._discover_protected_resource()
499512
discovery_response = yield discovery_request
500-
await self._handle_protected_resource_response(discovery_response)
513+
discovery_handled = await self._handle_protected_resource_response(discovery_response)
514+
515+
# If path-aware discovery failed, try fallback to root
516+
if not discovery_handled:
517+
discovery_request = await self._discover_protected_resource(is_fallback=True)
518+
discovery_response = yield discovery_request
519+
await self._handle_protected_resource_response(discovery_response)
501520

502521
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
503522
oauth_request = await self._discover_oauth_metadata()
@@ -549,7 +568,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
549568
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
550569
discovery_request = await self._discover_protected_resource()
551570
discovery_response = yield discovery_request
552-
await self._handle_protected_resource_response(discovery_response)
571+
discovery_handled = await self._handle_protected_resource_response(discovery_response)
572+
573+
# If path-aware discovery failed, try fallback to root
574+
if not discovery_handled:
575+
discovery_request = await self._discover_protected_resource(is_fallback=True)
576+
discovery_response = yield discovery_request
577+
await self._handle_protected_resource_response(discovery_response)
553578

554579
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
555580
oauth_request = await self._discover_oauth_metadata()

tests/client/test_auth.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,16 @@ async def test_discover_protected_resource_request(self, oauth_provider):
201201
request = await oauth_provider._discover_protected_resource()
202202

203203
assert request.method == "GET"
204+
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp"
205+
assert "mcp-protocol-version" in request.headers
206+
207+
@pytest.mark.anyio
208+
async def test_discover_protected_resource_request_fallback(self, oauth_provider):
209+
"""Test protected resource discovery request building after a failure to discover metadata at the standard endpoint."""
210+
request = await oauth_provider._discover_protected_resource(is_fallback=True)
211+
212+
assert request.method == "GET"
213+
# Falls back to the root
204214
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
205215
assert "mcp-protocol-version" in request.headers
206216

0 commit comments

Comments
 (0)