Skip to content

Commit 28ae4f7

Browse files
committed
Making unit tests simpler
1 parent c4fb621 commit 28ae4f7

File tree

2 files changed

+34
-152
lines changed

2 files changed

+34
-152
lines changed

src/mcp/client/streamable_http.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
MCP_SESSION_ID = "mcp-session-id"
4343
LAST_EVENT_ID = "last-event-id"
4444
CONTENT_TYPE = "content-type"
45-
HEADER_CAPTURE = "[TESTING_HEADER_CAPTURE]"
4645
ACCEPT = "Accept"
4746

4847

@@ -273,16 +272,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
273272
if is_complete:
274273
break
275274

276-
async def _is_testing_header_capture(self, response: httpx.Response) -> str | None:
277-
try:
278-
content = await response.aread()
279-
if content.decode().startswith(HEADER_CAPTURE):
280-
return content.decode()
281-
except Exception as _:
282-
return None
283-
284-
return None
285-
286275
async def _handle_post_request(self, ctx: RequestContext) -> None:
287276
"""Handle a POST request with response processing."""
288277
headers = await self._update_headers(ctx.headers)
@@ -307,23 +296,6 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
307296
)
308297
return
309298

310-
# To test if headers are being forwarded correctly, in unit tests
311-
# we have a mock server that returns a 418 status code with the
312-
# HEADER_CAPTURE prefix. If the response has this status code
313-
# with the prefix, return the response content as part of the error message.
314-
if response.status_code == 418:
315-
test_error_message = await self._is_testing_header_capture(response)
316-
# If this is coming from the test case return the response content
317-
if test_error_message and isinstance(message.root, JSONRPCRequest):
318-
jsonrpc_error = JSONRPCError(
319-
jsonrpc="2.0",
320-
id=message.root.id,
321-
error=ErrorData(code=32600, message=test_error_message),
322-
)
323-
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
324-
await ctx.read_stream_writer.send(session_message)
325-
return
326-
327299
response.raise_for_status()
328300
if is_initialization:
329301
self._maybe_extract_session_id_from_response(response)

tests/shared/test_streamable_http.py

