@@ -129,6 +129,8 @@ class StreamableHTTPServerTransport:
129129 _read_stream_writer : MemoryObjectSendStream [SessionMessage | Exception ] | None = (
130130 None
131131 )
132+ _read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ] | None = None
133+ _write_stream : MemoryObjectSendStream [SessionMessage ] | None = None
132134 _write_stream_reader : MemoryObjectReceiveStream [SessionMessage ] | None = None
133135
134136 def __init__ (
@@ -163,7 +165,11 @@ def __init__(
163165 self .is_json_response_enabled = is_json_response_enabled
164166 self ._event_store = event_store
165167 self ._request_streams : dict [
166- RequestId , MemoryObjectSendStream [EventMessage ]
168+ RequestId ,
169+ tuple [
170+ MemoryObjectSendStream [EventMessage ],
171+ MemoryObjectReceiveStream [EventMessage ],
172+ ],
167173 ] = {}
168174 self ._terminated = False
169175
@@ -239,6 +245,19 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
239245
240246 return event_data
241247
248+ async def _clean_up_memory_streams (self , request_id : RequestId ) -> None :
249+ """Clean up memory streams for a given request ID."""
250+ if request_id in self ._request_streams :
251+ try :
252+ # Close the request stream
253+ await self ._request_streams [request_id ][0 ].aclose ()
254+ await self ._request_streams [request_id ][1 ].aclose ()
255+ except Exception as e :
256+ logger .debug (f"Error closing memory streams: { e } " )
257+ finally :
258+ # Remove the request stream from the mapping
259+ self ._request_streams .pop (request_id , None )
260+
242261 async def handle_request (self , scope : Scope , receive : Receive , send : Send ) -> None :
243262 """Application entry point that handles all HTTP requests"""
244263 request = Request (scope , receive )
@@ -386,13 +405,11 @@ async def _handle_post_request(
386405
387406 # Extract the request ID outside the try block for proper scope
388407 request_id = str (message .root .id )
389- # Create promise stream for getting response
390- request_stream_writer , request_stream_reader = (
391- anyio .create_memory_object_stream [EventMessage ](0 )
392- )
393-
394408 # Register this stream for the request ID
395- self ._request_streams [request_id ] = request_stream_writer
409+ self ._request_streams [request_id ] = anyio .create_memory_object_stream [
410+ EventMessage
411+ ](0 )
412+ request_stream_reader = self ._request_streams [request_id ][1 ]
396413
397414 if self .is_json_response_enabled :
398415 # Process the message
@@ -441,11 +458,7 @@ async def _handle_post_request(
441458 )
442459 await response (scope , receive , send )
443460 finally :
444- # Clean up the request stream
445- if request_id in self ._request_streams :
446- self ._request_streams .pop (request_id , None )
447- await request_stream_reader .aclose ()
448- await request_stream_writer .aclose ()
461+ await self ._clean_up_memory_streams (request_id )
449462 else :
450463 # Create SSE stream
451464 sse_stream_writer , sse_stream_reader = (
@@ -467,16 +480,12 @@ async def sse_writer():
467480 event_message .message .root ,
468481 JSONRPCResponse | JSONRPCError ,
469482 ):
470- if request_id :
471- self ._request_streams .pop (request_id , None )
472483 break
473484 except Exception as e :
474485 logger .exception (f"Error in SSE writer: { e } " )
475486 finally :
476487 logger .debug ("Closing SSE writer" )
477- # Clean up the request-specific streams
478- if request_id and request_id in self ._request_streams :
479- self ._request_streams .pop (request_id , None )
488+ await self ._clean_up_memory_streams (request_id )
480489
481490 # Create and start EventSourceResponse
482491 # SSE stream mode (original behavior)
@@ -507,9 +516,9 @@ async def sse_writer():
507516 await writer .send (session_message )
508517 except Exception :
509518 logger .exception ("SSE response error" )
510- # Clean up the request stream if something goes wrong
511- if request_id and request_id in self . _request_streams :
512- self ._request_streams . pop (request_id , None )
519+ await sse_stream_writer . aclose ()
520+ await sse_stream_reader . aclose ()
521+ await self ._clean_up_memory_streams (request_id )
513522
514523 except Exception as err :
515524 logger .exception ("Error handling POST request" )
@@ -583,12 +592,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
583592 async def standalone_sse_writer ():
584593 try :
585594 # Create a standalone message stream for server-initiated messages
586- standalone_stream_writer , standalone_stream_reader = (
595+
596+ self ._request_streams [GET_STREAM_KEY ] = (
587597 anyio .create_memory_object_stream [EventMessage ](0 )
588598 )
589-
590- # Register this stream using the special key
591- self ._request_streams [GET_STREAM_KEY ] = standalone_stream_writer
599+ standalone_stream_reader = self ._request_streams [GET_STREAM_KEY ][1 ]
592600
593601 async with sse_stream_writer , standalone_stream_reader :
594602 # Process messages from the standalone stream
@@ -605,8 +613,7 @@ async def standalone_sse_writer():
605613 logger .exception (f"Error in standalone SSE writer: { e } " )
606614 finally :
607615 logger .debug ("Closing standalone SSE writer" )
608- # Remove the stream from request_streams
609- self ._request_streams .pop (GET_STREAM_KEY , None )
616+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
610617
611618 # Create and start EventSourceResponse
612619 response = EventSourceResponse (
@@ -620,8 +627,9 @@ async def standalone_sse_writer():
620627 await response (request .scope , request .receive , send )
621628 except Exception as e :
622629 logger .exception (f"Error in standalone SSE response: { e } " )
623- # Clean up the request stream
624- self ._request_streams .pop (GET_STREAM_KEY , None )
630+ # await sse_stream_writer.aclose()
631+ # await sse_stream_reader.aclose()
632+ await self ._clean_up_memory_streams (GET_STREAM_KEY )
625633
626634 async def _handle_delete_request (self , request : Request , send : Send ) -> None :
627635 """Handle DELETE requests for explicit session termination."""
@@ -638,15 +646,15 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
638646 if not await self ._validate_session (request , send ):
639647 return
640648
641- self ._terminate_session ()
649+ await self ._terminate_session ()
642650
643651 response = self ._create_json_response (
644652 None ,
645653 HTTPStatus .OK ,
646654 )
647655 await response (request .scope , request .receive , send )
648656
649- def _terminate_session (self ) -> None :
657+ async def _terminate_session (self ) -> None :
650658 """Terminate the current session, closing all streams.
651659
652660 Once terminated, all requests with this session ID will receive 404 Not Found.
@@ -658,19 +666,27 @@ def _terminate_session(self) -> None:
658666 # We need a copy of the keys to avoid modification during iteration
659667 request_stream_keys = list (self ._request_streams .keys ())
660668
661- # Close all request streams (synchronously)
669+ # Close all request streams asynchronously
662670 for key in request_stream_keys :
663671 try :
664- # Get the stream
665- stream = self ._request_streams .get (key )
666- if stream :
667- # We must use close() here, not aclose() since this is a sync method
668- stream .close ()
672+ await self ._clean_up_memory_streams (key )
669673 except Exception as e :
670674 logger .debug (f"Error closing stream { key } during termination: { e } " )
671675
672676 # Clear the request streams dictionary immediately
673677 self ._request_streams .clear ()
678+ try :
679+ if self ._read_stream_writer is not None :
680+ await self ._read_stream_writer .aclose ()
681+ if self ._read_stream is not None :
682+ await self ._read_stream .aclose ()
683+ if self ._write_stream_reader is not None :
684+ await self ._write_stream_reader .aclose ()
685+ if self ._write_stream is not None :
686+ await self ._write_stream .aclose ()
687+ except Exception as e :
688+ logger .debug (f"Error closing streams: { e } " )
689+ pass
674690
675691 async def _handle_unsupported_request (self , request : Request , send : Send ) -> None :
676692 """Handle unsupported HTTP methods."""
@@ -758,10 +774,10 @@ async def send_event(event_message: EventMessage) -> None:
758774
759775 # If stream ID not in mapping, create it
760776 if stream_id and stream_id not in self ._request_streams :
761- msg_writer , msg_reader = anyio . create_memory_object_stream [
762- EventMessage
763- ]( 0 )
764- self ._request_streams [stream_id ] = msg_writer
777+ self . _request_streams [ stream_id ] = (
778+ anyio . create_memory_object_stream [ EventMessage ]( 0 )
779+ )
780+ msg_reader = self ._request_streams [stream_id ][ 1 ]
765781
766782 # Forward messages to SSE
767783 async with msg_reader :
@@ -783,6 +799,9 @@ async def send_event(event_message: EventMessage) -> None:
783799 await response (request .scope , request .receive , send )
784800 except Exception as e :
785801 logger .exception (f"Error in replay response: { e } " )
802+ finally :
803+ await sse_stream_writer .aclose ()
804+ await sse_stream_reader .aclose ()
786805
787806 except Exception as e :
788807 logger .exception (f"Error replaying events: { e } " )
@@ -820,7 +839,9 @@ async def connect(
820839
821840 # Store the streams
822841 self ._read_stream_writer = read_stream_writer
842+ self ._read_stream = read_stream
823843 self ._write_stream_reader = write_stream_reader
844+ self ._write_stream = write_stream
824845
825846 # Start a task group for message routing
826847 async with anyio .create_task_group () as tg :
@@ -865,7 +886,7 @@ async def message_router():
865886 if request_stream_id in self ._request_streams :
866887 try :
867888 # Send both the message and the event ID
868- await self ._request_streams [request_stream_id ].send (
889+ await self ._request_streams [request_stream_id ][ 0 ] .send (
869890 EventMessage (message , event_id )
870891 )
871892 except (
@@ -874,6 +895,12 @@ async def message_router():
874895 ):
875896 # Stream might be closed, remove from registry
876897 self ._request_streams .pop (request_stream_id , None )
898+ else :
899+ logging .debug (
900+ f"""Request stream { request_stream_id } not found
901+ for message. Still processing message as the client
902+ might reconnect and replay."""
903+ )
877904 except Exception as e :
878905 logger .exception (f"Error in message router: { e } " )
879906
@@ -884,9 +911,20 @@ async def message_router():
884911 # Yield the streams for the caller to use
885912 yield read_stream , write_stream
886913 finally :
887- for stream in list (self ._request_streams .values ()):
914+ for stream_id in list (self ._request_streams .keys ()):
888915 try :
889- await stream .aclose ()
890- except Exception :
916+ await self ._clean_up_memory_streams (stream_id )
917+ except Exception as e :
918+ logger .debug (f"Error closing request stream: { e } " )
891919 pass
892920 self ._request_streams .clear ()
921+
922+ # Clean up the read and write streams
923+ try :
924+ await read_stream_writer .aclose ()
925+ await read_stream .aclose ()
926+ await write_stream_reader .aclose ()
927+ await write_stream .aclose ()
928+ except Exception as e :
929+ logger .debug (f"Error closing streams: { e } " )
930+ pass
0 commit comments