1010
1111import typing
1212from typing import Any , cast
13+ from typing import Callable , Awaitable
1314
1415import anyio
1516import anyio .abc
@@ -65,6 +66,8 @@ async def handle_async_request(
6566 ) -> Response :
6667 assert isinstance (request .stream , AsyncByteStream )
6768
69+ disconnect_event = anyio .Event ()
70+
6871 # ASGI scope.
6972 scope = {
7073 "type" : "http" ,
@@ -97,11 +100,17 @@ async def handle_async_request(
97100 content_send_channel , content_receive_channel = anyio .create_memory_object_stream [bytes ](100 )
98101
99102 # ASGI callables.
103+ async def send_disconnect () -> None :
104+ disconnect_event .set ()
105+
100106 async def receive () -> dict [str , Any ]:
101107 nonlocal request_complete
102108
109+ if disconnect_event .is_set ():
110+ return {"type" : "http.disconnect" }
111+
103112 if request_complete :
104- await response_complete .wait ()
113+ await disconnect_event .wait ()
105114 return {"type" : "http.disconnect" }
106115
107116 try :
@@ -140,7 +149,9 @@ async def process_messages() -> None:
140149 async with asgi_receive_channel :
141150 async for message in asgi_receive_channel :
142151 if message ["type" ] == "http.response.start" :
143- assert not response_started
152+ if response_started :
153+ # Ignore duplicate response.start from ASGI app during SSE disconnect
154+ continue
144155 status_code = message ["status" ]
145156 response_headers = message .get ("headers" , [])
146157 response_started = True
@@ -163,7 +174,7 @@ async def process_messages() -> None:
163174 # Ensure events are set even if there's an error
164175 initial_response_ready .set ()
165176 response_complete .set ()
166- await content_send_channel . aclose ()
177+
167178
168179 # Create tasks for running the app and processing messages
169180 self .task_group .start_soon (run_app )
@@ -176,7 +187,7 @@ async def process_messages() -> None:
176187 return Response (
177188 status_code ,
178189 headers = response_headers ,
179- stream = StreamingASGIResponseStream (content_receive_channel ),
190+ stream = StreamingASGIResponseStream (content_receive_channel , send_disconnect ),
180191 )
181192
182193
@@ -192,12 +203,18 @@ class StreamingASGIResponseStream(AsyncByteStream):
192203 def __init__ (
193204 self ,
194205 receive_channel : anyio .streams .memory .MemoryObjectReceiveStream [bytes ],
206+ send_disconnect : Callable [[], Awaitable [None ]],
195207 ) -> None :
196208 self .receive_channel = receive_channel
209+ self .send_disconnect = send_disconnect
197210
198211 async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
199212 try :
200213 async for chunk in self .receive_channel :
201214 yield chunk
202215 finally :
203- await self .receive_channel .aclose ()
216+ await self .aclose ()
217+
218+ async def aclose (self ) -> None :
219+ await self .receive_channel .aclose ()
220+ await self .send_disconnect ()
0 commit comments