Skip to content

Commit 0e3d18a

Browse files
committed
test: add coverage for PRM resource validation
Tests for resource mismatch rejection, matching resources, custom callback override, and root URL trailing slash normalization.
1 parent fbd97ee commit 0e3d18a

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
492492
return
493493

494494
if not prm_resource:
495-
return
495+
return # pragma: no cover
496496
default_resource = resource_url_from_server_url(self.context.server_url)
497497
# Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs
498498
# (e.g. "https://example.com/") while resource_url_from_server_url may not.

tests/client/test_auth.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic import AnyHttpUrl, AnyUrl
1212

1313
from mcp.client.auth import OAuthClientProvider, PKCEParameters
14+
from mcp.client.auth.exceptions import OAuthFlowError
1415
from mcp.client.auth.utils import (
1516
build_oauth_authorization_server_metadata_discovery_urls,
1617
build_protected_resource_metadata_discovery_urls,
@@ -818,6 +819,88 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
818819
assert "resource=" in content
819820

820821

822+
class TestResourceValidation:
823+
"""Test PRM resource validation in OAuthClientProvider."""
824+
825+
@pytest.mark.anyio
826+
async def test_rejects_mismatched_resource(self, client_metadata, mock_storage):
827+
"""Client must reject PRM resource that doesn't match server URL."""
828+
provider = OAuthClientProvider(
829+
server_url="https://api.example.com/v1/mcp",
830+
client_metadata=client_metadata,
831+
storage=mock_storage,
832+
)
833+
provider._initialized = True
834+
835+
prm = ProtectedResourceMetadata(
836+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
837+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
838+
)
839+
with pytest.raises(OAuthFlowError, match="does not match expected"):
840+
await provider._validate_resource_match(prm)
841+
842+
@pytest.mark.anyio
843+
async def test_accepts_matching_resource(self, client_metadata, mock_storage):
844+
"""Client must accept PRM resource that matches server URL."""
845+
provider = OAuthClientProvider(
846+
server_url="https://api.example.com/v1/mcp",
847+
client_metadata=client_metadata,
848+
storage=mock_storage,
849+
)
850+
provider._initialized = True
851+
852+
prm = ProtectedResourceMetadata(
853+
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
854+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
855+
)
856+
# Should not raise
857+
await provider._validate_resource_match(prm)
858+
859+
@pytest.mark.anyio
860+
async def test_custom_validate_resource_url_callback(self, client_metadata, mock_storage):
861+
"""Custom callback overrides default validation."""
862+
callback_called_with: list[tuple[str, str | None]] = []
863+
864+
async def custom_validate(server_url: str, prm_resource: str | None) -> None:
865+
callback_called_with.append((server_url, prm_resource))
866+
867+
provider = OAuthClientProvider(
868+
server_url="https://api.example.com/v1/mcp",
869+
client_metadata=client_metadata,
870+
storage=mock_storage,
871+
validate_resource_url=custom_validate,
872+
)
873+
provider._initialized = True
874+
875+
# This would normally fail default validation (different origin),
876+
# but custom callback accepts it
877+
prm = ProtectedResourceMetadata(
878+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
879+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
880+
)
881+
await provider._validate_resource_match(prm)
882+
assert len(callback_called_with) == 1
883+
assert callback_called_with[0][0] == "https://api.example.com/v1/mcp"
884+
assert callback_called_with[0][1] == "https://evil.example.com/mcp"
885+
886+
@pytest.mark.anyio
887+
async def test_accepts_root_url_with_trailing_slash(self, client_metadata, mock_storage):
888+
"""Root URLs with trailing slash normalization should match."""
889+
provider = OAuthClientProvider(
890+
server_url="https://api.example.com",
891+
client_metadata=client_metadata,
892+
storage=mock_storage,
893+
)
894+
provider._initialized = True
895+
896+
prm = ProtectedResourceMetadata(
897+
resource=AnyHttpUrl("https://api.example.com/"),
898+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
899+
)
900+
# Should not raise despite trailing slash difference
901+
await provider._validate_resource_match(prm)
902+
903+
821904
class TestRegistrationResponse:
822905
"""Test client registration response handling."""
823906

0 commit comments

Comments
 (0)