Skip to content

Commit 10f6754

Browse files
committed
## Fix race condition in Streamable HTTP transport
1 parent 5983a65 commit 10f6754

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

src/mcp/client/streamable_http.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import anyio
1616
import httpx
17-
from anyio.abc import TaskGroup
17+
from anyio.abc import TaskGroup, TaskStatus
1818
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1919
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2020

@@ -376,10 +376,14 @@ async def post_writer(
376376
write_stream: MemoryObjectSendStream[SessionMessage],
377377
start_get_stream: Callable[[], None],
378378
tg: TaskGroup,
379+
*,
380+
task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
379381
) -> None:
380382
"""Handle writing requests to the server."""
381383
try:
382384
async with write_stream_reader:
385+
# Signal that we're ready to receive messages
386+
task_status.started(None)
383387
async for session_message in write_stream_reader:
384388
message = session_message.message
385389
metadata = (
@@ -493,7 +497,10 @@ async def streamablehttp_client(
493497
def start_get_stream() -> None:
494498
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
495499

496-
tg.start_soon(
500+
# Use tg.start() to ensure post_writer is ready before yielding.
501+
# This prevents a race condition where the client might try to send
502+
# a message before the writer task is ready to receive it.
503+
await tg.start(
497504
transport.post_writer,
498505
client,
499506
write_stream_reader,

tests/shared/test_streamable_http.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,3 +1633,67 @@ async def test_handle_sse_event_skips_empty_data():
16331633
finally:
16341634
await write_stream.aclose()
16351635
await read_stream.aclose()
1636+
1637+
1638+
@pytest.mark.anyio
1639+
async def test_streamablehttp_no_race_condition_on_consecutive_requests(basic_server: None, basic_server_url: str):
1640+
"""Test that consecutive requests after initialize() work reliably.
1641+
1642+
This test verifies the fix for the race condition where list_tools()
1643+
could intermittently return empty results immediately after initialize().
1644+
The fix ensures post_writer is fully ready before yielding from the
1645+
context manager by using tg.start() instead of tg.start_soon().
1646+
1647+
We run multiple iterations to catch any intermittent issues.
1648+
"""
1649+
for iteration in range(10):
1650+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1651+
read_stream,
1652+
write_stream,
1653+
_,
1654+
):
1655+
async with ClientSession(read_stream, write_stream) as session:
1656+
# Initialize the session
1657+
result = await session.initialize()
1658+
assert isinstance(result, InitializeResult)
1659+
assert result.serverInfo.name == SERVER_NAME
1660+
1661+
# Immediately call list_tools() - this should never fail or return empty
1662+
tools = await session.list_tools()
1663+
assert len(tools.tools) > 0, f"Iteration {iteration}: list_tools() returned empty"
1664+
assert tools.tools[0].name == "test_tool"
1665+
1666+
# Make several more consecutive requests to ensure stability
1667+
tools2 = await session.list_tools()
1668+
assert len(tools2.tools) == len(tools.tools)
1669+
1670+
# Read a resource
1671+
resource = await session.read_resource(uri=AnyUrl("foobar://test-iteration"))
1672+
assert len(resource.contents) == 1
1673+
1674+
1675+
@pytest.mark.anyio
1676+
async def test_streamablehttp_rapid_request_sequence(basic_server: None, basic_server_url: str):
1677+
"""Test that rapid sequences of requests work correctly.
1678+
1679+
This stress test verifies that the transport handles rapid request sequences
1680+
without race conditions or message loss.
1681+
"""
1682+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1683+
read_stream,
1684+
write_stream,
1685+
_,
1686+
):
1687+
async with ClientSession(read_stream, write_stream) as session:
1688+
# Initialize
1689+
result = await session.initialize()
1690+
assert isinstance(result, InitializeResult)
1691+
1692+
# Rapid sequence of requests
1693+
for i in range(20):
1694+
tools = await session.list_tools()
1695+
assert len(tools.tools) == 6, f"Request {i}: Expected 6 tools, got {len(tools.tools)}"
1696+
1697+
# Verify we can still make other types of requests
1698+
resource = await session.read_resource(uri=AnyUrl("foobar://final-test"))
1699+
assert len(resource.contents) == 1

0 commit comments

Comments
 (0)