From 68c4a6ece35eead49a8ca75537c2301ed25cc411 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 11:13:09 -0500 Subject: [PATCH 1/7] PYTHON-5053 - AsyncMongoClient.close() should await all background tasks --- pymongo/asynchronous/mongo_client.py | 2 ++ pymongo/asynchronous/monitor.py | 6 ++++++ pymongo/asynchronous/server.py | 2 ++ pymongo/asynchronous/topology.py | 2 ++ pymongo/periodic_executor.py | 2 ++ pymongo/synchronous/mongo_client.py | 2 ++ pymongo/synchronous/monitor.py | 6 ++++++ pymongo/synchronous/server.py | 2 ++ pymongo/synchronous/topology.py | 2 ++ 9 files changed, 26 insertions(+) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index cf7de19c2f..962da3def3 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1564,6 +1564,8 @@ async def close(self) -> None: if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() + if not _IS_SYNC: + await self._kill_cursors_executor.join() self._closed = True if not _IS_SYNC: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index ad1bc70aba..1d3d685a10 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -111,6 +111,8 @@ async def close(self) -> None: open() restarts the monitor after closing. """ self.gc_safe_close() + if not _IS_SYNC: + await self._executor.join() async def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" @@ -191,6 +193,8 @@ def gc_safe_close(self) -> None: async def close(self) -> None: self.gc_safe_close() + if not _IS_SYNC: + await self._executor.join() await self._rtt_monitor.close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. @@ -460,6 +464,8 @@ async def close(self) -> None: self.gc_safe_close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. + if not _IS_SYNC: + await self._executor.join() await self._pool.reset() async def add_sample(self, sample: float) -> None: diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 72f22584e2..6cf827faab 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -115,6 +115,8 @@ async def close(self) -> None: ) await self._monitor.close() + if not _IS_SYNC: + await self._monitor.join() await self._pool.close() def request_check(self) -> None: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 6d67710a7e..0b98b59366 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -705,6 +705,8 @@ async def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() + if not _IS_SYNC: + await self._srv_monitor.join() self._opened = False self._closed = True diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e3..f51a988728 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None: callback; see monitor.py. """ self._stopped = True + if self._task is not None: + self._task.cancel() async def join(self, timeout: Optional[int] = None) -> None: if self._task is not None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 706623c214..d764871f32 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1558,6 +1558,8 @@ def close(self) -> None: if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + if not _IS_SYNC: + self._kill_cursors_executor.join() self._closed = True if not _IS_SYNC: diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index df4130d4ab..81a33d4b3e 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -111,6 +111,8 @@ def close(self) -> None: open() restarts the monitor after closing. """ self.gc_safe_close() + if not _IS_SYNC: + self._executor.join() def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" @@ -191,6 +193,8 @@ def gc_safe_close(self) -> None: def close(self) -> None: self.gc_safe_close() + if not _IS_SYNC: + self._executor.join() self._rtt_monitor.close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. @@ -460,6 +464,8 @@ def close(self) -> None: self.gc_safe_close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. + if not _IS_SYNC: + self._executor.join() self._pool.reset() def add_sample(self, sample: float) -> None: diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index ed48cc6cc8..4dc2b7a0a0 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -115,6 +115,8 @@ def close(self) -> None: ) self._monitor.close() + if not _IS_SYNC: + self._monitor.join() self._pool.close() def request_check(self) -> None: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index b03269ae43..867e84c466 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -703,6 +703,8 @@ def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() + if not _IS_SYNC: + self._srv_monitor.join() self._opened = False self._closed = True From d14d8e807fd2959550700e450a84f291a1ccd25d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 3 Feb 2025 16:15:45 -0500 Subject: [PATCH 2/7] Don't call join() inside close() --- pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/asynchronous/monitor.py | 10 ++++------ pymongo/asynchronous/topology.py | 2 ++ pymongo/synchronous/mongo_client.py | 4 ++-- pymongo/synchronous/monitor.py | 10 ++++------ pymongo/synchronous/topology.py | 2 ++ 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 962da3def3..fbe026b973 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1559,13 +1559,13 @@ async def close(self) -> None: # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. self._kill_cursors_executor.close() + if not _IS_SYNC: + await self._kill_cursors_executor.join() await self._process_kill_cursors() await self._topology.close() if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() - if not _IS_SYNC: - await self._kill_cursors_executor.join() self._closed = True if not _IS_SYNC: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 1d3d685a10..be7fd07f1c 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -111,8 +111,6 @@ async def close(self) -> None: open() restarts the monitor after closing. """ self.gc_safe_close() - if not _IS_SYNC: - await self._executor.join() async def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" @@ -191,10 +189,12 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() + async def join(self, timeout: Optional[int] = None) -> None: + await self._executor.join(timeout) + await self._rtt_monitor.join() + async def close(self) -> None: self.gc_safe_close() - if not _IS_SYNC: - await self._executor.join() await self._rtt_monitor.close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. @@ -464,8 +464,6 @@ async def close(self) -> None: self.gc_safe_close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. - if not _IS_SYNC: - await self._executor.join() await self._pool.reset() async def add_sample(self, sample: float) -> None: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 0b98b59366..d295ad8c86 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -520,6 +520,8 @@ async def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() + if not _IS_SYNC: + await self._srv_monitor.join() # Clear the pool from a failed heartbeat. if reset_pool: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index d764871f32..be363299c3 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1553,13 +1553,13 @@ def close(self) -> None: # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. self._kill_cursors_executor.close() + if not _IS_SYNC: + self._kill_cursors_executor.join() self._process_kill_cursors() self._topology.close() if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() - if not _IS_SYNC: - self._kill_cursors_executor.join() self._closed = True if not _IS_SYNC: diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 81a33d4b3e..e0e0407623 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -111,8 +111,6 @@ def close(self) -> None: open() restarts the monitor after closing. """ self.gc_safe_close() - if not _IS_SYNC: - self._executor.join() def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" @@ -191,10 +189,12 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() + def join(self, timeout: Optional[int] = None) -> None: + self._executor.join(timeout) + self._rtt_monitor.join() + def close(self) -> None: self.gc_safe_close() - if not _IS_SYNC: - self._executor.join() self._rtt_monitor.close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. @@ -464,8 +464,6 @@ def close(self) -> None: self.gc_safe_close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. - if not _IS_SYNC: - self._executor.join() self._pool.reset() def add_sample(self, sample: float) -> None: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 867e84c466..2fcd0bfd51 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -520,6 +520,8 @@ def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() + if not _IS_SYNC: + self._srv_monitor.join() # Clear the pool from a failed heartbeat. if reset_pool: From 6c6a32da16fa3650bc79d671ddd89055dab7cc2b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 11:51:27 -0500 Subject: [PATCH 3/7] Store tasks to be awaited inside Topology --- pymongo/asynchronous/mongo_client.py | 6 ++++-- pymongo/asynchronous/monitor.py | 8 ++++---- pymongo/asynchronous/server.py | 2 -- pymongo/asynchronous/topology.py | 24 ++++++++++++++++++------ pymongo/synchronous/mongo_client.py | 6 ++++-- pymongo/synchronous/monitor.py | 8 ++++---- pymongo/synchronous/server.py | 2 -- pymongo/synchronous/topology.py | 24 ++++++++++++++++++------ 8 files changed, 52 insertions(+), 28 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index fbe026b973..9d6d6cd8db 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1559,14 +1559,16 @@ async def close(self) -> None: # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. self._kill_cursors_executor.close() - if not _IS_SYNC: - await self._kill_cursors_executor.join() await self._process_kill_cursors() await self._topology.close() if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() self._closed = True + if not _IS_SYNC: + self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type] + join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value] + await asyncio.gather(*join_tasks) if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index be7fd07f1c..ea0416e7f0 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -112,9 +112,9 @@ async def close(self) -> None: """ self.gc_safe_close() - async def join(self, timeout: Optional[int] = None) -> None: + async def join(self) -> None: """Wait for the monitor to stop.""" - await self._executor.join(timeout) + await self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,8 +189,8 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() - async def join(self, timeout: Optional[int] = None) -> None: - await self._executor.join(timeout) + async def join(self) -> None: + await self._executor.join() await self._rtt_monitor.join() async def close(self) -> None: diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 6cf827faab..72f22584e2 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -115,8 +115,6 @@ async def close(self) -> None: ) await self._monitor.close() - if not _IS_SYNC: - await self._monitor.join() await self._pool.close() def request_check(self) -> None: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index d295ad8c86..f7de5bd926 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -29,7 +30,7 @@ from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.asynchronous.monitor import SrvMonitor +from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor from pymongo.asynchronous.pool import Pool from pymongo.asynchronous.server import Server from pymongo.errors import ( @@ -207,6 +208,9 @@ async def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + async def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,7 @@ async def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): await server.close() + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -288,10 +293,17 @@ async def select_servers( selector, server_timeout, operation, operation_id, address ) - return [ + servers = [ cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions ] + if not _IS_SYNC and self._monitor_tasks: + joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] + await asyncio.gather(*joins) + self._monitor_tasks = [] + + return servers + async def _select_servers_loop( self, selector: Callable[[Selection], Selection], @@ -520,8 +532,7 @@ async def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() - if not _IS_SYNC: - await self._srv_monitor.join() + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -697,6 +708,7 @@ async def close(self) -> None: old_td = self._description for server in self._servers.values(): await server.close() + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -707,8 +719,7 @@ async def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() - if not _IS_SYNC: - await self._srv_monitor.join() + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -948,6 +959,7 @@ async def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): await server.close() + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index be363299c3..b24e22f11c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1553,14 +1553,16 @@ def close(self) -> None: # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. self._kill_cursors_executor.close() - if not _IS_SYNC: - self._kill_cursors_executor.join() self._process_kill_cursors() self._topology.close() if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() self._closed = True + if not _IS_SYNC: + self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type] + join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value] + asyncio.gather(*join_tasks) if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index e0e0407623..0787a28a0a 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -112,9 +112,9 @@ def close(self) -> None: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + def join(self) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,8 +189,8 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() - def join(self, timeout: Optional[int] = None) -> None: - self._executor.join(timeout) + def join(self) -> None: + self._executor.join() self._rtt_monitor.join() def close(self) -> None: diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 4dc2b7a0a0..ed48cc6cc8 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -115,8 +115,6 @@ def close(self) -> None: ) self._monitor.close() - if not _IS_SYNC: - self._monitor.join() self._pool.close() def request_check(self) -> None: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 2fcd0bfd51..d98928d0a2 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -61,7 +62,7 @@ writable_server_selector, ) from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.monitor import MonitorBase, SrvMonitor from pymongo.synchronous.pool import Pool from pymongo.synchronous.server import Server from pymongo.topology_description import ( @@ -207,6 +208,9 @@ def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,7 @@ def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): server.close() + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -288,10 +293,17 @@ def select_servers( selector, server_timeout, operation, operation_id, address ) - return [ + servers = [ cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions ] + if not _IS_SYNC and self._monitor_tasks: + joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] + asyncio.gather(*joins) + self._monitor_tasks = [] + + return servers + def _select_servers_loop( self, selector: Callable[[Selection], Selection], @@ -520,8 +532,7 @@ def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() - if not _IS_SYNC: - self._srv_monitor.join() + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -695,6 +706,7 @@ def close(self) -> None: old_td = self._description for server in self._servers.values(): server.close() + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -705,8 +717,7 @@ def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() - if not _IS_SYNC: - self._srv_monitor.join() + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -946,6 +957,7 @@ def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: From 861dbb5f653513d2bdb2f6b46edff7ae8ea972cf Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 15:25:59 -0500 Subject: [PATCH 4/7] Address review --- pymongo/asynchronous/mongo_client.py | 9 +++++++-- pymongo/asynchronous/monitor.py | 3 +-- pymongo/asynchronous/topology.py | 26 ++++++++++++++++++-------- pymongo/synchronous/mongo_client.py | 9 +++++++-- pymongo/synchronous/monitor.py | 3 +-- pymongo/synchronous/topology.py | 26 ++++++++++++++++++-------- 6 files changed, 52 insertions(+), 24 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 9d6d6cd8db..0e13608548 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1566,8 +1566,13 @@ async def close(self) -> None: await self._encrypter.close() self._closed = True if not _IS_SYNC: - self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type] - join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value] + join_tasks = [self._kill_cursors_executor] + try: + while self._topology._monitor_tasks: + join_tasks.append(self._topology._monitor_tasks.pop()) + except IndexError: + pass + join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] await asyncio.gather(*join_tasks) if not _IS_SYNC: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index ea0416e7f0..cb41ee26b3 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -190,8 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() async def join(self) -> None: - await self._executor.join() - await self._rtt_monitor.join() + await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) async def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index f7de5bd926..c64a449e78 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -245,7 +245,8 @@ async def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): await server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -298,9 +299,14 @@ async def select_servers( ] if not _IS_SYNC and self._monitor_tasks: - joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] - await asyncio.gather(*joins) - self._monitor_tasks = [] + join_tasks = [] + try: + while self._monitor_tasks: + join_tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] + await asyncio.gather(*join_tasks) return servers @@ -532,7 +538,8 @@ async def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() - self._monitor_tasks.append(self._srv_monitor) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -708,7 +715,8 @@ async def close(self) -> None: old_td = self._description for server in self._servers.values(): await server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -719,7 +727,8 @@ async def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() - self._monitor_tasks.append(self._srv_monitor) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -959,7 +968,8 @@ async def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): await server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index b24e22f11c..5a755e0464 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1560,8 +1560,13 @@ def close(self) -> None: self._encrypter.close() self._closed = True if not _IS_SYNC: - self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type] - join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value] + join_tasks = [self._kill_cursors_executor] + try: + while self._topology._monitor_tasks: + join_tasks.append(self._topology._monitor_tasks.pop()) + except IndexError: + pass + join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] asyncio.gather(*join_tasks) if not _IS_SYNC: diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 0787a28a0a..b3a3de7104 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -190,8 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() def join(self) -> None: - self._executor.join() - self._rtt_monitor.join() + asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index d98928d0a2..c83f163444 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -245,7 +245,8 @@ def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -298,9 +299,14 @@ def select_servers( ] if not _IS_SYNC and self._monitor_tasks: - joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value] - asyncio.gather(*joins) - self._monitor_tasks = [] + join_tasks = [] + try: + while self._monitor_tasks: + join_tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] + asyncio.gather(*join_tasks) return servers @@ -532,7 +538,8 @@ def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() - self._monitor_tasks.append(self._srv_monitor) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -706,7 +713,8 @@ def close(self) -> None: old_td = self._description for server in self._servers.values(): server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -717,7 +725,8 @@ def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() - self._monitor_tasks.append(self._srv_monitor) + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -957,7 +966,8 @@ def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() - self._monitor_tasks.append(server._monitor) + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: From 2ac4ea3f0ba5c121e5fe2cab789d3a29d6bac15d Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 16:33:45 -0500 Subject: [PATCH 5/7] Address review --- pymongo/asynchronous/mongo_client.py | 11 +++-------- pymongo/asynchronous/monitor.py | 2 +- pymongo/asynchronous/topology.py | 27 ++++++++++++++------------- pymongo/synchronous/mongo_client.py | 11 +++-------- pymongo/synchronous/monitor.py | 2 +- pymongo/synchronous/topology.py | 27 ++++++++++++++------------- 6 files changed, 36 insertions(+), 44 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 0e13608548..ffef52cce6 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1566,14 +1566,9 @@ async def close(self) -> None: await self._encrypter.close() self._closed = True if not _IS_SYNC: - join_tasks = [self._kill_cursors_executor] - try: - while self._topology._monitor_tasks: - join_tasks.append(self._topology._monitor_tasks.pop()) - except IndexError: - pass - join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] - await asyncio.gather(*join_tasks) + await asyncio.gather( + *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value] + ) if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index cb41ee26b3..7148da57f0 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -190,7 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() async def join(self) -> None: - await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) + await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value] async def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index c64a449e78..78e512f65b 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -289,27 +289,19 @@ async def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + await self.cleanup_monitors() + async with self._lock: server_descriptions = await self._select_servers_loop( selector, server_timeout, operation, operation_id, address ) - servers = [ + return [ cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions ] - if not _IS_SYNC and self._monitor_tasks: - join_tasks = [] - try: - while self._monitor_tasks: - join_tasks.append(self._monitor_tasks.pop()) - except IndexError: - pass - join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] - await asyncio.gather(*join_tasks) - - return servers - async def _select_servers_loop( self, selector: Callable[[Selection], Selection], @@ -1057,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + async def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + await asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5a755e0464..54b8d4102c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1560,14 +1560,9 @@ def close(self) -> None: self._encrypter.close() self._closed = True if not _IS_SYNC: - join_tasks = [self._kill_cursors_executor] - try: - while self._topology._monitor_tasks: - join_tasks.append(self._topology._monitor_tasks.pop()) - except IndexError: - pass - join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] - asyncio.gather(*join_tasks) + asyncio.gather( + *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value] + ) if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index b3a3de7104..0e848000c0 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -190,7 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() def join(self) -> None: - asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) + asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value] def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index c83f163444..f9c2e6d669 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -289,27 +289,19 @@ def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + self.cleanup_monitors() + with self._lock: server_descriptions = self._select_servers_loop( selector, server_timeout, operation, operation_id, address ) - servers = [ + return [ cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions ] - if not _IS_SYNC and self._monitor_tasks: - join_tasks = [] - try: - while self._monitor_tasks: - join_tasks.append(self._monitor_tasks.pop()) - except IndexError: - pass - join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value] - asyncio.gather(*join_tasks) - - return servers - def _select_servers_loop( self, selector: Callable[[Selection], Selection], @@ -1055,6 +1047,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: From 24e96f08b910b91e613385a3faa8872b0f893854 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 09:12:51 -0500 Subject: [PATCH 6/7] return_exceptions=True for gather calls --- pymongo/asynchronous/mongo_client.py | 3 ++- pymongo/asynchronous/monitor.py | 4 +++- pymongo/asynchronous/topology.py | 2 +- pymongo/synchronous/mongo_client.py | 3 ++- pymongo/synchronous/monitor.py | 2 +- pymongo/synchronous/topology.py | 2 +- 6 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index ffef52cce6..34d4030dcd 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1567,7 +1567,8 @@ async def close(self) -> None: self._closed = True if not _IS_SYNC: await asyncio.gather( - *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value] + *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()], # type: ignore[func-returns-value] + return_exceptions=True, ) if not _IS_SYNC: diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 7148da57f0..01d28f82a1 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -190,7 +190,9 @@ def gc_safe_close(self) -> None: self.cancel_check() async def join(self) -> None: - await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value] + await asyncio.gather( + *[self._executor.join(), self._rtt_monitor.join()], return_exceptions=True + ) # type: ignore[func-returns-value] async def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 78e512f65b..3033377de5 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -1056,7 +1056,7 @@ async def cleanup_monitors(self) -> None: tasks.append(self._monitor_tasks.pop()) except IndexError: pass - await asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value] + await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] def __repr__(self) -> str: msg = "" diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 54b8d4102c..3453f8399e 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1561,7 +1561,8 @@ def close(self) -> None: self._closed = True if not _IS_SYNC: asyncio.gather( - *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value] + *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()], # type: ignore[func-returns-value] + return_exceptions=True, ) if not _IS_SYNC: diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 0e848000c0..e9402b40b2 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -190,7 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() def join(self) -> None: - asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value] + asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()], return_exceptions=True) # type: ignore[func-returns-value] def close(self) -> None: self.gc_safe_close() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index f9c2e6d669..09b61f6d05 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -1054,7 +1054,7 @@ def cleanup_monitors(self) -> None: tasks.append(self._monitor_tasks.pop()) except IndexError: pass - asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value] + asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] def __repr__(self) -> str: msg = "" From a2eb4bf3c426c89caa859ec32df27d8c27d4cbb7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 14:45:25 -0500 Subject: [PATCH 7/7] Cleanup gathers --- pymongo/asynchronous/mongo_client.py | 3 ++- pymongo/asynchronous/monitor.py | 2 +- pymongo/synchronous/mongo_client.py | 3 ++- pymongo/synchronous/monitor.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 34d4030dcd..365fc62100 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1567,7 +1567,8 @@ async def close(self) -> None: self._closed = True if not _IS_SYNC: await asyncio.gather( - *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()], # type: ignore[func-returns-value] + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] return_exceptions=True, ) diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 01d28f82a1..abde7a9055 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -191,7 +191,7 @@ def gc_safe_close(self) -> None: async def join(self) -> None: await asyncio.gather( - *[self._executor.join(), self._rtt_monitor.join()], return_exceptions=True + self._executor.join(), self._rtt_monitor.join(), return_exceptions=True ) # type: ignore[func-returns-value] async def close(self) -> None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3453f8399e..8cd08ab725 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1561,7 +1561,8 @@ def close(self) -> None: self._closed = True if not _IS_SYNC: asyncio.gather( - *[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()], # type: ignore[func-returns-value] + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] return_exceptions=True, ) diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index e9402b40b2..211635d8b8 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -190,7 +190,7 @@ def gc_safe_close(self) -> None: self.cancel_check() def join(self) -> None: - asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()], return_exceptions=True) # type: ignore[func-returns-value] + asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value] def close(self) -> None: self.gc_safe_close()