Skip to content

Commit 2a784e4

Browse files
committed
check prm resource before storing metadata and add tests
1 parent 9eae96a commit 2a784e4

File tree

2 files changed

+64
-4
lines changed

2 files changed

+64
-4
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,7 @@ def get_resource_url(self) -> str:
146146

147147
# If PRM provides a resource that's a valid parent, use it
148148
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
149-
prm_resource = str(self.protected_resource_metadata.resource)
150-
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
151-
resource = prm_resource
149+
resource = str(self.protected_resource_metadata.resource)
152150

153151
return resource
154152

@@ -292,6 +290,13 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
292290
try:
293291
content = await response.aread()
294292
metadata = ProtectedResourceMetadata.model_validate_json(content)
293+
# Validate resource field BEFORE storing metadata per RFC 9728 Section 3.3.
294+
if not check_resource_allowed(
295+
requested_resource=self.context.server_url,
296+
configured_resource=str(metadata.resource),
297+
):
298+
return False
299+
295300
self.context.protected_resource_metadata = metadata
296301
if metadata.authorization_servers:
297302
self.context.auth_server_url = str(metadata.authorization_servers[0])

tests/client/test_auth.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,61 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
642642
content = request.content.decode()
643643
assert "resource=" in content
644644

645+
@pytest.mark.anyio
646+
async def test_reject_metadata_with_mismatched_origin(self, oauth_provider: OAuthClientProvider):
647+
"""Test RFC 9728 Section 3.3: reject metadata with different scheme, host, or port."""
648+
# Test different scheme
649+
response_wrong_scheme = httpx.Response(
650+
200,
651+
content=b'{"resource": "http://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
652+
)
653+
result = await oauth_provider._handle_protected_resource_response(response_wrong_scheme)
654+
assert result is False
655+
assert oauth_provider.context.protected_resource_metadata is None
656+
657+
# Test different host
658+
response_wrong_host = httpx.Response(
659+
200,
660+
content=b'{"resource": "https://evil.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
661+
)
662+
result = await oauth_provider._handle_protected_resource_response(response_wrong_host)
663+
assert result is False
664+
assert oauth_provider.context.protected_resource_metadata is None
665+
666+
# Test different port
667+
response_wrong_port = httpx.Response(
668+
200,
669+
content=b'{"resource": "https://api.example.com:8080/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
670+
)
671+
result = await oauth_provider._handle_protected_resource_response(response_wrong_port)
672+
assert result is False
673+
674+
# Ensure no metadata was set
675+
assert oauth_provider.context.protected_resource_metadata is None
676+
677+
@pytest.mark.anyio
678+
async def test_reject_metadata_with_invalid_path_hierarchy(self, oauth_provider: OAuthClientProvider):
679+
"""Test RFC 9728 Section 3.3: reject metadata where resource is child of server URL."""
680+
681+
# Invalid: resource is child path
682+
response_child_path = httpx.Response(
683+
200,
684+
content=b'{"resource": "https://api.example.com/v1/mcp/subpath", "authorization_servers": ["https://auth.example.com"]}',
685+
)
686+
result = await oauth_provider._handle_protected_resource_response(response_child_path)
687+
assert result is False
688+
assert oauth_provider.context.protected_resource_metadata is None
689+
690+
# Valid: resource is parent path
691+
response_parent_path = httpx.Response(
692+
200,
693+
content=b'{"resource": "https://api.example.com/v1", "authorization_servers": ["https://auth.example.com"]}',
694+
)
695+
result = await oauth_provider._handle_protected_resource_response(response_parent_path)
696+
assert result is True
697+
assert oauth_provider.context.protected_resource_metadata is not None
698+
assert str(oauth_provider.context.protected_resource_metadata.resource) == "https://api.example.com/v1"
699+
645700

646701
class TestRegistrationResponse:
647702
"""Test client registration response handling."""
@@ -745,7 +800,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide
745800
# Send a successful discovery response with minimal protected resource metadata
746801
discovery_response = httpx.Response(
747802
200,
748-
content=b'{"resource": "https://api.example.com/mcp", "authorization_servers": ["https://auth.example.com"]}',
803+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}',
749804
request=discovery_request,
750805
)
751806

0 commit comments

Comments
 (0)