Skip to content

Commit 8ab1d66

Browse files
committed
Address comments
1 parent 28ae4f7 commit 8ab1d66

File tree

2 files changed

+19
-28
lines changed

2 files changed

+19
-28
lines changed

src/mcp/client/streamable_http.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class AuthClientProvider(Protocol):
7878
"""Base class that can be extended to implement custom client-to-server
7979
authentication"""
8080

81-
async def get_auth_headers(self) -> dict[str, str]:
81+
async def get_headers(self) -> dict[str, str]:
8282
"""Gets auth headers for authenticating to an MCP server.
8383
Clients may call this API multiple times per request to an MCP server.
8484
@@ -132,12 +132,12 @@ async def _update_headers_with_auth_headers(
132132
self, base_headers: dict[str, str]
133133
) -> dict[str, str]:
134134
"""Update headers with auth_headers if auth client provider is specified.
135-
The headers are merged giving precedence to the base_headers to
136-
avoid overwriting existing Authorization headers"""
135+
The headers are merged, giving precedence to any headers already
136+
specified in base_headers"""
137137
if self.auth_client_provider is None:
138138
return base_headers
139139

140-
auth_headers = await self.auth_client_provider.get_auth_headers()
140+
auth_headers = await self.auth_client_provider.get_headers()
141141
return {**auth_headers, **base_headers}
142142

143143
async def _update_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
@@ -476,10 +476,9 @@ async def streamablehttp_client(
476476
477477
`auth_client_provider` instance of `AuthClientProvider` that can be passed to
478478
support client-to-server authentication. Before each request to the MCP Server,
479-
the auth_client_provider.get_token method is invoked to retrieve a fresh
480-
authentication token and update the request headers. Note that if the passed in
481-
`headers` already contain an Authorization header, that header will take precedence
482-
over any tokens generated by this provider.
479+
the auth_client_provider.get_headers() method is invoked to retrieve headers
480+
for authentication. Note that any headers already specified in `headers`
481+
will take precedence over headers returned by auth_client_provider.get_headers()
483482
484483
Yields:
485484
Tuple containing:

tests/shared/test_streamable_http.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,24 +1226,14 @@ async def sampling_callback(
12261226
)
12271227

12281228

1229-
class MockAuthClientProvider:
1230-
"""Mock implementation of AuthClientProvider for testing."""
1231-
1232-
def __init__(self, token: str):
1233-
self.token = token
1234-
1235-
async def get_auth_headers(self) -> dict[str, str]:
1236-
return {"Authorization": "Bearer " + self.token}
1237-
1238-
12391229
@pytest.mark.anyio
12401230
async def test_auth_client_provider_headers(basic_server, basic_server_url):
12411231
"""Test that auth token provider correctly sets Authorization header."""
12421232
# Create a mock token provider
1243-
client_provider = MockAuthClientProvider("test-token-123")
1244-
client_provider.get_auth_headers = AsyncMock(
1245-
return_value={"Authorization": "Bearer test-token-123"}
1246-
)
1233+
client_provider = AsyncMock()
1234+
client_provider.get_headers.return_value = {
1235+
"Authorization": "Bearer test-token-123"
1236+
}
12471237

12481238
# Create client with token provider
12491239
async with streamablehttp_client(
@@ -1258,17 +1248,17 @@ async def test_auth_client_provider_headers(basic_server, basic_server_url):
12581248
tools = await session.list_tools()
12591249
assert len(tools.tools) == 4
12601250

1261-
client_provider.get_auth_headers.assert_called()
1251+
client_provider.get_headers.assert_called()
12621252

12631253

12641254
@pytest.mark.anyio
12651255
async def test_auth_client_provider_called_per_request(basic_server, basic_server_url):
12661256
"""Test that auth token provider can return different tokens."""
12671257
# Create a dynamic token provider
1268-
client_provider = MockAuthClientProvider("test-token-123")
1269-
client_provider.get_auth_headers = AsyncMock(
1270-
return_value={"Authorization": "Bearer test-token-123"}
1271-
)
1258+
client_provider = AsyncMock()
1259+
client_provider.get_headers.return_value = {
1260+
"Authorization": "Bearer test-token-123"
1261+
}
12721262

12731263
# Create client with dynamic token provider
12741264
async with streamablehttp_client(
@@ -1284,4 +1274,6 @@ async def test_auth_client_provider_called_per_request(basic_server, basic_serve
12841274
tools = await session.list_tools()
12851275
assert len(tools.tools) == 4
12861276

1287-
client_provider.get_auth_headers.call_count > 1
1277+
# list_tools is called 3 times, but get_auth_headers is also used during
1278+
# session initialization and setup. Verify it's called at least 3 times.
1279+
assert client_provider.get_headers.call_count > 3

0 commit comments

Comments
 (0)