@@ -819,6 +819,165 @@ async def client(addr):
819819 self .assertEqual (msg1 , b"hello world 1!\n " )
820820 self .assertEqual (msg2 , b"hello world 2!\n " )
821821
822+ def _run_test_start_tls_behind_proxy (self , send_combined ):
823+ """Test start_tls() when TLS ClientHello arrives with PROXY header.
824+
825+ This simulates HAProxy with send-proxy, where the PROXY protocol
826+ header and TLS handshake data may arrive in the same TCP segment.
827+ Without the fix, buffered TLS data would be lost after start_tls().
828+ """
829+
830+ def reverse_message (data ):
831+ return data .strip ()[::- 1 ] + b'\n '
832+
833+ test_message = b"hello world\n "
834+ expected_response = reverse_message (test_message )
835+
836+ class TCPProxyServer :
837+ """A simple TCP proxy server that adds a PROXY protocol header
838+ before forwarding data to the target server."""
839+
840+ PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r \n "
841+
842+ def __init__ (self , loop , target_host , target_port ):
843+ self .loop = loop
844+ self .target_host = target_host
845+ self .target_port = target_port
846+ self .server = None
847+
848+ async def _pipe (self , reader , writer ):
849+ try :
850+ while True :
851+ data = await reader .read (4096 )
852+ if not data :
853+ break
854+ writer .write (data )
855+ await writer .drain ()
856+ finally :
857+ writer .close ()
858+ await writer .wait_closed ()
859+
860+ async def handle_client (self , client_reader , client_writer ):
861+ # Connecting to the target server
862+ remote_reader , remote_writer = await asyncio .open_connection (
863+ self .target_host , self .target_port )
864+
865+ # Reading data from the client (TLS ClientHello)
866+ tls_data = await client_reader .read (4096 )
867+
868+ if send_combined :
869+ # send everything together: PROXY + TLS data
870+ remote_writer .write (self .PROXY_LINE + tls_data )
871+ await remote_writer .drain ()
872+ else :
873+ # send TLS data after the PROXY line
874+ remote_writer .write (self .PROXY_LINE )
875+ await remote_writer .drain ()
876+ await asyncio .sleep (0.01 )
877+ remote_writer .write (tls_data )
878+ await remote_writer .drain ()
879+
880+ await asyncio .gather (
881+ self ._pipe (client_reader , remote_writer ),
882+ self ._pipe (remote_reader , client_writer ),
883+ )
884+
885+ def start (self ):
886+ sock = socket .create_server (('127.0.0.1' , 0 ))
887+ self .server = self .loop .run_until_complete (
888+ asyncio .start_server (self .handle_client , sock = sock ))
889+ return sock .getsockname ()
890+
891+ def stop (self ):
892+ if self .server :
893+ self .server .close ()
894+ self .loop .run_until_complete (self .server .wait_closed ())
895+ self .server = None
896+
897+ class ServerWithSendProxySupport :
898+ """A server that supports the PROXY protocol and starts TLS
899+ after receiving the PROXY header."""
900+
901+ def __init__ (self , test_case , loop ):
902+ self .test = test_case
903+ self .server = None
904+ self .loop = loop
905+
906+ async def handle_client (self , client_reader , client_writer ):
907+ proxy_line = await client_reader .readline ()
908+ self .test .assertEqual (proxy_line , TCPProxyServer .PROXY_LINE )
909+
910+ # Now we can start TLS
911+ self .test .assertIsNone (
912+ client_writer .get_extra_info ('sslcontext' ))
913+ await client_writer .start_tls (
914+ test_utils .simple_server_sslcontext ()
915+ )
916+ self .test .assertIsNotNone (
917+ client_writer .get_extra_info ('sslcontext' ))
918+
919+ data = await client_reader .readline ()
920+ client_writer .write (reverse_message (data ))
921+ await client_writer .drain ()
922+ client_writer .close ()
923+ await client_writer .wait_closed ()
924+
925+ def start (self ):
926+ sock = socket .create_server (('127.0.0.1' , 0 ))
927+ self .server = self .loop .run_until_complete (
928+ asyncio .start_server (self .handle_client ,
929+ sock = sock ))
930+ return sock .getsockname ()
931+
932+ def stop (self ):
933+ if self .server is not None :
934+ self .server .close ()
935+ self .loop .run_until_complete (self .server .wait_closed ())
936+ self .server = None
937+
938+ async def client (addr , test_case ):
939+ reader , writer = await asyncio .open_connection (* addr )
940+
941+ test_case .assertIsNone (writer .get_extra_info ('sslcontext' ))
942+ await writer .start_tls (test_utils .simple_client_sslcontext ())
943+ test_case .assertIsNotNone (writer .get_extra_info ('sslcontext' ))
944+
945+ writer .write (test_message )
946+ await writer .drain ()
947+ msgback = await reader .readline ()
948+ writer .close ()
949+ await writer .wait_closed ()
950+ return msgback
951+
952+ messages = []
953+ self .loop .set_exception_handler (lambda loop , ctx : messages .append (ctx ))
954+
955+ server = ServerWithSendProxySupport (self , self .loop )
956+ server_addr = server .start ()
957+
958+ proxy = TCPProxyServer (self .loop , * server_addr )
959+ proxy_addr = proxy .start ()
960+
961+ msg = self .loop .run_until_complete (
962+ asyncio .wait_for (client (proxy_addr , self ), timeout = 5.0 )
963+ )
964+
965+ proxy .stop ()
966+ server .stop ()
967+
968+ self .assertEqual (messages , [])
969+ self .assertEqual (msg , expected_response )
970+
971+ @unittest .skipIf (ssl is None , 'No ssl module' )
972+ def test_start_tls_behind_proxy_send_combined (self ):
973+ # Test with sending PROXY header and TLS data in one packet
974+ self ._run_test_start_tls_behind_proxy (send_combined = True )
975+
976+ @unittest .skipIf (ssl is None , 'No ssl module' )
977+ def test_start_tls_behind_proxy_send_separate (self ):
978+ # Test with sending PROXY header and TLS data in separate packets
979+ self ._run_test_start_tls_behind_proxy (send_combined = False )
980+
822981 def test_streamreader_constructor_without_loop (self ):
823982 with self .assertRaisesRegex (RuntimeError , 'no current event loop' ):
824983 asyncio .StreamReader ()
0 commit comments