|
3 | 3 | import json |
4 | 4 | import time |
5 | 5 | from collections.abc import Mapping |
6 | | -from types import SimpleNamespace |
| 6 | +from types import MethodType, SimpleNamespace |
7 | 7 | from typing import Any, cast |
8 | 8 |
|
9 | 9 | import pytest |
|
12 | 12 | from mcp.server.auth.handlers.token import ( |
13 | 13 | AuthorizationCodeRequest, |
14 | 14 | ClientCredentialsRequest, |
| 15 | + RefreshTokenRequest, |
15 | 16 | TokenErrorResponse, |
16 | 17 | TokenHandler, |
17 | 18 | TokenSuccessResponse, |
@@ -287,6 +288,51 @@ async def test_handle_route_refresh_token_invalid_scope() -> None: |
287 | 288 | } |
288 | 289 |
|
289 | 290 |
|
| 291 | +@pytest.mark.anyio |
| 292 | +async def test_handle_route_refresh_token_dispatches_to_handler( |
| 293 | + monkeypatch: pytest.MonkeyPatch, |
| 294 | +) -> None: |
| 295 | + provider = RefreshTokenProvider() |
| 296 | + client_info = OAuthClientInformationFull( |
| 297 | + client_id="client", |
| 298 | + grant_types=["refresh_token"], |
| 299 | + scope="alpha", |
| 300 | + ) |
| 301 | + handler = TokenHandler( |
| 302 | + provider=cast(OAuthAuthorizationServerProvider[Any, Any, Any], provider), |
| 303 | + client_authenticator=cast(ClientAuthenticator, DummyAuthenticator(client_info)), |
| 304 | + ) |
| 305 | + |
| 306 | + captured_requests: list[RefreshTokenRequest] = [] |
| 307 | + |
| 308 | + async def fake_handle_refresh_token( |
| 309 | + self: TokenHandler, |
| 310 | + client: OAuthClientInformationFull, |
| 311 | + token_request: RefreshTokenRequest, |
| 312 | + ) -> TokenSuccessResponse: |
| 313 | + captured_requests.append(token_request) |
| 314 | + return TokenSuccessResponse(root=OAuthToken(access_token="dispatched-token")) |
| 315 | + |
| 316 | + monkeypatch.setattr( |
| 317 | + handler, |
| 318 | + "_handle_refresh_token", |
| 319 | + MethodType(fake_handle_refresh_token, handler), |
| 320 | + ) |
| 321 | + |
| 322 | + request_data = { |
| 323 | + "grant_type": "refresh_token", |
| 324 | + "refresh_token": "refresh-token", |
| 325 | + "client_id": "client", |
| 326 | + "client_secret": "secret", |
| 327 | + } |
| 328 | + |
| 329 | + response = await handler.handle(cast(Request, DummyRequest(request_data))) |
| 330 | + |
| 331 | + assert response.status_code == 200 |
| 332 | + assert captured_requests |
| 333 | + assert isinstance(captured_requests[0], RefreshTokenRequest) |
| 334 | + |
| 335 | + |
290 | 336 | @pytest.mark.anyio |
291 | 337 | async def test_handle_route_token_exchange_branch() -> None: |
292 | 338 | provider = TokenExchangeProviderStub() |
|
0 commit comments