|
11 | 11 | from pydantic import AnyHttpUrl, AnyUrl |
12 | 12 |
|
13 | 13 | from mcp.client.auth import OAuthClientProvider, PKCEParameters |
| 14 | +from mcp.client.auth.exceptions import OAuthFlowError |
14 | 15 | from mcp.client.auth.utils import ( |
15 | 16 | build_oauth_authorization_server_metadata_discovery_urls, |
16 | 17 | build_protected_resource_metadata_discovery_urls, |
@@ -818,6 +819,88 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa |
818 | 819 | assert "resource=" in content |
819 | 820 |
|
820 | 821 |
|
| 822 | +class TestResourceValidation: |
| 823 | + """Test PRM resource validation in OAuthClientProvider.""" |
| 824 | + |
| 825 | + @pytest.mark.anyio |
| 826 | + async def test_rejects_mismatched_resource(self, client_metadata, mock_storage): |
| 827 | + """Client must reject PRM resource that doesn't match server URL.""" |
| 828 | + provider = OAuthClientProvider( |
| 829 | + server_url="https://api.example.com/v1/mcp", |
| 830 | + client_metadata=client_metadata, |
| 831 | + storage=mock_storage, |
| 832 | + ) |
| 833 | + provider._initialized = True |
| 834 | + |
| 835 | + prm = ProtectedResourceMetadata( |
| 836 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 837 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 838 | + ) |
| 839 | + with pytest.raises(OAuthFlowError, match="does not match expected"): |
| 840 | + await provider._validate_resource_match(prm) |
| 841 | + |
| 842 | + @pytest.mark.anyio |
| 843 | + async def test_accepts_matching_resource(self, client_metadata, mock_storage): |
| 844 | + """Client must accept PRM resource that matches server URL.""" |
| 845 | + provider = OAuthClientProvider( |
| 846 | + server_url="https://api.example.com/v1/mcp", |
| 847 | + client_metadata=client_metadata, |
| 848 | + storage=mock_storage, |
| 849 | + ) |
| 850 | + provider._initialized = True |
| 851 | + |
| 852 | + prm = ProtectedResourceMetadata( |
| 853 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 854 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 855 | + ) |
| 856 | + # Should not raise |
| 857 | + await provider._validate_resource_match(prm) |
| 858 | + |
| 859 | + @pytest.mark.anyio |
| 860 | + async def test_custom_validate_resource_url_callback(self, client_metadata, mock_storage): |
| 861 | + """Custom callback overrides default validation.""" |
| 862 | + callback_called_with: list[tuple[str, str | None]] = [] |
| 863 | + |
| 864 | + async def custom_validate(server_url: str, prm_resource: str | None) -> None: |
| 865 | + callback_called_with.append((server_url, prm_resource)) |
| 866 | + |
| 867 | + provider = OAuthClientProvider( |
| 868 | + server_url="https://api.example.com/v1/mcp", |
| 869 | + client_metadata=client_metadata, |
| 870 | + storage=mock_storage, |
| 871 | + validate_resource_url=custom_validate, |
| 872 | + ) |
| 873 | + provider._initialized = True |
| 874 | + |
| 875 | + # This would normally fail default validation (different origin), |
| 876 | + # but custom callback accepts it |
| 877 | + prm = ProtectedResourceMetadata( |
| 878 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 879 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 880 | + ) |
| 881 | + await provider._validate_resource_match(prm) |
| 882 | + assert len(callback_called_with) == 1 |
| 883 | + assert callback_called_with[0][0] == "https://api.example.com/v1/mcp" |
| 884 | + assert callback_called_with[0][1] == "https://evil.example.com/mcp" |
| 885 | + |
| 886 | + @pytest.mark.anyio |
| 887 | + async def test_accepts_root_url_with_trailing_slash(self, client_metadata, mock_storage): |
| 888 | + """Root URLs with trailing slash normalization should match.""" |
| 889 | + provider = OAuthClientProvider( |
| 890 | + server_url="https://api.example.com", |
| 891 | + client_metadata=client_metadata, |
| 892 | + storage=mock_storage, |
| 893 | + ) |
| 894 | + provider._initialized = True |
| 895 | + |
| 896 | + prm = ProtectedResourceMetadata( |
| 897 | + resource=AnyHttpUrl("https://api.example.com/"), |
| 898 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 899 | + ) |
| 900 | + # Should not raise despite trailing slash difference |
| 901 | + await provider._validate_resource_match(prm) |
| 902 | + |
| 903 | + |
821 | 904 | class TestRegistrationResponse: |
822 | 905 | """Test client registration response handling.""" |
823 | 906 |
|
|
0 commit comments