Skip to content

Commit a8067e1

Browse files
committed
Add tests for client_credentials flow
1 parent ed2a486 commit a8067e1

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

tests/client/test_auth.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
Tests for refactored OAuth client authentication implementation.
33
"""
44

5+
import base64
56
import time
7+
import urllib
8+
import urllib.parse
69

710
import httpx
811
import pytest
@@ -386,7 +389,7 @@ async def test_register_client_skip_if_registered(self, oauth_provider, mock_sto
386389
assert request is None
387390

388391
@pytest.mark.anyio
389-
async def test_token_exchange_request(self, oauth_provider):
392+
async def test_token_exchange_request_authorization_code(self, oauth_provider):
390393
"""Test token exchange request building."""
391394
# Set up required context
392395
oauth_provider.context.client_info = OAuthClientInformationFull(
@@ -409,6 +412,65 @@ async def test_token_exchange_request(self, oauth_provider):
409412
assert "client_id=test_client" in content
410413
assert "client_secret=test_secret" in content
411414

415+
@pytest.mark.anyio
416+
async def test_token_exchange_request_client_credentials_basic(self, oauth_provider):
417+
"""Test token exchange request building."""
418+
# Set up required context
419+
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
420+
grant_types=["client_credentials"],
421+
token_endpoint_auth_method="client_secret_basic",
422+
client_id="test_client",
423+
client_secret="test_secret",
424+
redirect_uris=None,
425+
scope="read write",
426+
)
427+
428+
request = await oauth_provider._exchange_token_client_credentials()
429+
430+
assert request.method == "POST"
431+
assert str(request.url) == "https://api.example.com/token"
432+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
433+
434+
# Check form data
435+
content = urllib.parse.unquote_plus(request.content.decode())
436+
assert "grant_type=client_credentials" in content
437+
assert "scope=read write" in content
438+
assert "resource=https://api.example.com/v1/mcp" in content
439+
assert "client_id=test_client" not in content
440+
assert "client_secret=test_secret" not in content
441+
442+
# Check auth header
443+
assert "Authorization" in request.headers
444+
assert request.headers["Authorization"].startswith("Basic ")
445+
assert base64.b64decode(request.headers["Authorization"].split(" ")[1]).decode() == "test_client:test_secret"
446+
447+
@pytest.mark.anyio
448+
async def test_token_exchange_request_client_credentials_post(self, oauth_provider):
449+
"""Test token exchange request building."""
450+
# Set up required context
451+
oauth_provider.context.client_info = oauth_provider.context.client_metadata = OAuthClientInformationFull(
452+
grant_types=["client_credentials"],
453+
token_endpoint_auth_method="client_secret_post",
454+
client_id="test_client",
455+
client_secret="test_secret",
456+
redirect_uris=None,
457+
scope="read write",
458+
)
459+
460+
request = await oauth_provider._exchange_token_client_credentials()
461+
462+
assert request.method == "POST"
463+
assert str(request.url) == "https://api.example.com/token"
464+
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
465+
466+
# Check form data
467+
content = urllib.parse.unquote_plus(request.content.decode())
468+
assert "grant_type=client_credentials" in content
469+
assert "scope=read write" in content
470+
assert "resource=https://api.example.com/v1/mcp" in content
471+
assert "client_id=test_client" in content
472+
assert "client_secret=test_secret" in content
473+
412474
@pytest.mark.anyio
413475
async def test_refresh_token_request(self, oauth_provider, valid_tokens):
414476
"""Test refresh token request building."""

0 commit comments

Comments
 (0)