@@ -198,16 +198,16 @@ class TestOAuthFlow:
198198 """Test OAuth flow methods."""
199199
200200 @pytest .mark .anyio
201- async def test_protected_resource_discovery_urls (self , client_metadata , mock_storage ):
202- """Test protected resource discovery URL generation with fallback ."""
201+ async def test_protected_resource_discovery_urls_generation (self , client_metadata , mock_storage ):
202+ """Test that discovery URL generation works correctly for different server URLs ."""
203203
204204 async def redirect_handler (url : str ) -> None :
205205 pass
206206
207207 async def callback_handler () -> tuple [str , str | None ]:
208208 return "test_auth_code" , "test_state"
209209
210- # Test with path component
210+ # Test with path component - should have both path-specific and base endpoints
211211 provider = OAuthClientProvider (
212212 server_url = "https://api.example.com/api/2.0/mcp" ,
213213 client_metadata = client_metadata ,
@@ -222,7 +222,7 @@ async def callback_handler() -> tuple[str, str | None]:
222222 "https://api.example.com/.well-known/oauth-protected-resource" ,
223223 ]
224224
225- # Test without path component
225+ # Test without path component - should only have base endpoint
226226 provider = OAuthClientProvider (
227227 server_url = "https://api.example.com" ,
228228 client_metadata = client_metadata ,
@@ -594,6 +594,177 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider, mock_storage):
594594 assert oauth_provider .context .current_tokens .access_token == "new_access_token"
595595 assert oauth_provider .context .token_expiry_time is not None
596596
597+ @pytest .mark .anyio
598+ async def test_auth_flow_protected_resource_fallback (self , client_metadata , mock_storage ):
599+ """Test that the OAuth flow correctly implements fallback from path-specific to base endpoint."""
600+
601+ async def redirect_handler (url : str ) -> None :
602+ pass
603+
604+ async def callback_handler () -> tuple [str , str | None ]:
605+ return "test_auth_code" , "test_state"
606+
607+ provider = OAuthClientProvider (
608+ server_url = "https://api.example.com/api/2.0/mcp" ,
609+ client_metadata = client_metadata ,
610+ storage = mock_storage ,
611+ redirect_handler = redirect_handler ,
612+ callback_handler = callback_handler ,
613+ )
614+
615+ provider .context .current_tokens = None
616+ provider .context .token_expiry_time = None
617+ provider ._initialized = True
618+
619+ test_request = httpx .Request ("GET" , "https://api.example.com/api/2.0/mcp" )
620+ auth_flow = provider .async_auth_flow (test_request )
621+
622+ # Step 1: Original request without auth
623+ request = await auth_flow .__anext__ ()
624+ assert "Authorization" not in request .headers
625+
626+ # Step 2: 401 triggers protected resource discovery - should try path-specific first
627+ response = httpx .Response (401 , request = test_request )
628+ path_discovery_request = await auth_flow .asend (response )
629+ assert (
630+ str (path_discovery_request .url )
631+ == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
632+ )
633+
634+ # Step 3: Path-specific fails with 404 - should trigger fallback
635+ path_404_response = httpx .Response (404 , request = path_discovery_request )
636+ base_discovery_request = await auth_flow .asend (path_404_response )
637+ assert str (base_discovery_request .url ) == "https://api.example.com/.well-known/oauth-protected-resource"
638+
639+ # Step 4: Base endpoint succeeds - should store metadata and continue to OAuth discovery
640+ successful_response = httpx .Response (
641+ 200 ,
642+ content = b'{"resource": "https://api.example.com", "authorization_servers": ["https://api.example.com"]}' ,
643+ request = base_discovery_request ,
644+ )
645+
646+ # Verify the fallback worked and metadata was stored
647+ await auth_flow .asend (successful_response )
648+ assert provider .context .protected_resource_metadata is not None
649+ assert str (provider .context .protected_resource_metadata .resource ) == "https://api.example.com/"
650+
651+ # Clean up the generator
652+ try :
653+ await auth_flow .aclose ()
654+ except Exception :
655+ pass
656+
657+ @pytest .mark .anyio
658+ async def test_auth_flow_www_authenticate_no_fallback (self , client_metadata , mock_storage ):
659+ """Test that WWW-Authenticate header skips fallback logic entirely."""
660+
661+ async def redirect_handler (url : str ) -> None :
662+ pass
663+
664+ async def callback_handler () -> tuple [str , str | None ]:
665+ return "test_auth_code" , "test_state"
666+
667+ provider = OAuthClientProvider (
668+ server_url = "https://api.example.com/api/2.0/mcp" ,
669+ client_metadata = client_metadata ,
670+ storage = mock_storage ,
671+ redirect_handler = redirect_handler ,
672+ callback_handler = callback_handler ,
673+ )
674+
675+ provider .context .current_tokens = None
676+ provider .context .token_expiry_time = None
677+ provider ._initialized = True
678+
679+ test_request = httpx .Request ("GET" , "https://api.example.com/api/2.0/mcp" )
680+ auth_flow = provider .async_auth_flow (test_request )
681+
682+ # Step 1: Original request without auth
683+ request = await auth_flow .__anext__ ()
684+ assert "Authorization" not in request .headers
685+
686+ # Step 2: 401 with WWW-Authenticate should use that URL directly
687+ response = httpx .Response (
688+ 401 ,
689+ headers = {
690+ "WWW-Authenticate" : 'Bearer resource_metadata="https://custom.example.com/.well-known/oauth-protected-resource"'
691+ },
692+ request = test_request ,
693+ )
694+
695+ www_auth_request = await auth_flow .asend (response )
696+ assert str (www_auth_request .url ) == "https://custom.example.com/.well-known/oauth-protected-resource"
697+
698+ # Step 3: Should proceed directly to OAuth metadata discovery (no fallback attempted)
699+ successful_response = httpx .Response (
700+ 200 ,
701+ content = b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}' ,
702+ request = www_auth_request ,
703+ )
704+
705+ await auth_flow .asend (successful_response )
706+ assert provider .context .protected_resource_metadata is not None
707+
708+ # Clean up the generator
709+ try :
710+ await auth_flow .aclose ()
711+ except Exception :
712+ pass
713+
714+ @pytest .mark .anyio
715+ async def test_auth_flow_no_fallback_on_success (self , client_metadata , mock_storage ):
716+ """Test that first successful discovery response stops the fallback process."""
717+
718+ async def redirect_handler (url : str ) -> None :
719+ pass
720+
721+ async def callback_handler () -> tuple [str , str | None ]:
722+ return "test_auth_code" , "test_state"
723+
724+ provider = OAuthClientProvider (
725+ server_url = "https://api.example.com/api/2.0/mcp" ,
726+ client_metadata = client_metadata ,
727+ storage = mock_storage ,
728+ redirect_handler = redirect_handler ,
729+ callback_handler = callback_handler ,
730+ )
731+
732+ provider .context .current_tokens = None
733+ provider .context .token_expiry_time = None
734+ provider ._initialized = True
735+
736+ test_request = httpx .Request ("GET" , "https://api.example.com/api/2.0/mcp" )
737+ auth_flow = provider .async_auth_flow (test_request )
738+
739+ # Step 1: Original request without auth
740+ request = await auth_flow .__anext__ ()
741+ assert "Authorization" not in request .headers
742+
743+ # Step 2: 401 triggers path-specific discovery
744+ response = httpx .Response (401 , request = test_request )
745+ path_discovery_request = await auth_flow .asend (response )
746+ assert (
747+ str (path_discovery_request .url )
748+ == "https://api.example.com/.well-known/oauth-protected-resource/api/2.0/mcp"
749+ )
750+
751+ # Step 3: Path-specific succeeds - should skip fallback and go to OAuth discovery
752+ successful_response = httpx .Response (
753+ 200 ,
754+ content = b'{"resource": "https://api.example.com/api/2.0/mcp", "authorization_servers": ["https://api.example.com"]}' ,
755+ request = path_discovery_request ,
756+ )
757+
758+ await auth_flow .asend (successful_response )
759+ assert provider .context .protected_resource_metadata is not None
760+ assert str (provider .context .protected_resource_metadata .resource ) == "https://api.example.com/api/2.0/mcp"
761+
762+ # Clean up the generator
763+ try :
764+ await auth_flow .aclose ()
765+ except Exception :
766+ pass
767+
597768
598769@pytest .mark .parametrize (
599770 (
0 commit comments