Skip to content

Commit 6592391

Browse files
authored
Merge pull request #34 from sacha-development-stuff/codex/fix-type-annotations-in-tests
Fix pyright typing issues in tests
2 parents b2a3b27 + 20215fb commit 6592391

File tree

4 files changed

+83
-45
lines changed

4 files changed

+83
-45
lines changed

tests/unit/client/test_oauth2_providers.py

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import base64
22
import time
3-
from types import SimpleNamespace
3+
from types import SimpleNamespace, TracebackType
4+
from typing import Iterator, cast
45

56
import httpx
67
import pytest
8+
from pydantic import AnyUrl
79

810
from mcp.client.auth.oauth2 import (
911
ClientCredentialsProvider,
@@ -49,7 +51,12 @@ def __init__(
4951
async def __aenter__(self) -> "DummyAsyncClient":
5052
return self
5153

52-
async def __aexit__(self, exc_type, exc, tb) -> None:
54+
async def __aexit__(
55+
self,
56+
exc_type: type[BaseException] | None,
57+
exc: BaseException | None,
58+
tb: TracebackType | None,
59+
) -> bool | None:
5360
return None
5461

5562
async def send(self, request: httpx.Request) -> httpx.Response:
@@ -63,12 +70,16 @@ async def post(self, url: str, *, data: dict[str, str], headers: dict[str, str])
6370

6471
class AsyncClientFactory:
6572
def __init__(self, clients: list[DummyAsyncClient]) -> None:
66-
self._clients = iter(clients)
73+
self._clients: Iterator[DummyAsyncClient] = iter(clients)
6774

68-
def __call__(self, *args, **kwargs) -> DummyAsyncClient:
75+
def __call__(self, *args: object, **kwargs: object) -> DummyAsyncClient:
6976
return next(self._clients)
7077

7178

79+
def _redirect_uris() -> list[AnyUrl]:
80+
return cast(list[AnyUrl], ["https://client.example.com/callback"])
81+
82+
7283
def _metadata_json() -> dict[str, object]:
7384
return {
7485
"issuer": "https://auth.example.com",
@@ -107,7 +118,7 @@ def _make_response(status: int, *, json_data: dict[str, object] | None = None) -
107118
@pytest.mark.anyio
108119
async def test_handle_oauth_metadata_response_sets_scope() -> None:
109120
storage = InMemoryStorage()
110-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
121+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
111122
provider = ClientCredentialsProvider(
112123
"https://api.example.com/service",
113124
metadata,
@@ -130,7 +141,7 @@ async def test_client_credentials_initialize_loads_cached_values() -> None:
130141
storage.tokens = stored_token
131142
storage.client_info = stored_client
132143

133-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
144+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
134145
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
135146

136147
await provider.initialize()
@@ -141,7 +152,7 @@ async def test_client_credentials_initialize_loads_cached_values() -> None:
141152

142153
def test_create_registration_request_uses_cached_client_info() -> None:
143154
storage = InMemoryStorage()
144-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
155+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
145156
provider = ClientCredentialsProvider(
146157
"https://api.example.com/service",
147158
metadata,
@@ -155,7 +166,7 @@ def test_create_registration_request_uses_cached_client_info() -> None:
155166

156167
def test_create_registration_request_uses_context() -> None:
157168
storage = InMemoryStorage()
158-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
169+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
159170
provider = ClientCredentialsProvider(
160171
"https://api.example.com/service",
161172
metadata,
@@ -172,7 +183,7 @@ def test_create_registration_request_uses_context() -> None:
172183

173184
def test_create_registration_request_builds_url_from_metadata() -> None:
174185
storage = InMemoryStorage()
175-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
186+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
176187
provider = ClientCredentialsProvider(
177188
"https://api.example.com/service",
178189
metadata,
@@ -187,7 +198,7 @@ def test_create_registration_request_builds_url_from_metadata() -> None:
187198

188199
def test_create_registration_request_builds_url_from_server() -> None:
189200
storage = InMemoryStorage()
190-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
201+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
191202
provider = ClientCredentialsProvider(
192203
"https://api.example.com/service/path",
193204
metadata,
@@ -201,7 +212,7 @@ def test_create_registration_request_builds_url_from_server() -> None:
201212

202213
def test_apply_client_auth_requires_client_id() -> None:
203214
storage = InMemoryStorage()
204-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
215+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
205216
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
206217

207218
with pytest.raises(OAuthFlowError):
@@ -210,7 +221,7 @@ def test_apply_client_auth_requires_client_id() -> None:
210221

211222
def test_apply_client_auth_basic() -> None:
212223
storage = InMemoryStorage()
213-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
224+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
214225
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
215226
provider._metadata = OAuthMetadata.model_validate(
216227
{**_metadata_json(), "token_endpoint_auth_methods_supported": ["client_secret_basic"]}
@@ -229,7 +240,7 @@ def test_apply_client_auth_basic() -> None:
229240

230241
def test_apply_client_auth_basic_requires_secret() -> None:
231242
storage = InMemoryStorage()
232-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
243+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
233244
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
234245
provider._metadata = OAuthMetadata.model_validate(
235246
{**_metadata_json(), "token_endpoint_auth_methods_supported": ["client_secret_basic"]}
@@ -241,7 +252,7 @@ def test_apply_client_auth_basic_requires_secret() -> None:
241252

242253
def test_apply_client_auth_post_method() -> None:
243254
storage = InMemoryStorage()
244-
metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
255+
metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
245256
provider = ClientCredentialsProvider("https://api.example.com/service", metadata, storage)
246257
provider._metadata = OAuthMetadata.model_validate(
247258
{**_metadata_json(), "token_endpoint_auth_methods_supported": ["client_secret_post"]}
@@ -259,9 +270,9 @@ def test_apply_client_auth_post_method() -> None:
259270

260271

261272
@pytest.mark.anyio
262-
async def test_client_credentials_request_token_with_metadata(monkeypatch) -> None:
273+
async def test_client_credentials_request_token_with_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
263274
storage = InMemoryStorage()
264-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
275+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
265276
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
266277

267278
metadata_response = _make_response(200, json_data=_metadata_json())
@@ -286,9 +297,9 @@ async def test_client_credentials_request_token_with_metadata(monkeypatch) -> No
286297

287298

288299
@pytest.mark.anyio
289-
async def test_client_credentials_request_token_without_metadata(monkeypatch) -> None:
300+
async def test_client_credentials_request_token_without_metadata(monkeypatch: pytest.MonkeyPatch) -> None:
290301
storage = InMemoryStorage()
291-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope="alpha")
302+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
292303
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
293304

294305
metadata_responses = [_make_response(404) for _ in range(4)]
@@ -312,7 +323,7 @@ async def test_client_credentials_request_token_without_metadata(monkeypatch) ->
312323
@pytest.mark.anyio
313324
async def test_client_credentials_ensure_token_returns_when_valid() -> None:
314325
storage = InMemoryStorage()
315-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
326+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
316327
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
317328
provider._current_tokens = OAuthToken(access_token="token")
318329
provider._token_expiry_time = time.time() + 60
@@ -334,7 +345,7 @@ async def fake_request_token() -> None:
334345
@pytest.mark.anyio
335346
async def test_client_credentials_validate_token_scopes_rejects_extra() -> None:
336347
storage = InMemoryStorage()
337-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope="alpha")
348+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
338349
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
339350

340351
token = OAuthToken(access_token="token", scope="alpha beta")
@@ -346,7 +357,7 @@ async def test_client_credentials_validate_token_scopes_rejects_extra() -> None:
346357
@pytest.mark.anyio
347358
async def test_client_credentials_validate_token_scopes_accepts_server_defined() -> None:
348359
storage = InMemoryStorage()
349-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope=None)
360+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope=None)
350361
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
351362

352363
token = OAuthToken(access_token="token", scope="delta")
@@ -355,9 +366,9 @@ async def test_client_credentials_validate_token_scopes_accepts_server_defined()
355366

356367

357368
@pytest.mark.anyio
358-
async def test_client_credentials_async_auth_flow_handles_401(monkeypatch) -> None:
369+
async def test_client_credentials_async_auth_flow_handles_401(monkeypatch: pytest.MonkeyPatch) -> None:
359370
storage = InMemoryStorage()
360-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
371+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
361372
provider = ClientCredentialsProvider("https://api.example.com/service", client_metadata, storage)
362373

363374
async def fake_initialize() -> None:
@@ -372,7 +383,7 @@ async def fake_ensure_token() -> None:
372383
request = httpx.Request("GET", "https://api.example.com/resource")
373384
flow = provider.async_auth_flow(request)
374385

375-
prepared_request = await flow.asend(None)
386+
prepared_request = await anext(flow)
376387
assert prepared_request.headers["Authorization"] == "Bearer flow-token"
377388

378389
response = httpx.Response(401, request=prepared_request)
@@ -383,9 +394,9 @@ async def fake_ensure_token() -> None:
383394

384395

385396
@pytest.mark.anyio
386-
async def test_token_exchange_request_token(monkeypatch) -> None:
397+
async def test_token_exchange_request_token(monkeypatch: pytest.MonkeyPatch) -> None:
387398
storage = InMemoryStorage()
388-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope="alpha")
399+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
389400

390401
async def provide_subject() -> str:
391402
return "subject-token"
@@ -432,7 +443,7 @@ async def test_token_exchange_initialize_loads_cached_values() -> None:
432443
storage.tokens = stored_token
433444
storage.client_info = stored_client
434445

435-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
446+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
436447

437448
async def provide_subject() -> str:
438449
return "subject-token"
@@ -453,7 +464,7 @@ async def provide_subject() -> str:
453464
@pytest.mark.anyio
454465
async def test_token_exchange_validate_token_scopes_rejects_extra() -> None:
455466
storage = InMemoryStorage()
456-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope="alpha")
467+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope="alpha")
457468

458469
async def provide_subject() -> str:
459470
return "subject-token"
@@ -474,7 +485,7 @@ async def provide_subject() -> str:
474485
@pytest.mark.anyio
475486
async def test_token_exchange_validate_token_scopes_accepts_server_defined() -> None:
476487
storage = InMemoryStorage()
477-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"], scope=None)
488+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris(), scope=None)
478489

479490
async def provide_subject() -> str:
480491
return "subject-token"
@@ -492,9 +503,9 @@ async def provide_subject() -> str:
492503

493504

494505
@pytest.mark.anyio
495-
async def test_token_exchange_async_auth_flow_handles_401(monkeypatch) -> None:
506+
async def test_token_exchange_async_auth_flow_handles_401(monkeypatch: pytest.MonkeyPatch) -> None:
496507
storage = InMemoryStorage()
497-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
508+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
498509

499510
async def provide_subject() -> str:
500511
return "subject-token"
@@ -518,7 +529,7 @@ async def fake_ensure_token() -> None:
518529
request = httpx.Request("GET", "https://api.example.com/resource")
519530
flow = provider.async_auth_flow(request)
520531

521-
prepared_request = await flow.asend(None)
532+
prepared_request = await anext(flow)
522533
assert prepared_request.headers["Authorization"] == "Bearer flow-token"
523534

524535
response = httpx.Response(401, request=prepared_request)
@@ -531,7 +542,7 @@ async def fake_ensure_token() -> None:
531542
@pytest.mark.anyio
532543
async def test_token_exchange_ensure_token_returns_when_valid() -> None:
533544
storage = InMemoryStorage()
534-
client_metadata = OAuthClientMetadata(redirect_uris=["https://client.example.com/callback"])
545+
client_metadata = OAuthClientMetadata(redirect_uris=_redirect_uris())
535546

536547
async def provide_subject() -> str:
537548
return "subject-token"

tests/unit/client/test_stdio_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import anyio
44
import pytest
5+
from types import TracebackType
6+
from typing import Any
57

68
from mcp.client import stdio as stdio_module
79
from mcp.client.stdio import StdioServerParameters, stdio_client
@@ -23,15 +25,20 @@ def __init__(self) -> None:
2325
async def __aenter__(self) -> "DummyProcess":
2426
return self
2527

26-
async def __aexit__(self, exc_type, exc, tb) -> None:
28+
async def __aexit__(
29+
self,
30+
exc_type: type[BaseException] | None,
31+
exc: BaseException | None,
32+
tb: TracebackType | None,
33+
) -> bool | None:
2734
return None
2835

2936
async def wait(self) -> None:
3037
return None
3138

3239

3340
class BrokenPipeStream:
34-
def __init__(self, *args, **kwargs) -> None:
41+
def __init__(self, *args: Any, **kwargs: Any) -> None:
3542
pass
3643

3744
def __aiter__(self) -> "BrokenPipeStream":
@@ -42,14 +49,14 @@ async def __anext__(self) -> str:
4249

4350

4451
@pytest.mark.anyio
45-
async def test_stdio_client_handles_broken_pipe(monkeypatch) -> None:
52+
async def test_stdio_client_handles_broken_pipe(monkeypatch: pytest.MonkeyPatch) -> None:
4653
server = StdioServerParameters(command="dummy")
4754

4855
async def fake_checkpoint() -> None:
4956
nonlocal checkpoint_calls
5057
checkpoint_calls += 1
5158

52-
async def fake_create_process(*args, **kwargs) -> DummyProcess:
59+
async def fake_create_process(*args: object, **kwargs: object) -> DummyProcess:
5360
return DummyProcess()
5461

5562
checkpoint_calls = 0

0 commit comments

Comments
 (0)