44Contains tests for both server and client sides of the StreamableHTTP transport.
55"""
66
7- import json
87import multiprocessing
98import socket
109import time
1918import uvicorn
2019from pydantic import AnyUrl
2120from starlette .applications import Starlette
22- from starlette .requests import Request
23- from starlette .responses import Response
2421from starlette .routing import Mount
2522
2623import mcp .types as types
2724from mcp .client .session import ClientSession
28- from mcp .client .streamable_http import HEADER_CAPTURE , streamablehttp_client
25+ from mcp .client .streamable_http import streamablehttp_client
2926from mcp .server import Server
3027from mcp .server .streamable_http import (
3128 MCP_SESSION_ID_HEADER ,
@@ -247,46 +244,8 @@ def create_app(
247244 return app
248245
249246
250- def create_header_capture_app () -> Starlette :
251- """Implement a minimal Starlette app that intercepts every request,
252- extracts its headers, and responds with status 418 (Test Status code),
253- embedding the captured headers as the JSON response body.
254- We use this server solely to verify that the MCP Server is forwarding
255- headers correctly."""
256-
257- # Create a wrapper that captures headers and returns them in error response
258- async def header_capture_wrapper (scope , receive , send ):
259- # Capture headers
260- request = Request (scope , receive = receive )
261- headers = dict (request .headers )
262-
263- # Return error response with headers in body
264- response = Response (
265- HEADER_CAPTURE + json .dumps ({"headers" : headers }),
266- status_code = 418 ,
267- )
268- await response (scope , receive , send )
269-
270- # Create an ASGI application that uses our wrapper
271- app = Starlette (
272- debug = True ,
273- routes = [
274- Mount ("/mcp" , app = header_capture_wrapper ),
275- ],
276- )
277-
278- return app
279-
280-
281- def _get_captured_headrs (str ) -> dict [str , str ]:
282- return json .loads (str .split (HEADER_CAPTURE )[1 ])["headers" ]
283-
284-
285247def run_server (
286- port : int ,
287- is_json_response_enabled = False ,
288- event_store : EventStore | None = None ,
289- testing_header_capture : bool = False ,
248+ port : int , is_json_response_enabled = False , event_store : EventStore | None = None
290249) -> None :
291250 """Run the test server.
292251
@@ -296,11 +255,7 @@ def run_server(
296255 event_store: Optional event store for testing resumability.
297256 """
298257
299- if testing_header_capture :
300- app = create_header_capture_app ()
301- else :
302- app = create_app (is_json_response_enabled , event_store )
303-
258+ app = create_app (is_json_response_enabled , event_store )
304259 # Configure server
305260 config = uvicorn .Config (
306261 app = app ,
@@ -341,16 +296,11 @@ def json_server_port() -> int:
341296 return s .getsockname ()[1 ]
342297
343298
344- def _start_basic_server (
345- basic_server_port : int , testing_header_capture : bool
346- ) -> Generator [ None , None , None ]:
299+ @ pytest . fixture
300+ def basic_server ( basic_server_port : int ) -> Generator [ None , None , None ]:
301+ """Start a basic server."""
347302 proc = multiprocessing .Process (
348- target = run_server ,
349- kwargs = {
350- "port" : basic_server_port ,
351- "testing_header_capture" : testing_header_capture ,
352- },
353- daemon = True ,
303+ target = run_server , kwargs = {"port" : basic_server_port }, daemon = True
354304 )
355305 proc .start ()
356306
@@ -375,18 +325,6 @@ def _start_basic_server(
375325 proc .join (timeout = 2 )
376326
377327
378- @pytest .fixture
379- def basic_server (basic_server_port : int ) -> Generator [None , None , None ]:
380- yield from _start_basic_server (basic_server_port , testing_header_capture = False )
381-
382-
383- @pytest .fixture
384- def basic_server_with_header_capture (
385- basic_server_port : int ,
386- ) -> Generator [None , None , None ]:
387- yield from _start_basic_server (basic_server_port , testing_header_capture = True )
388-
389-
390328@pytest .fixture
391329def event_store () -> SimpleEventStore :
392330 """Create a test event store."""
@@ -1295,83 +1233,55 @@ def __init__(self, token: str):
12951233 self .token = token
12961234
12971235 async def get_auth_headers (self ) -> dict [str , str ]:
1298- return {"Authorization" : f "Bearer { self .token } " }
1236+ return {"Authorization" : "Bearer " + self .token }
12991237
13001238
13011239@pytest .mark .anyio
1302- async def test_auth_client_provider_headers (
1303- basic_server_with_header_capture , basic_server_url
1304- ):
1240+ async def test_auth_client_provider_headers (basic_server , basic_server_url ):
13051241 """Test that auth token provider correctly sets Authorization header."""
13061242 # Create a mock token provider
1307- client_provider = MockAuthClientProvider ("short-lived-token-123" )
1243+ client_provider = MockAuthClientProvider ("test-token-123" )
1244+ client_provider .get_auth_headers = AsyncMock (
1245+ return_value = {"Authorization" : "Bearer test-token-123" }
1246+ )
13081247
13091248 # Create client with token provider
13101249 async with streamablehttp_client (
13111250 f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
13121251 ) as (read_stream , write_stream , _ ):
13131252 async with ClientSession (read_stream , write_stream ) as session :
13141253 # Initialize the session
1315- with pytest .raises (McpError ) as mcpError :
1316- _ = await session .initialize ()
1317- assert (
1318- _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1319- == "Bearer short-lived-token-123"
1320- )
1254+ result = await session .initialize ()
1255+ assert isinstance (result , InitializeResult )
1256+
1257+ # Make a request to verify headers
1258+ tools = await session .list_tools ()
1259+ assert len (tools .tools ) == 4
1260+
1261+ client_provider .get_auth_headers .assert_called ()
13211262
13221263
13231264@pytest .mark .anyio
1324- async def test_auth_client_provider_token_called_on_every_request (
1325- basic_server_with_header_capture , basic_server_url
1326- ):
1265+ async def test_auth_client_provider_called_per_request (basic_server , basic_server_url ):
13271266 """Test that auth token provider can return different tokens."""
13281267 # Create a dynamic token provider
1329- client_provider = MockAuthClientProvider ("short-lived-token-123" )
1268+ client_provider = MockAuthClientProvider ("test-token-123" )
1269+ client_provider .get_auth_headers = AsyncMock (
1270+ return_value = {"Authorization" : "Bearer test-token-123" }
1271+ )
13301272
1273+ # Create client with dynamic token provider
13311274 async with streamablehttp_client (
13321275 f"{ basic_server_url } /mcp" , auth_client_provider = client_provider
13331276 ) as (read_stream , write_stream , _ ):
13341277 async with ClientSession (read_stream , write_stream ) as session :
13351278 # Initialize the session
1336- with pytest .raises (McpError ) as mcpError :
1337- _ = await session .initialize ()
1338- assert (
1339- _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1340- == "Bearer short-lived-token-123"
1341- )
1342-
1343- # Mock a new token and ensure the new token is returned
1344- client_provider .get_auth_headers = AsyncMock (
1345- return_value = {"Authorization" : "Bearer short-lived-token-456" }
1346- )
1347- with pytest .raises (McpError ) as mcpError :
1348- _ = await session .initialize ()
1349- assert (
1350- _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1351- == "Bearer short-lived-token-456"
1352- )
1279+ result = await session .initialize ()
1280+ assert isinstance (result , InitializeResult )
13531281
1282+ # Make multiple requests to verify token updates
1283+ for i in range (3 ):
1284+ tools = await session .list_tools ()
1285+ assert len (tools .tools ) == 4
13541286
1355- @pytest .mark .anyio
1356- async def test_auth_client_provider_headers_not_overridden (
1357- basic_server_with_header_capture , basic_server_url
1358- ):
1359- """Test that provided headers override auth client provider headers."""
1360- # Create a mock token provider
1361- client_provider = MockAuthClientProvider ("short-lived-token" )
1362-
1363- # Create client with token provider and custom headers
1364- custom_headers = {"Authorization" : "Bearer original-long-lived-token" }
1365- async with streamablehttp_client (
1366- f"{ basic_server_url } /mcp" ,
1367- auth_client_provider = client_provider ,
1368- headers = custom_headers ,
1369- ) as (read_stream , write_stream , _ ):
1370- async with ClientSession (read_stream , write_stream ) as session :
1371- # Original token is used and not short-lived-token from the provider
1372- with pytest .raises (McpError ) as mcpError :
1373- _ = await session .initialize ()
1374- assert (
1375- _get_captured_headrs (mcpError .value .error .message )["Authorization" ]
1376- == "Bearer original-long-lived-token"
1377- )
1287+ client_provider .get_auth_headers .call_count > 1
0 commit comments