Skip to content

Commit 502bfaf

Browse files
committed
use Handles to ensure selector thread don't raise and halt handling
1 parent c12d733 commit 502bfaf

File tree

2 files changed

+124
-6
lines changed

2 files changed

+124
-6
lines changed

Lib/asyncio/_selector_thread.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
Protocol,
3131
)
3232

33+
from . import events
34+
3335

3436
class _HasFileno(Protocol):
3537
def fileno(self) -> int:
@@ -259,10 +261,11 @@ def _handle_event(
259261
so exception handler wrappers, etc. are applied.
260262
"""
261263
try:
262-
fileobj, callback = cb_map[fd]
264+
fileobj, handle = cb_map[fd]
263265
except KeyError:
264266
return
265-
callback()
267+
if not handle.cancelled():
268+
handle._run()
266269

267270
def _split_fd(self, fd: _FileDescriptorLike) -> tuple[int, _FileDescriptorLike]:
268271
"""Return fd, file object
@@ -283,30 +286,38 @@ def add_reader(
283286
self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any
284287
) -> None:
285288
fd, fileobj = self._split_fd(fd)
286-
self._readers[fd] = (fileobj, functools.partial(callback, *args))
289+
if fd in self._readers:
290+
_, handle = self._readers[fd]
291+
handle.cancel()
292+
self._readers[fd] = (fileobj, events.Handle(callback, args, self._real_loop))
287293
self._wake_selector()
288294

289295
def add_writer(
290296
self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any
291297
) -> None:
292298
fd, fileobj = self._split_fd(fd)
293-
self._writers[fd] = (fileobj, functools.partial(callback, *args))
299+
if fd in self._writers:
300+
_, handle = self._writers[fd]
301+
handle.cancel()
302+
self._writers[fd] = (fileobj, events.Handle(callback, args, self._real_loop))
294303
self._wake_selector()
295304

296305
def remove_reader(self, fd: _FileDescriptorLike) -> bool:
297306
fd, _ = self._split_fd(fd)
298307
try:
299-
del self._readers[fd]
308+
_, handle = self._readers.pop(fd)
300309
except KeyError:
301310
return False
311+
handle.cancel()
302312
self._wake_selector()
303313
return True
304314

305315
def remove_writer(self, fd: _FileDescriptorLike) -> bool:
306316
fd, _ = self._split_fd(fd)
307317
try:
308-
del self._writers[fd]
318+
_, handle = self._writers.pop(fd)
309319
except KeyError:
310320
return False
321+
handle.cancel()
311322
self._wake_selector()
312323
return True
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import asyncio
2+
import select
3+
import socket
4+
import time
5+
import unittest
6+
from asyncio._selector_thread import SelectorThread
7+
from unittest import mock
8+
9+
10+
class SelectorThreadTest((unittest.IsolatedAsyncioTestCase)):
11+
async def asyncSetUp(self):
12+
self._sockets = []
13+
self.selector_thread = SelectorThread(asyncio.get_running_loop())
14+
15+
def socketpair(self):
16+
pair = socket.socketpair()
17+
self._sockets.extend(pair)
18+
return pair
19+
20+
async def asyncTearDown(self):
21+
self.selector_thread.close()
22+
for s in self._sockets:
23+
s.close()
24+
25+
async def test_slow_reader(self):
26+
a, b = self.socketpair()
27+
def recv():
28+
b.recv(100)
29+
mock_recv = mock.MagicMock(wraps=recv)
30+
# make sure select is only called once when
31+
# event loop thread is slow to consume events
32+
a.sendall(b"msg")
33+
with mock.patch("select.select", wraps=select.select) as mock_select:
34+
self.selector_thread.add_reader(b, mock_recv)
35+
time.sleep(0.1)
36+
self.assertEqual(mock_select.call_count, 1)
37+
await asyncio.sleep(0.1)
38+
self.assertEqual(mock_recv.call_count, 1)
39+
40+
async def test_reader_error(self):
41+
# test error handling in callbacks doesn't break handling
42+
a, b = self.socketpair()
43+
a.sendall(b"to_b")
44+
45+
selector_thread = self.selector_thread
46+
47+
# make sure it's called a few
48+
n_failures = 5
49+
counter = 0
50+
bad_recv_done = asyncio.Future()
51+
def bad_recv(sock):
52+
# fail the first n_failures calls, then succeed
53+
nonlocal counter
54+
counter += 1
55+
if counter > n_failures:
56+
bad_recv_done.set_result(None)
57+
sock.recv(10)
58+
return
59+
raise Exception("Testing reader error")
60+
61+
recv_callback = mock.MagicMock(wraps=bad_recv)
62+
63+
exception_handler = mock.MagicMock()
64+
asyncio.get_running_loop().set_exception_handler(exception_handler)
65+
66+
selector_thread.add_reader(b, recv_callback, b)
67+
68+
# make sure start_select is called
69+
# even when recv callback errors,
70+
with mock.patch.object(selector_thread, "_start_select", wraps=selector_thread._start_select) as start_select:
71+
await asyncio.wait_for(bad_recv_done, timeout=10)
72+
73+
# make sure recv is called N + 1 times,
74+
# exception N times,
75+
# start_select at least that many
76+
self.assertEqual(recv_callback.call_count, n_failures + 1)
77+
self.assertEqual(exception_handler.call_count, n_failures)
78+
self.assertGreaterEqual(start_select.call_count, n_failures)
79+
80+
async def test_read_write(self):
81+
a, b = self.socketpair()
82+
read_future = asyncio.Future()
83+
sent = b"asdf"
84+
loop = asyncio.get_running_loop()
85+
selector_thread = self.selector_thread
86+
87+
def write():
88+
a.sendall(sent)
89+
loop.remove_writer(a)
90+
91+
def read():
92+
msg = b.recv(100)
93+
read_future.set_result(msg)
94+
95+
selector_thread.add_reader(b, read)
96+
assert b.fileno() in selector_thread._readers
97+
selector_thread.add_writer(a, write)
98+
assert a.fileno() in selector_thread._writers
99+
msg = await asyncio.wait_for(read_future, timeout=10)
100+
101+
selector_thread.remove_writer(a)
102+
assert a.fileno() not in selector_thread._writers
103+
selector_thread.remove_reader(b)
104+
assert b.fileno() not in selector_thread._readers
105+
a.close()
106+
b.close()
107+
assert msg == sent

0 commit comments

Comments
 (0)