Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@

_NOT_SET = object()

_NOT_SET_TLS_CACHE = object()


class NoHostAvailable(Exception):
"""
Expand Down Expand Up @@ -875,6 +877,50 @@
.. versionadded:: 3.17.0
"""

tls_session_cache = _NOT_SET_TLS_CACHE
"""
TLS session cache configuration for faster reconnections.
When SSL/TLS is enabled, TLS sessions are cached and reused for subsequent
connections to the same endpoint, reducing handshake latency.

Can be set to:

- ``_NOT_SET_TLS_CACHE`` (default): A :class:`~cassandra.tls.DefaultTLSSessionCache` is
automatically created when SSL/TLS is enabled.
- ``None``: Disable TLS session caching entirely.
- An instance of :class:`~cassandra.tls.TLSSessionCacheOptions` for
fine-grained control over session caching behavior (e.g., cache_by_host_only option).
- An instance of :class:`~cassandra.tls.TLSSessionCache` (or a custom subclass)
for complete control over session caching implementation.

Example disabling caching::

cluster = Cluster(ssl_context=ssl_context, tls_session_cache=None)

Example with options::

from cassandra.tls import TLSSessionCacheOptions

options = TLSSessionCacheOptions(
max_size=200,
ttl=7200,
cache_by_host_only=True
)
cluster = Cluster(ssl_context=ssl_context, tls_session_cache=options)

Example with custom cache::

from cassandra.tls import TLSSessionCache

class MyCustomCache(TLSSessionCache):
# Custom implementation
pass

cluster = Cluster(ssl_context=ssl_context, tls_session_cache=MyCustomCache())

.. versionadded:: 3.30.0
"""

sockopts = None
"""
An optional list of tuples which will be used as arguments to
Expand Down Expand Up @@ -1204,6 +1250,7 @@
idle_heartbeat_timeout=30,
no_compact=False,
ssl_context=None,
tls_session_cache=_NOT_SET_TLS_CACHE,
endpoint_factory=None,
application_name=None,
application_version=None,
Expand Down Expand Up @@ -1420,6 +1467,21 @@

self.ssl_options = ssl_options
self.ssl_context = ssl_context
self.tls_session_cache = tls_session_cache

# Initialize TLS session cache if SSL is enabled and caching is not disabled
self._tls_session_cache = None
if (ssl_context or ssl_options) and tls_session_cache is not None:
from cassandra.tls import TLSSessionCache, TLSSessionCacheOptions, DefaultTLSSessionCache

if isinstance(tls_session_cache, TLSSessionCache):
self._tls_session_cache = tls_session_cache
elif isinstance(tls_session_cache, TLSSessionCacheOptions):
self._tls_session_cache = tls_session_cache.create_cache()
else:
# Default: create cache with default parameters
self._tls_session_cache = DefaultTLSSessionCache()

self.sockopts = sockopts
self.cql_version = cql_version
self.max_schema_agreement_wait = max_schema_agreement_wait
Expand Down Expand Up @@ -1661,6 +1723,7 @@
kwargs_dict.setdefault('sockopts', self.sockopts)
kwargs_dict.setdefault('ssl_options', self.ssl_options)
kwargs_dict.setdefault('ssl_context', self.ssl_context)
kwargs_dict.setdefault('tls_session_cache', self._tls_session_cache)
kwargs_dict.setdefault('cql_version', self.cql_version)
kwargs_dict.setdefault('protocol_version', self.protocol_version)
kwargs_dict.setdefault('user_type_map', self._user_types)
Expand Down Expand Up @@ -4246,7 +4309,7 @@
self._scheduled_tasks.discard(task)
fn, args, kwargs = task
kwargs = dict(kwargs)
future = self._executor.submit(fn, *args, **kwargs)

Check failure on line 4312 in cassandra/cluster.py

View workflow job for this annotation

GitHub Actions / test asyncio (3.11)

cannot schedule new futures after shutdown
future.add_done_callback(self._log_if_failed)
else:
self._queue.put_nowait((run_at, i, task))
Expand Down
54 changes: 52 additions & 2 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,15 @@ def socket_family(self):
"""
return socket.AF_UNSPEC

@property
def tls_session_cache_key(self):
"""
Returns the cache key components for TLS session caching.
This is a tuple that uniquely identifies this endpoint for TLS session purposes.
Subclasses may override this to include additional components (e.g., SNI server name).
"""
return (self.address, self.port)

