Skip to content

Commit b2ed3c3

Browse files
committed
asyncio.streams: transfer buffered data to SSL layer in start_tls()
1 parent cd4d0ae commit b2ed3c3

File tree

3 files changed

+189
-0
lines changed

3 files changed

+189
-0
lines changed

Lib/asyncio/base_events.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,30 @@ async def _sendfile_fallback(self, transp, file, offset, count):
13111311
file.seek(offset + total_sent)
13121312
await proto.restore()
13131313

1314+
def _transfer_buffered_data_to_ssl(self, protocol, ssl_protocol):
1315+
"""Transfer buffered data from StreamReader to SSL incoming BIO.
1316+
1317+
When using start_tls() mid-connection (e.g., after reading a
1318+
PROXY protocol header), any data already buffered in the
1319+
StreamReader would be lost. This transfers that data to the
1320+
SSL layer so the handshake can proceed.
1321+
1322+
Note: This only works with StreamReaderProtocol (used by the
1323+
streams API). Custom Protocol implementations that buffer data
1324+
must handle this manually before calling start_tls().
1325+
"""
1326+
if not hasattr(protocol, '_stream_reader'):
1327+
return
1328+
1329+
stream_reader = protocol._stream_reader
1330+
if stream_reader is None:
1331+
return
1332+
1333+
buffer = stream_reader._buffer
1334+
if buffer:
1335+
ssl_protocol._incoming.write(buffer)
1336+
buffer.clear()
1337+
13141338
async def start_tls(self, transport, protocol, sslcontext, *,
13151339
server_side=False,
13161340
server_hostname=None,
@@ -1341,6 +1365,8 @@ async def start_tls(self, transport, protocol, sslcontext, *,
13411365
ssl_shutdown_timeout=ssl_shutdown_timeout,
13421366
call_connection_made=False)
13431367

1368+
self._transfer_buffered_data_to_ssl(protocol, ssl_protocol)
1369+
13441370
# Pause early so that "ssl_protocol.data_received()" doesn't
13451371
# have a chance to get called before "ssl_protocol.connection_made()".
13461372
transport.pause_reading()

Lib/test/test_asyncio/test_streams.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,165 @@ async def client(addr):
819819
self.assertEqual(msg1, b"hello world 1!\n")
820820
self.assertEqual(msg2, b"hello world 2!\n")
821821

822+
def _run_test_start_tls_behind_proxy(self, send_combined):
823+
"""Test start_tls() when TLS ClientHello arrives with PROXY header.
824+
825+
This simulates HAProxy with send-proxy, where the PROXY protocol
826+
header and TLS handshake data may arrive in the same TCP segment.
827+
Without the fix, buffered TLS data would be lost after start_tls().
828+
"""
829+
830+
def reverse_message(data):
831+
return data.strip()[::-1] + b'\n'
832+
833+
test_message = b"hello world\n"
834+
expected_response = reverse_message(test_message)
835+
836+
class TCPProxyServer:
837+
"""A simple TCP proxy server that adds a PROXY protocol header
838+
before forwarding data to the target server."""
839+
840+
PROXY_LINE = b"PROXY TCP4 127.0.0.1 127.0.0.1 54321 443\r\n"
841+
842+
def __init__(self, loop, target_host, target_port):
843+
self.loop = loop
844+
self.target_host = target_host
845+
self.target_port = target_port
846+
self.server = None
847+
848+
async def _pipe(self, reader, writer):
849+
try:
850+
while True:
851+
data = await reader.read(4096)
852+
if not data:
853+
break
854+
writer.write(data)
855+
await writer.drain()
856+
finally:
857+
writer.close()
858+
await writer.wait_closed()
859+
860+
async def handle_client(self, client_reader, client_writer):
861+
# Connecting to the target server
862+
remote_reader, remote_writer = await asyncio.open_connection(
863+
self.target_host, self.target_port)
864+
865+
# Reading data from the client (TLS ClientHello)
866+
tls_data = await client_reader.read(4096)
867+
868+
if send_combined:
869+
# send everything together: PROXY + TLS data
870+
remote_writer.write(self.PROXY_LINE + tls_data)
871+
await remote_writer.drain()
872+
else:
873+
# send TLS data after the PROXY line
874+
remote_writer.write(self.PROXY_LINE)
875+
await remote_writer.drain()
876+
await asyncio.sleep(0.01)
877+
remote_writer.write(tls_data)
878+
await remote_writer.drain()
879+
880+
await asyncio.gather(
881+
self._pipe(client_reader, remote_writer),
882+
self._pipe(remote_reader, client_writer),
883+
)
884+
885+
def start(self):
886+
sock = socket.create_server(('127.0.0.1', 0))
887+
self.server = self.loop.run_until_complete(
888+
asyncio.start_server(self.handle_client, sock=sock))
889+
return sock.getsockname()
890+
891+
def stop(self):
892+
if self.server:
893+
self.server.close()
894+
self.loop.run_until_complete(self.server.wait_closed())
895+
self.server = None
896+
897+
class ServerWithSendProxySupport:
898+
"""A server that supports the PROXY protocol and starts TLS
899+
after receiving the PROXY header."""
900+
901+
def __init__(self, test_case, loop):
902+
self.test = test_case
903+
self.server = None
904+
self.loop = loop
905+
906+
async def handle_client(self, client_reader, client_writer):
907+
proxy_line = await client_reader.readline()
908+
self.test.assertEqual(proxy_line, TCPProxyServer.PROXY_LINE)
909+
910+
# Now we can start TLS
911+
self.test.assertIsNone(
912+
client_writer.get_extra_info('sslcontext'))
913+
await client_writer.start_tls(
914+
test_utils.simple_server_sslcontext()
915+
)
916+
self.test.assertIsNotNone(
917+
client_writer.get_extra_info('sslcontext'))
918+
919+
data = await client_reader.readline()
920+
client_writer.write(reverse_message(data))
921+
await client_writer.drain()
922+
client_writer.close()
923+
await client_writer.wait_closed()
924+
925+
def start(self):
926+
sock = socket.create_server(('127.0.0.1', 0))
927+
self.server = self.loop.run_until_complete(
928+
asyncio.start_server(self.handle_client,
929+
sock=sock))
930+
return sock.getsockname()
931+
932+
def stop(self):
933+
if self.server is not None:
934+
self.server.close()
935+
self.loop.run_until_complete(self.server.wait_closed())
936+
self.server = None
937+
938+
async def client(addr, test_case):
939+
reader, writer = await asyncio.open_connection(*addr)
940+
941+
test_case.assertIsNone(writer.get_extra_info('sslcontext'))
942+
await writer.start_tls(test_utils.simple_client_sslcontext())
943+
test_case.assertIsNotNone(writer.get_extra_info('sslcontext'))
944+
945+
writer.write(test_message)
946+
await writer.drain()
947+
msgback = await reader.readline()
948+
writer.close()
949+
await writer.wait_closed()
950+
return msgback
951+
952+
messages = []
953+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
954+
955+
server = ServerWithSendProxySupport(self, self.loop)
956+
server_addr = server.start()
957+
958+
proxy = TCPProxyServer(self.loop, *server_addr)
959+
proxy_addr = proxy.start()
960+
961+
msg = self.loop.run_until_complete(
962+
asyncio.wait_for(client(proxy_addr, self), timeout=5.0)
963+
)
964+
965+
proxy.stop()
966+
server.stop()
967+
968+
self.assertEqual(messages, [])
969+
self.assertEqual(msg, expected_response)
970+
971+
@unittest.skipIf(ssl is None, 'No ssl module')
972+
def test_start_tls_behind_proxy_send_combined(self):
973+
# Test with sending PROXY header and TLS data in one packet
974+
self._run_test_start_tls_behind_proxy(send_combined=True)
975+
976+
@unittest.skipIf(ssl is None, 'No ssl module')
977+
def test_start_tls_behind_proxy_send_separate(self):
978+
# Test with sending PROXY header and TLS data in separate packets
979+
self._run_test_start_tls_behind_proxy(send_combined=False)
980+
822981
def test_streamreader_constructor_without_loop(self):
823982
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
824983
asyncio.StreamReader()
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Fix :meth:`asyncio.StreamWriter.start_tls` to transfer buffered data from
2+
:class:`~asyncio.StreamReader` to the SSL layer, preventing data loss when
3+
upgrading a connection to TLS mid-stream (e.g., when implementing PROXY
4+
protocol support).

0 commit comments

Comments
 (0)