@@ -308,6 +308,93 @@ def _check_content_type(self, request: Request) -> bool:
308308
309309 return any (part == CONTENT_TYPE_JSON for part in content_type_parts )
310310
311+ def _is_async_operation_response (self , response_message : JSONRPCMessage ) -> bool :
312+ """Check if response is for an async operation that should keep stream open."""
313+ try :
314+ if not isinstance (response_message .root , JSONRPCResponse ):
315+ return False
316+
317+ result = response_message .root .result
318+ if not result :
319+ return False
320+
321+ # Check if result has _operation with token
322+ if hasattr (result , "__getitem__" ) and "_operation" in result :
323+ operation = result ["_operation" ] # type: ignore
324+ if hasattr (operation , "__getitem__" ) and "token" in operation :
325+ return bool (operation ["token" ]) # type: ignore
326+
327+ return False
328+ except (TypeError , KeyError , AttributeError ):
329+ return False
330+
331+ async def _handle_sse_mode (
332+ self ,
333+ message : JSONRPCMessage ,
334+ request : Request ,
335+ writer : MemoryObjectSendStream [SessionMessage | Exception ],
336+ request_id : str ,
337+ request_stream_reader : MemoryObjectReceiveStream [EventMessage ],
338+ scope : Scope ,
339+ receive : Receive ,
340+ send : Send ,
341+ ) -> None :
342+ """Handle SSE response mode."""
343+ # Create SSE stream
344+ sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [dict [str , str ]](0 )
345+
346+ async def sse_writer ():
347+ # Get the request ID from the incoming request message
348+ try :
349+ async with sse_stream_writer , request_stream_reader :
350+ # Process messages from the request-specific stream
351+ async for event_message in request_stream_reader :
352+ # Build the event data
353+ event_data = self ._create_event_data (event_message )
354+ await sse_stream_writer .send (event_data )
355+
356+ # If response, remove from pending streams and close
357+ if isinstance (
358+ event_message .message .root ,
359+ JSONRPCResponse | JSONRPCError ,
360+ ):
361+ break
362+ except Exception :
363+ logger .exception ("Error in SSE writer" )
364+ finally :
365+ logger .debug ("Closing SSE writer" )
366+ await self ._clean_up_memory_streams (request_id )
367+
368+ # Create and start EventSourceResponse
369+ # SSE stream mode (original behavior)
370+ # Set up headers
371+ headers = {
372+ "Cache-Control" : "no-cache, no-transform" ,
373+ "Connection" : "keep-alive" ,
374+ "Content-Type" : CONTENT_TYPE_SSE ,
375+ ** ({MCP_SESSION_ID_HEADER : self .mcp_session_id } if self .mcp_session_id else {}),
376+ }
377+ response = EventSourceResponse (
378+ content = sse_stream_reader ,
379+ data_sender_callable = sse_writer ,
380+ headers = headers ,
381+ )
382+
383+ # Start the SSE response (this will send headers immediately)
384+ try :
385+ # First send the response to establish the SSE connection
386+ async with anyio .create_task_group () as tg :
387+ tg .start_soon (response , scope , receive , send )
388+ # Then send the message to be processed by the server
389+ metadata = ServerMessageMetadata (request_context = request )
390+ session_message = SessionMessage (message , metadata = metadata )
391+ await writer .send (session_message )
392+ except Exception :
393+ logger .exception ("SSE response error" )
394+ await sse_stream_writer .aclose ()
395+ await sse_stream_reader .aclose ()
396+ await self ._clean_up_memory_streams (request_id )
397+
311398 async def _handle_post_request (self , scope : Scope , request : Request , receive : Receive , send : Send ) -> None :
312399 """Handle POST requests containing JSON-RPC messages."""
313400 writer = self ._read_stream_writer
@@ -420,15 +507,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
420507 # At this point we should have a response
421508 if response_message :
422509 # Check if this is an async operation response - keep stream open
423- if (
424- isinstance (response_message .root , JSONRPCResponse )
425- and response_message .root .result
426- and "_operation" in response_message .root .result
427- and (
428- ("token" in response_message .root .result ["_operation" ])
429- and response_message .root .result ["_operation" ]["token" ]
430- )
431- ):
510+ if self ._is_async_operation_response (response_message ):
432511 # This is an async operation - keep the stream open for elicitation/sampling
433512 should_pop_stream = False
434513
@@ -455,61 +534,10 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
455534 if should_pop_stream :
456535 await self ._clean_up_memory_streams (request_id )
457536 else :
458- # Create SSE stream
459- sse_stream_writer , sse_stream_reader = anyio .create_memory_object_stream [dict [str , str ]](0 )
460-
461- async def sse_writer ():
462- # Get the request ID from the incoming request message
463- try :
464- async with sse_stream_writer , request_stream_reader :
465- # Process messages from the request-specific stream
466- async for event_message in request_stream_reader :
467- # Build the event data
468- event_data = self ._create_event_data (event_message )
469- await sse_stream_writer .send (event_data )
470-
471- # If response, remove from pending streams and close
472- if isinstance (
473- event_message .message .root ,
474- JSONRPCResponse | JSONRPCError ,
475- ):
476- break
477- except Exception :
478- logger .exception ("Error in SSE writer" )
479- finally :
480- logger .debug ("Closing SSE writer" )
481- await self ._clean_up_memory_streams (request_id )
482-
483- # Create and start EventSourceResponse
484- # SSE stream mode (original behavior)
485- # Set up headers
486- headers = {
487- "Cache-Control" : "no-cache, no-transform" ,
488- "Connection" : "keep-alive" ,
489- "Content-Type" : CONTENT_TYPE_SSE ,
490- ** ({MCP_SESSION_ID_HEADER : self .mcp_session_id } if self .mcp_session_id else {}),
491- }
492- response = EventSourceResponse (
493- content = sse_stream_reader ,
494- data_sender_callable = sse_writer ,
495- headers = headers ,
537+ await self ._handle_sse_mode (
538+ message , request , writer , request_id , request_stream_reader , scope , receive , send
496539 )
497540
498- # Start the SSE response (this will send headers immediately)
499- try :
500- # First send the response to establish the SSE connection
501- async with anyio .create_task_group () as tg :
502- tg .start_soon (response , scope , receive , send )
503- # Then send the message to be processed by the server
504- metadata = ServerMessageMetadata (request_context = request )
505- session_message = SessionMessage (message , metadata = metadata )
506- await writer .send (session_message )
507- except Exception :
508- logger .exception ("SSE response error" )
509- await sse_stream_writer .aclose ()
510- await sse_stream_reader .aclose ()
511- await self ._clean_up_memory_streams (request_id )
512-
513541 except Exception as err :
514542 logger .exception ("Error handling POST request" )
515543 response = self ._create_error_response (
0 commit comments