|
11 | 11 | from pydantic import AnyHttpUrl, AnyUrl |
12 | 12 |
|
13 | 13 | from mcp.client.auth import OAuthClientProvider, PKCEParameters |
| 14 | +from mcp.client.auth.exceptions import OAuthFlowError |
14 | 15 | from mcp.client.auth.utils import ( |
15 | 16 | build_oauth_authorization_server_metadata_discovery_urls, |
16 | 17 | build_protected_resource_metadata_discovery_urls, |
@@ -818,6 +819,137 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa |
818 | 819 | assert "resource=" in content |
819 | 820 |
|
820 | 821 |
|
| 822 | +class TestResourceValidation: |
| 823 | + """Test PRM resource validation in OAuthClientProvider.""" |
| 824 | + |
| 825 | + @pytest.mark.anyio |
| 826 | + async def test_rejects_mismatched_resource( |
| 827 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 828 | + ) -> None: |
| 829 | + """Client must reject PRM resource that doesn't match server URL.""" |
| 830 | + provider = OAuthClientProvider( |
| 831 | + server_url="https://api.example.com/v1/mcp", |
| 832 | + client_metadata=client_metadata, |
| 833 | + storage=mock_storage, |
| 834 | + ) |
| 835 | + provider._initialized = True |
| 836 | + |
| 837 | + prm = ProtectedResourceMetadata( |
| 838 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 839 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 840 | + ) |
| 841 | + with pytest.raises(OAuthFlowError, match="does not match expected"): |
| 842 | + await provider._validate_resource_match(prm) |
| 843 | + |
| 844 | + @pytest.mark.anyio |
| 845 | + async def test_accepts_matching_resource( |
| 846 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 847 | + ) -> None: |
| 848 | + """Client must accept PRM resource that matches server URL.""" |
| 849 | + provider = OAuthClientProvider( |
| 850 | + server_url="https://api.example.com/v1/mcp", |
| 851 | + client_metadata=client_metadata, |
| 852 | + storage=mock_storage, |
| 853 | + ) |
| 854 | + provider._initialized = True |
| 855 | + |
| 856 | + prm = ProtectedResourceMetadata( |
| 857 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 858 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 859 | + ) |
| 860 | + # Should not raise |
| 861 | + await provider._validate_resource_match(prm) |
| 862 | + |
| 863 | + @pytest.mark.anyio |
| 864 | + async def test_custom_validate_resource_url_callback( |
| 865 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 866 | + ) -> None: |
| 867 | + """Custom callback overrides default validation.""" |
| 868 | + callback_called_with: list[tuple[str, str | None]] = [] |
| 869 | + |
| 870 | + async def custom_validate(server_url: str, prm_resource: str | None) -> None: |
| 871 | + callback_called_with.append((server_url, prm_resource)) |
| 872 | + |
| 873 | + provider = OAuthClientProvider( |
| 874 | + server_url="https://api.example.com/v1/mcp", |
| 875 | + client_metadata=client_metadata, |
| 876 | + storage=mock_storage, |
| 877 | + validate_resource_url=custom_validate, |
| 878 | + ) |
| 879 | + provider._initialized = True |
| 880 | + |
| 881 | + # This would normally fail default validation (different origin), |
| 882 | + # but custom callback accepts it |
| 883 | + prm = ProtectedResourceMetadata( |
| 884 | + resource=AnyHttpUrl("https://evil.example.com/mcp"), |
| 885 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 886 | + ) |
| 887 | + await provider._validate_resource_match(prm) |
| 888 | + assert len(callback_called_with) == 1 |
| 889 | + assert callback_called_with[0][0] == "https://api.example.com/v1/mcp" |
| 890 | + assert callback_called_with[0][1] == "https://evil.example.com/mcp" |
| 891 | + |
| 892 | + @pytest.mark.anyio |
| 893 | + async def test_accepts_root_url_with_trailing_slash( |
| 894 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 895 | + ) -> None: |
| 896 | + """Root URLs with trailing slash normalization should match.""" |
| 897 | + provider = OAuthClientProvider( |
| 898 | + server_url="https://api.example.com", |
| 899 | + client_metadata=client_metadata, |
| 900 | + storage=mock_storage, |
| 901 | + ) |
| 902 | + provider._initialized = True |
| 903 | + |
| 904 | + prm = ProtectedResourceMetadata( |
| 905 | + resource=AnyHttpUrl("https://api.example.com/"), |
| 906 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 907 | + ) |
| 908 | + # Should not raise despite trailing slash difference |
| 909 | + await provider._validate_resource_match(prm) |
| 910 | + |
| 911 | + @pytest.mark.anyio |
| 912 | + async def test_accepts_server_url_with_trailing_slash( |
| 913 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 914 | + ) -> None: |
| 915 | + """Server URL with trailing slash should match PRM resource.""" |
| 916 | + provider = OAuthClientProvider( |
| 917 | + server_url="https://api.example.com/v1/mcp/", |
| 918 | + client_metadata=client_metadata, |
| 919 | + storage=mock_storage, |
| 920 | + ) |
| 921 | + provider._initialized = True |
| 922 | + |
| 923 | + prm = ProtectedResourceMetadata( |
| 924 | + resource=AnyHttpUrl("https://api.example.com/v1/mcp"), |
| 925 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 926 | + ) |
| 927 | + # Should not raise - both normalize to the same URL with trailing slash |
| 928 | + await provider._validate_resource_match(prm) |
| 929 | + |
| 930 | + @pytest.mark.anyio |
| 931 | + async def test_get_resource_url_uses_canonical_when_prm_mismatches( |
| 932 | + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage |
| 933 | + ) -> None: |
| 934 | + """get_resource_url falls back to canonical URL when PRM resource doesn't match.""" |
| 935 | + provider = OAuthClientProvider( |
| 936 | + server_url="https://api.example.com/v1/mcp", |
| 937 | + client_metadata=client_metadata, |
| 938 | + storage=mock_storage, |
| 939 | + ) |
| 940 | + provider._initialized = True |
| 941 | + |
| 942 | + # Set PRM with a resource that is NOT a parent of the server URL |
| 943 | + provider.context.protected_resource_metadata = ProtectedResourceMetadata( |
| 944 | + resource=AnyHttpUrl("https://other.example.com/mcp"), |
| 945 | + authorization_servers=[AnyHttpUrl("https://auth.example.com")], |
| 946 | + ) |
| 947 | + |
| 948 | + # get_resource_url should return the canonical server URL, not the PRM resource |
| 949 | + resource = provider.context.get_resource_url() |
| 950 | + assert resource == "https://api.example.com/v1/mcp" |
| 951 | + |
| 952 | + |
821 | 953 | class TestRegistrationResponse: |
822 | 954 | """Test client registration response handling.""" |
823 | 955 |
|
@@ -963,7 +1095,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide |
963 | 1095 | # Send a successful discovery response with minimal protected resource metadata |
964 | 1096 | discovery_response = httpx.Response( |
965 | 1097 | 200, |
966 | | - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 1098 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
967 | 1099 | request=discovery_request, |
968 | 1100 | ) |
969 | 1101 |
|
@@ -1116,7 +1248,7 @@ async def test_token_exchange_accepts_201_status( |
1116 | 1248 | # Send a successful discovery response with minimal protected resource metadata |
1117 | 1249 | discovery_response = httpx.Response( |
1118 | 1250 | 200, |
1119 | | - content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}', |
| 1251 | + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', |
1120 | 1252 | request=discovery_request, |
1121 | 1253 | ) |
1122 | 1254 |
|
|
0 commit comments