Skip to content

Commit fedadb3

Browse files
committed
Add tests covering OAuth scope and discovery branches
1 parent 308bc63 commit fedadb3

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,34 @@ async def test_client_credentials_request_token_without_metadata(monkeypatch: py
464464
assert provider._metadata is None
465465

466466

467+
@pytest.mark.anyio
468+
async def test_client_credentials_request_token_omits_scope_when_unset(monkeypatch: pytest.MonkeyPatch) -> None:
469+
storage = InMemoryStorage()
470+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope=None)
471+
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
472+
473+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
474+
provider._client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
475+
476+
class RecordingAsyncClient(DummyAsyncClient):
477+
def __init__(self) -> None:
478+
super().__init__(post_responses=[_make_response(200, json_data=_token_json())])
479+
self.last_data: dict[str, str] | None = None
480+
481+
async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) -> httpx.Response:
482+
self.last_data = data
483+
return await super().post(url, data=data, headers=headers)
484+
485+
clients: list[DummyAsyncClient] = [RecordingAsyncClient()]
486+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
487+
488+
await provider._request_token()
489+
490+
recorded_client = cast(RecordingAsyncClient, clients[0])
491+
assert recorded_client.last_data is not None
492+
assert "scope" not in recorded_client.last_data
493+
494+
467495
@pytest.mark.anyio
468496
async def test_client_credentials_ensure_token_returns_when_valid() -> None:
469497
storage = InMemoryStorage()
@@ -668,6 +696,43 @@ async def test_token_exchange_request_token_handles_invalid_metadata(monkeypatch
668696
actor_supplier.assert_awaited_once()
669697

670698

699+
@pytest.mark.anyio
700+
async def test_token_exchange_request_token_skips_discovery_when_no_urls(monkeypatch: pytest.MonkeyPatch) -> None:
701+
storage = InMemoryStorage()
702+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
703+
704+
subject_supplier = AsyncMock(return_value="subject-token")
705+
706+
provider = TokenExchangeProvider(
707+
"https://api.example.com/service",
708+
client_metadata,
709+
storage,
710+
subject_token_supplier=subject_supplier,
711+
)
712+
713+
provider._client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
714+
provider._get_discovery_urls = MethodType(lambda self, server_url=None: [], provider)
715+
716+
class RecordingAsyncClient(DummyAsyncClient):
717+
def __init__(self) -> None:
718+
super().__init__(post_responses=[_make_response(200, json_data=_token_json())])
719+
self.last_data: dict[str, str] | None = None
720+
721+
async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) -> httpx.Response:
722+
self.last_data = data
723+
return await super().post(url, data=data, headers=headers)
724+
725+
clients: list[DummyAsyncClient] = [RecordingAsyncClient()]
726+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
727+
728+
await provider._request_token()
729+
730+
recorded_client = cast(RecordingAsyncClient, clients[0])
731+
assert recorded_client.last_data is not None
732+
assert subject_supplier.await_count == 1
733+
assert provider._metadata is None
734+
735+
671736
@pytest.mark.anyio
672737
async def test_token_exchange_request_token_excludes_resource_when_unset(monkeypatch: pytest.MonkeyPatch) -> None:
673738
storage = InMemoryStorage()
@@ -705,6 +770,42 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str])
705770
assert "resource" not in recorded_client.last_data
706771

707772

773+
@pytest.mark.anyio
774+
async def test_token_exchange_request_token_omits_scope_when_unset(monkeypatch: pytest.MonkeyPatch) -> None:
775+
storage = InMemoryStorage()
776+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope=None)
777+
778+
subject_supplier = AsyncMock(return_value="subject-token")
779+
780+
provider = TokenExchangeProvider(
781+
"https://api.example.com/service",
782+
client_metadata,
783+
storage,
784+
subject_token_supplier=subject_supplier,
785+
)
786+
787+
provider._metadata = OAuthMetadata.model_validate(_metadata_json())
788+
provider._client_info = OAuthClientInformationFull(client_id="client", client_secret="secret")
789+
790+
class RecordingAsyncClient(DummyAsyncClient):
791+
def __init__(self) -> None:
792+
super().__init__(post_responses=[_make_response(200, json_data=_token_json())])
793+
self.last_data: dict[str, str] | None = None
794+
795+
async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str]) -> httpx.Response:
796+
self.last_data = data
797+
return await super().post(url, data=data, headers=headers)
798+
799+
clients: list[DummyAsyncClient] = [RecordingAsyncClient()]
800+
monkeypatch.setattr("mcp.client.auth.oauth2.httpx.AsyncClient", AsyncClientFactory(clients))
801+
802+
await provider._request_token()
803+
804+
recorded_client = cast(RecordingAsyncClient, clients[0])
805+
assert recorded_client.last_data is not None
806+
assert "scope" not in recorded_client.last_data
807+
808+
708809
@pytest.mark.anyio
709810
async def test_token_exchange_request_token_raises_on_failure(monkeypatch: pytest.MonkeyPatch) -> None:
710811
storage = InMemoryStorage()

0 commit comments

Comments
 (0)