@@ -1226,24 +1226,14 @@ async def sampling_callback(
12261226 )
12271227
12281228
1229- class MockAuthClientProvider :
1230- """Mock implementation of AuthClientProvider for testing."""
1231-
1232- def __init__ (self , token : str ):
1233- self .token = token
1234-
1235- async def get_auth_headers (self ) -> dict [str , str ]:
1236- return {"Authorization" : "Bearer " + self .token }
1237-
1238-
12391229@pytest .mark .anyio
12401230async def test_auth_client_provider_headers (basic_server , basic_server_url ):
12411231 """Test that auth token provider correctly sets Authorization header."""
12421232 # Create a mock token provider
1243- client_provider = MockAuthClientProvider ( "test-token-123" )
1244- client_provider .get_auth_headers = AsyncMock (
1245- return_value = { "Authorization" : "Bearer test-token-123" }
1246- )
1233+ client_provider = AsyncMock ( )
1234+ client_provider .get_headers . return_value = {
1235+ "Authorization" : "Bearer test-token-123"
1236+ }
12471237
12481238 # Create client with token provider
12491239 async with streamablehttp_client (
@@ -1258,17 +1248,17 @@ async def test_auth_client_provider_headers(basic_server, basic_server_url):
12581248 tools = await session .list_tools ()
12591249 assert len (tools .tools ) == 4
12601250
1261- client_provider .get_auth_headers .assert_called ()
1251+ client_provider .get_headers .assert_called ()
12621252
12631253
12641254@pytest .mark .anyio
12651255async def test_auth_client_provider_called_per_request (basic_server , basic_server_url ):
12661256 """Test that auth token provider can return different tokens."""
12671257 # Create a dynamic token provider
1268- client_provider = MockAuthClientProvider ( "test-token-123" )
1269- client_provider .get_auth_headers = AsyncMock (
1270- return_value = { "Authorization" : "Bearer test-token-123" }
1271- )
1258+ client_provider = AsyncMock ( )
1259+ client_provider .get_headers . return_value = {
1260+ "Authorization" : "Bearer test-token-123"
1261+ }
12721262
12731263 # Create client with dynamic token provider
12741264 async with streamablehttp_client (
@@ -1284,4 +1274,6 @@ async def test_auth_client_provider_called_per_request(basic_server, basic_serve
12841274 tools = await session .list_tools ()
12851275 assert len (tools .tools ) == 4
12861276
1287- client_provider .get_auth_headers .call_count > 1
1277+ # list_tools is called 3 times, but get_auth_headers is also used during
1278+ # session initialization and setup. Verify it's called at least 3 times.
1279+ assert client_provider .get_headers .call_count > 3
0 commit comments