Skip to content

Commit d6acc58

Browse files
authored
Merge pull request #54 from sacha-development-stuff/codex/fix-nameerror-in-legacy-server-tests
Fix OAuth discovery fallbacks for legacy servers
2 parents 738f2c5 + 219b71f commit d6acc58

File tree

2 files changed

+60
-43
lines changed

2 files changed

+60
-43
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,20 @@ def _create_registration_request(self, metadata: OAuthMetadata | None = None) ->
251251
headers={"Content-Type": "application/json"},
252252
)
253253

254-
async def _handle_registration_response(self, response: httpx.Response) -> None:
254+
async def _handle_registration_response(
255+
self, response: httpx.Response
256+
) -> OAuthClientInformationFull:
255257
if response.status_code not in (200, 201):
256258
await response.aread()
257259
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
258260
content = await response.aread()
259261
client_info = OAuthClientInformationFull.model_validate_json(content)
260262
self._client_info = client_info
261263
await self.storage.set_client_info(client_info)
264+
context = getattr(self, "context", None)
265+
if context is not None:
266+
context.client_info = client_info
267+
return client_info
262268

263269
def _apply_client_auth(
264270
self,
@@ -315,6 +321,18 @@ def __init__(
315321
)
316322
self._initialized = False
317323

324+
def _build_protected_resource_discovery_urls(self, resource_metadata_url: str | None) -> list[str]:
325+
"""Build the list of PRM discovery URLs with legacy fallbacks."""
326+
return build_protected_resource_metadata_discovery_urls(
327+
resource_metadata_url, self.context.server_url
328+
)
329+
330+
def _get_discovery_urls(self, server_url: str | None = None) -> list[str]:
331+
"""Build OAuth authorization server discovery URLs with legacy fallbacks."""
332+
return build_oauth_authorization_server_metadata_discovery_urls(
333+
server_url, self.context.server_url
334+
)
335+
318336
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
319337
"""
320338
Handle protected resource metadata discovery response.
@@ -324,28 +342,30 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
324342
Returns:
325343
True if metadata was successfully discovered, False if we should try next URL
326344
"""
327-
if response.status_code == 200:
328-
try:
329-
content = await response.aread()
330-
metadata = ProtectedResourceMetadata.model_validate_json(content)
331-
self.context.protected_resource_metadata = metadata
332-
if metadata.authorization_servers: # pragma: no branch
333-
self.context.auth_server_url = str(metadata.authorization_servers[0])
334-
return True
335-
336-
except ValidationError: # pragma: no cover
337-
# Invalid metadata - try next URL
338-
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
339-
return False
340-
elif response.status_code == 404: # pragma: no cover
341-
# Not found - try next URL in fallback chain
342-
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
343-
return False
344-
else:
345-
# Other error - fail immediately
346-
raise OAuthFlowError(
347-
f"Protected Resource Metadata request failed: {response.status_code}"
348-
) # pragma: no cover
345+
metadata = await handle_protected_resource_response(response)
346+
if metadata:
347+
self.context.protected_resource_metadata = metadata
348+
if metadata.authorization_servers: # pragma: no branch
349+
self.context.auth_server_url = str(metadata.authorization_servers[0])
350+
return True
351+
352+
logger.debug(
353+
"Protected resource metadata discovery failed with status %s at %s",
354+
response.status_code,
355+
response.request.url,
356+
)
357+
return False
358+
359+
async def _handle_oauth_metadata_response(
360+
self, response: httpx.Response
361+
) -> tuple[bool, OAuthMetadata | None]:
362+
ok, asm = await handle_auth_metadata_response(response)
363+
if asm:
364+
self.context.oauth_metadata = asm
365+
self._metadata = asm
366+
if self.context.client_metadata.scope is None and asm.scopes_supported is not None:
367+
self.context.client_metadata.scope = " ".join(asm.scopes_supported)
368+
return ok, asm
349369

350370
async def _perform_authorization(self) -> httpx.Request:
351371
"""Perform the authorization flow."""
@@ -560,34 +580,33 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
560580
self._metadata = None
561581

562582
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
563-
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
564-
www_auth_resource_metadata_url, self.context.server_url
583+
prm_discovery_urls = self._build_protected_resource_discovery_urls(
584+
www_auth_resource_metadata_url
565585
)
566586

567587
for url in prm_discovery_urls: # pragma: no branch
568-
discovery_request = create_oauth_metadata_request(url)
588+
discovery_request = self._create_oauth_metadata_request(url)
569589
discovery_response = yield discovery_request
570590

571-
prm = await handle_protected_resource_response(discovery_response)
572-
if prm:
573-
self.context.protected_resource_metadata = prm
574-
if prm.authorization_servers: # pragma: no branch
575-
self.context.auth_server_url = str(prm.authorization_servers[0])
591+
handled = await self._handle_protected_resource_response(discovery_response)
592+
if handled:
576593
break
577594

578-
logger.debug(f"Protected resource metadata discovery failed: {url}")
579-
580595
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
581-
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
582-
self.context.auth_server_url, self.context.server_url
583-
)
596+
asm_discovery_urls = self._get_discovery_urls(self.context.auth_server_url)
584597

585598
authorization_metadata: OAuthMetadata | None = None
586599
for url in asm_discovery_urls: # pragma: no branch
587-
oauth_metadata_request = create_oauth_metadata_request(url)
600+
oauth_metadata_request = self._create_oauth_metadata_request(url)
588601
oauth_metadata_response = yield oauth_metadata_request
589602

590-
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
603+
result = await self._handle_oauth_metadata_response(oauth_metadata_response)
604+
if isinstance(result, tuple):
605+
ok, asm = result
606+
else:
607+
ok = bool(result) if result is not None else True
608+
asm = self.context.oauth_metadata or self._metadata
609+
591610
if not ok:
592611
break
593612
if asm:
@@ -615,9 +634,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
615634
self.context.get_authorization_base_url(self.context.server_url),
616635
)
617636
registration_response = yield registration_request
618-
client_information = await handle_registration_response(registration_response)
619-
self.context.client_info = client_information
620-
await self.context.storage.set_client_info(client_information)
637+
await self._handle_registration_response(registration_response)
621638

622639
# Step 5: Perform authorization and complete token exchange
623640
token_response = yield await self._perform_authorization()

tests/client/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,7 +1364,7 @@ async def callback_handler() -> tuple[str, str | None]:
13641364
)
13651365

13661366
# Mock authorization
1367-
provider._perform_authorization_code_grant = mock.AsyncMock(
1367+
provider._perform_authorization_code_grant = AsyncMock(
13681368
return_value=("test_auth_code", "test_code_verifier")
13691369
)
13701370

@@ -1470,7 +1470,7 @@ async def callback_handler() -> tuple[str, str | None]:
14701470
request=oauth_metadata_request,
14711471
)
14721472

1473-
provider._perform_authorization_code_grant = mock.AsyncMock(
1473+
provider._perform_authorization_code_grant = AsyncMock(
14741474
return_value=("test_auth_code", "test_code_verifier")
14751475
)
14761476

0 commit comments

Comments
 (0)