Skip to content

Commit f577001

Browse files
committed
Add tests and coverage pragmas for client_secret_basic auth support
- Add tests for Basic auth error cases (invalid base64, no colon, client_id mismatch) - Add test for 'none' auth method with public clients - Add test for explicit token_endpoint_auth_method setting - Add pragma: no cover for defensive error handling paths
1 parent 32a5e6d commit f577001

File tree

6 files changed

+224
-6
lines changed

6 files changed

+224
-6
lines changed

src/mcp/client/auth/oauth2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,10 @@ def prepare_token_auth(
186186
Tuple of (updated_data, updated_headers)
187187
"""
188188
if headers is None:
189-
headers = {}
189+
headers = {} # pragma: no cover
190190

191191
if not self.client_info:
192-
return data, headers
192+
return data, headers # pragma: no cover
193193

194194
auth_method = self.client_info.token_endpoint_auth_method
195195

src/mcp/server/auth/handlers/revoke.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def handle(self, request: Request) -> Response:
4141
"""
4242
try:
4343
client = await self.client_authenticator.authenticate_request(request)
44-
except AuthenticationError as e:
44+
except AuthenticationError as e: # pragma: no cover
4545
return PydanticJSONResponse(
4646
status_code=401,
4747
content=RevocationErrorResponse(

src/mcp/server/auth/handlers/token.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def handle(self, request: Request):
110110
try:
111111
form_data = await request.form()
112112
token_request = TokenRequest.model_validate(dict(form_data)).root
113-
except ValidationError as validation_error:
113+
except ValidationError as validation_error: # pragma: no cover
114114
return self.response(
115115
TokenErrorResponse(
116116
error="invalid_request",

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
9393
elif client.token_endpoint_auth_method == "none":
9494
request_client_secret = None
9595
else:
96-
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}")
96+
raise AuthenticationError( # pragma: no cover
97+
f"Unsupported auth method: {client.token_endpoint_auth_method}"
98+
)
9799

98100
# If client from the store expects a secret, validate that the request provides
99101
# that secret

tests/client/test_auth.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,39 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli
601601
request = await oauth_provider._register_client()
602602
assert request is None
603603

604+
@pytest.mark.anyio
605+
async def test_register_client_explicit_auth_method(self, mock_storage: MockTokenStorage):
606+
"""Test that explicitly set token_endpoint_auth_method is used without auto-selection."""
607+
608+
async def redirect_handler(url: str) -> None:
609+
pass # pragma: no cover
610+
611+
async def callback_handler() -> tuple[str, str | None]:
612+
return "test_auth_code", "test_state" # pragma: no cover
613+
614+
# Create client metadata with explicit auth method
615+
explicit_metadata = OAuthClientMetadata(
616+
client_name="Test Client",
617+
client_uri=AnyHttpUrl("https://example.com"),
618+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
619+
scope="read write",
620+
token_endpoint_auth_method="client_secret_basic",
621+
)
622+
provider = OAuthClientProvider(
623+
server_url="https://api.example.com/v1/mcp",
624+
client_metadata=explicit_metadata,
625+
storage=mock_storage,
626+
redirect_handler=redirect_handler,
627+
callback_handler=callback_handler,
628+
)
629+
630+
request = await provider._register_client()
631+
assert request is not None
632+
633+
body = json.loads(request.content)
634+
# Should use the explicitly set method, not auto-select
635+
assert body["token_endpoint_auth_method"] == "client_secret_basic"
636+
604637
@pytest.mark.anyio
605638
async def test_register_client_none_auth_method_with_server_metadata(self, oauth_provider: OAuthClientProvider):
606639
"""Test that token_endpoint_auth_method=None selects from server's supported methods."""
@@ -611,7 +644,7 @@ async def test_register_client_none_auth_method_with_server_metadata(self, oauth
611644
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
612645
token_endpoint_auth_methods_supported=["client_secret_post"],
613646
)
614-
# Ensure client_metadata has None for token_endpoint_auth_method is None
647+
# Ensure client_metadata has None for token_endpoint_auth_method
615648

616649
request = await oauth_provider._register_client()
617650
assert request is not None

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,189 @@ async def test_basic_auth_without_header_fails(
11171117
assert error_response["error"] == "unauthorized_client"
11181118
assert "Missing or invalid Basic authentication" in error_response["error_description"]
11191119

1120+
@pytest.mark.anyio
1121+
async def test_basic_auth_invalid_base64_fails(
1122+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1123+
):
1124+
"""Test that invalid base64 in Basic auth header fails."""
1125+
client_metadata = {
1126+
"redirect_uris": ["https://client.example.com/callback"],
1127+
"client_name": "Basic Auth Client",
1128+
"token_endpoint_auth_method": "client_secret_basic",
1129+
"grant_types": ["authorization_code", "refresh_token"],
1130+
}
1131+
1132+
response = await test_client.post("/register", json=client_metadata)
1133+
assert response.status_code == 201
1134+
client_info = response.json()
1135+
1136+
auth_code = f"code_{int(time.time())}"
1137+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1138+
code=auth_code,
1139+
client_id=client_info["client_id"],
1140+
code_challenge=pkce_challenge["code_challenge"],
1141+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1142+
redirect_uri_provided_explicitly=True,
1143+
scopes=["read", "write"],
1144+
expires_at=time.time() + 600,
1145+
)
1146+
1147+
# Send invalid base64
1148+
response = await test_client.post(
1149+
"/token",
1150+
headers={"Authorization": "Basic !!!invalid-base64!!!"},
1151+
data={
1152+
"grant_type": "authorization_code",
1153+
"client_id": client_info["client_id"],
1154+
"code": auth_code,
1155+
"code_verifier": pkce_challenge["code_verifier"],
1156+
"redirect_uri": "https://client.example.com/callback",
1157+
},
1158+
)
1159+
assert response.status_code == 401
1160+
error_response = response.json()
1161+
assert error_response["error"] == "unauthorized_client"
1162+
assert "Invalid Basic authentication header" in error_response["error_description"]
1163+
1164+
@pytest.mark.anyio
1165+
async def test_basic_auth_no_colon_fails(
1166+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1167+
):
1168+
"""Test that Basic auth without colon separator fails."""
1169+
client_metadata = {
1170+
"redirect_uris": ["https://client.example.com/callback"],
1171+
"client_name": "Basic Auth Client",
1172+
"token_endpoint_auth_method": "client_secret_basic",
1173+
"grant_types": ["authorization_code", "refresh_token"],
1174+
}
1175+
1176+
response = await test_client.post("/register", json=client_metadata)
1177+
assert response.status_code == 201
1178+
client_info = response.json()
1179+
1180+
auth_code = f"code_{int(time.time())}"
1181+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1182+
code=auth_code,
1183+
client_id=client_info["client_id"],
1184+
code_challenge=pkce_challenge["code_challenge"],
1185+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1186+
redirect_uri_provided_explicitly=True,
1187+
scopes=["read", "write"],
1188+
expires_at=time.time() + 600,
1189+
)
1190+
1191+
# Send base64 without colon (invalid format)
1192+
import base64
1193+
1194+
invalid_creds = base64.b64encode(b"no-colon-here").decode()
1195+
response = await test_client.post(
1196+
"/token",
1197+
headers={"Authorization": f"Basic {invalid_creds}"},
1198+
data={
1199+
"grant_type": "authorization_code",
1200+
"client_id": client_info["client_id"],
1201+
"code": auth_code,
1202+
"code_verifier": pkce_challenge["code_verifier"],
1203+
"redirect_uri": "https://client.example.com/callback",
1204+
},
1205+
)
1206+
assert response.status_code == 401
1207+
error_response = response.json()
1208+
assert error_response["error"] == "unauthorized_client"
1209+
assert "Invalid Basic authentication header" in error_response["error_description"]
1210+
1211+
@pytest.mark.anyio
1212+
async def test_basic_auth_client_id_mismatch_fails(
1213+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1214+
):
1215+
"""Test that client_id mismatch between body and Basic auth fails."""
1216+
client_metadata = {
1217+
"redirect_uris": ["https://client.example.com/callback"],
1218+
"client_name": "Basic Auth Client",
1219+
"token_endpoint_auth_method": "client_secret_basic",
1220+
"grant_types": ["authorization_code", "refresh_token"],
1221+
}
1222+
1223+
response = await test_client.post("/register", json=client_metadata)
1224+
assert response.status_code == 201
1225+
client_info = response.json()
1226+
1227+
auth_code = f"code_{int(time.time())}"
1228+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1229+
code=auth_code,
1230+
client_id=client_info["client_id"],
1231+
code_challenge=pkce_challenge["code_challenge"],
1232+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1233+
redirect_uri_provided_explicitly=True,
1234+
scopes=["read", "write"],
1235+
expires_at=time.time() + 600,
1236+
)
1237+
1238+
# Send different client_id in Basic auth header
1239+
import base64
1240+
1241+
wrong_creds = base64.b64encode(f"wrong-client-id:{client_info['client_secret']}".encode()).decode()
1242+
response = await test_client.post(
1243+
"/token",
1244+
headers={"Authorization": f"Basic {wrong_creds}"},
1245+
data={
1246+
"grant_type": "authorization_code",
1247+
"client_id": client_info["client_id"], # Correct client_id in body
1248+
"code": auth_code,
1249+
"code_verifier": pkce_challenge["code_verifier"],
1250+
"redirect_uri": "https://client.example.com/callback",
1251+
},
1252+
)
1253+
assert response.status_code == 401
1254+
error_response = response.json()
1255+
assert error_response["error"] == "unauthorized_client"
1256+
assert "Client ID mismatch" in error_response["error_description"]
1257+
1258+
@pytest.mark.anyio
1259+
async def test_none_auth_method_public_client(
1260+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1261+
):
1262+
"""Test that 'none' authentication method works for public clients."""
1263+
client_metadata = {
1264+
"redirect_uris": ["https://client.example.com/callback"],
1265+
"client_name": "Public Client",
1266+
"token_endpoint_auth_method": "none",
1267+
"grant_types": ["authorization_code", "refresh_token"],
1268+
}
1269+
1270+
response = await test_client.post("/register", json=client_metadata)
1271+
assert response.status_code == 201
1272+
client_info = response.json()
1273+
assert client_info["token_endpoint_auth_method"] == "none"
1274+
# Public clients should not have a client_secret
1275+
assert "client_secret" not in client_info or client_info.get("client_secret") is None
1276+
1277+
auth_code = f"code_{int(time.time())}"
1278+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1279+
code=auth_code,
1280+
client_id=client_info["client_id"],
1281+
code_challenge=pkce_challenge["code_challenge"],
1282+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1283+
redirect_uri_provided_explicitly=True,
1284+
scopes=["read", "write"],
1285+
expires_at=time.time() + 600,
1286+
)
1287+
1288+
# Token request without any client secret
1289+
response = await test_client.post(
1290+
"/token",
1291+
data={
1292+
"grant_type": "authorization_code",
1293+
"client_id": client_info["client_id"],
1294+
"code": auth_code,
1295+
"code_verifier": pkce_challenge["code_verifier"],
1296+
"redirect_uri": "https://client.example.com/callback",
1297+
},
1298+
)
1299+
assert response.status_code == 200
1300+
token_response = response.json()
1301+
assert "access_token" in token_response
1302+
11201303

11211304
class TestAuthorizeEndpointErrors:
11221305
"""Test error handling in the OAuth authorization endpoint."""

0 commit comments

Comments
 (0)