def resolve(self):
"""
Resolve the endpoint to an address/port. This is called
Expand Down Expand Up @@ -275,6 +284,14 @@ def port(self):
def ssl_options(self):
return self._ssl_options

@property
def tls_session_cache_key(self):
"""
Returns the cache key including server_name for SNI endpoints.
This prevents cache collisions when multiple SNI endpoints use the same proxy.
"""
return (self.address, self.port, self._server_name)

def resolve(self):
try:
resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port,
Expand Down Expand Up @@ -349,6 +366,14 @@ def port(self):
def socket_family(self):
return socket.AF_UNIX

@property
def tls_session_cache_key(self):
"""
Returns the cache key for Unix socket endpoints.
Since Unix sockets don't have a port, only the path is used.
"""
return (self._unix_socket_path,)

def resolve(self):
return self.address, None

Expand Down Expand Up @@ -687,6 +712,7 @@ class Connection(object):
endpoint = None
ssl_options = None
ssl_context = None
tls_session_cache = None
last_error = None

# The current number of operations that are in flight. More precisely,
Expand Down Expand Up @@ -763,14 +789,15 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
ssl_options=None, sockopts=None, compression: Union[bool, str] = True,
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False,
ssl_context=None, owning_pool=None, shard_id=None, total_shards=None,
ssl_context=None, tls_session_cache=None, owning_pool=None, shard_id=None, total_shards=None,
on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None):
# TODO next major rename host to endpoint and remove port kwarg.
self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port)

self.authenticator = authenticator
self.ssl_options = ssl_options.copy() if ssl_options else {}
self.ssl_context = ssl_context
self.tls_session_cache = tls_session_cache
self.sockopts = sockopts
self.compression = compression
self.cql_version = cql_version
Expand Down Expand Up @@ -913,7 +940,21 @@ def _wrap_socket_from_context(self):
server_hostname = self.endpoint.address
opts['server_hostname'] = server_hostname

return self.ssl_context.wrap_socket(self._socket, **opts)
# Try to get a cached TLS session for resumption
# Note: Session resumption works with both TLS 1.2 and TLS 1.3
# Python's ssl module handles both transparently via SSLSession objects
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
opts['session'] = cached_session
log.debug("Using cached TLS session for %s", self.endpoint)

ssl_socket = self.ssl_context.wrap_socket(self._socket, **opts)

# Note: Session is NOT stored here - it will be stored after successful connection
# in _connect_socket() to ensure we only cache sessions for successful connections

return ssl_socket

def _initiate_connection(self, sockaddr):
if self.features.shard_id is not None:
Expand Down Expand Up @@ -968,6 +1009,15 @@ def _connect_socket(self):
# run that here.
if self._check_hostname:
self._validate_hostname()

# Store the TLS session after successful connection
# This ensures we only cache sessions for connections that actually succeeded
if self.tls_session_cache and self.ssl_context and hasattr(self._socket, 'session'):
if self._socket.session:
self.tls_session_cache.set_session(self.endpoint, self._socket.session)
if hasattr(self._socket, 'session_reused') and self._socket.session_reused:
log.debug("TLS session was reused for %s", self.endpoint)

sockerr = None
break
except socket.error as err:
Expand Down
14 changes: 14 additions & 0 deletions cassandra/io/eventletreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,27 @@ def _wrap_socket_from_context(self):
# This is necessary for SNI
self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))

# Apply cached TLS session for resumption (PyOpenSSL)
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
self._socket.set_session(cached_session)
log.debug("Using cached TLS session for %s", self.endpoint)

def _initiate_connection(self, sockaddr):
if self.uses_legacy_ssl_options:
super(EventletConnection, self)._initiate_connection(sockaddr)
else:
self._socket.connect(sockaddr)
if self.ssl_context or self.ssl_options:
self._socket.do_handshake()
# Store TLS session after successful handshake (PyOpenSSL)
if self.tls_session_cache:
session = self._socket.get_session()
if session:
self.tls_session_cache.set_session(self.endpoint, session)
if self._socket.session_reused():
log.debug("TLS session was reused for %s", self.endpoint)

def _match_hostname(self):
if self.uses_legacy_ssl_options:
Expand Down
20 changes: 19 additions & 1 deletion cassandra/io/twistedreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,12 @@ def _on_loop_timer(self):

@implementer(IOpenSSLClientConnectionCreator)
class _SSLCreator(object):
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout):
def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout, tls_session_cache=None):
self.endpoint = endpoint
self.ssl_options = ssl_options
self.check_hostname = check_hostname
self.timeout = timeout
self.tls_session_cache = tls_session_cache

if ssl_context:
self.context = ssl_context
Expand Down Expand Up @@ -171,11 +172,27 @@ def info_callback(self, connection, where, ret):
transport = connection.get_app_data()
transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint)))

# Store TLS session after successful handshake (PyOpenSSL)
if self.tls_session_cache:
session = connection.get_session()
if session:
self.tls_session_cache.set_session(self.endpoint, session)
if connection.session_reused():
log.debug("TLS session was reused for %s", self.endpoint)

def clientConnectionForTLS(self, tlsProtocol):
connection = SSL.Connection(self.context, None)
connection.set_app_data(tlsProtocol)
if self.ssl_options and "server_hostname" in self.ssl_options:
connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii'))

# Apply cached TLS session for resumption (PyOpenSSL)
if self.tls_session_cache:
cached_session = self.tls_session_cache.get_session(self.endpoint)
if cached_session:
connection.set_session(cached_session)
log.debug("Using cached TLS session for %s", self.endpoint)

return connection


Expand Down Expand Up @@ -241,6 +258,7 @@ def add_connection(self):
self.ssl_options,
self._check_hostname,
self.connect_timeout,
tls_session_cache=self.tls_session_cache,
)

endpoint = SSL4ClientEndpoint(
Expand Down
Loading
Loading