Skip to content

Commit 2561b33

Browse files
committed
add unit test
1 parent f11104b commit 2561b33

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

tests/client/test_auth.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,103 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl
423423
except StopAsyncIteration:
424424
pass # Expected - generator should complete
425425

426+
@pytest.mark.anyio
427+
async def test_prm_endpoint_not_implemented_fallthrough(self, oauth_provider: OAuthClientProvider):
428+
"""Test that PRM endpoint failures fall through without raising errors (backward compatibility)."""
429+
# Ensure no tokens are stored
430+
oauth_provider.context.current_tokens = None
431+
oauth_provider.context.token_expiry_time = None
432+
oauth_provider._initialized = True
433+
434+
# Mock client info to skip DCR
435+
oauth_provider.context.client_info = OAuthClientInformationFull(
436+
client_id="existing_client",
437+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
438+
)
439+
440+
# Create a test request
441+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
442+
443+
# Mock the auth flow
444+
auth_flow = oauth_provider.async_auth_flow(test_request)
445+
446+
# First request should be the original request without auth header
447+
request = await auth_flow.__anext__()
448+
assert "Authorization" not in request.headers
449+
450+
# Send a 401 response to trigger the OAuth flow
451+
response = httpx.Response(
452+
401,
453+
headers={
454+
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
455+
},
456+
request=test_request,
457+
)
458+
459+
# Next request should be to discover protected resource metadata
460+
discovery_request = await auth_flow.asend(response)
461+
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
462+
assert discovery_request.method == "GET"
463+
464+
# Send a 404 response - PRM endpoint not implemented (legacy server)
465+
# This should NOT raise an error, but fall through to legacy OAuth discovery
466+
prm_404_response = httpx.Response(
467+
404,
468+
content=b"Not Found",
469+
request=discovery_request,
470+
)
471+
472+
# Next request should fall through to legacy OAuth discovery fallback
473+
# Since PRM failed, it should try OAuth metadata discovery
474+
oauth_metadata_request = await auth_flow.asend(prm_404_response)
475+
assert oauth_metadata_request.method == "GET"
476+
# Should try one of the fallback URLs
477+
assert ".well-known/oauth-authorization-server" in str(oauth_metadata_request.url)
478+
479+
# Send a successful OAuth metadata response to continue the flow
480+
oauth_metadata_response = httpx.Response(
481+
200,
482+
content=(
483+
b'{"issuer": "https://api.example.com", '
484+
b'"authorization_endpoint": "https://api.example.com/authorize", '
485+
b'"token_endpoint": "https://api.example.com/token"}'
486+
),
487+
request=oauth_metadata_request,
488+
)
489+
490+
# Mock the authorization process
491+
oauth_provider._perform_authorization_code_grant = mock.AsyncMock(
492+
return_value=("test_auth_code", "test_code_verifier")
493+
)
494+
495+
# Next request should be token exchange (mocked authorization, so goes straight to token)
496+
token_request = await auth_flow.asend(oauth_metadata_response)
497+
assert str(token_request.url) == "https://api.example.com/token"
498+
assert token_request.method == "POST"
499+
500+
# Send a successful token response
501+
token_response = httpx.Response(
502+
200,
503+
content=(
504+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
505+
b'"refresh_token": "new_refresh_token"}'
506+
),
507+
request=token_request,
508+
)
509+
510+
# After OAuth flow completes, the original request is retried with auth header
511+
final_request = await auth_flow.asend(token_response)
512+
assert final_request.headers["Authorization"] == "Bearer new_access_token"
513+
assert final_request.method == "GET"
514+
assert str(final_request.url) == "https://api.example.com/v1/mcp"
515+
516+
# Send final success response to properly close the generator
517+
final_response = httpx.Response(200, request=final_request)
518+
try:
519+
await auth_flow.asend(final_response)
520+
except StopAsyncIteration:
521+
pass # Expected - generator should complete
522+
426523
@pytest.mark.anyio
427524
async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider):
428525
"""Test successful metadata response handling."""

0 commit comments

Comments
 (0)