Skip to content

Commit cca61b9

Browse files
refactor: convert test class to functions, use snapshots
Address review feedback: - Convert TestResourceValidation class to standalone test functions - Use inline_snapshot for assertion values
1 parent 7c3abaa commit cca61b9

File tree

1 file changed

+113
-114
lines changed

1 file changed

+113
-114
lines changed

tests/client/test_auth.py

Lines changed: 113 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -819,135 +819,134 @@ async def test_resource_param_included_with_protected_resource_metadata(self, oa
819819
assert "resource=" in content
820820

821821

822-
class TestResourceValidation:
823-
"""Test PRM resource validation in OAuthClientProvider."""
822+
@pytest.mark.anyio
823+
async def test_validate_resource_rejects_mismatched_resource(
824+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
825+
) -> None:
826+
"""Client must reject PRM resource that doesn't match server URL."""
827+
provider = OAuthClientProvider(
828+
server_url="https://api.example.com/v1/mcp",
829+
client_metadata=client_metadata,
830+
storage=mock_storage,
831+
)
832+
provider._initialized = True
824833

825-
@pytest.mark.anyio
826-
async def test_rejects_mismatched_resource(
827-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
828-
) -> None:
829-
"""Client must reject PRM resource that doesn't match server URL."""
830-
provider = OAuthClientProvider(
831-
server_url="https://api.example.com/v1/mcp",
832-
client_metadata=client_metadata,
833-
storage=mock_storage,
834-
)
835-
provider._initialized = True
834+
prm = ProtectedResourceMetadata(
835+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
836+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
837+
)
838+
with pytest.raises(OAuthFlowError, match="does not match expected"):
839+
await provider._validate_resource_match(prm)
836840

837-
prm = ProtectedResourceMetadata(
838-
resource=AnyHttpUrl("https://evil.example.com/mcp"),
839-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
840-
)
841-
with pytest.raises(OAuthFlowError, match="does not match expected"):
842-
await provider._validate_resource_match(prm)
843841

844-
@pytest.mark.anyio
845-
async def test_accepts_matching_resource(
846-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
847-
) -> None:
848-
"""Client must accept PRM resource that matches server URL."""
849-
provider = OAuthClientProvider(
850-
server_url="https://api.example.com/v1/mcp",
851-
client_metadata=client_metadata,
852-
storage=mock_storage,
853-
)
854-
provider._initialized = True
842+
@pytest.mark.anyio
843+
async def test_validate_resource_accepts_matching_resource(
844+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
845+
) -> None:
846+
"""Client must accept PRM resource that matches server URL."""
847+
provider = OAuthClientProvider(
848+
server_url="https://api.example.com/v1/mcp",
849+
client_metadata=client_metadata,
850+
storage=mock_storage,
851+
)
852+
provider._initialized = True
855853

856-
prm = ProtectedResourceMetadata(
857-
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
858-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
859-
)
860-
# Should not raise
861-
await provider._validate_resource_match(prm)
854+
prm = ProtectedResourceMetadata(
855+
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
856+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
857+
)
858+
# Should not raise
859+
await provider._validate_resource_match(prm)
862860

863-
@pytest.mark.anyio
864-
async def test_custom_validate_resource_url_callback(
865-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
866-
) -> None:
867-
"""Custom callback overrides default validation."""
868-
callback_called_with: list[tuple[str, str | None]] = []
869861

870-
async def custom_validate(server_url: str, prm_resource: str | None) -> None:
871-
callback_called_with.append((server_url, prm_resource))
862+
@pytest.mark.anyio
863+
async def test_validate_resource_custom_callback(
864+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
865+
) -> None:
866+
"""Custom callback overrides default validation."""
867+
callback_called_with: list[tuple[str, str | None]] = []
872868

873-
provider = OAuthClientProvider(
874-
server_url="https://api.example.com/v1/mcp",
875-
client_metadata=client_metadata,
876-
storage=mock_storage,
877-
validate_resource_url=custom_validate,
878-
)
879-
provider._initialized = True
869+
async def custom_validate(server_url: str, prm_resource: str | None) -> None:
870+
callback_called_with.append((server_url, prm_resource))
880871

881-
# This would normally fail default validation (different origin),
882-
# but custom callback accepts it
883-
prm = ProtectedResourceMetadata(
884-
resource=AnyHttpUrl("https://evil.example.com/mcp"),
885-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
886-
)
887-
await provider._validate_resource_match(prm)
888-
assert len(callback_called_with) == 1
889-
assert callback_called_with[0][0] == "https://api.example.com/v1/mcp"
890-
assert callback_called_with[0][1] == "https://evil.example.com/mcp"
872+
provider = OAuthClientProvider(
873+
server_url="https://api.example.com/v1/mcp",
874+
client_metadata=client_metadata,
875+
storage=mock_storage,
876+
validate_resource_url=custom_validate,
877+
)
878+
provider._initialized = True
891879

892-
@pytest.mark.anyio
893-
async def test_accepts_root_url_with_trailing_slash(
894-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
895-
) -> None:
896-
"""Root URLs with trailing slash normalization should match."""
897-
provider = OAuthClientProvider(
898-
server_url="https://api.example.com",
899-
client_metadata=client_metadata,
900-
storage=mock_storage,
901-
)
902-
provider._initialized = True
880+
# This would normally fail default validation (different origin),
881+
# but custom callback accepts it
882+
prm = ProtectedResourceMetadata(
883+
resource=AnyHttpUrl("https://evil.example.com/mcp"),
884+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
885+
)
886+
await provider._validate_resource_match(prm)
887+
assert callback_called_with == snapshot([("https://api.example.com/v1/mcp", "https://evil.example.com/mcp")])
903888

904-
prm = ProtectedResourceMetadata(
905-
resource=AnyHttpUrl("https://api.example.com/"),
906-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
907-
)
908-
# Should not raise despite trailing slash difference
909-
await provider._validate_resource_match(prm)
910889

911-
@pytest.mark.anyio
912-
async def test_accepts_server_url_with_trailing_slash(
913-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
914-
) -> None:
915-
"""Server URL with trailing slash should match PRM resource."""
916-
provider = OAuthClientProvider(
917-
server_url="https://api.example.com/v1/mcp/",
918-
client_metadata=client_metadata,
919-
storage=mock_storage,
920-
)
921-
provider._initialized = True
890+
@pytest.mark.anyio
891+
async def test_validate_resource_accepts_root_url_with_trailing_slash(
892+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
893+
) -> None:
894+
"""Root URLs with trailing slash normalization should match."""
895+
provider = OAuthClientProvider(
896+
server_url="https://api.example.com",
897+
client_metadata=client_metadata,
898+
storage=mock_storage,
899+
)
900+
provider._initialized = True
922901

923-
prm = ProtectedResourceMetadata(
924-
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
925-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
926-
)
927-
# Should not raise - both normalize to the same URL with trailing slash
928-
await provider._validate_resource_match(prm)
902+
prm = ProtectedResourceMetadata(
903+
resource=AnyHttpUrl("https://api.example.com/"),
904+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
905+
)
906+
# Should not raise despite trailing slash difference
907+
await provider._validate_resource_match(prm)
929908

930-
@pytest.mark.anyio
931-
async def test_get_resource_url_uses_canonical_when_prm_mismatches(
932-
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
933-
) -> None:
934-
"""get_resource_url falls back to canonical URL when PRM resource doesn't match."""
935-
provider = OAuthClientProvider(
936-
server_url="https://api.example.com/v1/mcp",
937-
client_metadata=client_metadata,
938-
storage=mock_storage,
939-
)
940-
provider._initialized = True
941909

942-
# Set PRM with a resource that is NOT a parent of the server URL
943-
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
944-
resource=AnyHttpUrl("https://other.example.com/mcp"),
945-
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
946-
)
910+
@pytest.mark.anyio
911+
async def test_validate_resource_accepts_server_url_with_trailing_slash(
912+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
913+
) -> None:
914+
"""Server URL with trailing slash should match PRM resource."""
915+
provider = OAuthClientProvider(
916+
server_url="https://api.example.com/v1/mcp/",
917+
client_metadata=client_metadata,
918+
storage=mock_storage,
919+
)
920+
provider._initialized = True
921+
922+
prm = ProtectedResourceMetadata(
923+
resource=AnyHttpUrl("https://api.example.com/v1/mcp"),
924+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
925+
)
926+
# Should not raise - both normalize to the same URL with trailing slash
927+
await provider._validate_resource_match(prm)
928+
929+
930+
@pytest.mark.anyio
931+
async def test_get_resource_url_uses_canonical_when_prm_mismatches(
932+
client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
933+
) -> None:
934+
"""get_resource_url falls back to canonical URL when PRM resource doesn't match."""
935+
provider = OAuthClientProvider(
936+
server_url="https://api.example.com/v1/mcp",
937+
client_metadata=client_metadata,
938+
storage=mock_storage,
939+
)
940+
provider._initialized = True
941+
942+
# Set PRM with a resource that is NOT a parent of the server URL
943+
provider.context.protected_resource_metadata = ProtectedResourceMetadata(
944+
resource=AnyHttpUrl("https://other.example.com/mcp"),
945+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
946+
)
947947

948-
# get_resource_url should return the canonical server URL, not the PRM resource
949-
resource = provider.context.get_resource_url()
950-
assert resource == "https://api.example.com/v1/mcp"
948+
# get_resource_url should return the canonical server URL, not the PRM resource
949+
assert provider.context.get_resource_url() == snapshot("https://api.example.com/v1/mcp")
951950

952951

953952
class TestRegistrationResponse:

0 commit comments

Comments
 (0)