Skip to content

Commit d3d9940

Browse files
committed
gh-142352: simplify start_tls buffered-data test
Replace the proxy-based combined/separate cases with a deterministic buffered-ClientHello check and a ping/pong verification after start_tls.
1 parent a3905bd commit d3d9940

File tree

1 file changed

+21
-63
lines changed

1 file changed

+21
-63
lines changed

Lib/test/test_asyncio/test_streams.py

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import asyncio
1616
from test.test_asyncio import utils as test_utils
17-
from test.support import socket_helper
17+
from test.support import socket_helper, LOOPBACK_TIMEOUT
1818

1919

2020
def tearDownModule():
@@ -819,60 +819,34 @@ async def client(addr):
819819
self.assertEqual(msg1, b"hello world 1!\n")
820820
self.assertEqual(msg2, b"hello world 2!\n")
821821

822-
def _test_start_tls_buffered_data(self, send_combined):
822+
@unittest.skipIf(ssl is None, 'No ssl module')
823+
def test_start_tls_buffered_data(self):
823824
# gh-142352: test start_tls() with buffered data
824825

825-
PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n"
826-
TEST_MESSAGE = b"hello world\n"
827-
828-
async def pipe(src, dst):
829-
try:
830-
while data := await src.read(4096):
831-
dst.write(data)
832-
await dst.drain()
833-
finally:
834-
dst.close()
835-
await dst.wait_closed()
836-
837-
async def proxy_handler(client_reader, client_writer, backend_addr):
838-
backend_reader, backend_writer = await asyncio.open_connection(
839-
*backend_addr)
840-
try:
841-
tls_data = await client_reader.read(4096)
842-
if send_combined:
843-
backend_writer.write(PROXY_LINE + tls_data)
844-
else:
845-
backend_writer.write(PROXY_LINE)
846-
await backend_writer.drain()
847-
await asyncio.sleep(0.01)
848-
backend_writer.write(tls_data)
849-
await backend_writer.drain()
850-
851-
await asyncio.gather(
852-
pipe(client_reader, backend_writer),
853-
pipe(backend_reader, client_writer),
854-
)
855-
finally:
856-
client_writer.close()
857-
backend_writer.close()
858-
await asyncio.gather(
859-
client_writer.wait_closed(),
860-
backend_writer.wait_closed(),
861-
return_exceptions=True
862-
)
863-
864826
async def server_handler(client_reader, client_writer):
865-
self.assertEqual(await client_reader.readline(), PROXY_LINE)
827+
# Wait for TLS ClientHello to be buffered before start_tls().
828+
await asyncio.wait_for(
829+
client_reader._wait_for_data('test_start_tls_buffered_data'),
830+
LOOPBACK_TIMEOUT,
831+
)
832+
self.assertTrue(client_reader._buffer)
866833
await client_writer.start_tls(test_utils.simple_server_sslcontext())
867-
self.assertEqual(await client_reader.readline(), TEST_MESSAGE)
834+
835+
line = await asyncio.wait_for(client_reader.readline(), LOOPBACK_TIMEOUT)
836+
self.assertEqual(line, b"ping\n")
837+
client_writer.write(b"pong\n")
838+
await client_writer.drain()
868839
client_writer.close()
869840
await client_writer.wait_closed()
870841

871842
async def client(addr):
872-
_, writer = await asyncio.open_connection(*addr)
843+
reader, writer = await asyncio.open_connection(*addr)
873844
await writer.start_tls(test_utils.simple_client_sslcontext())
874-
writer.write(TEST_MESSAGE)
845+
846+
writer.write(b"ping\n")
875847
await writer.drain()
848+
line = await asyncio.wait_for(reader.readline(), LOOPBACK_TIMEOUT)
849+
self.assertEqual(line, b"pong\n")
876850
writer.close()
877851
await writer.wait_closed()
878852

@@ -881,31 +855,15 @@ async def run_test():
881855
server_handler, socket_helper.HOSTv4, 0)
882856
server_addr = server.sockets[0].getsockname()
883857

884-
proxy = await asyncio.start_server(
885-
lambda r, w: proxy_handler(r, w, server_addr),
886-
socket_helper.HOSTv4, 0)
887-
proxy_addr = proxy.sockets[0].getsockname()
888-
889-
await asyncio.wait_for(client(proxy_addr), timeout=5.0)
890-
proxy.close()
858+
await asyncio.wait_for(client(server_addr), timeout=5.0)
891859
server.close()
892-
await asyncio.gather(proxy.wait_closed(), server.wait_closed())
860+
await server.wait_closed()
893861

894862
messages = []
895863
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
896864
self.loop.run_until_complete(run_test())
897865
self.assertEqual(messages, [])
898866

899-
@unittest.skipIf(ssl is None, 'No ssl module')
900-
def test_start_tls_buffered_data_combined(self):
901-
# gh-142352: Test TLS data buffered before start_tls
902-
self._test_start_tls_buffered_data(send_combined=True)
903-
904-
@unittest.skipIf(ssl is None, 'No ssl module')
905-
def test_start_tls_buffered_data_separate(self):
906-
# gh-142352: Test TLS data sent separately
907-
self._test_start_tls_buffered_data(send_combined=False)
908-
909867
def test_streamreader_constructor_without_loop(self):
910868
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
911869
asyncio.StreamReader()

0 commit comments

Comments
 (0)