diff --git a/redis/client.py b/redis/client.py index d3ab3cfcfe..9d46b89dfe 100755 --- a/redis/client.py +++ b/redis/client.py @@ -125,6 +125,17 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): Connection object to talk to redis. It is not safe to pass PubSub or Pipeline objects between threads. + + :param float idle_connection_timeout: + If set, connections that have been idle for longer than this timeout + (in seconds) will be automatically closed. If unset, idle connections + are never closed. This parameter is passed through to the connection pool + constructor, so it's only used when a connection_pool instance is not provided. + :param float idle_check_interval: + Minimum time between idle connection cleanup runs. Defaults to 60 seconds. + Only used when idle_connection_timeout is set. As with idle_connection_timeout, + this parameter is passed through to the connection pool constructor, + so it's only used when a connection_pool instance is not provided. """ @classmethod @@ -250,6 +261,8 @@ def __init__( cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + idle_connection_timeout: Optional[float] = None, + idle_check_interval: float = 60.0, ) -> None: """ Initialize a new Redis client. @@ -314,6 +327,8 @@ def __init__( "redis_connect_func": redis_connect_func, "credential_provider": credential_provider, "protocol": protocol, + "idle_connection_timeout": idle_connection_timeout, + "idle_check_interval": idle_check_interval, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/cluster.py b/redis/cluster.py index 8f42c1a235..98a1304ea5 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -196,6 +196,8 @@ def parse_cluster_myshardid(resp, **options): "username", "cache", "cache_config", + "idle_connection_timeout", + "idle_check_interval", ) KWARGS_DISABLED_KEYS = ("host", "port", "retry") diff --git a/redis/connection.py b/redis/connection.py index 0a87777ac3..c5f712b878 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,4 +1,7 @@ import copy +import datetime +import heapq +import logging import os import socket import sys @@ -6,8 +9,10 @@ import time import weakref from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, field from itertools import chain -from queue import Empty, Full, LifoQueue +from queue import Empty, Full, LifoQueue, Queue from typing import ( Any, Callable, @@ -16,6 +21,7 @@ List, Literal, Optional, + Tuple, Type, TypeVar, Union, @@ -76,6 +82,8 @@ if HIREDIS_AVAILABLE: import hiredis +logger = logging.getLogger(__name__) + SYM_STAR = b"*" SYM_DOLLAR = b"$" SYM_CRLF = b"\r\n" @@ -162,6 +170,10 @@ def pack(self, *args): class ConnectionInterface: + pid: int + retry: Retry + maintenance_notification_hash: int | None + @abstractmethod def repr_pieces(self): pass @@ -2361,6 +2373,207 @@ def disconnect_free_connections( conn.disconnect() +@dataclass(order=True) +class _PoolMetadata: + """Metadata for a registered pool, used both for tracking and scheduling.""" + + next_check_time: ( + float # timestamp when this pool should be checked next (for heapq ordering) + ) + pool_id: int = field(compare=False) # id(pool) for identification + # we use a weakref to the connection pool itself; this is because clients rely + # on garbage collection to close a pool after it's no longer needed. 
+ # however, if we didn't use a weakref, the IdleConnectionCleanupManager + # would always have a reference to the pool which is never dropped, and + # the GC + disconnect would never happen. + pool_ref: weakref.ref = field(compare=False) + idle_timeout: float = field(compare=False) + minimum_check_interval: float = field(compare=False) # minimum time between checks + + +class IdleConnectionCleanupManager: + """Global singleton manager for idle connection cleanup across all pools. + + This manager maintains a single worker thread that handles cleanup for all + connection pools, using a priority queue to efficiently schedule checks. + """ + + _instance: Optional["IdleConnectionCleanupManager"] = None + _instance_lock = threading.Lock() + + def __init__(self): + # callers should use get_instance() instead of directly calling the constructor + self._schedule_lock = threading.RLock() + self._condition = threading.Condition(self._schedule_lock) + self._schedule: List[_PoolMetadata] = [] # heapq-based priority queue + self._registered_pool_ids: set[int] = set() # set of pool ids in the _schedule + self._worker_thread: Optional[threading.Thread] = None + self._shutdown_event = threading.Event() + # WAL-style tracking: if we pop a pool from schedule but crash before re-adding, + # we can recover it from here + self._pool_being_cleaned: Optional[_PoolMetadata] = None + + @classmethod + def get_instance(cls) -> "IdleConnectionCleanupManager": + # get or create the singleton instance + if cls._instance is None: + with cls._instance_lock: + if cls._instance is None: + cls._instance = IdleConnectionCleanupManager() + cls._instance._start_worker() + return cls._instance + + def _start_worker(self) -> None: + """Start the background worker thread.""" + self._shutdown_event.clear() + self._worker_thread = threading.Thread( + target=self._worker_loop, daemon=True, name="IdleConnectionCleanupManager" + ) + self._worker_thread.start() + + def register_pool(self, pool: "ConnectionPool", next_check_time: float) -> None: + # Register a pool for idle connection cleanup. + # Called when a connection is released. + + if pool.idle_connection_timeout is None: + # no need to register, because this pool doesn't close idle connections + return + + pool_id = id(pool) + + with self._condition: + if pool_id in self._registered_pool_ids: + # no-op if already registered + return + + metadata = _PoolMetadata( + next_check_time=next_check_time, + pool_id=pool_id, + pool_ref=weakref.ref(pool), + idle_timeout=pool.idle_connection_timeout, + minimum_check_interval=pool.idle_check_interval, + ) + + self._registered_pool_ids.add(pool_id) + heapq.heappush(self._schedule, metadata) + + # wake up worker to potentially adjust sleep time + self._condition.notify() + + def unregister_pool(self, pool: "ConnectionPool") -> None: + # Unregister a pool from cleanup + pool_id = id(pool) + with self._condition: + self._registered_pool_ids.discard(pool_id) + # Note: We don't remove from schedule immediately, because the heapq + # doesn't have a fast way to do this. The worker will skip it when it + # processes the entry. 
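The unregistration note above describes a lazy-removal (tombstone) scheme: entries are never deleted from the heap in place, because `heapq` has no efficient arbitrary delete; instead the worker discards stale entries when they surface at the top of the schedule. A minimal, standalone sketch of that pattern — illustrative only, with names that are not part of this patch:

```python
import heapq
import time

schedule = []           # min-heap of (next_check_time, pool_id)
registered_ids = set()  # ids that are still considered live

def register(pool_id: int, next_check_time: float) -> None:
    registered_ids.add(pool_id)
    heapq.heappush(schedule, (next_check_time, pool_id))

def unregister(pool_id: int) -> None:
    # O(1): just forget the id; its heap entry becomes a tombstone
    registered_ids.discard(pool_id)

def pop_next_due():
    # Pop due entries, skipping tombstones left behind by unregister()
    while schedule and schedule[0][0] <= time.time():
        _, pool_id = heapq.heappop(schedule)
        if pool_id in registered_ids:
            return pool_id
    return None
```

The worker loop that follows adds one more wrinkle: after a cleanup pass it reschedules the pool at `max(oldest_conn_time + idle_timeout, now + minimum_check_interval)`, so a pool is revisited when its oldest surviving connection could first become eligible for closing, but never more often than `idle_check_interval`.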
+ + def _worker_loop(self) -> None: + # processes pools in schedule order + while not self._shutdown_event.is_set(): + try: + with self._condition: + # first, check if we have a pool from a previous failed iteration + if self._pool_being_cleaned is not None: + # re-add it to schedule before processing anything else + heapq.heappush(self._schedule, self._pool_being_cleaned) + self._pool_being_cleaned = None + + # get the next pool to be processed + next_pair = self._get_next_pool() + if next_pair is None: + continue + metadata, pool = next_pair + + # use a WAL pattern to be defensive against bugs resulting + # in us dequeueing a pool, and never re-enqueueing it. + self._pool_being_cleaned = metadata + + # release lock while doing cleanup, since this is relatively slow + try: + oldest_conn_time = pool._cleanup_idle_connections() + except Exception as e: + logger.warning( + "Error during idle connection cleanup for pool %s: %s", + id(pool), + e, + exc_info=True, + ) + oldest_conn_time = None + finally: + # make sure to drop the pool reference - we never want the idle connection + # thread to be the only thing holding a reference to a pool, because this can + # keep the pool from being GC'd, and closing all of its connections. + del pool + + with self._condition: + self._reschedule_pool(metadata, oldest_conn_time) + # after the pool is rescheduled, we can clean up the WAL + self._pool_being_cleaned = None + except Exception as e: + logger.error( + "Unexpected error in idle connection cleanup worker: %s", + e, + exc_info=True, + ) + + def _get_next_pool(self) -> "Tuple[_PoolMetadata, ConnectionPool] | None": + if not self._schedule: + # No pools to manage, wait for registration or shutdown + self._condition.wait() + return None + + # Peek at next pool to check + metadata = self._schedule[0] + + wait_time = metadata.next_check_time - time.time() + if wait_time > 0: + # Sleep until next scheduled check (or until notified) + self._condition.wait(timeout=wait_time) + return None + + heapq.heappop(self._schedule) + + if metadata.pool_id not in self._registered_pool_ids: + # pool was unregistered + return None + + pool = metadata.pool_ref() + if pool is None: + # pool was GC'd + self._registered_pool_ids.discard(metadata.pool_id) + return None + + return metadata, pool + + def _reschedule_pool(self, metadata: _PoolMetadata, oldest_conn_time: float | None): + if metadata.pool_id not in self._registered_pool_ids: + # pool was unregistered while we were cleaning it + return + + # reschedule this pool, or remove if empty + if oldest_conn_time: + next_check = max( + # check when the oldest connection will become idle + oldest_conn_time + metadata.idle_timeout, + # ...but don't check more frequently than check_interval + time.time() + metadata.minimum_check_interval, + ) + # Pool has connections, reschedule it + metadata.next_check_time = next_check + heapq.heappush(self._schedule, metadata) + else: + # Pool is empty, remove from tracking entirely + self._registered_pool_ids.discard(metadata.pool_id) + + +class PooledConnection: + def __init__(self, connection: ConnectionInterface): + self.connection = connection + self.last_used: datetime.datetime = datetime.datetime.now() + + class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInterface): """ Create a connection pool. 
``If max_connections`` is set, then this @@ -2378,6 +2591,14 @@ class ConnectionPool(MaintNotificationsAbstractConnectionPool, ConnectionPoolInt If the ``maint_notifications_config`` is not provided but the ``protocol`` is 3, the maintenance notifications will be enabled by default. + The pool can automatically close idle connections to free up resources. + Set ``idle_connection_timeout`` (in seconds) to enable this feature. + Connections that remain idle (not checked out from the pool) longer than + this timeout will be automatically closed and removed from the pool. + The ``idle_check_interval`` parameter controls the minimum time between + cleanup checks (default: 60 seconds). All pools in the process share a + single background thread for cleanup operations. + Any additional keyword arguments are passed to the constructor of ``connection_class``. """ @@ -2437,17 +2658,31 @@ def __init__( max_connections: Optional[int] = None, cache_factory: Optional[CacheFactoryInterface] = None, maint_notifications_config: Optional[MaintNotificationsConfig] = None, + idle_connection_timeout: Optional[float] = None, + idle_check_interval: float = 60.0, **connection_kwargs, ): max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') + if idle_connection_timeout is not None and idle_connection_timeout <= 0: + raise ValueError( + '"idle_connection_timeout" must be a positive number or None' + ) + + if idle_check_interval <= 0: + raise ValueError('"idle_check_interval" must be a positive number') + self.connection_class = connection_class self._connection_kwargs = connection_kwargs self.max_connections = max_connections self.cache = None self._cache_factory = cache_factory + self._available_connections: list[PooledConnection] = [] + self._in_use_connections: set[ConnectionInterface] = set() + self.idle_connection_timeout = idle_connection_timeout + self.idle_check_interval = idle_check_interval if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): if self._connection_kwargs.get("protocol") not in [3, "3"]: @@ -2535,6 +2770,50 @@ def reset(self) -> None: # reset() and they will immediately release _fork_lock and continue on. self.pid = os.getpid() + def _cleanup_idle_connections(self) -> Optional[float]: + """Remove connections that have been idle for longer than the timeout. + + Returns: + Timestamp of the oldest remaining connection, or None if pool is empty. 
+ """ + if self.idle_connection_timeout is None: + return None + + now = datetime.datetime.now() + connections_to_disconnect = [] + oldest_connection_time = None + + with self._lock: + connections_to_keep = [] + for pooled_conn in self._available_connections: + idle_time = (now - pooled_conn.last_used).total_seconds() + if idle_time < self.idle_connection_timeout: + connections_to_keep.append(pooled_conn) + # Track the oldest connection we're keeping + conn_timestamp = pooled_conn.last_used.timestamp() + if ( + oldest_connection_time is None + or conn_timestamp < oldest_connection_time + ): + oldest_connection_time = conn_timestamp + else: + # Mark for disconnection + connections_to_disconnect.append(pooled_conn) + self._created_connections -= 1 + + self._available_connections = connections_to_keep + + # Disconnect outside the lock to avoid blocking pool operations + for pooled_conn in connections_to_disconnect: + try: + pooled_conn.connection.disconnect() + except Exception as e: + logger.warning( + "Error disconnecting idle connection: %s", e, exc_info=True + ) + + return oldest_connection_time + def _checkpid(self) -> None: # _checkpid() attempts to keep ConnectionPool fork-safe on modern # systems. this is called by all ConnectionPool methods that @@ -2587,13 +2866,15 @@ def _checkpid(self) -> None: reason="Use get_connection() without args instead", version="5.3.0", ) - def get_connection(self, command_name=None, *keys, **options) -> "Connection": + def get_connection( + self, command_name=None, *keys, **options + ) -> "ConnectionInterface": "Get a connection from the pool" self._checkpid() with self._lock: try: - connection = self._available_connections.pop() + connection = self._available_connections.pop().connection except IndexError: connection = self.make_connection() self._in_use_connections.add(connection) @@ -2647,9 +2928,10 @@ def make_connection(self) -> "ConnectionInterface": ) return self.connection_class(**kwargs) - def release(self, connection: "Connection") -> None: + def release(self, connection: "ConnectionInterface") -> None: "Releases the connection back to the pool" self._checkpid() + release_time = time.time() with self._lock: try: self._in_use_connections.remove(connection) @@ -2661,7 +2943,7 @@ def release(self, connection: "Connection") -> None: if self.owns_connection(connection): if connection.should_reconnect(): connection.disconnect() - self._available_connections.append(connection) + self._available_connections.append(PooledConnection(connection)) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) ) @@ -2673,7 +2955,12 @@ def release(self, connection: "Connection") -> None: connection.disconnect() return - def owns_connection(self, connection: "Connection") -> int: + # Register with manager if pool was empty (will be a no-op if already registered) + if self.idle_connection_timeout is not None: + next_check = release_time + self.idle_connection_timeout + IdleConnectionCleanupManager.get_instance().register_pool(self, next_check) + + def owns_connection(self, connection: "ConnectionInterface") -> int: return connection.pid == self.pid def disconnect(self, inuse_connections: bool = True) -> None: @@ -2686,30 +2973,29 @@ def disconnect(self, inuse_connections: bool = True) -> None: """ self._checkpid() with self._lock: + connections = (p.connection for p in self._available_connections) if inuse_connections: - connections = chain( - self._available_connections, self._in_use_connections - ) - else: - connections = 
self._available_connections + connections = chain(connections, self._in_use_connections) for connection in connections: connection.disconnect() def close(self) -> None: """Close the pool, disconnecting all connections""" + if self.idle_connection_timeout is not None: + IdleConnectionCleanupManager.get_instance().unregister_pool(self) self.disconnect() def set_retry(self, retry: Retry) -> None: self.connection_kwargs.update({"retry": retry}) - for conn in self._available_connections: + for conn in self._get_free_connections(): conn.retry = retry for conn in self._in_use_connections: conn.retry = retry def re_auth_callback(self, token: TokenInterface): with self._lock: - for conn in self._available_connections: + for conn in self._get_free_connections(): conn.retry.call_with_retry( lambda: conn.send_command( "AUTH", token.try_get("oid"), token.get_value() @@ -2727,7 +3013,7 @@ def _get_pool_lock(self): def _get_free_connections(self): with self._lock: - return self._available_connections + return [p.connection for p in self._available_connections] def _get_in_use_connections(self): with self._lock: @@ -2774,6 +3060,9 @@ class BlockingConnectionPool(ConnectionPool): >>> # Raise a ``ConnectionError`` after five seconds if a connection is >>> # not available. >>> pool = BlockingConnectionPool(timeout=5) + + Like :py:class:`~redis.ConnectionPool`, this pool also supports automatic + idle connection cleanup via the ``idle_connection_timeout`` parameter. """ def __init__( @@ -2782,24 +3071,25 @@ def __init__( timeout=20, connection_class=Connection, queue_class=LifoQueue, + idle_connection_timeout: Optional[float] = None, + idle_check_interval: float = 60.0, **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout self._in_maintenance = False - self._locked = False + self.pool: Queue[PooledConnection | None] super().__init__( connection_class=connection_class, max_connections=max_connections, + idle_connection_timeout=idle_connection_timeout, + idle_check_interval=idle_check_interval, **connection_kwargs, ) def reset(self): # Create and fill up a thread safe queue with ``None`` values. - try: - if self._in_maintenance: - self._lock.acquire() - self._locked = True + with self._maintenance_lock(): self.pool = self.queue_class(self.max_connections) while True: try: @@ -2810,13 +3100,6 @@ def reset(self): # Keep a list of actual connection instances so that we can # disconnect them later. self._connections = [] - finally: - if self._locked: - try: - self._lock.release() - except Exception: - pass - self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -2831,11 +3114,7 @@ def reset(self): def make_connection(self): "Make a fresh connection." - try: - if self._in_maintenance: - self._lock.acquire() - self._locked = True - + with self._maintenance_lock(): if self.cache is not None: connection = CacheProxyConnection( self.connection_class(**self.connection_kwargs), @@ -2846,13 +3125,6 @@ def make_connection(self): connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection - finally: - if self._locked: - try: - self._lock.release() - except Exception: - pass - self._locked = False @deprecated_args( args_to_warn=["*"], @@ -2877,28 +3149,21 @@ def get_connection(self, command_name=None, *keys, **options): # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. 
connection = None - try: - if self._in_maintenance: - self._lock.acquire() - self._locked = True + with self._maintenance_lock(): try: - connection = self.pool.get(block=True, timeout=self.timeout) + pooled_connection = self.pool.get(block=True, timeout=self.timeout) except Empty: # Note that this is not caught by the redis client and will be # raised unless handled by application code. If you want never to raise ConnectionError("No connection available.") - # If the ``connection`` is actually ``None`` then that's a cue to make + # If the ``pooled_connection`` is actually ``None`` then that's a cue to make # a new connection to add to the pool. - if connection is None: + if pooled_connection: + connection = pooled_connection.connection + else: connection = self.make_connection() - finally: - if self._locked: - try: - self._lock.release() - except Exception: - pass - self._locked = False + self._in_use_connections.add(connection) try: # ensure this connection is connected to Redis @@ -2927,10 +3192,15 @@ def release(self, connection): # Make sure we haven't changed process. self._checkpid() - try: - if self._in_maintenance: - self._lock.acquire() - self._locked = True + release_time = time.time() + with self._maintenance_lock(): + try: + self._in_use_connections.remove(connection) + except KeyError: + # Gracefully fail when a connection is returned to this pool + # that the pool doesn't actually own + return + if not self.owns_connection(connection): # pool doesn't own this connection. do not add it back # to the pool. instead add a None value which is a placeholder @@ -2943,53 +3213,99 @@ def release(self, connection): connection.disconnect() # Put the connection back into the pool. try: - self.pool.put_nowait(connection) + self.pool.put_nowait(PooledConnection(connection)) except Full: # perhaps the pool has been reset() after a fork? regardless, # we don't want this connection pass - finally: - if self._locked: - try: - self._lock.release() - except Exception: - pass - self._locked = False + + # Register with manager if pool was empty (will be a no-op if already registered) + if self.idle_connection_timeout is not None: + next_check = release_time + self.idle_connection_timeout + IdleConnectionCleanupManager.get_instance().register_pool(self, next_check) def disconnect(self, inuse_connections: bool = True): "Disconnects either all connections in the pool or just the free connections." self._checkpid() - try: - if self._in_maintenance: - self._lock.acquire() - self._locked = True + with self._maintenance_lock(): if inuse_connections: connections = self._connections else: connections = self._get_free_connections() for connection in connections: connection.disconnect() - finally: - if self._locked: - try: - self._lock.release() - except Exception: - pass - self._locked = False + + def _cleanup_idle_connections(self) -> Optional[float]: + """Remove connections that have been idle for longer than the timeout. + + Override for BlockingConnectionPool to work with Queue structure. + + Returns: + Timestamp of the oldest remaining connection, or None if pool is empty. + """ + + if self.idle_connection_timeout is None: + return None + + now = datetime.datetime.now() + connections_to_disconnect = [] + oldest_connection_time = None + + with self._maintenance_lock(): + # Access the internal deque directly while holding the queue's mutex + # Note: it's safe to manipulate pool.queue while holding the lock, + # but ONLY because we're not adding / removing elements. 
If we were, + # we'd need to update pool.not_empty, pool.not_full, etc. as well, + # to keep all the state in sync. + with self.pool.mutex: + # Iterate through the internal deque in-place + for i, item in enumerate(self.pool.queue): + # Check if this is an idle connection that should be cleaned up + if item is None: + continue + idle_time = (now - item.last_used).total_seconds() + if idle_time >= self.idle_connection_timeout: + # Mark for disconnection and replace with None placeholder + connections_to_disconnect.append(item) + self.pool.queue[i] = None + # Remove from _connections tracking list + try: + self._connections.remove(item.connection) + except ValueError as e: + logger.debug( + "Connection not found in _connections list during cleanup: %s", + e, + ) + else: + # Track the oldest connection we're keeping + conn_timestamp = item.last_used.timestamp() + if ( + oldest_connection_time is None + or conn_timestamp < oldest_connection_time + ): + oldest_connection_time = conn_timestamp + + # Disconnect outside all locks to avoid blocking pool operations + for pooled_conn in connections_to_disconnect: + try: + pooled_conn.connection.disconnect() + except Exception as e: + logger.warning( + "Error disconnecting idle connection in BlockingConnectionPool: %s", + e, + exc_info=True, + ) + return oldest_connection_time def _get_free_connections(self): with self._lock: - return {conn for conn in self.pool.queue if conn} + return [ + pooled.connection for pooled in self.pool.queue if pooled is not None + ] def _get_in_use_connections(self): with self._lock: - # free connections - connections_in_queue = {conn for conn in self.pool.queue if conn} - # in self._connections we keep all created connections - # so the ones that are not in the queue are the in use ones - return { - conn for conn in self._connections if conn not in connections_in_queue - } + return self._in_use_connections def set_in_maintenance(self, in_maintenance: bool): """ @@ -2999,3 +3315,15 @@ def set_in_maintenance(self, in_maintenance: bool): The pool will be in maintenance mode only when we are processing a MOVING notification. 
""" self._in_maintenance = in_maintenance + + @contextmanager + def _maintenance_lock(self): + locked = False + try: + if self._in_maintenance: + self._lock.acquire() + locked = True + yield + finally: + if locked: + self._lock.release() diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 7365c6ff13..0bc263455d 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,6 +1,9 @@ +import datetime +import gc import os import re import time +import weakref from contextlib import closing from threading import Thread from unittest import mock @@ -30,12 +33,11 @@ def __init__(self, **kwargs): self.kwargs = kwargs self.pid = os.getpid() self._sock = None + self._disconnected = False def connect(self): self._sock = mock.Mock() - - def disconnect(self): - self._sock = None + self._disconnected = False def can_read(self): return False @@ -43,6 +45,13 @@ def can_read(self): def should_reconnect(self): return False + def disconnect(self): + self._sock = None + self._disconnected = True + + def re_auth(self): + pass + class TestConnectionPool: def get_pool( @@ -643,7 +652,7 @@ def test_on_connect_error(self): bad_connection.info() pool = bad_connection.connection_pool assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not pool._available_connections[0].connection._sock @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.8.8") @@ -688,7 +697,7 @@ def test_busy_loading_from_pipeline(self, r): pool = r.connection_pool assert not pipe.connection assert len(pool._available_connections) == 1 - assert not pool._available_connections[0]._sock + assert not pool._available_connections[0].connection._sock @skip_if_server_version_lt("2.8.8") @skip_if_redis_enterprise() @@ -955,3 +964,608 @@ def test_health_check_in_pubsub_poll(self, r): assert wait_for_message(p) is None m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) self.assert_interval_advanced(p.connection) + + +class MockDateTime: + """Context manager for mocking datetime.datetime.now() and time.time().""" + + def __init__(self, start_time=None): + if start_time is None: + start_time = datetime.datetime(2024, 1, 1, 12, 0, 0) + self.current_time = start_time + self.start_time = start_time + + def advance(self, seconds): + """Advance the mocked time by the given number of seconds.""" + self.current_time = self.current_time + datetime.timedelta(seconds=seconds) + + def __enter__(self): + self._datetime_patcher = mock.patch("redis.connection.datetime") + self._time_patcher = mock.patch("redis.connection.time") + + mock_datetime = self._datetime_patcher.__enter__() + mock_time = self._time_patcher.__enter__() + + mock_datetime.datetime.now = lambda: self.current_time + mock_datetime.datetime.side_effect = datetime.datetime + mock_time.time = lambda: self.current_time.timestamp() + + return self + + def __exit__(self, *args): + self._time_patcher.__exit__(*args) + return self._datetime_patcher.__exit__(*args) + + +class TestIdleConnectionTimeout: + """Tests for idle connection timeout functionality.""" + + def test_idle_timeout_parameters_validation(self): + """Test that idle timeout parameters are validated properly.""" + # Valid parameters should work + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=5.0, + ) + assert pool.idle_connection_timeout == 10.0 + assert pool.idle_check_interval == 5.0 + pool.close() + + # None for idle_connection_timeout should 
work (disables feature) + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=None, + ) + assert pool.idle_connection_timeout is None + pool.close() + + # Invalid idle_connection_timeout should raise ValueError + with pytest.raises(ValueError, match="idle_connection_timeout"): + redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=-1.0, + ) + + with pytest.raises(ValueError, match="idle_connection_timeout"): + redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=0, + ) + + # Invalid idle_check_interval should raise ValueError + with pytest.raises(ValueError, match="idle_check_interval"): + redis.ConnectionPool( + connection_class=DummyConnection, + idle_check_interval=-1.0, + ) + + with pytest.raises(ValueError, match="idle_check_interval"): + redis.ConnectionPool( + connection_class=DummyConnection, + idle_check_interval=0, + ) + + def test_pool_not_registered_without_timeout(self): + """Test that pool is not registered when idle_connection_timeout is None.""" + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=None, + ) + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + assert id(pool) not in manager._registered_pool_ids + pool.close() + + # Get and release a connection, which would trigger registration if we had set the idle timeout + conn = pool.get_connection() + pool.release(conn) + + assert id(pool) not in manager._registered_pool_ids + pool.close() + + def test_pool_registered_with_timeout(self): + """Test that pool is registered with manager when idle_connection_timeout is set.""" + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=1.0, + ) + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + # Pool is not registered until a connection is released + assert id(pool) not in manager._registered_pool_ids + + # Get and release a connection to trigger registration + conn = pool.get_connection() + pool.release(conn) + + assert id(pool) in manager._registered_pool_ids + assert manager._worker_thread is not None + assert manager._worker_thread.is_alive() + pool.close() + + # After close, pool should be unregistered + assert id(pool) not in manager._registered_pool_ids + + def test_idle_connections_cleaned_up(self): + """Test that idle connections are actually cleaned up.""" + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=1.0, # 1 second timeout + idle_check_interval=0.5, # Check every 0.5 seconds + ) + + # Get and release a connection + conn1 = pool.get_connection() + pool.release(conn1) + + # Should have 1 available connection + assert len(pool._available_connections) == 1 + assert pool._created_connections == 1 + + # Advance time past the idle timeout + mock_time.advance(1.5) + + # Manually trigger cleanup + pool._cleanup_idle_connections() + + # The idle connection should have been cleaned up + assert len(pool._available_connections) == 0 + assert pool._created_connections == 0 + + # Pool should still work after cleanup + conn2 = pool.get_connection() + assert conn2 is not None + pool.release(conn2) + + pool.close() + + def test_fresh_connections_not_cleaned_up(self): + """Test that recently used connections are not cleaned up.""" + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + 
idle_connection_timeout=2.0, + idle_check_interval=0.5, + ) + + # Get and release a connection + conn1 = pool.get_connection() + pool.release(conn1) + + # Advance time less than the timeout + mock_time.advance(0.8) + + # Manually trigger cleanup + pool._cleanup_idle_connections() + + # Connection should still be available + assert len(pool._available_connections) == 1 + + pool.close() + + def test_blocking_pool_idle_timeout(self): + """Test idle timeout with BlockingConnectionPool.""" + with MockDateTime() as mock_time: + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=5, + timeout=1, + idle_connection_timeout=1.0, + idle_check_interval=0.5, + ) + + # Get and release some connections + conn1 = pool.get_connection() + conn2 = pool.get_connection() + pool.release(conn1) + pool.release(conn2) + + # Should have 2 connections + assert len(pool._connections) == 2 + + # Advance time past the idle timeout + mock_time.advance(1.5) + + # Manually trigger cleanup + pool._cleanup_idle_connections() + + # Connections should be cleaned up + assert len(pool._connections) == 0 + + # Pool should still work + conn3 = pool.get_connection() + assert conn3 is not None + pool.release(conn3) + + pool.close() + + def test_blocking_pool_parameters(self): + """Test that BlockingConnectionPool accepts idle timeout parameters.""" + pool = redis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=5, + timeout=1, + idle_connection_timeout=10.0, + idle_check_interval=5.0, + ) + assert pool.idle_connection_timeout == 10.0 + assert pool.idle_check_interval == 5.0 + pool.close() + + def test_multiple_pools_independent_cleanup(self): + """Test that multiple pools clean up independently.""" + with MockDateTime() as mock_time: + pool1 = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=1.0, + idle_check_interval=0.5, + ) + pool2 = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=2.0, + idle_check_interval=0.5, + ) + + # Create connections in both pools + conn1 = pool1.get_connection() + conn2 = pool2.get_connection() + pool1.release(conn1) + pool2.release(conn2) + + # Advance time past pool1's timeout but not pool2's + mock_time.advance(1.5) + + # Trigger cleanup for both pools + pool1._cleanup_idle_connections() + pool2._cleanup_idle_connections() + + # Pool1 should be cleaned up, pool2 should not + assert len(pool1._available_connections) == 0 + assert len(pool2._available_connections) == 1 + + pool1.close() + pool2.close() + + def test_pool_garbage_collection(self): + """Test that pool can be garbage collected when no longer referenced.""" + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=0.5, + ) + + # Get and release a connection to trigger registration + conn = pool.get_connection() + pool.release(conn) + + assert id(pool) in manager._registered_pool_ids + + pool_weak_ref = weakref.ref(pool) + del pool + gc.collect() + assert pool_weak_ref() is None + + def test_manager_singleton(self): + """Test that IdleConnectionCleanupManager is a singleton.""" + manager1 = redis.connection.IdleConnectionCleanupManager.get_instance() + manager2 = redis.connection.IdleConnectionCleanupManager.get_instance() + assert manager1 is manager2 + + def test_manager_shared_across_pools(self): + """Test that multiple pools share the same cleanup manager.""" + pool1 
= redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + ) + conn = pool1.get_connection() + pool1.release(conn) + + pool2 = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=5.0, + ) + conn = pool2.get_connection() + pool2.release(conn) + + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + # Both pools should be registered with the same manager + assert id(pool1) in manager._registered_pool_ids + assert id(pool2) in manager._registered_pool_ids + + pool1.close() + pool2.close() + + def test_manager_connection_release_notification(self): + """Test that manager is notified when connections are released.""" + with MockDateTime(): + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=5.0, + ) + + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + pool_id = id(pool) + + # Get and release a connection + conn = pool.get_connection() + pool.release(conn) + + # Manager should have metadata for this pool + assert pool_id in manager._registered_pool_ids + metadata = manager._schedule[0] + + # Check that idle_timeout and check_interval are stored correctly + assert metadata.pool_id == pool_id + assert metadata.idle_timeout == 10.0 + assert metadata.minimum_check_interval == 5.0 + + pool.close() + + def test_manager_schedules_multiple_pools(self): + """Test that manager correctly schedules cleanup for multiple pools.""" + with MockDateTime(): + # Create pools with different timeouts + pool1 = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=5.0, + idle_check_interval=1.0, + ) + pool1.release(pool1.get_connection()) + pool2 = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=2.0, + ) + pool2.release(pool2.get_connection()) + + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + # Both pools should be in the schedule + pool_ids_in_schedule = {entry.pool_id for entry in manager._schedule} + assert id(pool1) in pool_ids_in_schedule + assert id(pool2) in pool_ids_in_schedule + + pool1.close() + pool2.close() + + def test_manager_schedules_empty_pool_on_release(self): + """Test that manager re-registers an empty pool when a connection is released.""" + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=10.0, + idle_check_interval=5.0, + ) + + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + pool_id = id(pool) + + # Now pool should not be tracked + assert pool_id not in manager._registered_pool_ids + # Filter for this pool_id AND check that weakref is not dead (in case memory address was reused) + schedule = [ + entry + for entry in manager._schedule + if entry.pool_id == pool_id and entry.pool_ref() is not None + ] + assert len(schedule) == 0 + + # Release a connection + conn = pool.get_connection() + pool.release(conn) + + # Pool should now be re-registered and scheduled + assert pool_id in manager._registered_pool_ids + # Filter for this pool_id AND check that weakref is not dead (in case memory address was reused) + schedule = [ + entry + for entry in manager._schedule + if entry.pool_id == pool_id and entry.pool_ref() is not None + ] + assert len(schedule) == 1 + + pool.close() + + def test_manager_automatically_cleans_idle_connections(self): + """Integration test: Manager automatically cleans up idle connections without manual trigger.""" + 
import time + + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=1.0, # 1 second timeout + idle_check_interval=0.5, # Check every 0.5 seconds + ) + + try: + # Get and release a connection + conn1 = pool.get_connection() + pool.release(conn1) + + # Should have 1 available connection + assert len(pool._available_connections) == 1 + assert pool._created_connections == 1 + + # Advance time past the idle timeout + mock_time.advance(1.5) + + # Notify the manager to wake up and check (simulates time passing) + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + with manager._condition: + manager._condition.notify() + + # Poll until the worker thread processes (with timeout) + deadline = time.time() + 1.0 # 1 second timeout + while time.time() < deadline: + if len(pool._available_connections) == 0: + break + time.sleep(0.01) + + # The manager should have cleaned it up automatically + assert len(pool._available_connections) == 0 + assert pool._created_connections == 0 + finally: + pool.close() + + def test_manager_reschedules_pools_after_cleanup(self): + """Integration test: Manager reschedules pools that still have connections after cleanup.""" + import time + + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=1.5, # 1.5 seconds timeout + idle_check_interval=0.5, # Check every 0.5 seconds + ) + + try: + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + # Get and release two connections + conn1 = pool.get_connection() + conn2 = pool.get_connection() + pool.release(conn1) + + # Advance time, then release conn2 + mock_time.advance(1.0) + pool.release(conn2) + + # Now we have: + # - conn1: idle for 1.0s + # - conn2: idle for 0s + + # Advance time for first cleanup cycle + mock_time.advance(0.6) # Total: conn1 at 1.6s, conn2 at 0.6s + + # Wake up manager + with manager._condition: + manager._condition.notify() + + # Poll until first cleanup happens + deadline = time.time() + 1.0 + while time.time() < deadline: + if len(pool._available_connections) == 1: + break + time.sleep(0.01) + + # conn1 should be cleaned (>1.5s), conn2 should remain (<1.5s) + assert len(pool._available_connections) == 1 + assert pool._created_connections == 1 + + # Verify pool was rescheduled by advancing time again + mock_time.advance(1.0) # Total: conn2 at 1.6s + + # Wake up manager + with manager._condition: + manager._condition.notify() + + # Poll until second cleanup happens + deadline = time.time() + 1.0 + while time.time() < deadline: + if len(pool._available_connections) == 0: + break + time.sleep(0.01) + + # Now conn2 should also be cleaned + assert len(pool._available_connections) == 0 + assert pool._created_connections == 0 + finally: + pool.close() + + def test_manager_removes_empty_pools_from_tracking(self): + """Integration test: Manager removes empty pools from its internal tracking.""" + import time + + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=1.0, # 1 second timeout + idle_check_interval=0.5, # Check every 0.5 seconds + ) + + try: + # Get and release a connection + conn = pool.get_connection() + pool.release(conn) + + # Pool should be registered + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + pool_id = id(pool) + assert pool_id in manager._registered_pool_ids + + # Advance time past timeout + 
mock_time.advance(1.5) + + # Wake up manager + with manager._condition: + manager._condition.notify() + + # Poll until cleanup happens + deadline = time.time() + 1.0 + while time.time() < deadline: + if pool_id not in manager._registered_pool_ids: + break + time.sleep(0.01) + + # Pool should be empty + assert len(pool._available_connections) == 0 + + # Pool should be removed from manager's tracking + assert pool_id not in manager._registered_pool_ids + finally: + pool.close() + + def test_manager_schedules_at_correct_time(self): + """Integration test: Manager schedules cleanups at the correct time based on idle_timeout.""" + import time + + with MockDateTime() as mock_time: + pool = redis.ConnectionPool( + connection_class=DummyConnection, + idle_connection_timeout=2.0, # 2 seconds timeout + idle_check_interval=0.5, # Check every 0.5 seconds + ) + + try: + manager = redis.connection.IdleConnectionCleanupManager.get_instance() + + # Get and release a connection + conn = pool.get_connection() + pool.release(conn) + + # Connection should NOT be cleaned up before timeout + mock_time.advance(1.0) # 1 second - less than 2 second timeout + + # Wake up manager + with manager._condition: + manager._condition.notify() + + # Give worker thread time to process, but it shouldn't clean anything + time.sleep(0.05) + + assert len(pool._available_connections) == 1 + assert pool._created_connections == 1 + + # Connection SHOULD be cleaned up after timeout + mock_time.advance(1.5) # Total 2.5 seconds - more than 2 second timeout + + # Wake up manager + with manager._condition: + manager._condition.notify() + + # Poll until cleanup happens + deadline = time.time() + 1.0 + while time.time() < deadline: + if len(pool._available_connections) == 0: + break + time.sleep(0.01) + + assert len(pool._available_connections) == 0 + assert pool._created_connections == 0 + finally: + pool.close() diff --git a/tests/test_maint_notifications_handling.py b/tests/test_maint_notifications_handling.py index 556b63d7e1..74541066e0 100644 --- a/tests/test_maint_notifications_handling.py +++ b/tests/test_maint_notifications_handling.py @@ -104,12 +104,7 @@ def validate_free_connections_state( ): """Helper method to validate state of free/available connections.""" - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections - else: - raise ValueError(f"Unsupported pool type: {type(pool)}") + free_connections = pool._get_free_connections() connected_count = 0 for connection in free_connections: @@ -2076,10 +2071,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): ) # validate free connections for ip1 changed_free_connections = 0 - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections + free_connections = pool._get_free_connections() for conn in free_connections: if conn.host == new_ip: changed_free_connections += 1 @@ -2126,10 +2118,7 @@ def test_migrating_after_moving_multiple_proxies(self, pool_class): ) # validate free connections for ip2 changed_free_connections = 0 - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections + free_connections = 
pool._get_free_connections() for conn in free_connections: if conn.host == new_ip_2: changed_free_connections += 1 diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 6a90d55cc6..b8e4e60866 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -121,11 +121,13 @@ def target(pool, parent_conn): assert child_conn.pid != parent_conn.pid pool.release(child_conn) assert pool._created_connections == 1 - assert child_conn in pool._available_connections + assert child_conn in [p.connection for p in pool._available_connections] pool.release(parent_conn) assert pool._created_connections == 1 - assert child_conn in pool._available_connections - assert parent_conn not in pool._available_connections + assert child_conn in [p.connection for p in pool._available_connections] + assert parent_conn not in [ + p.connection for p in pool._available_connections + ] proc = self._mp_context.Process(target=target, args=(pool, parent_conn)) proc.start() diff --git a/tests/test_scenario/test_maint_notifications.py b/tests/test_scenario/test_maint_notifications.py index 7d99bfe8ae..398527bfc2 100644 --- a/tests/test_scenario/test_maint_notifications.py +++ b/tests/test_scenario/test_maint_notifications.py @@ -184,8 +184,8 @@ def _execute_migrate_bind_flow( def _get_all_connections_in_pool(self, client: Redis) -> List[ConnectionInterface]: connections = [] if hasattr(client.connection_pool, "_available_connections"): - for conn in client.connection_pool._available_connections: - connections.append(conn) + for pooled_conn in client.connection_pool._available_connections: + connections.append(pooled_conn.connection) if hasattr(client.connection_pool, "_in_use_connections"): for conn in client.connection_pool._in_use_connections: connections.append(conn)
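Taken together, the new parameters can be enabled from either the client or a pool directly. A usage sketch based on the parameters introduced in this diff (host, port, and timeout values are illustrative):

```python
import redis

# Via the client: the parameters are forwarded to the ConnectionPool the
# client creates internally (they have no effect when connection_pool= is
# passed in explicitly).
r = redis.Redis(
    host="localhost",
    port=6379,
    idle_connection_timeout=30.0,  # close connections idle for more than 30s
    idle_check_interval=10.0,      # run cleanup at most once every 10s
)

# Or on a pool directly; BlockingConnectionPool accepts the same arguments.
pool = redis.BlockingConnectionPool(
    max_connections=20,
    timeout=5,
    idle_connection_timeout=30.0,
    idle_check_interval=10.0,
)
client = redis.Redis(connection_pool=pool)
```

All pools in the process share the single `IdleConnectionCleanupManager` worker thread, so enabling the feature on many pools does not multiply background threads.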