Skip to content

Commit 6263b53

Browse files
authored
Merge pull request #42 from sacha-development-stuff/codex/fix-coverage-failure-in-tests-pz6y8h
Add tests to cover OAuth token flows
2 parents ea8b2ef + d8bef42 commit 6263b53

File tree

2 files changed

+121
-0
lines changed

2 files changed

+121
-0
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,57 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str])
705705
assert "resource" not in recorded_client.last_data
706706

707707

708+
@pytest.mark.anyio
709+
async def test_token_exchange_request_token_skips_client_error_and_omits_scope(
710+
monkeypatch: pytest.MonkeyPatch,
711+
) -> None:
712+
storage = InMemoryStorage()
713+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
714+
715+
subject_supplier = AsyncMock(return_value="subject-token")
716+
717+
provider = TokenExchangeProvider(
718+
"https://api.example.com/service",
719+
client_metadata,
720+
storage,
721+
subject_token_supplier=subject_supplier,
722+
)
723+
724+
metadata_without_scopes = _metadata_json()
725+
metadata_without_scopes.pop("scopes_supported", None)
726+
727+
metadata_responses = [
728+
_make_response(404),
729+
_make_response(200, json_data=metadata_without_scopes),
730+
]
731+
registration_response = _make_response(200, json_data=_registration_json())
732+
733+
class RecordingAsyncClient(DummyAsyncClient):
734+
def __init__(self) -> None:
735+
super().__init__(post_responses=[_make_response(200, json_data=_token_json())])
736+
self.last_data: dict[str, str] | None = None
737+
738+
async def post(
739+
self, url: str, *, data: dict[str, str], headers: dict[str, str]
740+
) -> httpx.Response:
741+
self.last_data = data
742+
return await super().post(url, data=data, headers=headers)
743+
744+
clients: list[DummyAsyncClient] = [
745+
DummyAsyncClient(send_responses=metadata_responses),
746+
DummyAsyncClient(send_responses=[registration_response]),
747+
RecordingAsyncClient(),
748+
]
749+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
750+
751+
await provider._request_token()
752+
753+
recorded_client = cast(RecordingAsyncClient, clients[-1])
754+
assert recorded_client.last_data is not None
755+
assert "scope" not in recorded_client.last_data
756+
assert provider.client_metadata.scope is None
757+
758+
708759
@pytest.mark.anyio
709760
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
710761
storage = InMemoryStorage()

tests/unit/server/auth/test_token_handler.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ async def exchange_client_credentials(self, client_info: object, scopes: list[st
5252
raise TokenError(error="invalid_client", error_description="bad credentials")
5353

5454

55+
class ClientCredentialsProviderSuccess:
56+
def __init__(self) -> None:
57+
self.last_scopes: list[str] | None = None
58+
59+
async def exchange_client_credentials(self, client_info: object, scopes: list[str]) -> OAuthToken:
60+
self.last_scopes = scopes
61+
return OAuthToken(access_token="client-token")
62+
63+
5564
class TokenExchangeProviderStub:
5665
def __init__(self) -> None:
5766
self.last_call: dict[str, Any] | None = None
@@ -155,6 +164,67 @@ async def test_handle_client_credentials_returns_token_error() -> None:
155164
assert result.error_description == "bad credentials"
156165

157166

167+
@pytest.mark.anyio
168+
async def test_handle_route_authorization_code_branch() -> None:
169+
code_verifier = "a" * 64
170+
digest = hashlib.sha256(code_verifier.encode()).digest()
171+
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
172+
173+
provider = AuthorizationCodeProvider(expected_code="auth-code", code_challenge=code_challenge)
174+
client_info = OAuthClientInformationFull(
175+
client_id="client",
176+
grant_types=["authorization_code"],
177+
scope="alpha",
178+
)
179+
handler = TokenHandler(
180+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
181+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
182+
)
183+
184+
request_data = {
185+
"grant_type": "authorization_code",
186+
"code": "auth-code",
187+
"redirect_uri": None,
188+
"client_id": "client",
189+
"client_secret": "secret",
190+
"code_verifier": code_verifier,
191+
}
192+
193+
response = await handler.handle(cast(Request, DummyRequest(request_data)))
194+
195+
assert response.status_code == 200
196+
payload = json.loads(bytes(response.body).decode())
197+
assert payload["access_token"] == "auth-token"
198+
199+
200+
@pytest.mark.anyio
201+
async def test_handle_route_client_credentials_branch() -> None:
202+
provider = ClientCredentialsProviderSuccess()
203+
client_info = OAuthClientInformationFull(
204+
client_id="client",
205+
grant_types=["client_credentials"],
206+
scope="alpha beta",
207+
)
208+
handler = TokenHandler(
209+
provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider),
210+
client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)),
211+
)
212+
213+
request_data = {
214+
"grant_type": "client_credentials",
215+
"scope": "beta",
216+
"client_id": "client",
217+
"client_secret": "secret",
218+
}
219+
220+
response = await handler.handle(cast(Request, DummyRequest(request_data)))
221+
222+
assert response.status_code == 200
223+
payload = json.loads(bytes(response.body).decode())
224+
assert payload["access_token"] == "client-token"
225+
assert provider.last_scopes == ["beta"]
226+
227+
158228
@pytest.mark.anyio
159229
async def test_handle_route_refresh_token_branch() -> None:
160230
provider = RefreshTokenProvider()

0 commit comments

Comments
 (0)