diff --git a/src/mss/linux/xshmgetimage.py b/src/mss/linux/xshmgetimage.py index 9858552..5628910 100644 --- a/src/mss/linux/xshmgetimage.py +++ b/src/mss/linux/xshmgetimage.py @@ -190,8 +190,19 @@ def _grab_impl_xshmgetimage(self, monitor: Monitor) -> ScreenShot: new_size = monitor["width"] * monitor["height"] * 4 # Slicing the memoryview creates a new memoryview that points to the relevant subregion. Making this and # then copying it into a fresh bytearray is much faster than slicing the mmap object. - img_mv = memoryview(self._buf)[:new_size] - img_data = bytearray(img_mv) + try: + img_mv: memoryview | None = memoryview(self._buf)[:new_size] + assert img_mv is not None # noqa: S101 + img_data = bytearray(img_mv) + finally: + # Imagine that an exception happened in the above code, such as an asynchronous KeyboardInterrupt. Let's + # imagine it happened after we created img_mv, but while we were populating img_data, a process that can + # take a few milliseconds. That exception includes this stack frame, and hence holds a reference to + # img_mv. If the exception unwinds to the enclosing `with mss() as sct:` block, then self._buf.close() + # would be executed, to close the mmapped region. But img_mv still exists, and self._buf.close() would + # throw an exception, because it can't close the region while references exist. To prevent that, remove + # the reference to img_mv during the stack unwind here. + img_mv = None return self.cls_image(img_data, monitor) diff --git a/src/tests/test_gnu_linux.py b/src/tests/test_gnu_linux.py index dcb1d99..048d6f3 100644 --- a/src/tests/test_gnu_linux.py +++ b/src/tests/test_gnu_linux.py @@ -4,6 +4,7 @@ from __future__ import annotations +import builtins import ctypes.util import platform from ctypes import CFUNCTYPE, POINTER, _Pointer, c_int @@ -315,3 +316,32 @@ def test_shm_fallback() -> None: sct.grab(sct.monitors[0]) # Ensure that it really did have to fall back; otherwise, we'd need to change how we test this case. assert sct.shm_status == mss.linux.xshmgetimage.ShmStatus.UNAVAILABLE + + +def test_exception_while_holding_memoryview(monkeypatch: pytest.MonkeyPatch) -> None: + """Verify that an exception at a particular point doesn't prevent cleanup. + + The particular point is the window when the XShmGetImage's mmapped + buffer has a memoryview still outstanding, and the pixel data is + being copied into a bytearray. This can take a few milliseconds. + """ + # Force an exception during bytearray(img_mv) + real_bytearray = builtins.bytearray + + def boom(*args: list, **kwargs: dict[str, Any]) -> bytearray: + # Only explode when called with the memoryview (the code path we care about). + if len(args) > 0 and isinstance(args[0], memoryview): + # We still need to eliminate args from the stack frame, just like the fix. + del args, kwargs + msg = "Boom!" + raise RuntimeError(msg) + return real_bytearray(*args, **kwargs) + + # We have to be careful about the order in which we catch things. If we were to catch and discard the exception + # before the MSS object closes, it won't trigger the bug. That's why we have the pytest.raises outside the + # mss.mss block. In addition, we do as much as we can before patching bytearray, to limit its scope. + with pytest.raises(RuntimeError, match="Boom!"), mss.mss(backend="xshmgetimage") as sct: # noqa: PT012 + monitor = sct.monitors[0] + with monkeypatch.context() as m: + m.setattr(builtins, "bytearray", boom) + sct.grab(monitor)