@@ -819,6 +819,80 @@ async def client(addr):
819819 self .assertEqual (msg1 , b"hello world 1!\n " )
820820 self .assertEqual (msg2 , b"hello world 2!\n " )
821821
822+ def _test_start_tls_buffered_data (self , send_combined ):
823+ # gh-142352: test start_tls() with buffered data
824+
825+ PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r \n "
826+ TEST_MESSAGE = b"hello world\n "
827+
828+ async def pipe (src , dst ):
829+ while data := await src .read (4096 ):
830+ dst .write (data )
831+ await dst .drain ()
832+
833+ async def proxy_handler (client_reader , client_writer , backend_addr ):
834+ backend_reader , backend_writer = await asyncio .open_connection (
835+ * backend_addr )
836+ tls_data = await client_reader .read (4096 )
837+ if send_combined :
838+ backend_writer .write (PROXY_LINE + tls_data )
839+ else :
840+ backend_writer .write (PROXY_LINE )
841+ await backend_writer .drain ()
842+ await asyncio .sleep (0.01 )
843+ backend_writer .write (tls_data )
844+ await backend_writer .drain ()
845+
846+ await asyncio .gather (
847+ pipe (client_reader , backend_writer ),
848+ pipe (backend_reader , client_writer ),
849+ )
850+
851+ async def server_handler (client_reader , client_writer ):
852+ self .assertEqual (await client_reader .readline (), PROXY_LINE )
853+ await client_writer .start_tls (test_utils .simple_server_sslcontext ())
854+ self .assertEqual (await client_reader .readline (), TEST_MESSAGE )
855+ client_writer .close ()
856+ await client_writer .wait_closed ()
857+
858+ async def client (addr ):
859+ _ , writer = await asyncio .open_connection (* addr )
860+ await writer .start_tls (test_utils .simple_client_sslcontext ())
861+ writer .write (TEST_MESSAGE )
862+ await writer .drain ()
863+ writer .close ()
864+ await writer .wait_closed ()
865+
866+ async def run_test ():
867+ server = await asyncio .start_server (
868+ server_handler , socket_helper .HOSTv4 , 0 )
869+ server_addr = server .sockets [0 ].getsockname ()
870+
871+ proxy = await asyncio .start_server (
872+ lambda r , w : proxy_handler (r , w , server_addr ),
873+ socket_helper .HOSTv4 , 0 )
874+ proxy_addr = proxy .sockets [0 ].getsockname ()
875+
876+ await asyncio .wait_for (client (proxy_addr ), timeout = 5.0 )
877+ proxy .close ()
878+ server .close ()
879+ await asyncio .gather (proxy .wait_closed (), server .wait_closed ())
880+
881+ messages = []
882+ self .loop .set_exception_handler (lambda loop , ctx : messages .append (ctx ))
883+ self .loop .run_until_complete (run_test ())
884+ self .assertEqual (messages , [])
885+
886+ @unittest .skipIf (ssl is None , 'No ssl module' )
887+ def test_start_tls_buffered_data_combined (self ):
888+ # gh-142352: Test TLS data buffered before start_tls
889+ self ._test_start_tls_buffered_data (send_combined = True )
890+
891+ @unittest .skipIf (ssl is None , 'No ssl module' )
892+ def test_start_tls_buffered_data_separate (self ):
893+ # gh-142352: Test TLS data sent separately
894+ self ._test_start_tls_buffered_data (send_combined = False )
895+
822896 def test_streamreader_constructor_without_loop (self ):
823897 with self .assertRaisesRegex (RuntimeError , 'no current event loop' ):
824898 asyncio .StreamReader ()
0 commit comments