From 16fac638d5976f95754f99dae11023fe50306dd6 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Fri, 13 Feb 2026 11:09:27 +0100 Subject: [PATCH 1/2] Add host-id to AddressTranslator --- cassandra/connection.py | 14 +++++++++++--- cassandra/policies.py | 21 +++++++++++++++++++++ tests/unit/test_policies.py | 2 +- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..c25372e39f 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -244,11 +244,19 @@ def create(self, row): if port is None: port = self.port if self.port else 9042 + # Extract host metadata for V2 address translation + host_id = row.get("host_id") + + # Use V2 API if available, fall back to V1 + translator = self.cluster.address_translator + if hasattr(translator, 'translate_with_host_id'): + translated_addr = translator.translate_with_host_id(addr, host_id) + else: + translated_addr = translator.translate(addr) + # create the endpoint with the translated address # TODO next major, create a TranslatedEndPoint type - return DefaultEndPoint( - self.cluster.address_translator.translate(addr), - port) + return DefaultEndPoint(translated_addr, port) @total_ordering diff --git a/cassandra/policies.py b/cassandra/policies.py index e742708019..1f8456481f 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1210,9 +1210,27 @@ class AddressTranslator(object): def translate(self, addr): """ Accepts the node ip address, and returns a translated address to be used connecting to this node. + + Legacy V1 API for backward compatibility. New implementations should override translate_with_host_id. """ raise NotImplementedError() + def translate_with_host_id(self, addr, host_id): + """ + V2 API: Accepts the node ip address and optional Host ID, returns translated address. 
+ + :param addr: The node IP address to translate + :param host_id: Host ID + :return: Translated address to be used for connecting + + This method provides access to Host metadata (especially host_id) which is required + for PrivateLink and similar scenarios where translation is keyed by Host ID rather + than IP address. + + Default implementation delegates to translate() for backward compatibility. + """ + return self.translate(addr) + class IdentityTranslator(AddressTranslator): """ @@ -1221,6 +1239,9 @@ class IdentityTranslator(AddressTranslator): def translate(self, addr): return addr + def translate_with_host_id(self, addr, host_id): + return addr + class EC2MultiRegionTranslator(AddressTranslator): """ diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..e0754baec9 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -1442,7 +1442,7 @@ def test_identity_translator(self): def test_ec2_multi_region_translator(self, *_): ec2t = EC2MultiRegionTranslator() addr = '127.0.0.1' - translated = ec2t.translate(addr) + translated = ec2t.translate_with_host_id(addr, "") assert translated is not addr # verifies that the resolver path is followed assert translated == addr # and that it resolves to the same address From b2b13a2764a71472a20a5e63113b4c00a1e98ac3 Mon Sep 17 00:00:00 2001 From: sylwiaszunejko Date: Mon, 16 Feb 2026 10:02:47 +0100 Subject: [PATCH 2/2] Implement scylla-specific ClientRoutes feature This feature was implemented in scylladb/scylladb#27323. The idea is to enable clients to dynamically learn address translation information from the system.client_routes table. When this table is updated, drivers receive a CLIENT_ROUTES_CHANGE event carrying the scope of the change. This PR adds the ability to configure the driver to read this table, handle these events, and keep the address translation mapping up to date. 
--- cassandra/client_routes.py | 716 ++++++++++++++++++ cassandra/cluster.py | 84 +- cassandra/policies.py | 1 - .../standard/test_client_routes.py | 187 +++++ tests/unit/test_client_routes.py | 600 +++++++++++++++ 5 files changed, 1583 insertions(+), 5 deletions(-) create mode 100644 cassandra/client_routes.py create mode 100644 tests/integration/standard/test_client_routes.py create mode 100644 tests/unit/test_client_routes.py diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py new file mode 100644 index 0000000000..6dee4b490d --- /dev/null +++ b/cassandra/client_routes.py @@ -0,0 +1,716 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Client Routes support for Private Link and similar network configurations. + +This module implements support for dynamic address translation via the +system.client_routes table and CLIENT_ROUTES_CHANGE events. 
+""" + +from __future__ import absolute_import + +import logging +import socket +import threading +import time +import uuid +from collections import namedtuple + +from cassandra import DriverException +from cassandra.query import dict_factory + +log = logging.getLogger(__name__) + + +class ClientRoutesEndpoint: + + def __init__(self, connection_id, connection_addr=None): + """ + :param connection_id: UUID string or UUID object identifying the connection + :param connection_addr: Optional string address for initial connection + """ + if connection_id is None: + raise ValueError("connection_id is required") + + if isinstance(connection_id, str): + try: + self.connection_id = uuid.UUID(connection_id) + except ValueError: + raise ValueError(f"Invalid UUID format for connection_id: {connection_id}") + elif isinstance(connection_id, uuid.UUID): + self.connection_id = connection_id + else: + raise TypeError("connection_id must be a UUID or string") + + self.connection_addr = connection_addr + + def __repr__(self): + return f"ClientRoutesEndpoint(connection_id={self.connection_id}, connection_addr={self.connection_addr})" + + +class ClientRoutesConfig: + """ + Configuration for client routes (Private Link support). 
+ + :param endpoints: List of ClientRoutesEndpoint objects (REQUIRED, at least one) + :param table_name: Name of the system table to query (default: "system.client_routes") + """ + + def __init__(self, endpoints, table_name="system.client_routes", + max_resolver_concurrency=1, resolve_healthy_endpoint_period_ms=500, + block_unknown_endpoints=False): + """ + :param endpoints: List of ClientRoutesEndpoint objects + :param table_name: System table name for route discovery + :param max_resolver_concurrency: Maximum concurrent DNS resolutions (must be > 0, default: 1) + :param resolve_healthy_endpoint_period_ms: How often to re-resolve healthy endpoints in ms (must be >= 0, default: 500) + :param block_unknown_endpoints: Only process events for configured connection IDs (default: False) + """ + if not endpoints: + raise ValueError("At least one endpoint must be specified") + + if not isinstance(endpoints, (list, tuple)): + raise TypeError("endpoints must be a list or tuple") + + for endpoint in endpoints: + if not isinstance(endpoint, ClientRoutesEndpoint): + raise TypeError("All endpoints must be ClientRoutesEndpoint instances") + + if max_resolver_concurrency <= 0: + raise ValueError("max_resolver_concurrency must be > 0") + + if resolve_healthy_endpoint_period_ms < 0: + raise ValueError("resolve_healthy_endpoint_period_ms must be >= 0") + + self.endpoints = list(endpoints) + self.table_name = table_name + self.max_resolver_concurrency = max_resolver_concurrency + self.resolve_healthy_endpoint_period_ms = resolve_healthy_endpoint_period_ms + self.block_unknown_endpoints = block_unknown_endpoints + + def __repr__(self): + return f"ClientRoutesConfig(endpoints={self.endpoints}, table_name={self.table_name})" + + +# Internal data structures + +ResolvedRoute = namedtuple('ResolvedRoute', [ + 'connection_id', + 'host_id', + 'address', # DNS hostname from system.client_routes + 'port', + 'tls_port', + 'datacenter', + 'rack', + 'all_known_ips', # List of all resolved IP 
addresses + 'current_ip', # Currently selected IP address + 'update_time', # Timestamp of last resolution + 'forced_resolve' # Flag to force resolution on next cycle +]) + + +class DNSResolver: + """ + DNS resolver with caching and concurrency limits. + + :param cache_duration_ms: How long to cache DNS resolutions (default: 500ms) + :param max_concurrent: Maximum concurrent DNS resolutions (default: 1) + """ + + def __init__(self, cache_duration_ms=500, max_concurrent=1): + self.cache_duration_ms = cache_duration_ms + self.max_concurrent = max_concurrent + self._lock = threading.Lock() + self._semaphore = threading.Semaphore(max_concurrent) + + def resolve(self, hostname, cached_ips=None, current_ip=None, cached_at=None): + """ + Resolve a hostname to all IP addresses with caching. + + :param hostname: DNS hostname to resolve + :param cached_ips: Previously resolved list of IPs (for cache check) + :param current_ip: Previously selected current IP + :param cached_at: Timestamp of cached resolution + :return: Tuple of (all_ips, current_ip, timestamp) or (cached_ips, current_ip, cached_at) on failure + """ + # Check if cached result is still valid + if cached_ips is not None and cached_at is not None: + age_ms = (time.time() - cached_at) * 1000 + if age_ms < self.cache_duration_ms: + return (cached_ips, current_ip, cached_at) + + # Acquire semaphore for DNS resolution + if not self._semaphore.acquire(blocking=True, timeout=5.0): + log.warning("DNS resolution timeout acquiring semaphore for %s", hostname) + # Return cached IPs as fallback if available + return (cached_ips, current_ip, cached_at) if cached_ips else (None, None, None) + + try: + # Perform DNS resolution + try: + # Get all IPv4 addresses + result = socket.getaddrinfo(hostname, None, socket.AF_INET, socket.SOCK_STREAM) + if result: + # Extract all unique IPs + all_ips = list(set([addr[4][0] for addr in result])) + timestamp = time.time() + + # Select current IP: prefer existing if still in list, otherwise 
use first + selected_ip = None + if current_ip and current_ip in all_ips: + selected_ip = current_ip + else: + selected_ip = all_ips[0] if all_ips else None + + log.debug("Resolved %s to %d IPs, selected: %s", hostname, len(all_ips), selected_ip) + return (all_ips, selected_ip, timestamp) + else: + log.warning("No DNS results for %s", hostname) + return (cached_ips, current_ip, cached_at) if cached_ips else (None, None, None) + except (socket.gaierror, Exception) as e: + log.warning("DNS resolution failed for %s: %s", hostname, e) + # Return cached IPs as fallback (best-effort continuity) + return (cached_ips, current_ip, cached_at) if cached_ips else (None, None, None) + finally: + self._semaphore.release() + + +class ResolvedRoutes: + """ + Thread-safe storage for resolved routes using lock-free reads. + + This uses atomic pointer swaps for updates, allowing lock-free reads + while serializing writes. + """ + + def __init__(self): + self._routes_by_host_id = {} # Dict[UUID, ResolvedRoute] + self._lock = threading.RLock() + + def get_by_host_id(self, host_id): + """ + Get route for a host ID (lock-free read). + + :param host_id: UUID of the host + :return: ResolvedRoute or None + """ + return self._routes_by_host_id.get(host_id) + + def get_all(self): + """ + Get all routes as a list (lock-free read). + + :return: List of ResolvedRoute + """ + return list(self._routes_by_host_id.values()) + + def update(self, routes): + """ + Replace all routes atomically. + + :param routes: List of ResolvedRoute objects + """ + with self._lock: + self._routes_by_host_id = {route.host_id: route for route in routes} + + def merge(self, new_routes): + """ + Merge new routes with existing ones atomically. 
+ + :param new_routes: List of ResolvedRoute objects to merge + """ + with self._lock: + updated = dict(self._routes_by_host_id) + for route in new_routes: + updated[route.host_id] = route + self._routes_by_host_id = updated + + def merge_with_unresolved(self, new_routes): + """ + Merge unresolved routes, marking changed ones for forced resolution. + + :param new_routes: List of ResolvedRoute objects from system.client_routes + """ + with self._lock: + updated = dict(self._routes_by_host_id) + + for new_route in new_routes: + key = new_route.host_id + existing = updated.get(key) + + if existing is None: + # New route, add with forced_resolve=True + updated[key] = new_route._replace(forced_resolve=True) + else: + # Check if route details changed (address, port, tls_port) + if (existing.connection_id != new_route.connection_id or + existing.address != new_route.address or + existing.port != new_route.port or + existing.tls_port != new_route.tls_port): + # Route changed, mark for forced resolution + updated[key] = new_route._replace(forced_resolve=True) + # Otherwise keep existing route with its resolution state + + self._routes_by_host_id = updated + + def update_single(self, host_id, update_fn): + """ + Update a single route using CAS (compare-and-swap) pattern. + + :param host_id: UUID of the host to update + :param update_fn: Function that takes existing route and returns updated route + :return: Updated ResolvedRoute or None if not found + """ + with self._lock: + existing = self._routes_by_host_id.get(host_id) + if existing: + updated = update_fn(existing) + self._routes_by_host_id[host_id] = updated + return updated + return None + + +def _parse_route_row(row): + """ + Parse a row from system.client_routes into a ResolvedRoute. 
+ + :param row: dict from system.client_routes query result + :return: ResolvedRoute (with all_known_ips and current_ip as None initially) + """ + return ResolvedRoute( + connection_id=row['connection_id'], + host_id=row['host_id'], + address=row['address'], + port=row['port'], + tls_port=row.get('tls_port'), + datacenter=row.get('datacenter'), + rack=row.get('rack'), + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True # Force initial resolution + ) + + +class ClientRoutesHandler: + """ + Handles dynamic address translation for Private Link via system.client_routes. + + Lifecycle: + 1. Construction: Create with configuration + 2. Initialization: Read system.client_routes after control connection established + 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes + 4. Translation: Translate addresses using Host ID lookup + 5. Shutdown: Clean up resources + """ + + def __init__(self, config, ssl_enabled=False): + """ + :param config: ClientRoutesConfig instance + :param ssl_enabled: Whether TLS is enabled (determines port selection) + """ + if not isinstance(config, ClientRoutesConfig): + raise TypeError("config must be a ClientRoutesConfig instance") + + self.config = config + self.ssl_enabled = ssl_enabled + self._resolver = DNSResolver() + self._routes = ResolvedRoutes() + self._initial_endpoints = {ep.connection_id for ep in config.endpoints} + self._is_shutdown = False + self._lock = threading.RLock() + + def initialize(self, control_connection): + """ + Initialize handler after control connection is established. + + Reads system.client_routes for all configured connection IDs and resolves DNS. + This is a synchronous operation that blocks until complete. 
+ + :param control_connection: The ControlConnection instance + """ + if self._is_shutdown: + return + + log.info("[client routes] Initializing with %d endpoints", len(self.config.endpoints)) + + try: + # Query all connection IDs + connection_ids = [ep.connection_id for ep in self.config.endpoints] + routes = self._query_routes(control_connection, connection_ids=connection_ids) + + # Merge unresolved routes and resolve + self._routes.merge_with_unresolved(routes) + self._resolve_and_update_in_place() + + log.info("[client routes] Initialized with %d routes", len(self._routes.get_all())) + except Exception as e: + log.error("[client routes] Initialization failed: %s", e, exc_info=True) + raise + + def handle_client_routes_change(self, control_connection, change_type, connection_ids, host_ids): + """ + Handle CLIENT_ROUTES_CHANGE event. + + :param control_connection: The ControlConnection instance + :param change_type: Type of change (e.g., "UPDATED") + :param connection_ids: List of affected connection ID strings + :param host_ids: List of affected host ID strings + """ + if self._is_shutdown: + return + + log.debug("[client routes] Handling CLIENT_ROUTES_CHANGE: change_type=%s, " + "connection_ids=%s, host_ids=%s", + change_type, connection_ids, host_ids) + + try: + # Filter connection IDs if block_unknown_endpoints is enabled + filtered_conn_ids = None + if connection_ids: + if self.config.block_unknown_endpoints: + configured_ids = {str(ep.connection_id) for ep in self.config.endpoints} + filtered = [cid for cid in connection_ids if cid in configured_ids] + if not filtered: + log.debug("[client routes] All connection IDs filtered out, ignoring event") + return + filtered_conn_ids = [uuid.UUID(cid) for cid in filtered] + else: + filtered_conn_ids = [uuid.UUID(cid) for cid in connection_ids] + + host_uuids = [uuid.UUID(hid) for hid in host_ids] if host_ids else None + + # Query affected routes + routes = self._query_routes( + control_connection, + 
connection_ids=filtered_conn_ids, + host_ids=host_uuids + ) + + # Merge and resolve + self._routes.merge_with_unresolved(routes) + self._resolve_and_update_in_place() + + log.debug("[client routes] Updated routes after CLIENT_ROUTES_CHANGE") + except Exception as e: + log.warning("[client routes] Failed to handle CLIENT_ROUTES_CHANGE: %s", e, exc_info=True) + + def handle_control_connection_reconnect(self, control_connection): + """ + Handle control connection recreation - full re-read of all connection IDs. + + :param control_connection: The new ControlConnection instance + """ + if self._is_shutdown: + return + + log.info("[client routes] Control connection reconnected, re-reading all routes") + + try: + self.initialize(control_connection) + except Exception as e: + log.error("[client routes] Failed to re-initialize after reconnect: %s", e, exc_info=True) + + def _query_routes(self, control_connection, connection_ids=None, host_ids=None): + """ + Query system.client_routes table. + + :param control_connection: ControlConnection to execute query + :param connection_ids: Optional list of connection UUIDs to filter by + :param host_ids: Optional list of host UUIDs to filter by + :return: List of ResolvedRoute (with all_known_ips/current_ip/update_time as None) + """ + query_parts = [f"SELECT * FROM {self.config.table_name}"] + where_clauses = [] + + if connection_ids: + conn_id_list = ', '.join(str(cid) for cid in connection_ids) + where_clauses.append(f"connection_id IN ({conn_id_list})") + + if host_ids: + host_id_list = ', '.join(str(hid) for hid in host_ids) + where_clauses.append(f"host_id IN ({host_id_list})") + + if where_clauses: + query_parts.append("WHERE " + " AND ".join(where_clauses)) + + if (not connection_ids or len(connection_ids) == 0) and (not host_ids or len(host_ids) == 0): + query_parts.append("ALLOW FILTERING") + query = " ".join(query_parts) + + log.debug("[client routes] Querying: %s", query) + + from cassandra.protocol import QueryMessage + from 
cassandra import ConsistencyLevel + + query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE) + result = control_connection._connection.wait_for_response( + query_msg, timeout=control_connection._timeout + ) + + routes = [] + if hasattr(result, 'parsed_rows') and result.parsed_rows: + rows = dict_factory( + result.column_names, + result.parsed_rows) + for row in rows: + try: + routes.append(_parse_route_row(row)) + except Exception as e: + log.warning("[client routes] Failed to parse route row: %s", e) + + return routes + + def _resolve_and_update_in_place(self): + """ + Resolve routes that need resolution. + + Routes are resolved if: + 1. Marked with forced_resolve=True + 2. Have no IP address information + 3. Were resolved more than resolve_healthy_endpoint_period_ms ago + """ + all_routes = self._routes.get_all() + if not all_routes: + return + + # Calculate cutoff time for healthy re-resolution + cutoff_time = None + if self.config.resolve_healthy_endpoint_period_ms == 0: + cutoff_time = None # Never re-resolve healthy endpoints + else: + cutoff_time = time.time() - (self.config.resolve_healthy_endpoint_period_ms / 1000.0) + + # Identify routes that need resolution + routes_to_resolve = [] + for route in all_routes: + needs_resolve = False + + if route.current_ip is None or route.all_known_ips is None or route.forced_resolve: + needs_resolve = True + elif cutoff_time is not None and route.update_time is not None: + if route.update_time < cutoff_time: + needs_resolve = True + + if needs_resolve: + routes_to_resolve.append(route) + + if not routes_to_resolve: + return + + # Resolve in parallel with concurrency limit + from concurrent.futures import ThreadPoolExecutor, as_completed + + with ThreadPoolExecutor(max_workers=self.config.max_resolver_concurrency) as executor: + futures = {} + for route in routes_to_resolve: + future = executor.submit(self._resolve_single_route, route) + futures[future] = route + + # Wait for all resolutions and 
update + for future in as_completed(futures): + try: + resolved_route = future.result() + if resolved_route: + # Update single route atomically + self._routes.update_single(resolved_route.host_id, lambda _: resolved_route) + except Exception as e: + route = futures[future] + log.warning("[client routes] Failed to resolve %s: %s", route.address, e) + + def _resolve_single_route(self, route): + """ + Resolve a single route and return updated ResolvedRoute. + """ + all_ips, current_ip, timestamp = self._resolver.resolve( + route.address, + route.all_known_ips, + route.current_ip, + route.update_time + ) + + if all_ips and current_ip: + return route._replace( + all_known_ips=all_ips, + current_ip=current_ip, + update_time=timestamp, + forced_resolve=False + ) + else: + # Resolution failed, keep old values but update time + return route._replace( + update_time=time.time() + ) + + def translate_address(self, addr, host_id): + """ + Translate an address using Host ID lookup. + + This is called per connection attempt. Key behaviors: + - Initial endpoints (contact points) are NOT translated + - Empty Host ID returns address unchanged (bootstrap placeholder) + - Host ID not found returns error + - Found route with resolved IP returns translated address + - Found route without resolved IP performs on-demand DNS resolution with CAS retry + + :param addr: Original IP address + :param host_id: Host ID + :return: Translated IP address + :raises: DriverException if translation fails + """ + if self._is_shutdown: + return addr + + if host_id is None: + return addr + + if not host_id: + return addr + + # Look up route by host ID + route = self._routes.get_by_host_id(host_id) + + if route is None: + raise DriverException( + f"[client routes] No route found for host_id={host_id}, " + f"addr={addr}. This may indicate configuration mismatch or " + f"that CLIENT_ROUTES_CHANGE events are not being received." 
+ ) + + # If already resolved, return it + if route.current_ip: + return route.current_ip + + # On-demand DNS resolution with CAS retry loop (like Go) + log.debug("[client routes] On-demand DNS resolution for %s", route.address) + + # Try to resolve + all_ips, current_ip, timestamp = self._resolver.resolve( + route.address, + route.all_known_ips, + route.current_ip, + route.update_time + ) + + if not current_ip: + raise DriverException( + f"[client routes] DNS resolution failed for {route.address} " + f"(host_id={host_id})" + ) + + # Update with resolved data + updated_route = route._replace( + all_known_ips=all_ips, + current_ip=current_ip, + update_time=timestamp, + forced_resolve=False + ) + + # CAS retry loop (similar to Go's for loop) + max_retries = 10 + for attempt in range(max_retries): + with self._routes._lock: + # Re-check current state + current_route = self._routes.get_by_host_id(host_id) + + if current_route is None: + raise DriverException( + f"[client routes] Route for host_id={host_id} disappeared during resolution" + ) + + # If someone else resolved it, use their result + if current_route.current_ip and current_route.update_time and updated_route.update_time: + if current_route.update_time >= updated_route.update_time: + return current_route.current_ip + + # Try to update + self._routes._routes_by_host_id[host_id] = updated_route + return current_ip + + # If we exhausted retries, just return what we resolved + log.warning("[client routes] CAS retry limit reached for host_id=%s", host_id) + return current_ip + + def get_port(self, host): + """ + Get the appropriate port for a host based on SSL configuration. 
+ + :param host: Host object with host_id + :return: Port number or None if not found + """ + route = self._routes.get_by_host_id(host.host_id) if host else None + if not route: + return None + + # Pick TLS port if SSL enabled and available, otherwise regular port + if self.ssl_enabled: + if route.tls_port and route.tls_port > 0: + return route.tls_port + elif route.port and route.port > 0: + return route.port + else: + if route.port and route.port > 0: + return route.port + elif route.tls_port and route.tls_port > 0: + return route.tls_port + + return None + + def shutdown(self): + """ + Shutdown the handler and release resources. + """ + with self._lock: + if self._is_shutdown: + return + + self._is_shutdown = True + log.info("[client routes] Handler shutdown") + + +class ClientRoutesAddressTranslator: + """ + AddressTranslator implementation that uses ClientRoutesHandler. + + This bridges the AddressTranslator interface with the ClientRoutesHandler. + """ + + def __init__(self, handler): + """ + :param handler: ClientRoutesHandler instance + """ + self.handler = handler + + def translate(self, addr): + """ + Legacy V1 API - not sufficient for client routes. + """ + # Can't properly translate without host_id, return unchanged + return addr + + def translate_with_host_id(self, addr, host_id=None): + """ + V2 API - translate using Host metadata. 
+ """ + if not host_id: + return addr + + try: + return self.handler.translate_address(addr, host_id) + except Exception as e: + log.warning("[client routes] Translation failed for %s: %s", addr, e) + return addr + + diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 622b706330..0f75fa01fd 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1215,7 +1215,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info:Optional[ApplicationInfoBase]=None, + client_routes_config=None ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1337,7 +1338,31 @@ def __init__(self, raise ValueError("conviction_policy_factory must be callable") self.conviction_policy_factory = conviction_policy_factory - if address_translator is not None: + # Validate mutual exclusivity of client_routes_config and address_translator + if client_routes_config is not None and address_translator is not None: + raise ValueError("client_routes_config and address_translator are mutually exclusive") + + # Handle client routes configuration + self._client_routes_handler = None + if client_routes_config is not None: + from cassandra.client_routes import ClientRoutesConfig, ClientRoutesHandler, ClientRoutesAddressTranslator + + if not isinstance(client_routes_config, ClientRoutesConfig): + raise TypeError("client_routes_config must be a ClientRoutesConfig instance") + + ssl_enabled = ssl_context is not None or ssl_options is not None + self._client_routes_handler = ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled) + self.address_translator = ClientRoutesAddressTranslator(self._client_routes_handler) + + if contact_points is _NOT_SET or not self._contact_points_explicit: + seed_addrs = [ep.connection_addr for ep in client_routes_config.endpoints + if ep.connection_addr] 
+ if seed_addrs: + self.contact_points = seed_addrs + self._contact_points_explicit = True + log.info("[client routes] Using %d endpoint connection addresses as contact points", + len(seed_addrs)) + elif address_translator is not None: if isinstance(address_translator, type): raise TypeError("address_translator should not be a class, it should be an instance of that class") self.address_translator = address_translator @@ -1798,6 +1823,12 @@ def shutdown(self): if self.metrics_enabled and self.metrics: self.metrics.shutdown() + if self._client_routes_handler is not None: + try: + self._client_routes_handler.shutdown() + except Exception: + log.warning("Error shutting down client routes handler", exc_info=True) + _discard_cluster_shutdown(self) def __enter__(self): @@ -3612,11 +3643,24 @@ def _try_connect(self, endpoint): # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: - connection.register_watchers({ + watchers = { "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + } + + if self._cluster._client_routes_handler is not None: + watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change') + + connection.register_watchers(watchers, register_timeout=self._timeout) + + if self._cluster._client_routes_handler is not None: + try: + self._cluster._client_routes_handler.initialize(self) + except Exception as e: + log.error("[control connection] Failed to initialize client routes handler: %s", e, exc_info=True) + connection.close() + raise sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS @@ -3658,6 
+3702,13 @@ def _reconnect(self): log.debug("[control connection] Attempting to reconnect") try: self._set_new_connection(self._reconnect_internal()) + + # Notify client routes handler of reconnection (full re-read) + if self._cluster._client_routes_handler is not None: + try: + self._cluster._client_routes_handler.handle_control_connection_reconnect(self) + except Exception as e: + log.warning("[control connection] Failed to notify client routes handler of reconnection: %s", e) except NoHostAvailable: # make a retry schedule (which includes backoff) schedule = self._cluster.reconnection_policy.new_schedule() @@ -3979,6 +4030,31 @@ def _handle_status_change(self, event): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _handle_client_routes_change(self, event): + """ + Handle CLIENT_ROUTES_CHANGE event from the server. + + This event indicates that the system.client_routes table has been updated + and we need to refresh our route mappings. + """ + if self._cluster._client_routes_handler is None: + log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured") + return + + change_type = event.get("change_type") + connection_ids = event.get("connection_ids", []) + host_ids = event.get("host_ids", []) + + log.debug("[control connection] Received CLIENT_ROUTES_CHANGE: change_type=%s, " + "connection_ids=%s, host_ids=%s", change_type, connection_ids, host_ids) + + # Handle the event asynchronously + self._cluster.scheduler.schedule_unique( + 0, + self._cluster._client_routes_handler.handle_client_routes_change, + self, change_type, connection_ids, host_ids + ) + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return diff --git a/cassandra/policies.py b/cassandra/policies.py index 1f8456481f..8e2be162be 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -1262,7 +1262,6 @@ def translate(self, addr): pass return addr - class 
SpeculativeExecutionPolicy(object): """ Interface for specifying speculative execution plans diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py new file mode 100644 index 0000000000..ea78d0cd48 --- /dev/null +++ b/tests/integration/standard/test_client_routes.py @@ -0,0 +1,187 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import uuid + +from cassandra.cluster import Cluster +from cassandra.client_routes import ClientRoutesConfig, ClientRoutesEndpoint, ClientRoutesHandler +from tests.integration import TestCluster, use_cluster + +def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes', [3], start=True) + +class TestGetHostPortMapping(unittest.TestCase): + """ + Test _query_routes method with different filtering scenarios. + + This test matches the Golang TestGetHostPortMapping implementation. 
+ """ + + @classmethod + def setUpClass(cls): + """Create test keyspace and table, populate with test data.""" + cls.cluster = TestCluster() + cls.session = cls.cluster.connect() + + cls.session.execute(""" + CREATE KEYSPACE IF NOT EXISTS gocql_test + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} + """) + + cls.session.execute(""" + CREATE TABLE IF NOT EXISTS gocql_test.client_routes ( + connection_id uuid, + host_id uuid, + address text, + port int, + tls_port int, + alternator_port int, + alternator_https_port int, + datacenter text, + rack text, + PRIMARY KEY (connection_id, host_id) + ) + """) + + cls.session.execute("TRUNCATE gocql_test.client_routes") + + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [uuid.uuid4() for _ in range(3)] + cls.racks = ["rack1", "rack2", "rack3"] + cls.expected = [] + + for idx, host_id in enumerate(cls.host_ids): + rack = cls.racks[idx] + ip = f"127.0.0.{idx + 1}" + + for connection_id in cls.connection_ids: + cls.session.execute( + """ + INSERT INTO gocql_test.client_routes + (connection_id, host_id, address, port, tls_port, + alternator_port, alternator_https_port, datacenter, rack) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + (connection_id, host_id, ip, 9042, 9142, 0, 0, 'dc1', rack) + ) + + cls.expected.append({ + 'connection_id': connection_id, + 'host_id': host_id, + 'address': ip, + 'port': 9042, + 'tls_port': 9142, + 'datacenter': 'dc1', + 'rack': rack + }) + + cls._sort_routes(cls.expected) + + @classmethod + def tearDownClass(cls): + """Clean up test keyspace.""" + try: + cls.session.execute("DROP KEYSPACE IF EXISTS gocql_test") + finally: + cls.cluster.shutdown() + + @staticmethod + def _sort_routes(routes): + """Sort routes by connection_id then host_id for deterministic comparison.""" + routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + + def _query_and_compare(self, connection_ids, host_ids, expected): + """ + Query routes using 
ClientRoutesHandler._query_routes and compare with expected. + + :param connection_ids: List of connection UUIDs or None + :param host_ids: List of host UUIDs or None + :param expected: Expected list of route dicts + """ + config = ClientRoutesConfig( + endpoints=[ClientRoutesEndpoint( + connection_id=self.connection_ids[0], + connection_addr="127.0.0.1" + )], + table_name="gocql_test.client_routes" + ) + handler = ClientRoutesHandler(config) + + routes = handler._query_routes( + self.cluster.control_connection, + connection_ids=connection_ids, + host_ids=host_ids + ) + + got = [] + for route in routes: + got.append({ + 'connection_id': route.connection_id, + 'host_id': route.host_id, + 'address': route.address, + 'port': route.port, + 'tls_port': route.tls_port, + 'datacenter': route.datacenter, + 'rack': route.rack + }) + + self._sort_routes(got) + + self.assertEqual(len(got), len(expected), + f"Expected {len(expected)} routes, got {len(got)}") + + for i, (got_route, expected_route) in enumerate(zip(got, expected)): + self.assertEqual(got_route['connection_id'], expected_route['connection_id'], + f"Route {i}: connection_id mismatch") + self.assertEqual(got_route['host_id'], expected_route['host_id'], + f"Route {i}: host_id mismatch") + self.assertEqual(got_route['address'], expected_route['address'], + f"Route {i}: address mismatch") + self.assertEqual(got_route['port'], expected_route['port'], + f"Route {i}: port mismatch") + self.assertEqual(got_route['tls_port'], expected_route['tls_port'], + f"Route {i}: tls_port mismatch") + + def test_get_all(self): + """Test querying all routes without filters.""" + self._query_and_compare(None, None, self.expected) + + def test_get_all_hosts(self): + """Test querying with connection_ids filter only.""" + self._query_and_compare(self.connection_ids, None, self.expected) + + def test_get_all_connections(self): + """Test querying with host_ids filter only.""" + self._query_and_compare(None, self.host_ids, self.expected) + + 
def test_get_concrete(self): + """Test querying with both connection_ids and host_ids filters.""" + self._query_and_compare(self.connection_ids, self.host_ids, self.expected) + + def test_get_concrete_host(self): + """Test querying specific connection and host combination.""" + filtered_expected = [ + r for r in self.expected + if r['connection_id'] == self.connection_ids[0] and + r['host_id'] == self.host_ids[0] + ] + + self._query_and_compare( + [self.connection_ids[0]], + [self.host_ids[0]], + filtered_expected + ) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py new file mode 100644 index 0000000000..e31e05e2fd --- /dev/null +++ b/tests/unit/test_client_routes.py @@ -0,0 +1,600 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import uuid +import time +from unittest.mock import Mock, patch +from collections import namedtuple + +from cassandra.client_routes import ( + ClientRoutesEndpoint, + ClientRoutesConfig, + DNSResolver, + ResolvedRoutes, + ResolvedRoute, + ClientRoutesHandler, + ClientRoutesAddressTranslator, + _parse_route_row +) +from cassandra import DriverException + + +class TestClientRoutesEndpoint(unittest.TestCase): + + def test_endpoint_with_uuid_string(self): + conn_id = str(uuid.uuid4()) + endpoint = ClientRoutesEndpoint(conn_id, "10.0.0.1") + self.assertIsInstance(endpoint.connection_id, uuid.UUID) + self.assertEqual(endpoint.connection_addr, "10.0.0.1") + + def test_endpoint_with_uuid_object(self): + conn_id = uuid.uuid4() + endpoint = ClientRoutesEndpoint(conn_id) + self.assertEqual(endpoint.connection_id, conn_id) + self.assertIsNone(endpoint.connection_addr) + + def test_endpoint_invalid_uuid(self): + with self.assertRaises(ValueError): + ClientRoutesEndpoint("not-a-uuid") + + def test_endpoint_none_connection_id(self): + with self.assertRaises(ValueError): + ClientRoutesEndpoint(None) + + +class TestClientRoutesConfig(unittest.TestCase): + + def test_config_with_endpoints(self): + ep1 = ClientRoutesEndpoint(uuid.uuid4(), "10.0.0.1") + ep2 = ClientRoutesEndpoint(uuid.uuid4(), "10.0.0.2") + config = ClientRoutesConfig([ep1, ep2]) + self.assertEqual(len(config.endpoints), 2) + self.assertEqual(config.table_name, "system.client_routes") + + def test_config_custom_table_name(self): + ep = ClientRoutesEndpoint(uuid.uuid4()) + config = ClientRoutesConfig([ep], table_name="custom.routes") + self.assertEqual(config.table_name, "custom.routes") + + def test_config_empty_endpoints(self): + with self.assertRaises(ValueError): + ClientRoutesConfig([]) + + def test_config_invalid_endpoint_type(self): + with self.assertRaises(TypeError): + ClientRoutesConfig(["not-an-endpoint"]) + + def test_config_max_resolver_concurrency_validation(self): + ep = 
ClientRoutesEndpoint(uuid.uuid4()) + # Must be > 0 + with self.assertRaises(ValueError): + ClientRoutesConfig([ep], max_resolver_concurrency=0) + with self.assertRaises(ValueError): + ClientRoutesConfig([ep], max_resolver_concurrency=-1) + + def test_config_resolve_period_validation(self): + ep = ClientRoutesEndpoint(uuid.uuid4()) + # Must be >= 0 + with self.assertRaises(ValueError): + ClientRoutesConfig([ep], resolve_healthy_endpoint_period_ms=-1) + # 0 is valid (never re-resolve healthy) + config = ClientRoutesConfig([ep], resolve_healthy_endpoint_period_ms=0) + self.assertEqual(config.resolve_healthy_endpoint_period_ms, 0) + + +class TestDNSResolver(unittest.TestCase): + + @patch('cassandra.client_routes.socket.getaddrinfo') + def test_resolve_success(self, mock_getaddrinfo): + mock_getaddrinfo.return_value = [ + (None, None, None, None, ('192.168.1.1', 9042)), + (None, None, None, None, ('192.168.1.2', 9042)) + ] + + resolver = DNSResolver() + all_ips, current_ip, timestamp = resolver.resolve("example.com") + + self.assertIsNotNone(all_ips) + self.assertEqual(len(all_ips), 2) + self.assertIn(current_ip, all_ips) + self.assertIsNotNone(timestamp) + + @patch('cassandra.client_routes.socket.getaddrinfo') + def test_resolve_uses_cache(self, mock_getaddrinfo): + mock_getaddrinfo.return_value = [ + (None, None, None, None, ('192.168.1.1', 9042)) + ] + + resolver = DNSResolver(cache_duration_ms=5000) + + # First resolution + all_ips1, current_ip1, ts1 = resolver.resolve("example.com") + + # Second resolution with cache (should return cached result) + all_ips2, current_ip2, ts2 = resolver.resolve("example.com", all_ips1, current_ip1, ts1) + + self.assertEqual(all_ips1, all_ips2) + self.assertEqual(current_ip1, current_ip2) + self.assertEqual(ts1, ts2) + # getaddrinfo should only be called once + self.assertEqual(mock_getaddrinfo.call_count, 1) + + @patch('cassandra.client_routes.socket.getaddrinfo') + def test_resolve_failure_returns_cached(self, mock_getaddrinfo): + 
mock_getaddrinfo.side_effect = Exception("DNS failure") + + resolver = DNSResolver() + cached_ips = ["192.168.1.100"] + cached_ip = "192.168.1.100" + cached_at = 12345.0 + + all_ips, current_ip, timestamp = resolver.resolve("example.com", cached_ips, cached_ip, cached_at) + + # Should return cached values on failure + self.assertEqual(all_ips, cached_ips) + self.assertEqual(current_ip, cached_ip) + self.assertEqual(timestamp, cached_at) + + +class TestResolvedRoutes(unittest.TestCase): + + def test_get_by_host_id(self): + routes = ResolvedRoutes() + host_id = uuid.uuid4() + route = ResolvedRoute( + connection_id=uuid.uuid4(), + host_id=host_id, + address="example.com", + port=9042, + tls_port=9142, + datacenter="dc1", + rack="rack1", + all_known_ips=["192.168.1.1", "192.168.1.2"], + current_ip="192.168.1.1", + update_time=12345.0, + forced_resolve=False + ) + + routes.update([route]) + + retrieved = routes.get_by_host_id(host_id) + self.assertEqual(retrieved.host_id, host_id) + self.assertEqual(retrieved.current_ip, "192.168.1.1") + + def test_merge_routes(self): + routes = ResolvedRoutes() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + + route1 = ResolvedRoute( + connection_id=uuid.uuid4(), host_id=host_id1, + address="host1.com", port=9042, tls_port=None, + datacenter="dc1", rack="rack1", + all_known_ips=["192.168.1.1"], current_ip="192.168.1.1", + update_time=12345.0, forced_resolve=False + ) + + route2 = ResolvedRoute( + connection_id=uuid.uuid4(), host_id=host_id2, + address="host2.com", port=9042, tls_port=None, + datacenter="dc1", rack="rack1", + all_known_ips=["192.168.1.2"], current_ip="192.168.1.2", + update_time=12345.0, forced_resolve=False + ) + + routes.update([route1]) + routes.merge([route2]) + + self.assertIsNotNone(routes.get_by_host_id(host_id1)) + self.assertIsNotNone(routes.get_by_host_id(host_id2)) + + +class TestClientRoutesHandler(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.endpoint = 
ClientRoutesEndpoint(self.conn_id, "10.0.0.1") + self.config = ClientRoutesConfig([self.endpoint]) + + def test_handler_initialization(self): + handler = ClientRoutesHandler(self.config, ssl_enabled=False) + self.assertIsNotNone(handler) + self.assertEqual(handler.ssl_enabled, False) + + @patch.object(ClientRoutesHandler, '_query_routes') + @patch.object(ClientRoutesHandler, '_resolve_and_update_in_place') + def test_initialize(self, mock_resolve, mock_query): + host_id = uuid.uuid4() + mock_query.return_value = [ + ResolvedRoute( + connection_id=self.conn_id, + host_id=host_id, + address="node1.example.com", + port=9042, + tls_port=9142, + datacenter="dc1", + rack="rack1", + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True + ) + ] + + handler = ClientRoutesHandler(self.config) + mock_control_conn = Mock() + + handler.initialize(mock_control_conn) + + # Verify route was stored + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + mock_resolve.assert_called_once() + + def test_translate_address_no_host(self): + handler = ClientRoutesHandler(self.config) + addr = "192.168.1.1" + + # Should return unchanged when host is None + result = handler.translate_address(addr, None) + self.assertEqual(result, addr) + + def test_translate_address_not_found(self): + handler = ClientRoutesHandler(self.config) + + class MockHost: + def __init__(self): + self.host_id = uuid.uuid4() + + mock_host = MockHost() + + with self.assertRaises(DriverException) as cm: + handler.translate_address("192.168.1.1", mock_host) + + self.assertIn("No route found", str(cm.exception)) + + def test_get_port_ssl_enabled(self): + handler = ClientRoutesHandler(self.config, ssl_enabled=True) + host_id = uuid.uuid4() + + route = ResolvedRoute( + connection_id=self.conn_id, + host_id=host_id, + address="node1.example.com", + port=9042, + tls_port=9142, + datacenter="dc1", + rack="rack1", + all_known_ips=["192.168.1.1"], + current_ip="192.168.1.1", + 
update_time=12345.0, + forced_resolve=False + ) + handler._routes.update([route]) + + class MockHost: + def __init__(self, host_id): + self.host_id = host_id + + mock_host = MockHost(host_id) + port = handler.get_port(mock_host) + + # Should return TLS port when SSL is enabled + self.assertEqual(port, 9142) + + +class TestClientRoutesAddressTranslator(unittest.TestCase): + + def test_translate_v1_returns_unchanged(self): + mock_handler = Mock() + translator = ClientRoutesAddressTranslator(mock_handler) + + addr = "192.168.1.1" + result = translator.translate(addr) + + # V1 API should return unchanged (can't translate without host_id) + self.assertEqual(result, addr) + + def test_translate_with_host_success(self): + mock_handler = Mock() + mock_handler.translate_address.return_value = "10.0.0.1" + + translator = ClientRoutesAddressTranslator(mock_handler) + + result = translator.translate_with_host_id("192.168.1.1", uuid.uuid4()) + + self.assertEqual(result, "10.0.0.1") + mock_handler.translate_address.assert_called_once() + + +class TestParseRouteRow(unittest.TestCase): + + def test_parse_complete_row(self): + RowType = namedtuple('RowType', [ + 'connection_id', 'host_id', 'address', 'port', + 'tls_port', 'datacenter', 'rack' + ]) + + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + + row = { + 'connection_id': conn_id, + 'host_id': host_id, + 'address': "node1.example.com", + 'port': 9042, + 'tls_port': 9142, + 'datacenter': "dc1", + 'rack': "rack1" + } + + route = _parse_route_row(row) + + self.assertEqual(route.connection_id, conn_id) + self.assertEqual(route.host_id, host_id) + self.assertEqual(route.address, "node1.example.com") + self.assertEqual(route.port, 9042) + self.assertEqual(route.tls_port, 9142) + self.assertEqual(route.datacenter, "dc1") + self.assertEqual(route.rack, "rack1") + self.assertIsNone(route.all_known_ips) + self.assertIsNone(route.current_ip) + self.assertIsNone(route.update_time) + self.assertTrue(route.forced_resolve) # Should be True 
initially + + +class TestResolvedRouteMergeLogic(unittest.TestCase): + """Test ResolvedRoutes merge operations matching Golang behavior.""" + + def test_merge_with_unresolved_unchanged_record(self): + """Test that unchanged records don't get forcedResolve flag set.""" + routes = ResolvedRoutes() + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + + # Initial resolved route + route = ResolvedRoute( + connection_id=conn_id, + host_id=host_id, + address="a1", + port=9042, + tls_port=None, + datacenter=None, + rack=None, + all_known_ips=["10.0.0.1"], + current_ip="10.0.0.1", + update_time=time.time(), + forced_resolve=False + ) + routes.update([route]) + + # Merge with identical unresolved route + unresolved = [ResolvedRoute( + connection_id=conn_id, + host_id=host_id, + address="a1", + port=9042, + tls_port=None, + datacenter=None, + rack=None, + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True + )] + routes.merge_with_unresolved(unresolved) + + # Should remain unchanged (forced_resolve stays False) + result = routes.get_by_host_id(host_id) + self.assertFalse(result.forced_resolve) + + def test_merge_with_unresolved_changed_record(self): + """Test that changed records get forcedResolve flag set.""" + routes = ResolvedRoutes() + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + + # Initial resolved route + route = ResolvedRoute( + connection_id=conn_id, + host_id=host_id, + address="a1", + port=9042, + tls_port=None, + datacenter=None, + rack=None, + all_known_ips=["10.0.0.1"], + current_ip="10.0.0.1", + update_time=time.time(), + forced_resolve=False + ) + routes.update([route]) + + # Merge with changed unresolved route + unresolved = [ResolvedRoute( + connection_id=conn_id, + host_id=host_id, + address="a2", # Changed + port=9043, # Changed + tls_port=None, + datacenter=None, + rack=None, + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True + )] + routes.merge_with_unresolved(unresolved) + + # Should be 
updated with forcedResolve=True + result = routes.get_by_host_id(host_id) + self.assertEqual(result.address, "a2") + self.assertEqual(result.port, 9043) + self.assertTrue(result.forced_resolve) + + def test_merge_with_unresolved_new_record(self): + """Test that new records are added with forcedResolve=True.""" + routes = ResolvedRoutes() + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + + # Merge with new unresolved route + unresolved = [ResolvedRoute( + connection_id=conn_id, + host_id=host_id, + address="a3", + port=9044, + tls_port=None, + datacenter=None, + rack=None, + all_known_ips=None, + current_ip=None, + update_time=None, + forced_resolve=True + )] + routes.merge_with_unresolved(unresolved) + + # Should be added with forcedResolve=True + result = routes.get_by_host_id(host_id) + self.assertIsNotNone(result) + self.assertTrue(result.forced_resolve) + + +class TestQueryBuilding(unittest.TestCase): + """Test query building with different filter combinations matching Golang behavior.""" + + @patch.object(ClientRoutesHandler, '_resolve_and_update_in_place') + def test_query_all_routes(self, mock_resolve): + """Test querying all routes without filters.""" + conn_id = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(conn_id, "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + # Mock response + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query all routes + handler._query_routes(mock_control_conn, connection_ids=None, host_ids=None) + + # Verify query structure + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("select * from", query_str) + self.assertIn("allow filtering", query_str) + 
self.assertNotIn("where", query_str) + + @patch.object(ClientRoutesHandler, '_resolve_and_update_in_place') + def test_query_with_connection_ids_only(self, mock_resolve): + """Test querying with connection_ids filter only.""" + conn_id1 = uuid.uuid4() + conn_id2 = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(conn_id1, "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with connection_ids filter + handler._query_routes(mock_control_conn, connection_ids=[conn_id1, conn_id2], host_ids=None) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + self.assertIn("connection_id in", query_str) + self.assertNotIn("host_id in", query_str) + + @patch.object(ClientRoutesHandler, '_resolve_and_update_in_place') + def test_query_with_host_ids_only(self, mock_resolve): + """Test querying with host_ids filter only.""" + conn_id = uuid.uuid4() + host_id = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(conn_id, "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with host_ids filter + handler._query_routes(mock_control_conn, connection_ids=None, host_ids=[host_id]) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + 
self.assertIn("host_id in", query_str) + self.assertNotIn("connection_id in", query_str) + + @patch.object(ClientRoutesHandler, '_resolve_and_update_in_place') + def test_query_with_both_filters(self, mock_resolve): + """Test querying with both connection_ids and host_ids filters.""" + conn_id = uuid.uuid4() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + config = ClientRoutesConfig([ClientRoutesEndpoint(conn_id, "10.0.0.1")]) + handler = ClientRoutesHandler(config) + + mock_control_conn = Mock() + mock_connection = Mock() + mock_control_conn._connection = mock_connection + mock_control_conn._timeout = 10 + + mock_result = Mock() + mock_result.parsed_rows = [] + mock_result.column_names = [] + mock_connection.wait_for_response.return_value = mock_result + + # Query with both filters + handler._query_routes(mock_control_conn, connection_ids=[conn_id], host_ids=[host_id1, host_id2]) + + call_args = mock_connection.wait_for_response.call_args + query_msg = call_args[0][0] + query_str = query_msg.query.lower() + + self.assertIn("where", query_str) + self.assertIn("connection_id in", query_str) + self.assertIn("host_id in", query_str) + # When both filters present, should not use ALLOW FILTERING + self.assertNotIn("allow filtering", query_str) + + +if __name__ == '__main__': + unittest.main()