Skip to content

Commit eac35d4

Browse files
jonsheapcarleton
authored andcommitted
Add client_secret_basic authentication support
Add support for HTTP Basic Authentication (client_secret_basic) as a client authentication method for the token and revoke endpoints, alongside the existing client_secret_post method. This improves compatibility with OAuth servers like Keycloak that use Basic auth. Key changes: - Update OAuthClientMetadata to accept "client_secret_basic" as valid token_endpoint_auth_method - Return 401 status for authentication failures (was 400) - Update metadata endpoints to advertise both auth methods - Add tests for both auth methods and edge cases
1 parent fcffa14 commit eac35d4

File tree

7 files changed

+237
-41
lines changed

7 files changed

+237
-41
lines changed

src/mcp/server/auth/handlers/revoke.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,28 +40,25 @@ async def handle(self, request: Request) -> Response:
4040
Handler for the OAuth 2.0 Token Revocation endpoint.
4141
"""
4242
try:
43-
form_data = await request.form()
44-
revocation_request = RevocationRequest.model_validate(dict(form_data))
45-
except ValidationError as e:
43+
client = await self.client_authenticator.authenticate_request(request)
44+
except AuthenticationError as e:
4645
return PydanticJSONResponse(
47-
status_code=400,
46+
status_code=401,
4847
content=RevocationErrorResponse(
49-
error="invalid_request",
50-
error_description=stringify_pydantic_error(e),
48+
error="unauthorized_client",
49+
error_description=e.message,
5150
),
5251
)
5352

54-
# Authenticate client
5553
try:
56-
client = await self.client_authenticator.authenticate(
57-
revocation_request.client_id, revocation_request.client_secret
58-
)
59-
except AuthenticationError as e: # pragma: no cover
54+
form_data = await request.form()
55+
revocation_request = RevocationRequest.model_validate(dict(form_data))
56+
except ValidationError as e:
6057
return PydanticJSONResponse(
61-
status_code=401,
58+
status_code=400,
6259
content=RevocationErrorResponse(
63-
error="unauthorized_client",
64-
error_description=e.message,
60+
error="invalid_request",
61+
error_description=stringify_pydantic_error(e),
6562
),
6663
)
6764

src/mcp/server/auth/handlers/token.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
9191
)
9292

9393
async def handle(self, request: Request):
94+
try:
95+
client_info = await self.client_authenticator.authenticate_request(request)
96+
except AuthenticationError as e:
97+
# Authentication failures should return 401
98+
return PydanticJSONResponse(
99+
content=TokenErrorResponse(
100+
error="unauthorized_client",
101+
error_description=e.message,
102+
),
103+
status_code=401,
104+
headers={
105+
"Cache-Control": "no-store",
106+
"Pragma": "no-cache",
107+
},
108+
)
109+
94110
try:
95111
form_data = await request.form()
96112
token_request = TokenRequest.model_validate(dict(form_data)).root
@@ -102,19 +118,6 @@ async def handle(self, request: Request):
102118
)
103119
)
104120

105-
try:
106-
client_info = await self.client_authenticator.authenticate(
107-
client_id=token_request.client_id,
108-
client_secret=token_request.client_secret,
109-
)
110-
except AuthenticationError as e: # pragma: no cover
111-
return self.response(
112-
TokenErrorResponse(
113-
error="unauthorized_client",
114-
error_description=e.message,
115-
)
116-
)
117-
118121
if token_request.grant_type not in client_info.grant_types: # pragma: no cover
119122
return self.response(
120123
TokenErrorResponse(

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import base64
12
import time
23
from typing import Any
34

5+
from starlette.requests import Request
6+
47
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
58
from mcp.shared.auth import OAuthClientInformationFull
69

@@ -30,19 +33,69 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
3033
"""
3134
self.provider = provider
3235

33-
async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
34-
# Look up client information
35-
client = await self.provider.get_client(client_id)
36+
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
37+
"""
38+
Authenticate a client from an HTTP request.
39+
40+
Extracts client credentials from the appropriate location based on the
41+
client's registered authentication method and validates them.
42+
43+
Args:
44+
request: The HTTP request containing client credentials
45+
46+
Returns:
47+
The authenticated client information
48+
49+
Raises:
50+
AuthenticationError: If authentication fails
51+
"""
52+
form_data = await request.form()
53+
client_id = form_data.get("client_id")
54+
if not client_id:
55+
raise AuthenticationError("Missing client_id")
56+
57+
client = await self.provider.get_client(str(client_id))
3658
if not client:
3759
raise AuthenticationError("Invalid client_id") # pragma: no cover
3860

