@@ -864,3 +864,129 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool:
864864 await server_to_client_receive .aclose ()
865865 await client_to_server_send .aclose ()
866866 await client_to_server_receive .aclose ()
867+
868+
869+ @pytest .mark .anyio
870+ async def test_response_routing_skips_non_matching_routers () -> None :
871+ """Test that routing continues to next router when first doesn't match."""
872+ from mcp .shared .session import ResponseRouter
873+
874+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](10 )
875+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage ](10 )
876+
877+ # Track which routers were called
878+ router_calls : list [str ] = []
879+ response_received = anyio .Event ()
880+
881+ class NonMatchingRouter (ResponseRouter ):
882+ def route_response (self , request_id : str | int , response : dict [str , Any ]) -> bool :
883+ router_calls .append ("non_matching_response" )
884+ return False # Doesn't handle it
885+
886+ def route_error (self , request_id : str | int , error : ErrorData ) -> bool :
887+ router_calls .append ("non_matching_error" )
888+ return False # Doesn't handle it
889+
890+ class MatchingRouter (ResponseRouter ):
891+ def route_response (self , request_id : str | int , response : dict [str , Any ]) -> bool :
892+ router_calls .append ("matching_response" )
893+ response_received .set ()
894+ return True # Handles it
895+
896+ def route_error (self , request_id : str | int , error : ErrorData ) -> bool :
897+ router_calls .append ("matching_error" )
898+ response_received .set ()
899+ return True # Handles it
900+
901+ try :
902+ async with ServerSession (
903+ client_to_server_receive ,
904+ server_to_client_send ,
905+ InitializationOptions (
906+ server_name = "test-server" ,
907+ server_version = "1.0.0" ,
908+ capabilities = ServerCapabilities (),
909+ ),
910+ ) as server_session :
911+ # Add non-matching router first, then matching router
912+ server_session .add_response_router (NonMatchingRouter ())
913+ server_session .add_response_router (MatchingRouter ())
914+
915+ # Send a response - should skip first router and be handled by second
916+ response = JSONRPCResponse (jsonrpc = "2.0" , id = "test-req-1" , result = {"status" : "ok" })
917+ message = SessionMessage (message = JSONRPCMessage (response ))
918+ await client_to_server_send .send (message )
919+
920+ with anyio .fail_after (5 ):
921+ await response_received .wait ()
922+
923+ # Verify both routers were called (first returned False, second returned True)
924+ assert router_calls == ["non_matching_response" , "matching_response" ]
925+ finally :
926+ await server_to_client_send .aclose ()
927+ await server_to_client_receive .aclose ()
928+ await client_to_server_send .aclose ()
929+ await client_to_server_receive .aclose ()
930+
931+
932+ @pytest .mark .anyio
933+ async def test_error_routing_skips_non_matching_routers () -> None :
934+ """Test that error routing continues to next router when first doesn't match."""
935+ from mcp .shared .session import ResponseRouter
936+
937+ server_to_client_send , server_to_client_receive = anyio .create_memory_object_stream [SessionMessage ](10 )
938+ client_to_server_send , client_to_server_receive = anyio .create_memory_object_stream [SessionMessage ](10 )
939+
940+ # Track which routers were called
941+ router_calls : list [str ] = []
942+ error_received = anyio .Event ()
943+
944+ class NonMatchingRouter (ResponseRouter ):
945+ def route_response (self , request_id : str | int , response : dict [str , Any ]) -> bool :
946+ router_calls .append ("non_matching_response" )
947+ return False
948+
949+ def route_error (self , request_id : str | int , error : ErrorData ) -> bool :
950+ router_calls .append ("non_matching_error" )
951+ return False # Doesn't handle it
952+
953+ class MatchingRouter (ResponseRouter ):
954+ def route_response (self , request_id : str | int , response : dict [str , Any ]) -> bool :
955+ router_calls .append ("matching_response" )
956+ return True
957+
958+ def route_error (self , request_id : str | int , error : ErrorData ) -> bool :
959+ router_calls .append ("matching_error" )
960+ error_received .set ()
961+ return True # Handles it
962+
963+ try :
964+ async with ServerSession (
965+ client_to_server_receive ,
966+ server_to_client_send ,
967+ InitializationOptions (
968+ server_name = "test-server" ,
969+ server_version = "1.0.0" ,
970+ capabilities = ServerCapabilities (),
971+ ),
972+ ) as server_session :
973+ # Add non-matching router first, then matching router
974+ server_session .add_response_router (NonMatchingRouter ())
975+ server_session .add_response_router (MatchingRouter ())
976+
977+ # Send an error - should skip first router and be handled by second
978+ error_data = ErrorData (code = - 32600 , message = "Test error" )
979+ error_response = JSONRPCError (jsonrpc = "2.0" , id = "test-req-2" , error = error_data )
980+ message = SessionMessage (message = JSONRPCMessage (error_response ))
981+ await client_to_server_send .send (message )
982+
983+ with anyio .fail_after (5 ):
984+ await error_received .wait ()
985+
986+ # Verify both routers were called (first returned False, second returned True)
987+ assert router_calls == ["non_matching_error" , "matching_error" ]
988+ finally :
989+ await server_to_client_send .aclose ()
990+ await server_to_client_receive .aclose ()
991+ await client_to_server_send .aclose ()
992+ await client_to_server_receive .aclose ()
0 commit comments