Skip to content

Commit 8fc6473

Browse files
committed
prevent sse_client from cancelling external task groups
1 parent e1745f8 commit 8fc6473

File tree

2 files changed

+84
-73
lines changed

2 files changed

+84
-73
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
2-
from contextlib import asynccontextmanager
2+
from contextlib import asynccontextmanager, AsyncExitStack
33
from typing import Any
44
from urllib.parse import urljoin, urlparse
55

66
import anyio
77
import httpx
8-
from anyio.abc import TaskStatus
8+
from anyio.abc import TaskStatus, TaskGroup
99
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1010
from httpx_sse import aconnect_sse
1111
from httpx_sse._exceptions import SSEError
@@ -29,6 +29,7 @@ async def sse_client(
2929
sse_read_timeout: float = 60 * 5,
3030
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3131
auth: httpx.Auth | None = None,
32+
maybe_task_group: TaskGroup | None = None,
3233
):
3334
"""
3435
Client transport for SSE.
@@ -52,7 +53,15 @@ async def sse_client(
5253
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
5354
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
5455

55-
async with anyio.create_task_group() as tg:
56+
async with AsyncExitStack() as stack:
57+
# Only create a task group if one wasn't provided
58+
if maybe_task_group is None:
59+
tg = await stack.enter_async_context(anyio.create_task_group())
60+
else:
61+
tg = maybe_task_group
62+
63+
owns_task_group = maybe_task_group is None
64+
5665
try:
5766
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
5867
async with httpx_client_factory(
@@ -142,7 +151,8 @@ async def post_writer(endpoint_url: str):
142151
try:
143152
yield read_stream, write_stream
144153
finally:
145-
tg.cancel_scope.cancel()
154+
if owns_task_group:
155+
tg.cancel_scope.cancel()
146156
finally:
147157
await read_stream_writer.aclose()
148158
await write_stream.aclose()

tests/shared/test_sse.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -116,24 +116,28 @@ def server_app() -> Starlette:
116116

117117

118118
@pytest.fixture()
119-
async def http_client(server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]:
120-
"""Create test client using StreamingASGITransport"""
119+
async def tg() -> AsyncGenerator[TaskGroup, None]:
121120
async with anyio.create_task_group() as tg:
122-
transport = StreamingASGITransport(app=server_app, task_group=tg)
123-
async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client:
124-
yield client
121+
yield tg
125122

126123

127124
@pytest.fixture()
128-
async def sse_client_session(server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
129-
async with anyio.create_task_group() as tg:
130-
asgi_client_factory = create_asgi_client_factory(server_app, tg)
125+
async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[httpx.AsyncClient, None]:
126+
"""Create test client using StreamingASGITransport"""
127+
transport = StreamingASGITransport(app=server_app, task_group=tg)
128+
async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client:
129+
yield client
131130

132-
async with sse_client(
133-
f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
134-
) as streams:
135-
async with ClientSession(*streams) as session:
136-
yield session
131+
132+
@pytest.fixture()
133+
async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
134+
asgi_client_factory = create_asgi_client_factory(server_app, tg)
135+
136+
async with sse_client(
137+
f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory,
138+
) as streams:
139+
async with ClientSession(*streams) as session:
140+
yield session
137141

138142

139143
# Tests
@@ -228,15 +232,16 @@ async def mounted_server_app(server_app: Starlette) -> Starlette:
228232

229233

230234
@pytest.fixture()
231-
async def sse_client_mounted_server_app_session(mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
232-
async with anyio.create_task_group() as tg:
233-
asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg)
235+
async def sse_client_mounted_server_app_session(
236+
tg: TaskGroup, mounted_server_app: Starlette
237+
) -> AsyncGenerator[ClientSession, None]:
238+
asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg)
234239

235-
async with sse_client(
236-
f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
237-
) as streams:
238-
async with ClientSession(*streams) as session:
239-
yield session
240+
async with sse_client(
241+
f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory,
242+
) as streams:
243+
async with ClientSession(*streams) as session:
244+
yield session
240245

241246

242247
@pytest.mark.anyio
@@ -303,7 +308,7 @@ async def context_server_app() -> Starlette:
303308

304309

305310
@pytest.mark.anyio
306-
async def test_request_context_propagation(context_server_app: Starlette) -> None:
311+
async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None:
307312
"""Test that request context is properly propagated through SSE transport."""
308313
# Test with custom headers
309314
custom_headers = {
@@ -312,63 +317,59 @@ async def test_request_context_propagation(context_server_app: Starlette) -> Non
312317
"X-Trace-Id": "trace-123",
313318
}
314319

315-
async with anyio.create_task_group() as tg:
316-
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
320+
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
317321

318-
async with sse_client(
319-
f"{TEST_SERVER_BASE_URL}/sse",
320-
headers=custom_headers,
321-
httpx_client_factory=asgi_client_factory,
322-
sse_read_timeout=0.5,
323-
) as streams:
324-
async with ClientSession(*streams) as session:
325-
# Initialize the session
326-
result = await session.initialize()
327-
assert isinstance(result, InitializeResult)
322+
async with sse_client(
323+
f"{TEST_SERVER_BASE_URL}/sse",
324+
headers=custom_headers,
325+
httpx_client_factory=asgi_client_factory,
326+
sse_read_timeout=0.5,
327+
328+
) as streams:
329+
async with ClientSession(*streams) as session:
330+
# Initialize the session
331+
result = await session.initialize()
332+
assert isinstance(result, InitializeResult)
328333

329-
# Call the tool that echoes headers back
330-
tool_result = await session.call_tool("echo_headers", {})
334+
# Call the tool that echoes headers back
335+
tool_result = await session.call_tool("echo_headers", {})
331336

332-
# Parse the JSON response
333-
assert len(tool_result.content) == 1
334-
content_item = tool_result.content[0]
335-
headers_data = json.loads(content_item.text if content_item.type == "text" else "{}")
337+
# Parse the JSON response
338+
assert len(tool_result.content) == 1
339+
content_item = tool_result.content[0]
340+
headers_data = json.loads(content_item.text if content_item.type == "text" else "{}")
336341

337-
# Verify headers were propagated
338-
assert headers_data.get("authorization") == "Bearer test-token"
339-
assert headers_data.get("x-custom-header") == "test-value"
340-
assert headers_data.get("x-trace-id") == "trace-123"
342+
# Verify headers were propagated
343+
assert headers_data.get("authorization") == "Bearer test-token"
344+
assert headers_data.get("x-custom-header") == "test-value"
345+
assert headers_data.get("x-trace-id") == "trace-123"
341346

342347

343348
@pytest.mark.anyio
344-
async def test_request_context_isolation(context_server_app: Starlette) -> None:
349+
async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None:
345350
"""Test that request contexts are isolated between different SSE clients."""
346351
contexts: list[dict[str, Any]] = []
347352

348-
async with anyio.create_task_group() as tg:
349-
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
350-
351-
# Create multiple clients with different headers
352-
for i in range(3):
353-
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
354-
355-
async with sse_client(
356-
f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory
357-
) as (
358-
read_stream,
359-
write_stream,
360-
):
361-
async with ClientSession(read_stream, write_stream) as session:
362-
await session.initialize()
363-
364-
# Call the tool that echoes context
365-
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
366-
367-
assert len(tool_result.content) == 1
368-
context_data = json.loads(
369-
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
370-
)
371-
contexts.append(context_data)
353+
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
354+
355+
# Create multiple clients with different headers
356+
for i in range(3):
357+
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
358+
359+
async with sse_client(
360+
f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory,
361+
) as streams:
362+
async with ClientSession(*streams) as session:
363+
await session.initialize()
364+
365+
# Call the tool that echoes context
366+
tool_result = await session.call_tool("echo_context", {"request_id": f"request-{i}"})
367+
368+
assert len(tool_result.content) == 1
369+
context_data = json.loads(
370+
tool_result.content[0].text if tool_result.content[0].type == "text" else "{}"
371+
)
372+
contexts.append(context_data)
372373

373374
# Verify each request had its own context
374375
assert len(contexts) == 3

0 commit comments

Comments
 (0)