61+
request_client_secret = None
62+
auth_header = request.headers.get("Authorization", "")
63+
64+
if client.token_endpoint_auth_method == "client_secret_basic":
65+
if not auth_header.startswith("Basic "):
66+
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
67+
68+
try:
69+
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
70+
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
71+
if ":" not in decoded:
72+
raise ValueError("Invalid Basic auth format")
73+
basic_client_id, request_client_secret = decoded.split(":", 1)
74+
75+
if basic_client_id != client_id:
76+
raise AuthenticationError("Client ID mismatch in Basic auth")
77+
except AuthenticationError:
78+
raise
79+
except Exception:
80+
raise AuthenticationError("Invalid Basic authentication header")
81+
82+
elif client.token_endpoint_auth_method == "client_secret_post":
83+
request_client_secret = form_data.get("client_secret")
84+
if request_client_secret:
85+
request_client_secret = str(request_client_secret)
86+
87+
elif client.token_endpoint_auth_method == "none":
88+
request_client_secret = None
89+
else:
90+
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}")
91+
3992
# If client from the store expects a secret, validate that the request provides
4093
# that secret
4194
if client.client_secret: # pragma: no branch
42-
if not client_secret:
95+
if not request_client_secret:
4396
raise AuthenticationError("Client secret is required") # pragma: no cover
4497

45-
if client.client_secret != client_secret:
98+
if client.client_secret != request_client_secret:
4699
raise AuthenticationError("Invalid client_secret") # pragma: no cover
47100

48101
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def build_metadata(
165165
response_types_supported=["code"],
166166
response_modes_supported=None,
167167
grant_types_supported=["authorization_code", "refresh_token"],
168-
token_endpoint_auth_methods_supported=["client_secret_post"],
168+
token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
169169
token_endpoint_auth_signing_alg_values_supported=None,
170170
service_documentation=service_documentation_url,
171171
ui_locales_supported=None,
@@ -182,7 +182,7 @@ def build_metadata(
182182
# Add revocation endpoint if supported
183183
if revocation_options.enabled: # pragma: no branch
184184
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
185-
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
185+
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]
186186

187187
return metadata
188188

src/mcp/shared/auth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ class OAuthClientMetadata(BaseModel):
4343

4444
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
4545
# supported auth methods for the token endpoint
46-
token_endpoint_auth_method: Literal["none", "client_secret_post", "private_key_jwt"] = "client_secret_post"
46+
token_endpoint_auth_method: Literal[
47+
"none", "client_secret_post", "client_secret_basic", "private_key_jwt"
48+
] = "client_secret_post"
4749
# supported grant_types of this implementation
4850
grant_types: list[
4951
Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str

tests/client/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,10 +1103,10 @@ def test_build_metadata(
11031103
"registration_endpoint": Is(registration_endpoint),
11041104
"scopes_supported": ["read", "write", "admin"],
11051105
"grant_types_supported": ["authorization_code", "refresh_token"],
1106-
"token_endpoint_auth_methods_supported": ["client_secret_post"],
1106+
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
11071107
"service_documentation": Is(service_documentation_url),
11081108
"revocation_endpoint": Is(revocation_endpoint),
1109-
"revocation_endpoint_auth_methods_supported": ["client_secret_post"],
1109+
"revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
11101110
"code_challenge_methods_supported": ["S256"],
11111111
}
11121112
)

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import httpx
1414
import pytest
15-
from pydantic import AnyHttpUrl
15+
from pydantic import AnyHttpUrl, AnyUrl
1616
from starlette.applications import Starlette
1717

1818
from mcp.server.auth.provider import (
@@ -320,7 +320,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
320320
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
321321
assert metadata["response_types_supported"] == ["code"]
322322
assert metadata["code_challenge_methods_supported"] == ["S256"]
323-
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
323+
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post", "client_secret_basic"]
324324
assert metadata["grant_types_supported"] == [
325325
"authorization_code",
326326
"refresh_token",
@@ -339,8 +339,8 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient):
339339
},
340340
)
341341
error_response = response.json()
342-
assert error_response["error"] == "invalid_request"
343-
assert "error_description" in error_response # Contains validation error messages
342+
assert error_response["error"] == "unauthorized_client"
343+
assert "error_description" in error_response # Contains error message
344344

