Skip to content

Commit 7f241e7

Browse files
committed
add disconnect_event to handle closure from client/task_group
1 parent 38f307c commit 7f241e7

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

src/mcp/server/streaming_asgi_transport.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import typing
1212
from typing import Any, cast
13+
from typing import Callable, Awaitable
1314

1415
import anyio
1516
import anyio.abc
@@ -65,6 +66,8 @@ async def handle_async_request(
6566
) -> Response:
6667
assert isinstance(request.stream, AsyncByteStream)
6768

69+
disconnect_event = anyio.Event()
70+
6871
# ASGI scope.
6972
scope = {
7073
"type": "http",
@@ -97,11 +100,17 @@ async def handle_async_request(
97100
content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100)
98101

99102
# ASGI callables.
103+
async def send_disconnect() -> None:
104+
disconnect_event.set()
105+
100106
async def receive() -> dict[str, Any]:
101107
nonlocal request_complete
102108

109+
if disconnect_event.is_set():
110+
return {"type": "http.disconnect"}
111+
103112
if request_complete:
104-
await response_complete.wait()
113+
await disconnect_event.wait()
105114
return {"type": "http.disconnect"}
106115

107116
try:
@@ -140,7 +149,9 @@ async def process_messages() -> None:
140149
async with asgi_receive_channel:
141150
async for message in asgi_receive_channel:
142151
if message["type"] == "http.response.start":
143-
assert not response_started
152+
if response_started:
153+
# Ignore duplicate response.start from ASGI app during SSE disconnect
154+
continue
144155
status_code = message["status"]
145156
response_headers = message.get("headers", [])
146157
response_started = True
@@ -163,7 +174,7 @@ async def process_messages() -> None:
163174
# Ensure events are set even if there's an error
164175
initial_response_ready.set()
165176
response_complete.set()
166-
await content_send_channel.aclose()
177+
167178

168179
# Create tasks for running the app and processing messages
169180
self.task_group.start_soon(run_app)
@@ -176,7 +187,7 @@ async def process_messages() -> None:
176187
return Response(
177188
status_code,
178189
headers=response_headers,
179-
stream=StreamingASGIResponseStream(content_receive_channel),
190+
stream = StreamingASGIResponseStream(content_receive_channel, send_disconnect),
180191
)
181192

182193

@@ -192,12 +203,18 @@ class StreamingASGIResponseStream(AsyncByteStream):
192203
def __init__(
193204
self,
194205
receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes],
206+
send_disconnect: Callable[[], Awaitable[None]],
195207
) -> None:
196208
self.receive_channel = receive_channel
209+
self.send_disconnect = send_disconnect
197210

198211
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
199212
try:
200213
async for chunk in self.receive_channel:
201214
yield chunk
202215
finally:
203-
await self.receive_channel.aclose()
216+
await self.aclose()
217+
218+
async def aclose(self) -> None:
219+
await self.receive_channel.aclose()
220+
await self.send_disconnect()

0 commit comments

Comments
 (0)