Skip to content

Commit 7c57b82

Browse files
committed
coverage increase
1 parent cd397b0 commit 7c57b82

File tree

2 files changed

+296
-6
lines changed

2 files changed

+296
-6
lines changed

src/mcp/client/auth/extensions/client_credentials.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ def __init__(
8383

8484
async def _exchange_token_authorization_code(
8585
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
86-
) -> httpx.Request: # pragma: no cover
86+
) -> httpx.Request:
8787
"""Build token exchange request for authorization_code flow."""
8888
token_data = token_data or {}
8989
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
9090
self._add_client_authentication_jwt(token_data=token_data)
9191
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
9292

93-
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
93+
async def _perform_authorization(self) -> httpx.Request:
9494
"""Perform the authorization flow."""
9595
if "client_credentials" in self.context.client_metadata.grant_types:
9696
# SEP-1046: client_credentials grant with private_key_jwt authentication
@@ -103,12 +103,12 @@ async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
103103
else:
104104
return await super()._perform_authorization()
105105

106-
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
106+
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):
107107
"""Add JWT assertion for client authentication to token endpoint parameters."""
108108
if not self.jwt_parameters:
109-
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
109+
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow") # pragma: no cover
110110
if not self.context.oauth_metadata:
111-
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
111+
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover
112112

113113
# We need to set the audience to the issuer identifier of the authorization server
114114
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
@@ -122,7 +122,7 @@ def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # prag
122122
# it represents the resource server that will validate the token
123123
token_data["audience"] = self.context.get_resource_url()
124124

