Skip to content

Commit 86d80bf

Browse files
committed
gh-142352: asyncio.streams: transfer buffered data to SSL layer in start_tls()
1 parent 0e0d51c commit 86d80bf

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

Lib/asyncio/base_events.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,14 @@ async def start_tls(self, transport, protocol, sslcontext, *,
13411341
ssl_shutdown_timeout=ssl_shutdown_timeout,
13421342
call_connection_made=False)
13431343

1344+
# gh-142352: move buffered StreamReader data to SSLProtocol
1345+
stream_reader = getattr(protocol, '_stream_reader', None)
1346+
if stream_reader is not None:
1347+
buffer = stream_reader._buffer
1348+
if buffer:
1349+
ssl_protocol._incoming.write(buffer)
1350+
buffer.clear()
1351+
13441352
# Pause early so that "ssl_protocol.data_received()" doesn't
13451353
# have a chance to get called before "ssl_protocol.connection_made()".
13461354
transport.pause_reading()

Lib/test/test_asyncio/test_streams.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,80 @@ async def client(addr):
819819
self.assertEqual(msg1, b"hello world 1!\n")
820820
self.assertEqual(msg2, b"hello world 2!\n")
821821

822+
def _test_start_tls_buffered_data(self, send_combined):
823+
# gh-142352: test start_tls() with buffered data
824+
825+
PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n"
826+
TEST_MESSAGE = b"hello world\n"
827+
828+
async def pipe(src, dst):
829+
while data := await src.read(4096):
830+
dst.write(data)
831+
await dst.drain()
832+
833+
async def proxy_handler(client_reader, client_writer, backend_addr):
834+
backend_reader, backend_writer = await asyncio.open_connection(
835+
*backend_addr)
836+
tls_data = await client_reader.read(4096)
837+
if send_combined:
838+
backend_writer.write(PROXY_LINE + tls_data)
839+
else:
840+
backend_writer.write(PROXY_LINE)
841+
await backend_writer.drain()
842+
await asyncio.sleep(0.01)
843+
backend_writer.write(tls_data)
844+
await backend_writer.drain()
845+
846+
await asyncio.gather(
847+
pipe(client_reader, backend_writer),
848+
pipe(backend_reader, client_writer),
849+
)
850+
851+
async def server_handler(client_reader, client_writer):
852+
self.assertEqual(await client_reader.readline(), PROXY_LINE)
853+
await client_writer.start_tls(test_utils.simple_server_sslcontext())
854+
self.assertEqual(await client_reader.readline(), TEST_MESSAGE)
855+
client_writer.close()
856+
await client_writer.wait_closed()
857+
858+
async def client(addr):
859+
_, writer = await asyncio.open_connection(*addr)
860+
await writer.start_tls(test_utils.simple_client_sslcontext())
861+
writer.write(TEST_MESSAGE)
862+
await writer.drain()
863+
writer.close()
864+
await writer.wait_closed()
865+
866+
async def run_test():
867+
server = await asyncio.start_server(
868+
server_handler, socket_helper.HOSTv4, 0)
869+
server_addr = server.sockets[0].getsockname()
870+
871+
proxy = await asyncio.start_server(
872+
lambda r, w: proxy_handler(r, w, server_addr),
873+
socket_helper.HOSTv4, 0)
874+
proxy_addr = proxy.sockets[0].getsockname()
875+
876+
await asyncio.wait_for(client(proxy_addr), timeout=5.0)
877+
proxy.close()
878+
server.close()
879+
await asyncio.gather(proxy.wait_closed(), server.wait_closed())
880+
881+
messages = []
882+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
883+
self.loop.run_until_complete(run_test())
884+
self.assertEqual(messages, [])
885+
886+
@unittest.skipIf(ssl is None, 'No ssl module')
887+
def test_start_tls_buffered_data_combined(self):
888+
# gh-142352: Test TLS data buffered before start_tls
889+
self._test_start_tls_buffered_data(send_combined=True)
890+
891+
@unittest.skipIf(ssl is None, 'No ssl module')
892+
def test_start_tls_buffered_data_separate(self):
893+
# gh-142352: Test TLS data sent separately
894+
self._test_start_tls_buffered_data(send_combined=False)
895+
822896
def test_streamreader_constructor_without_loop(self):
823897
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
824898
asyncio.StreamReader()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from
2+
:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when
3+
upgrading a connection to TLS mid-stream (e.g., when implementing PROXY
4+
protocol support).

0 commit comments

Comments
 (0)