345345
@pytest.mark.anyio
346346
async def test_token_invalid_auth_code(
@@ -976,6 +976,147 @@ async def test_client_registration_default_response_types(
976976
assert "response_types" in data
977977
assert data["response_types"] == ["code"]
978978

979+
@pytest.mark.anyio
980+
async def test_client_secret_basic_authentication(
981+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
982+
):
983+
"""Test that client_secret_basic authentication works correctly."""
984+
client_metadata = {
985+
"redirect_uris": ["https://client.example.com/callback"],
986+
"client_name": "Basic Auth Client",
987+
"token_endpoint_auth_method": "client_secret_basic",
988+
"grant_types": ["authorization_code", "refresh_token"],
989+
}
990+
991+
response = await test_client.post("/register", json=client_metadata)
992+
assert response.status_code == 201
993+
client_info = response.json()
994+
assert client_info["token_endpoint_auth_method"] == "client_secret_basic"
995+
996+
auth_code = f"code_{int(time.time())}"
997+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
998+
code=auth_code,
999+
client_id=client_info["client_id"],
1000+
code_challenge=pkce_challenge["code_challenge"],
1001+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1002+
redirect_uri_provided_explicitly=True,
1003+
scopes=["read", "write"],
1004+
expires_at=time.time() + 600,
1005+
)
1006+
1007+
credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
1008+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
1009+
1010+
response = await test_client.post(
1011+
"/token",
1012+
headers={"Authorization": f"Basic {encoded_credentials}"},
1013+
data={
1014+
"grant_type": "authorization_code",
1015+
"client_id": client_info["client_id"],
1016+
"code": auth_code,
1017+
"code_verifier": pkce_challenge["code_verifier"],
1018+
"redirect_uri": "https://client.example.com/callback",
1019+
},
1020+
)
1021+
assert response.status_code == 200
1022+
token_response = response.json()
1023+
assert "access_token" in token_response
1024+
1025+
@pytest.mark.anyio
1026+
async def test_wrong_auth_method_without_valid_credentials_fails(
1027+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1028+
):
1029+
"""Test that using the wrong authentication method fails when credentials are missing."""
1030+
client_metadata = {
1031+
"redirect_uris": ["https://client.example.com/callback"],
1032+
"client_name": "Post Auth Client",
1033+
"token_endpoint_auth_method": "client_secret_post",
1034+
"grant_types": ["authorization_code", "refresh_token"],
1035+
}
1036+
1037+
response = await test_client.post("/register", json=client_metadata)
1038+
assert response.status_code == 201
1039+
client_info = response.json()
1040+
assert client_info["token_endpoint_auth_method"] == "client_secret_post"
1041+
1042+
auth_code = f"code_{int(time.time())}"
1043+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1044+
code=auth_code,
1045+
client_id=client_info["client_id"],
1046+
code_challenge=pkce_challenge["code_challenge"],
1047+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1048+
redirect_uri_provided_explicitly=True,
1049+
scopes=["read", "write"],
1050+
expires_at=time.time() + 600,
1051+
)
1052+
1053+
# Try to use Basic auth when client_secret_post is registered (without secret in body)
1054+
# This should fail because the secret is missing from the expected location
1055+
1056+
credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
1057+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
1058+
1059+
response = await test_client.post(
1060+
"/token",
1061+
headers={"Authorization": f"Basic {encoded_credentials}"},
1062+
data={
1063+
"grant_type": "authorization_code",
1064+
"client_id": client_info["client_id"],
1065+
# client_secret NOT in body where it should be
1066+
"code": auth_code,
1067+
"code_verifier": pkce_challenge["code_verifier"],
1068+
"redirect_uri": "https://client.example.com/callback",
1069+
},
1070+
)
1071+
assert response.status_code == 401
1072+
error_response = response.json()
1073+
assert error_response["error"] == "unauthorized_client"
1074+
assert "Client secret is required" in error_response["error_description"]
1075+
1076+
@pytest.mark.anyio
1077+
async def test_basic_auth_without_header_fails(
1078+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1079+
):
1080+
"""Test that omitting Basic auth when client_secret_basic is registered fails."""
1081+
client_metadata = {
1082+
"redirect_uris": ["https://client.example.com/callback"],
1083+
"client_name": "Basic Auth Client",
1084+
"token_endpoint_auth_method": "client_secret_basic",
1085+
"grant_types": ["authorization_code", "refresh_token"],
1086+
}
1087+
1088+
response = await test_client.post("/register", json=client_metadata)
1089+
assert response.status_code == 201
1090+
client_info = response.json()
1091+
assert client_info["token_endpoint_auth_method"] == "client_secret_basic"
1092+
1093+
auth_code = f"code_{int(time.time())}"
1094+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1095+
code=auth_code,
1096+
client_id=client_info["client_id"],
1097+
code_challenge=pkce_challenge["code_challenge"],
1098+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1099+
redirect_uri_provided_explicitly=True,
1100+
scopes=["read", "write"],
1101+
expires_at=time.time() + 600,
1102+
)
1103+
1104+
response = await test_client.post(
1105+
"/token",
1106+
data={
1107+
"grant_type": "authorization_code",
1108+
"client_id": client_info["client_id"],
1109+
"client_secret": client_info["client_secret"], # Secret in body (ignored)
1110+
"code": auth_code,
1111+
"code_verifier": pkce_challenge["code_verifier"],
1112+
"redirect_uri": "https://client.example.com/callback",
1113+
},
1114+
)
1115+
assert response.status_code == 401
1116+
error_response = response.json()
1117+
assert error_response["error"] == "unauthorized_client"
1118+
assert "Missing or invalid Basic authentication" in error_response["error_description"]
1119+
9791120

9801121
class TestAuthorizeEndpointErrors:
9811122
"""Test error handling in the OAuth authorization endpoint."""

0 commit comments

Comments
 (0)