125-
async def _exchange_token_client_credentials(self) -> httpx.Request: # pragma: no cover
125+
async def _exchange_token_client_credentials(self) -> httpx.Request:
126126
"""Build token exchange request for client_credentials grant.
127127
128128
This implements SEP-1046: OAuth Client Credentials Extension for MCP.

tests/client/auth/extensions/test_client_credentials.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import base64
12
import urllib.parse
23

34
import jwt
@@ -161,3 +162,292 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O
161162
assert claims["name"] == "John Doe"
162163
assert claims["admin"]
163164
assert claims["iat"] == 1516239022
165+
166+
@pytest.mark.anyio
167+
async def test_exchange_token_client_credentials_with_private_key_jwt(
168+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
169+
):
170+
"""Test client_credentials token exchange with private_key_jwt authentication."""
171+
# Set up required context for client_credentials with private_key_jwt
172+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
173+
client_id="test-client",
174+
grant_types=["client_credentials"],
175+
token_endpoint_auth_method="private_key_jwt",
176+
redirect_uris=[AnyUrl("http://localhost:0/unused")],
177+
scope="read write",
178+
)
179+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
180+
issuer=AnyHttpUrl("https://auth.example.com"),
181+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
182+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
183+
)
184+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
185+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
186+
rfc7523_oauth_provider.jwt_parameters = JWTParameters(
187+
issuer="test-client",
188+
subject="test-client",
189+
jwt_signing_algorithm="HS256",
190+
jwt_signing_key="a-string-secret-at-least-256-bits-long",
191+
jwt_lifetime_seconds=300,
192+
)
193+
194+
request = await rfc7523_oauth_provider._exchange_token_client_credentials()
195+
196+
assert request.method == "POST"
197+
assert str(request.url) == "https://auth.example.com/token"
198+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
199+
200+
# Check form data
201+
content = urllib.parse.unquote_plus(request.content.decode())
202+
assert "grant_type=client_credentials" in content
203+
assert "scope=read write" in content
204+
assert "resource=https://api.example.com/v1/mcp" in content
205+
assert "client_assertion=" in content
206+
assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content
207+
208+
@pytest.mark.anyio
209+
async def test_exchange_token_client_credentials_with_client_secret_basic(
210+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
211+
):
212+
"""Test client_credentials token exchange with client_secret_basic authentication."""
213+
# Set up required context for client_credentials with client_secret_basic
214+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
215+
client_id="test-client",
216+
client_secret="test-secret",
217+
grant_types=["client_credentials"],
218+
token_endpoint_auth_method="client_secret_basic",
219+
redirect_uris=[AnyUrl("http://localhost:0/unused")],
220+
scope="read write",
221+
)
222+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
223+
issuer=AnyHttpUrl("https://auth.example.com"),
224+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
225+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
226+
)
227+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
228+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
229+
# No JWT parameters needed for client_secret_basic
230+
231+
request = await rfc7523_oauth_provider._exchange_token_client_credentials()
232+
233+
assert request.method == "POST"
234+
assert str(request.url) == "https://auth.example.com/token"
235+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
236+
237+
# Check Authorization header (Basic auth)
238+
assert "Authorization" in request.headers
239+
auth_header = request.headers["Authorization"]
240+
assert auth_header.startswith("Basic ")
241+
242+
# Decode and verify credentials
243+
encoded_creds = auth_header[6:] # Remove "Basic " prefix
244+
decoded = base64.b64decode(encoded_creds).decode()
245+
assert decoded == "test-client:test-secret"
246+
247+
# Check form data
248+
content = urllib.parse.unquote_plus(request.content.decode())
249+
assert "grant_type=client_credentials" in content
250+
assert "scope=read write" in content
251+
assert "resource=https://api.example.com/v1/mcp" in content
252+
# client_secret should NOT be in body for client_secret_basic
253+
assert "client_secret=" not in content
254+
255+
@pytest.mark.anyio
256+
async def test_perform_authorization_routes_to_client_credentials(
257+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
258+
):
259+
"""Test that _perform_authorization routes to client_credentials when configured."""
260+
# Set up required context for client_credentials flow
261+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
262+
client_id="test-client",
263+
client_secret="test-secret",
264+
grant_types=["client_credentials"],
265+
token_endpoint_auth_method="client_secret_basic",
266+
redirect_uris=[AnyUrl("http://localhost:0/unused")],
267+
)
268+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
269+
issuer=AnyHttpUrl("https://auth.example.com"),
270+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
271+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
272+
)
273+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
274+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
275+
276+
request = await rfc7523_oauth_provider._perform_authorization()
277+
278+
# Should route to client_credentials flow
279+
content = urllib.parse.unquote_plus(request.content.decode())
280+
assert "grant_type=client_credentials" in content
281+
282+
@pytest.mark.anyio
283+
async def test_perform_authorization_routes_to_jwt_bearer(
284+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
285+
):
286+
"""Test that _perform_authorization routes to jwt-bearer when configured."""
287+
# Set up required context for jwt-bearer flow
288+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
289+
client_id="test-client",
290+
grant_types=["urn:ietf:params:oauth:grant-type:jwt-bearer"],
291+
token_endpoint_auth_method="private_key_jwt",
292+
redirect_uris=[AnyUrl("http://localhost:0/unused")],
293+
)
294+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
295+
issuer=AnyHttpUrl("https://auth.example.com"),
296+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
297+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
298+
)
299+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
300+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
301+
rfc7523_oauth_provider.jwt_parameters = JWTParameters(
302+
issuer="test-client",
303+
subject="test-client",
304+
jwt_signing_algorithm="HS256",
305+
jwt_signing_key="a-string-secret-at-least-256-bits-long",
306+
)
307+
308+
request = await rfc7523_oauth_provider._perform_authorization()
309+
310+
# Should route to jwt-bearer flow
311+
content = urllib.parse.unquote_plus(request.content.decode())
312+
assert "grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer" in content
313+
314+
@pytest.mark.anyio
315+
async def test_add_client_authentication_jwt(
316+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
317+
):
318+
"""Test _add_client_authentication_jwt adds correct JWT assertion parameters."""
319+
# Set up required context
320+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
321+
issuer=AnyHttpUrl("https://auth.example.com"),
322+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
323+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
324+
)
325+
rfc7523_oauth_provider.jwt_parameters = JWTParameters(
326+
issuer="test-client",
327+
subject="test-client",
328+
jwt_signing_algorithm="HS256",
329+
jwt_signing_key="a-string-secret-at-least-256-bits-long",
330+
)
331+
332+
token_data: dict = {}
333+
rfc7523_oauth_provider._add_client_authentication_jwt(token_data=token_data)
334+
335+
# Check that JWT assertion parameters were added
336+
assert "client_assertion" in token_data
337+
assert token_data["client_assertion_type"] == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
338+
assert token_data["audience"] == "https://api.example.com/v1/mcp"
339+
340+
# Verify the JWT assertion is valid and has correct audience (issuer identifier)
341+
claims = jwt.decode(
342+
token_data["client_assertion"],
343+
key="a-string-secret-at-least-256-bits-long",
344+
algorithms=["HS256"],
345+
audience="https://auth.example.com/", # Should be issuer, not token endpoint
346+
verify=True,
347+
)
348+
assert claims["iss"] == "test-client"
349+
assert claims["sub"] == "test-client"
350+
351+
@pytest.mark.anyio
352+
async def test_exchange_token_authorization_code_with_private_key_jwt(
353+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
354+
):
355+
"""Test authorization_code token exchange adds JWT when using private_key_jwt."""
356+
# Set up required context
357+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
358+
client_id="test-client",
359+
grant_types=["authorization_code"],
360+
token_endpoint_auth_method="private_key_jwt",
361+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
362+
)
363+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
364+
issuer=AnyHttpUrl("https://auth.example.com"),
365+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
366+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
367+
)
368+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
369+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
370+
rfc7523_oauth_provider.jwt_parameters = JWTParameters(
371+
issuer="test-client",
372+
subject="test-client",
373+
jwt_signing_algorithm="HS256",
374+
jwt_signing_key="a-string-secret-at-least-256-bits-long",
375+
)
376+
377+
request = await rfc7523_oauth_provider._exchange_token_authorization_code(
378+
"test-auth-code", "test-verifier"
379+
)
380+
381+
# Check form data contains JWT assertion
382+
content = urllib.parse.unquote_plus(request.content.decode())
383+
assert "grant_type=authorization_code" in content
384+
assert "code=test-auth-code" in content
385+
assert "client_assertion=" in content
386+
assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content
387+
388+
@pytest.mark.anyio
389+
async def test_exchange_token_authorization_code_without_private_key_jwt(
390+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
391+
):
392+
"""Test authorization_code token exchange without private_key_jwt uses standard auth."""
393+
# Set up required context with client_secret_post (not private_key_jwt)
394+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
395+
client_id="test-client",
396+
client_secret="test-secret",
397+
grant_types=["authorization_code"],
398+
token_endpoint_auth_method="client_secret_post",
399+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
400+
)
401+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
402+
issuer=AnyHttpUrl("https://auth.example.com"),
403+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
404+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
405+
)
406+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
407+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
408+
409+
request = await rfc7523_oauth_provider._exchange_token_authorization_code(
410+
"test-auth-code", "test-verifier"
411+
)
412+
413+
# Check form data does NOT contain JWT assertion
414+
content = urllib.parse.unquote_plus(request.content.decode())
415+
assert "grant_type=authorization_code" in content
416+
assert "code=test-auth-code" in content
417+
assert "client_assertion=" not in content
418+
# Should have client_secret in body for client_secret_post
419+
assert "client_secret=test-secret" in content
420+
421+
@pytest.mark.anyio
422+
async def test_perform_authorization_falls_back_to_parent(
423+
self, rfc7523_oauth_provider: RFC7523OAuthClientProvider
424+
):
425+
"""Test that _perform_authorization falls back to parent when not client_credentials or jwt-bearer."""
426+
# Set up required context with authorization_code grant (not client_credentials or jwt-bearer)
427+
rfc7523_oauth_provider.context.client_info = OAuthClientInformationFull(
428+
client_id="test-client",
429+
client_secret="test-secret",
430+
grant_types=["authorization_code", "refresh_token"],
431+
token_endpoint_auth_method="client_secret_post",
432+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
433+
)
434+
rfc7523_oauth_provider.context.oauth_metadata = OAuthMetadata(
435+
issuer=AnyHttpUrl("https://auth.example.com"),
436+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
437+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
438+
)
439+
rfc7523_oauth_provider.context.client_metadata = rfc7523_oauth_provider.context.client_info
440+
rfc7523_oauth_provider.context.protocol_version = "2025-06-18"
441+
442+
# Mock the parent class's _perform_authorization since it would try to do real OAuth
443+
from unittest.mock import AsyncMock, patch
444+
445+
mock_request = AsyncMock()
446+
with patch.object(
447+
rfc7523_oauth_provider.__class__.__bases__[0],
448+
"_perform_authorization",
449+
new=AsyncMock(return_value=mock_request),
450+
) as mock_parent:
451+
result = await rfc7523_oauth_provider._perform_authorization()
452+
mock_parent.assert_called_once()
453+
assert result == mock_request

0 commit comments

Comments
 (0)