@@ -819,135 +819,134 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
819819 assert "resource=" in content
820820
821821
822- class TestResourceValidation :
823- """Test PRM resource validation in OAuthClientProvider."""
822+ @pytest .mark .anyio
823+ async def test_validate_resource_rejects_mismatched_resource (
824+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
825+ ) -> None :
826+ """Client must reject PRM resource that doesn't match server URL."""
827+ provider = OAuthClientProvider (
828+ server_url = "https://api.example.com/v1/mcp" ,
829+ client_metadata = client_metadata ,
830+ storage = mock_storage ,
831+ )
832+ provider ._initialized = True
824833
825- @pytest .mark .anyio
826- async def test_rejects_mismatched_resource (
827- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
828- ) -> None :
829- """Client must reject PRM resource that doesn't match server URL."""
830- provider = OAuthClientProvider (
831- server_url = "https://api.example.com/v1/mcp" ,
832- client_metadata = client_metadata ,
833- storage = mock_storage ,
834- )
835- provider ._initialized = True
834+ prm = ProtectedResourceMetadata (
835+ resource = AnyHttpUrl ("https://evil.example.com/mcp" ),
836+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
837+ )
838+ with pytest .raises (OAuthFlowError , match = "does not match expected" ):
839+ await provider ._validate_resource_match (prm )
836840
837- prm = ProtectedResourceMetadata (
838- resource = AnyHttpUrl ("https://evil.example.com/mcp" ),
839- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
840- )
841- with pytest .raises (OAuthFlowError , match = "does not match expected" ):
842- await provider ._validate_resource_match (prm )
843841
844- @pytest .mark .anyio
845- async def test_accepts_matching_resource (
846- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
847- ) -> None :
848- """Client must accept PRM resource that matches server URL."""
849- provider = OAuthClientProvider (
850- server_url = "https://api.example.com/v1/mcp" ,
851- client_metadata = client_metadata ,
852- storage = mock_storage ,
853- )
854- provider ._initialized = True
842+ @pytest .mark .anyio
843+ async def test_validate_resource_accepts_matching_resource (
844+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
845+ ) -> None :
846+ """Client must accept PRM resource that matches server URL."""
847+ provider = OAuthClientProvider (
848+ server_url = "https://api.example.com/v1/mcp" ,
849+ client_metadata = client_metadata ,
850+ storage = mock_storage ,
851+ )
852+ provider ._initialized = True
855853
856- prm = ProtectedResourceMetadata (
857- resource = AnyHttpUrl ("https://api.example.com/v1/mcp" ),
858- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
859- )
860- # Should not raise
861- await provider ._validate_resource_match (prm )
854+ prm = ProtectedResourceMetadata (
855+ resource = AnyHttpUrl ("https://api.example.com/v1/mcp" ),
856+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
857+ )
858+ # Should not raise
859+ await provider ._validate_resource_match (prm )
862860
863- @pytest .mark .anyio
864- async def test_custom_validate_resource_url_callback (
865- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
866- ) -> None :
867- """Custom callback overrides default validation."""
868- callback_called_with : list [tuple [str , str | None ]] = []
869861
870- async def custom_validate (server_url : str , prm_resource : str | None ) -> None :
871- callback_called_with .append ((server_url , prm_resource ))
862+ @pytest .mark .anyio
863+ async def test_validate_resource_custom_callback (
864+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
865+ ) -> None :
866+ """Custom callback overrides default validation."""
867+ callback_called_with : list [tuple [str , str | None ]] = []
872868
873- provider = OAuthClientProvider (
874- server_url = "https://api.example.com/v1/mcp" ,
875- client_metadata = client_metadata ,
876- storage = mock_storage ,
877- validate_resource_url = custom_validate ,
878- )
879- provider ._initialized = True
869+ async def custom_validate (server_url : str , prm_resource : str | None ) -> None :
870+ callback_called_with .append ((server_url , prm_resource ))
880871
881- # This would normally fail default validation (different origin),
882- # but custom callback accepts it
883- prm = ProtectedResourceMetadata (
884- resource = AnyHttpUrl ("https://evil.example.com/mcp" ),
885- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
886- )
887- await provider ._validate_resource_match (prm )
888- assert len (callback_called_with ) == 1
889- assert callback_called_with [0 ][0 ] == "https://api.example.com/v1/mcp"
890- assert callback_called_with [0 ][1 ] == "https://evil.example.com/mcp"
872+ provider = OAuthClientProvider (
873+ server_url = "https://api.example.com/v1/mcp" ,
874+ client_metadata = client_metadata ,
875+ storage = mock_storage ,
876+ validate_resource_url = custom_validate ,
877+ )
878+ provider ._initialized = True
891879
892- @pytest .mark .anyio
893- async def test_accepts_root_url_with_trailing_slash (
894- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
895- ) -> None :
896- """Root URLs with trailing slash normalization should match."""
897- provider = OAuthClientProvider (
898- server_url = "https://api.example.com" ,
899- client_metadata = client_metadata ,
900- storage = mock_storage ,
901- )
902- provider ._initialized = True
880+ # This would normally fail default validation (different origin),
881+ # but custom callback accepts it
882+ prm = ProtectedResourceMetadata (
883+ resource = AnyHttpUrl ("https://evil.example.com/mcp" ),
884+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
885+ )
886+ await provider ._validate_resource_match (prm )
887+ assert callback_called_with == snapshot ([("https://api.example.com/v1/mcp" , "https://evil.example.com/mcp" )])
903888
904- prm = ProtectedResourceMetadata (
905- resource = AnyHttpUrl ("https://api.example.com/" ),
906- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
907- )
908- # Should not raise despite trailing slash difference
909- await provider ._validate_resource_match (prm )
910889
911- @pytest .mark .anyio
912- async def test_accepts_server_url_with_trailing_slash (
913- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
914- ) -> None :
915- """Server URL with trailing slash should match PRM resource ."""
916- provider = OAuthClientProvider (
917- server_url = "https://api.example.com/v1/mcp/ " ,
918- client_metadata = client_metadata ,
919- storage = mock_storage ,
920- )
921- provider ._initialized = True
890+ @pytest .mark .anyio
891+ async def test_validate_resource_accepts_root_url_with_trailing_slash (
892+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
893+ ) -> None :
894+ """Root URLs with trailing slash normalization should match."""
895+ provider = OAuthClientProvider (
896+ server_url = "https://api.example.com" ,
897+ client_metadata = client_metadata ,
898+ storage = mock_storage ,
899+ )
900+ provider ._initialized = True
922901
923- prm = ProtectedResourceMetadata (
924- resource = AnyHttpUrl ("https://api.example.com/v1/mcp " ),
925- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
926- )
927- # Should not raise - both normalize to the same URL with trailing slash
928- await provider ._validate_resource_match (prm )
902+ prm = ProtectedResourceMetadata (
903+ resource = AnyHttpUrl ("https://api.example.com/" ),
904+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
905+ )
906+ # Should not raise despite trailing slash difference
907+ await provider ._validate_resource_match (prm )
929908
930- @pytest .mark .anyio
931- async def test_get_resource_url_uses_canonical_when_prm_mismatches (
932- self , client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
933- ) -> None :
934- """get_resource_url falls back to canonical URL when PRM resource doesn't match."""
935- provider = OAuthClientProvider (
936- server_url = "https://api.example.com/v1/mcp" ,
937- client_metadata = client_metadata ,
938- storage = mock_storage ,
939- )
940- provider ._initialized = True
941909
942- # Set PRM with a resource that is NOT a parent of the server URL
943- provider .context .protected_resource_metadata = ProtectedResourceMetadata (
944- resource = AnyHttpUrl ("https://other.example.com/mcp" ),
945- authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
946- )
910+ @pytest .mark .anyio
911+ async def test_validate_resource_accepts_server_url_with_trailing_slash (
912+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
913+ ) -> None :
914+ """Server URL with trailing slash should match PRM resource."""
915+ provider = OAuthClientProvider (
916+ server_url = "https://api.example.com/v1/mcp/" ,
917+ client_metadata = client_metadata ,
918+ storage = mock_storage ,
919+ )
920+ provider ._initialized = True
921+
922+ prm = ProtectedResourceMetadata (
923+ resource = AnyHttpUrl ("https://api.example.com/v1/mcp" ),
924+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
925+ )
926+ # Should not raise - both normalize to the same URL with trailing slash
927+ await provider ._validate_resource_match (prm )
928+
929+
930+ @pytest .mark .anyio
931+ async def test_get_resource_url_uses_canonical_when_prm_mismatches (
932+ client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage
933+ ) -> None :
934+ """get_resource_url falls back to canonical URL when PRM resource doesn't match."""
935+ provider = OAuthClientProvider (
936+ server_url = "https://api.example.com/v1/mcp" ,
937+ client_metadata = client_metadata ,
938+ storage = mock_storage ,
939+ )
940+ provider ._initialized = True
941+
942+ # Set PRM with a resource that is NOT a parent of the server URL
943+ provider .context .protected_resource_metadata = ProtectedResourceMetadata (
944+ resource = AnyHttpUrl ("https://other.example.com/mcp" ),
945+ authorization_servers = [AnyHttpUrl ("https://auth.example.com" )],
946+ )
947947
948- # get_resource_url should return the canonical server URL, not the PRM resource
949- resource = provider .context .get_resource_url ()
950- assert resource == "https://api.example.com/v1/mcp"
948+ # get_resource_url should return the canonical server URL, not the PRM resource
949+ assert provider .context .get_resource_url () == snapshot ("https://api.example.com/v1/mcp" )
951950
952951
953952class TestRegistrationResponse :
0 commit comments