Lines changed: 34 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Contains tests for both server and client sides of the StreamableHTTP transport.
55
"""
66

7-
import json
87
import multiprocessing
98
import socket
109
import time
@@ -19,13 +18,11 @@
1918
import uvicorn
2019
from pydantic import AnyUrl
2120
from starlette.applications import Starlette
22-
from starlette.requests import Request
23-
from starlette.responses import Response
2421
from starlette.routing import Mount
2522

2623
import mcp.types as types
2724
from mcp.client.session import ClientSession
28-
from mcp.client.streamable_http import HEADER_CAPTURE, streamablehttp_client
25+
from mcp.client.streamable_http import streamablehttp_client
2926
from mcp.server import Server
3027
from mcp.server.streamable_http import (
3128
MCP_SESSION_ID_HEADER,
@@ -247,46 +244,8 @@ def create_app(
247244
return app
248245

249246

250-
def create_header_capture_app() -> Starlette:
251-
"""Implement a minimal Starlette app that intercepts every request,
252-
extracts its headers, and responds with status 418 (Test Status code),
253-
embedding the captured headers as the JSON response body.
254-
We use this server solely to verify that the MCP Server is forwarding
255-
headers correctly."""
256-
257-
# Create a wrapper that captures headers and returns them in error response
258-
async def header_capture_wrapper(scope, receive, send):
259-
# Capture headers
260-
request = Request(scope, receive=receive)
261-
headers = dict(request.headers)
262-
263-
# Return error response with headers in body
264-
response = Response(
265-
HEADER_CAPTURE + json.dumps({"headers": headers}),
266-
status_code=418,
267-
)
268-
await response(scope, receive, send)
269-
270-
# Create an ASGI application that uses our wrapper
271-
app = Starlette(
272-
debug=True,
273-
routes=[
274-
Mount("/mcp", app=header_capture_wrapper),
275-
],
276-
)
277-
278-
return app
279-
280-
281-
def _get_captured_headrs(str) -> dict[str, str]:
282-
return json.loads(str.split(HEADER_CAPTURE)[1])["headers"]
283-
284-
285247
def run_server(
286-
port: int,
287-
is_json_response_enabled=False,
288-
event_store: EventStore | None = None,
289-
testing_header_capture: bool = False,
248+
port: int, is_json_response_enabled=False, event_store: EventStore | None = None
290249
) -> None:
291250
"""Run the test server.
292251
@@ -296,11 +255,7 @@ def run_server(
296255
event_store: Optional event store for testing resumability.
297256
"""
298257

299-
if testing_header_capture:
300-
app = create_header_capture_app()
301-
else:
302-
app = create_app(is_json_response_enabled, event_store)
303-
258+
app = create_app(is_json_response_enabled, event_store)
304259
# Configure server
305260
config = uvicorn.Config(
306261
app=app,
@@ -341,16 +296,11 @@ def json_server_port() -> int:
341296
return s.getsockname()[1]
342297

343298

344-
def _start_basic_server(
345-
basic_server_port: int, testing_header_capture: bool
346-
) -> Generator[None, None, None]:
299+
@pytest.fixture
300+
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
301+
"""Start a basic server."""
347302
proc = multiprocessing.Process(
348-
target=run_server,
349-
kwargs={
350-
"port": basic_server_port,
351-
"testing_header_capture": testing_header_capture,
352-
},
353-
daemon=True,
303+
target=run_server, kwargs={"port": basic_server_port}, daemon=True
354304
)
355305
proc.start()
356306

@@ -375,18 +325,6 @@ def _start_basic_server(
375325
proc.join(timeout=2)
376326

377327

378-
@pytest.fixture
379-
def basic_server(basic_server_port: int) -> Generator[None, None, None]:
380-
yield from _start_basic_server(basic_server_port, testing_header_capture=False)
381-
382-
383-
@pytest.fixture
384-
def basic_server_with_header_capture(
385-
basic_server_port: int,
386-
) -> Generator[None, None, None]:
387-
yield from _start_basic_server(basic_server_port, testing_header_capture=True)
388-
389-
390328
@pytest.fixture
391329
def event_store() -> SimpleEventStore:
392330
"""Create a test event store."""
@@ -1295,83 +1233,55 @@ def __init__(self, token: str):
12951233
self.token = token
12961234

12971235
async def get_auth_headers(self) -> dict[str, str]:
1298-
return {"Authorization": f"Bearer {self.token}"}
1236+
return {"Authorization": "Bearer " + self.token}
12991237

13001238

13011239
@pytest.mark.anyio
1302-
async def test_auth_client_provider_headers(
1303-
basic_server_with_header_capture, basic_server_url
1304-
):
1240+
async def test_auth_client_provider_headers(basic_server, basic_server_url):
13051241
"""Test that auth token provider correctly sets Authorization header."""
13061242
# Create a mock token provider
1307-
client_provider = MockAuthClientProvider("short-lived-token-123")
1243+
client_provider = MockAuthClientProvider("test-token-123")
1244+
client_provider.get_auth_headers = AsyncMock(
1245+
return_value={"Authorization": "Bearer test-token-123"}
1246+
)
13081247

13091248
# Create client with token provider
13101249
async with streamablehttp_client(
13111250
f"{basic_server_url}/mcp", auth_client_provider=client_provider
13121251
) as (read_stream, write_stream, _):
13131252
async with ClientSession(read_stream, write_stream) as session:
13141253
# Initialize the session
1315-
with pytest.raises(McpError) as mcpError:
1316-
_ = await session.initialize()
1317-
assert (
1318-
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1319-
== "Bearer short-lived-token-123"
1320-
)
1254+
result = await session.initialize()
1255+
assert isinstance(result, InitializeResult)
1256+
1257+
# Make a request to verify headers
1258+
tools = await session.list_tools()
1259+
assert len(tools.tools) == 4
1260+
1261+
client_provider.get_auth_headers.assert_called()
13211262

13221263

13231264
@pytest.mark.anyio
1324-
async def test_auth_client_provider_token_called_on_every_request(
1325-
basic_server_with_header_capture, basic_server_url
1326-
):
1265+
async def test_auth_client_provider_called_per_request(basic_server, basic_server_url):
13271266
"""Test that auth token provider can return different tokens."""
13281267
# Create a dynamic token provider
1329-
client_provider = MockAuthClientProvider("short-lived-token-123")
1268+
client_provider = MockAuthClientProvider("test-token-123")
1269+
client_provider.get_auth_headers = AsyncMock(
1270+
return_value={"Authorization": "Bearer test-token-123"}
1271+
)
13301272

1273+
# Create client with dynamic token provider
13311274
async with streamablehttp_client(
13321275
f"{basic_server_url}/mcp", auth_client_provider=client_provider
13331276
) as (read_stream, write_stream, _):
13341277
async with ClientSession(read_stream, write_stream) as session:
13351278
# Initialize the session
1336-
with pytest.raises(McpError) as mcpError:
1337-
_ = await session.initialize()
1338-
assert (
1339-
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1340-
== "Bearer short-lived-token-123"
1341-
)
1342-
1343-
# Mock a new token and ensure the new token is returned
1344-
client_provider.get_auth_headers = AsyncMock(
1345-
return_value={"Authorization": "Bearer short-lived-token-456"}
1346-
)
1347-
with pytest.raises(McpError) as mcpError:
1348-
_ = await session.initialize()
1349-
assert (
1350-
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1351-
== "Bearer short-lived-token-456"
1352-
)
1279+
result = await session.initialize()
1280+
assert isinstance(result, InitializeResult)
13531281

1282+
# Make multiple requests to verify token updates
1283+
for i in range(3):
1284+
tools = await session.list_tools()
1285+
assert len(tools.tools) == 4
13541286

1355-
@pytest.mark.anyio
1356-
async def test_auth_client_provider_headers_not_overridden(
1357-
basic_server_with_header_capture, basic_server_url
1358-
):
1359-
"""Test that provided headers override auth client provider headers."""
1360-
# Create a mock token provider
1361-
client_provider = MockAuthClientProvider("short-lived-token")
1362-
1363-
# Create client with token provider and custom headers
1364-
custom_headers = {"Authorization": "Bearer original-long-lived-token"}
1365-
async with streamablehttp_client(
1366-
f"{basic_server_url}/mcp",
1367-
auth_client_provider=client_provider,
1368-
headers=custom_headers,
1369-
) as (read_stream, write_stream, _):
1370-
async with ClientSession(read_stream, write_stream) as session:
1371-
# Original token is used and not short-lived-token from the provider
1372-
with pytest.raises(McpError) as mcpError:
1373-
_ = await session.initialize()
1374-
assert (
1375-
_get_captured_headrs(mcpError.value.error.message)["Authorization"]
1376-
== "Bearer original-long-lived-token"
1377-
)
1287+
client_provider.get_auth_headers.call_count > 1

0 commit comments

Comments
 (0)