From 3b01b0ddd5b3ebc8d14030b6fbc89134e6222b3a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sun, 2 Feb 2025 21:28:33 +0000 Subject: [PATCH 01/16] chore: add failover_period to Connector --- google/cloud/sql/connector/connector.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 5160513f..ba419dde 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -67,6 +67,7 @@ def __init__( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> None: """Initializes a Connector instance. @@ -114,6 +115,11 @@ def __init__( name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver + + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occurred for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. 
+ Default: 30 """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -168,6 +174,7 @@ def __init__( self._quota_project = quota_project self._user_agent = user_agent self._resolver = resolver() + self._failover_period = failover_period # if ip_type is str, convert to IPTypes enum if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) From d783d636369e426f8b1bf823a084baf20e83699a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 3 Mar 2025 18:10:17 +0000 Subject: [PATCH 02/16] feat: automatically reset connection on failover --- google/cloud/sql/connector/connection_info.py | 22 ++++ google/cloud/sql/connector/connection_name.py | 4 + google/cloud/sql/connector/connector.py | 19 +-- google/cloud/sql/connector/instance.py | 13 ++- google/cloud/sql/connector/lazy.py | 15 ++- google/cloud/sql/connector/monitored_cache.py | 109 ++++++++++++++++++ tests/unit/test_connection_name.py | 4 + tests/unit/test_lazy.py | 21 ++++ 8 files changed, 197 insertions(+), 10 deletions(-) create mode 100644 google/cloud/sql/connector/monitored_cache.py diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 82e3a901..c9e48935 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import abc from dataclasses import dataclass import logging import ssl @@ -34,6 +35,27 @@ logger = logging.getLogger(name=__name__) +class ConnectionInfoCache(abc.ABC): + """Abstract class for Connector connection info caches.""" + + @abc.abstractmethod + async def connect_info(self) -> ConnectionInfo: + pass + + @abc.abstractmethod + async def force_refresh(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @property + @abc.abstractmethod + def closed(self) -> bool: + pass + + @dataclass class ConnectionInfo: """Contains all necessary 
information to connect securely to the diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index 1bf711ab..4f4b9979 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -38,6 +38,10 @@ def __str__(self) -> str: return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}" return f"{self.project}:{self.region}:{self.instance_name}" + def get_connection_string(self) -> str: + """Get the instance connection string for the Cloud SQL instance.""" + return f"{self.project}:{self.region}:{self.instance_name}" + def _parse_connection_name(connection_name: str) -> ConnectionName: return _parse_connection_name_with_domain_name(connection_name, "") diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index ba419dde..c993f56c 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -35,6 +35,7 @@ from google.cloud.sql.connector.enums import RefreshStrategy from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache import google.cloud.sql.connector.pg8000 as pg8000 import google.cloud.sql.connector.pymysql as pymysql import google.cloud.sql.connector.pytds as pytds @@ -149,9 +150,7 @@ def __init__( ) # initialize dict to store caches, key is a tuple consisting of instance # connection name string and enable_iam_auth boolean flag - self._cache: dict[ - tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache] - ] = {} + self._cache: dict[tuple[str, bool], MonitoredCache] = {} self._client: Optional[CloudSQLClient] = None # initialize credentials @@ -289,14 +288,14 @@ async def connect_async( ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if (instance_connection_string, enable_iam_auth) in 
self._cache: - cache = self._cache[(instance_connection_string, enable_iam_auth)] + monitored_cache = self._cache[(instance_connection_string, enable_iam_auth)] else: conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) - cache = LazyRefreshCache( + cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache( conn_name, self._client, self._keys, @@ -312,8 +311,14 @@ async def connect_async( self._keys, enable_iam_auth, ) + # wrap cache as a MonitoredCache + monitored_cache = MonitoredCache( + cache, + self._failover_period, + self._resolver, + ) logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(instance_connection_string, enable_iam_auth)] = cache + self._cache[(instance_connection_string, enable_iam_auth)] = monitored_cache connect_func = { "pymysql": pymysql.connect, @@ -342,7 +347,7 @@ async def connect_async( # attempt to get connection info for Cloud SQL instance try: - conn_info = await cache.connect_info() + conn_info = await monitored_cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 5df272fe..fb871130 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -24,6 +24,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter @@ -35,7 +36,7 @@ 
APPLICATION_NAME = "cloud-sql-python-connector" -class RefreshAheadCache: +class RefreshAheadCache(ConnectionInfoCache): """Cache that refreshes connection info in the background prior to expiration. Background tasks are used to schedule refresh attempts to get a new @@ -74,6 +75,15 @@ def __init__( self._refresh_in_progress = asyncio.locks.Event() self._current: asyncio.Task = self._schedule_refresh(0) self._next: asyncio.Task = self._current + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -212,3 +222,4 @@ async def close(self) -> None: # gracefully wait for tasks to cancel tasks = asyncio.gather(self._current, self._next, return_exceptions=True) await asyncio.wait_for(tasks, timeout=2.0) + self._closed = True diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 1bc4f90f..c75d07e5 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,13 +21,14 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) -class LazyRefreshCache: +class LazyRefreshCache(ConnectionInfoCache): """Cache that refreshes connection info when a caller requests a connection. 
Only refreshes the cache when a new connection is requested and the current @@ -62,6 +63,15 @@ def __init__( self._lock = asyncio.Lock() self._cached: Optional[ConnectionInfo] = None self._needs_refresh = False + self._closed = False + + @property + def conn_name(self) -> ConnectionName: + return self._conn_name + + @property + def closed(self) -> bool: + return self._closed async def force_refresh(self) -> None: """ @@ -121,4 +131,5 @@ async def close(self) -> None: """Close is a no-op and provided purely for a consistent interface with other cache types. """ - pass + self._closed = True + return diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py new file mode 100644 index 00000000..0a4cb6f1 --- /dev/null +++ b/google/cloud/sql/connector/monitored_cache.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +from typing import Any, Callable, Optional, Union + +from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_info import ConnectionInfoCache +from google.cloud.sql.connector.instance import RefreshAheadCache +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver + +logger = logging.getLogger(name=__name__) + + +class MonitoredCache(ConnectionInfoCache): + def __init__( + self, + cache: Union[RefreshAheadCache, LazyRefreshCache], + failover_period: int, + resolver: Union[DefaultResolver, DnsResolver], + ) -> None: + self.resolver = resolver + self.cache = cache + self.domain_name_ticker: Optional[asyncio.Task] = None + self.open_conns_count: int = 0 + + if self.cache.conn_name.domain_name: + self.domain_name_ticker = asyncio.create_task( + ticker(failover_period, self._check_domain_name) + ) + logger.debug( + f"['{self.cache.conn_name}']: Configured polling of domain " + f"name with failover period of {failover_period} seconds." + ) + + @property + def closed(self) -> bool: + return self.cache.closed + + async def _check_domain_name(self) -> None: + try: + # Resolve domain name and see if Cloud SQL instance connection name + # has changed. If it has, close all connections. + new_conn_name = await self.resolver.resolve( + self.cache.conn_name.domain_name + ) + if new_conn_name != self.cache.conn_name: + logger.debug( + f"['{self.cache.conn_name}']: Cloud SQL instance changed " + f"from {self.cache.conn_name.get_connection_string()} to " + f"{new_conn_name.get_connection_string()}, closing all " + "connections!" + ) + await self.close() + + except Exception as e: + # Domain name checks should not be fatal, log error and continue. 
+ logger.debug( + f"['{self.cache.conn_name}']: Unable to check domain name, " + f"domain name {self.cache.conn_name.domain_name} did not " + f"resolve: {e}" + ) + + async def connect_info(self) -> ConnectionInfo: + return await self.cache.connect_info() + + async def force_refresh(self) -> None: + return await self.cache.force_refresh() + + async def close(self) -> None: + # Cancel domain name ticker task. + if self.domain_name_ticker: + self.domain_name_ticker.cancel() + try: + await self.domain_name_ticker + except asyncio.CancelledError: + logger.debug( + f"['{self.cache.conn_name}']: Cancelled domain name polling task." + ) + + # If cache is already closed, no further work. + if self.cache.closed: + return + await self.cache.close() + + +async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: + """ + Ticker function to sleep for specified interval and then schedule call + to given function. + """ + while True: + # Sleep for interval and then schedule task + await asyncio.sleep(interval) + asyncio.create_task(function(*args, **kwargs)) diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index 783e14fe..9089618d 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -30,6 +30,8 @@ def test_ConnectionName() -> None: assert conn_name.domain_name == "" # test ConnectionName str() method prints instance connection name assert str(conn_name) == "project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() == "project:region:instance" def test_ConnectionName_with_domain_name() -> None: @@ -41,6 +43,8 @@ def test_ConnectionName_with_domain_name() -> None: assert conn_name.domain_name == "db.example.com" # test ConnectionName str() method prints with domain name assert str(conn_name) == "db.example.com -> project:region:instance" + # test ConnectionName.get_connection_string + assert conn_name.get_connection_string() 
== "project:region:instance" @pytest.mark.parametrize( diff --git a/tests/unit/test_lazy.py b/tests/unit/test_lazy.py index 344b073e..c6eef750 100644 --- a/tests/unit/test_lazy.py +++ b/tests/unit/test_lazy.py @@ -21,6 +21,27 @@ from google.cloud.sql.connector.utils import generate_keys +async def test_LazyRefreshCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that LazyRefreshCache properties work as expected. + """ + keys = asyncio.create_task(generate_keys()) + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=keys, + enable_iam_auth=False, + ) + # test conn_name property + assert cache.conn_name == conn_name + # test closed property + assert cache.closed is False + # close cache and make sure property is updated + await cache.close() + assert cache.closed is True + + async def test_LazyRefreshCache_connect_info(fake_client: CloudSQLClient) -> None: """ Test that LazyRefreshCache.connect_info works as expected. 
From 4f2fc4cca04eec1dbc202a4d34f30c0f943e89fb Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 3 Mar 2025 18:34:40 +0000 Subject: [PATCH 03/16] chore: add integration test with domain name --- .github/workflows/tests.yml | 2 ++ tests/system/test_pg8000_connection.py | 28 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e710138f..e4d4675b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -81,6 +81,7 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS + POSTGRES_CUSTOMER_CAS_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -102,6 +103,7 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" + POSTGRES_CUSTOMER_CAS_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index b56a8e82..b80e7835 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -22,6 +22,8 @@ import sqlalchemy from 
google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver def create_sqlalchemy_engine( @@ -30,6 +32,7 @@ def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", + resolver: DefaultResolver | DnsResolver = DefaultResolver, ) -> tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -64,8 +67,13 @@ def create_sqlalchemy_engine( Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver | google.cloud.sql.connector.DnsResolver]) + Resolver class for the Cloud SQL Connector. Can be one of + DefaultResolver (default) or DnsResolver. The resolver tells the + connector whether to resolve the 'instance_connection_name' as a + Cloud SQL instance connection name or as a domain name. 
""" - connector = Connector(refresh_strategy=refresh_strategy) + connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) def getconn() -> pg8000.dbapi.Connection: conn: pg8000.dbapi.Connection = connector.connect( @@ -153,3 +161,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None: curr_time = time[0] assert type(curr_time) is datetime connector.close() + + +def test_domain_name_pg8000_connection() -> None: + """Basic test to get time from database using domain name to connect.""" + domain_name = os.environ["POSTGRES_CUSTOMER_CAS_DOMAIN_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] + db = os.environ["POSTGRES_DB"] + + engine, connector = create_sqlalchemy_engine( + domain_name, user, password, db, "lazy", DnsResolver + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From f92fd8837a06aa533d977006d27e771c0e51ad87 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 3 Mar 2025 18:41:22 +0000 Subject: [PATCH 04/16] chore: update type hint --- tests/system/test_pg8000_connection.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index b80e7835..d56cc502 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -18,6 +18,8 @@ import os # [START cloud_sql_connector_postgres_pg8000] +from typing import Union + import pg8000 import sqlalchemy @@ -32,7 +34,7 @@ def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", - resolver: DefaultResolver | DnsResolver = DefaultResolver, + resolver: Union[DefaultResolver, DnsResolver] = DefaultResolver, ) -> tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and 
the connector. Callers are responsible for closing the pool and the From 7a1812a1d1bf808a92038664777b0b5b14159331 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 18 Mar 2025 13:53:43 +0000 Subject: [PATCH 05/16] chore: attempt moving socket into ConnectionInfo --- google/cloud/sql/connector/__init__.py | 2 +- google/cloud/sql/connector/connection_info.py | 24 ++++++++++++++++--- google/cloud/sql/connector/connector.py | 2 +- google/cloud/sql/connector/monitored_cache.py | 8 ++++++- google/cloud/sql/connector/pymysql.py | 12 +++------- google/cloud/sql/connector/pytds.py | 12 +++------- tests/conftest.py | 2 +- tests/system/test_connector_object.py | 2 +- tests/system/test_ip_types.py | 2 +- tests/system/test_pymysql_connection.py | 2 +- tests/system/test_pytds_connection.py | 2 +- tests/unit/test_connector.py | 2 +- tests/unit/test_instance.py | 2 +- tests/unit/test_rate_limiter.py | 2 +- tests/unit/test_utils.py | 2 +- 15 files changed, 45 insertions(+), 33 deletions(-) diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 99a5097a..6913337d 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index c9e48935..0e681604 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -17,12 +17,14 @@ import abc from dataclasses import dataclass import logging +import socket import ssl -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from aiofiles.tempfile import TemporaryDirectory from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.exceptions 
import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -30,7 +32,6 @@ if TYPE_CHECKING: import datetime - from google.cloud.sql.connector.enums import IPTypes logger = logging.getLogger(name=__name__) @@ -69,13 +70,21 @@ class ConnectionInfo: database_version: str expiration: datetime.datetime context: Optional[ssl.SSLContext] = None + sock: Optional[ssl.SSLSocket] = None - async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: + async def create_ssl_context( + self, enable_iam_auth: bool = False, return_socket: bool = False + ) -> Union[ssl.SSLContext, ssl.SSLSocket]: """Constructs a SSL/TLS context for the given connection info. Cache the SSL context to ensure we don't read from disk repeatedly when configuring a secure connection. """ + # Return socket if socket is cached and return_socket is set to True + if self.sock is not None and return_socket: + logger.debug("Socket in cache, returning it!") + return self.sock + # if SSL context is cached, use it if self.context is not None: return self.context @@ -116,6 +125,15 @@ async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLCont context.load_verify_locations(cafile=ca_filename) # set class attribute to cache context for subsequent calls self.context = context + # If return_socket is True, cache socket and return it + if return_socket: + logger.debug("Returning socket instead of context!") + sock = self.context.wrap_socket( + socket.create_connection((self.get_preferred_ip(IPTypes.PUBLIC), 3307)), + server_hostname="blah", + ) + self.sock = sock + return sock return context def get_preferred_ip(self, ip_type: IPTypes) -> str: diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7c25f306..356cad21 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -390,7 +390,7 @@ 
async def connect_async( except Exception: # with any exception, we attempt a force refresh, then throw the error - await cache.force_refresh() + await monitored_cache.force_refresh() raise async def _remove_cached( diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 0a4cb6f1..d807f257 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -36,7 +36,7 @@ def __init__( self.resolver = resolver self.cache = cache self.domain_name_ticker: Optional[asyncio.Task] = None - self.open_conns_count: int = 0 + self.open_conns: int = 0 if self.cache.conn_name.domain_name: self.domain_name_ticker = asyncio.create_task( @@ -66,6 +66,12 @@ async def _check_domain_name(self) -> None: "connections!" ) await self.close() + conn_info = await self.connect_info() + if conn_info.sock: + logger.debug(f"Socket type: {type(conn_info.sock)}") + conn_info.sock.close() + else: + logger.debug("Domain name mapping has not changed!") except Exception as e: # Domain name checks should not be fatal, log error and continue. diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index a1658436..580008de 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -14,7 +14,6 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING @@ -25,7 +24,7 @@ def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pymysql.connections.Connection": """Helper function to create a pymysql DB-API connection object. @@ -33,8 +32,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. 
- :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. :rtype: pymysql.Connection @@ -50,11 +49,6 @@ def connect( # allow automatic IAM database authentication to not require password kwargs["password"] = kwargs["password"] if "password" in kwargs else None - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) # pop timeout as timeout arg is called 'connect_timeout' for pymysql timeout = kwargs.pop("timeout") kwargs["connect_timeout"] = kwargs.get("connect_timeout", timeout) diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 243d90fd..1d51122f 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -15,7 +15,6 @@ """ import platform -import socket import ssl from typing import Any, TYPE_CHECKING @@ -27,15 +26,15 @@ import pytds -def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Connection": +def connect(ip_address: str, sock: ssl.SSLSocket, **kwargs: Any) -> "pytds.Connection": """Helper function to create a pytds DB-API connection object. :type ip_address: str :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -51,11 +50,6 @@ def connect(ip_address: str, ctx: ssl.SSLContext, **kwargs: Any) -> "pytds.Conne db = kwargs.pop("db", None) - # Create socket and wrap with context. 
- sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) if kwargs.pop("active_directory_auth", False): if platform.system() == "Windows": # Ignore username and password if using active directory auth diff --git a/tests/conftest.py b/tests/conftest.py index 3a1a38a2..c75de48c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index c2b5cf12..258b80aa 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_ip_types.py b/tests/system/test_ip_types.py index 2df3b1df..3af49c54 100644 --- a/tests/system/test_ip_types.py +++ b/tests/system/test_ip_types.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_pymysql_connection.py b/tests/system/test_pymysql_connection.py index 490b1fab..1e7e2683 100644 --- a/tests/system/test_pymysql_connection.py +++ b/tests/system/test_pymysql_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/system/test_pytds_connection.py b/tests/system/test_pytds_connection.py index d848abc1..fd88d230 100644 --- a/tests/system/test_pytds_connection.py +++ b/tests/system/test_pytds_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index e25c9a38..498c947c 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -1,4 +1,4 @@ -"""" 
+""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index aeedf339..1a3d6091 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 5e187b81..8ef586b5 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2021 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 6545bc7a..fe4e9095 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2019 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); From 6f6d5e49f6d71bbfe1b6060006f7e7469f1db914 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 19 Mar 2025 22:52:09 +0000 Subject: [PATCH 06/16] chore: revert connection_info.py --- google/cloud/sql/connector/connection_info.py | 46 ++----------------- 1 file changed, 3 insertions(+), 43 deletions(-) diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 0e681604..82e3a901 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -14,17 +14,14 @@ from __future__ import annotations -import abc from dataclasses import dataclass import logging -import socket import ssl -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING from aiofiles.tempfile import TemporaryDirectory from google.cloud.sql.connector.connection_name import ConnectionName -from google.cloud.sql.connector.enums import IPTypes from google.cloud.sql.connector.exceptions import 
CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -32,31 +29,11 @@ if TYPE_CHECKING: import datetime + from google.cloud.sql.connector.enums import IPTypes logger = logging.getLogger(name=__name__) -class ConnectionInfoCache(abc.ABC): - """Abstract class for Connector connection info caches.""" - - @abc.abstractmethod - async def connect_info(self) -> ConnectionInfo: - pass - - @abc.abstractmethod - async def force_refresh(self) -> None: - pass - - @abc.abstractmethod - async def close(self) -> None: - pass - - @property - @abc.abstractmethod - def closed(self) -> bool: - pass - - @dataclass class ConnectionInfo: """Contains all necessary information to connect securely to the @@ -70,21 +47,13 @@ class ConnectionInfo: database_version: str expiration: datetime.datetime context: Optional[ssl.SSLContext] = None - sock: Optional[ssl.SSLSocket] = None - async def create_ssl_context( - self, enable_iam_auth: bool = False, return_socket: bool = False - ) -> Union[ssl.SSLContext, ssl.SSLSocket]: + async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext: """Constructs a SSL/TLS context for the given connection info. Cache the SSL context to ensure we don't read from disk repeatedly when configuring a secure connection. 
""" - # Return socket if socket is cached and return_socket is set to True - if self.sock is not None and return_socket: - logger.debug("Socket in cache, returning it!") - return self.sock - # if SSL context is cached, use it if self.context is not None: return self.context @@ -125,15 +94,6 @@ async def create_ssl_context( context.load_verify_locations(cafile=ca_filename) # set class attribute to cache context for subsequent calls self.context = context - # If return_socket is True, cache socket and return it - if return_socket: - logger.debug("Returning socket instead of context!") - sock = self.context.wrap_socket( - socket.create_connection((self.get_preferred_ip(IPTypes.PUBLIC), 3307)), - server_hostname="blah", - ) - self.sock = sock - return sock return context def get_preferred_ip(self, ip_type: IPTypes) -> str: From e8702a214fdcf248220df404b3cb81a7d29560d6 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 14:14:33 +0000 Subject: [PATCH 07/16] chore: move socket initialization to Connector level --- google/cloud/sql/connector/connection_info.py | 22 ++++++++++++ google/cloud/sql/connector/connector.py | 34 +++++++++++++++---- google/cloud/sql/connector/monitored_cache.py | 32 ++++++++++++----- google/cloud/sql/connector/pg8000.py | 13 ++----- 4 files changed, 75 insertions(+), 26 deletions(-) diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 82e3a901..c9e48935 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -14,6 +14,7 @@ from __future__ import annotations +import abc from dataclasses import dataclass import logging import ssl @@ -34,6 +35,27 @@ logger = logging.getLogger(name=__name__) +class ConnectionInfoCache(abc.ABC): + """Abstract class for Connector connection info caches.""" + + @abc.abstractmethod + async def connect_info(self) -> ConnectionInfo: + pass + + @abc.abstractmethod + async def 
force_refresh(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @property + @abc.abstractmethod + def closed(self) -> bool: + pass + + @dataclass class ConnectionInfo: """Contains all necessary information to connect securely to the diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 356cad21..3b55b310 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -20,6 +20,7 @@ from functools import partial import logging import os +import socket from threading import Thread from types import TracebackType from typing import Any, Optional, Union @@ -47,6 +48,7 @@ logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] +SERVER_PROXY_PORT = 3307 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}" @@ -291,10 +293,11 @@ async def connect_async( driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - if (instance_connection_string, enable_iam_auth) in self._cache: - monitored_cache = self._cache[(instance_connection_string, enable_iam_auth)] + + conn_name = await self._resolver.resolve(instance_connection_string) + if (str(conn_name), enable_iam_auth) in self._cache: + monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] else: - conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( f"['{conn_name}']: Refresh strategy is set to lazy refresh" @@ -322,7 +325,7 @@ async def connect_async( self._resolver, ) logger.debug(f"['{conn_name}']: Connection info added to cache") - self._cache[(instance_connection_string, enable_iam_auth)] = monitored_cache + self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache connect_func = { "pymysql": pymysql.connect, @@ -358,7 +361,7 @@ async def connect_async( except Exception: # with an error 
from Cloud SQL Admin API call or IP type, invalidate # the cache and re-raise the error - await self._remove_cached(instance_connection_string, enable_iam_auth) + await self._remove_cached(str(conn_name), enable_iam_auth) raise logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn @@ -379,11 +382,21 @@ async def connect_async( await conn_info.create_ssl_context(enable_iam_auth), **kwargs, ) - # synchronous drivers are blocking and run using executor + # Create socket with SSLContext for sync drivers + ctx = await conn_info.create_ssl_context(enable_iam_auth) + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + # If this connection was opened using a domain name, then store it + # for later in case we need to forcibly close it on failover. + if conn_info.conn_name.domain_name: + monitored_cache.sockets.append(sock) + # Synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, - await conn_info.create_ssl_context(enable_iam_auth), + sock, **kwargs, ) return await self._loop.run_in_executor(None, connect_partial) @@ -468,6 +481,7 @@ async def create_async_connector( universe_domain: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver, + failover_period: int = 30, ) -> Connector: """Helper function to create Connector object for asyncio connections. @@ -519,6 +533,11 @@ async def create_async_connector( DnsResolver. Default: DefaultResolver + failover_period (int): The time interval in seconds between each + attempt to check if a failover has occurred for a given instance. + Must be used with `resolver=DnsResolver` to have any effect. + Default: 30 + Returns: A Connector instance configured with running event loop. 
""" @@ -537,4 +556,5 @@ async def create_async_connector( universe_domain=universe_domain, refresh_strategy=refresh_strategy, resolver=resolver, + failover_period=failover_period, ) diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index d807f257..1c9bb798 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -14,6 +14,7 @@ import asyncio import logging +import ssl from typing import Any, Callable, Optional, Union from google.cloud.sql.connector.connection_info import ConnectionInfo @@ -36,7 +37,7 @@ def __init__( self.resolver = resolver self.cache = cache self.domain_name_ticker: Optional[asyncio.Task] = None - self.open_conns: int = 0 + self.sockets: list[ssl.SSLSocket] = [] if self.cache.conn_name.domain_name: self.domain_name_ticker = asyncio.create_task( @@ -51,6 +52,15 @@ def __init__( def closed(self) -> bool: return self.cache.closed + async def _purge_closed_sockets(self) -> None: + open_sockets = [] + for socket in self.sockets: + # Check fileno as method to check if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + open_sockets.append(socket) + self.sockets = open_sockets + async def _check_domain_name(self) -> None: try: # Resolve domain name and see if Cloud SQL instance connection name @@ -66,12 +76,6 @@ async def _check_domain_name(self) -> None: "connections!" ) await self.close() - conn_info = await self.connect_info() - if conn_info.sock: - logger.debug(f"Socket type: {type(conn_info.sock)}") - conn_info.sock.close() - else: - logger.debug("Domain name mapping has not changed!") except Exception as e: # Domain name checks should not be fatal, log error and continue. @@ -97,10 +101,20 @@ async def close(self) -> None: logger.debug( f"['{self.cache.conn_name}']: Cancelled domain name polling task." 
) - + finally: + self.domain_name_ticker = None # If cache is already closed, no further work. - if self.cache.closed: + if self.closed: return + + # Close any still open sockets + for socket in self.sockets: + # Check fileno as method to check if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + socket.close() + + # Close underyling ConnectionInfoCache await self.cache.close() diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 1f66dde2..001a6acc 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -14,7 +14,6 @@ limitations under the License. """ -import socket import ssl from typing import Any, TYPE_CHECKING @@ -25,7 +24,7 @@ def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Helper function to create a pg8000 DB-API connection object. @@ -33,8 +32,8 @@ def connect( :param ip_address: A string containing an IP address for the Cloud SQL instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA + :type sock: ssl.SSLSocket + :param sock: An SSLSocket object created from the Cloud SQL server CA cert and ephemeral cert. @@ -48,12 +47,6 @@ def connect( 'Unable to import module "pg8000." Please install and try again.' ) - # Create socket and wrap with context. 
- sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) - user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) From ff9d6c91aece82d2175454c08bc7b3f180f258e7 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 14:28:52 +0000 Subject: [PATCH 08/16] chore: change secret back --- .github/workflows/tests.yml | 4 ++-- tests/system/test_pg8000_connection.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e4d4675b..b8e6eb58 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -81,7 +81,7 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS - POSTGRES_CUSTOMER_CAS_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_DOMAIN_NAME + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -103,7 +103,7 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" - POSTGRES_CUSTOMER_CAS_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_DOMAIN_NAME }}" + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ 
steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index b92ffc23..c47b860c 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -167,7 +167,7 @@ def test_customer_managed_CAS_pg8000_connection() -> None: def test_custom_SAN_with_dns_pg8000_connection() -> None: """Basic test to get time from database.""" - inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_DOMAIN_NAME"] + inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"] user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] db = os.environ["POSTGRES_DB"] From 24a6230075459838a580a9522ea743965f1daa9c Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 14:44:14 +0000 Subject: [PATCH 09/16] chore: lint --- google/cloud/sql/connector/connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 3b55b310..2a7c6c96 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -23,7 +23,7 @@ import socket from threading import Thread from types import TracebackType -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import google.auth from google.auth.credentials import Credentials @@ -336,7 +336,7 @@ async def connect_async( # only accept supported database drivers try: - connector = connect_func[driver] + connector: Callable = connect_func[driver] # type: ignore except KeyError: raise KeyError(f"Driver '{driver}' is not supported.") From ac5fca037e1c55470d3d9a518ffaf5d29a5cf57b Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 15:09:37 +0000 Subject: [PATCH 10/16] chore: update 
unit tests --- google/cloud/sql/connector/pg8000.py | 2 -- google/cloud/sql/connector/pymysql.py | 2 -- google/cloud/sql/connector/pytds.py | 2 -- tests/unit/test_pg8000.py | 13 +++++----- tests/unit/test_pymysql.py | 13 +++++----- tests/unit/test_pytds.py | 35 ++++++++++++--------------- 6 files changed, 28 insertions(+), 39 deletions(-) diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 001a6acc..baaee661 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -17,8 +17,6 @@ import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pg8000 diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index 580008de..f83f7076 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -17,8 +17,6 @@ import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pymysql diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 1d51122f..3128fdb6 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -20,8 +20,6 @@ from google.cloud.sql.connector.exceptions import PlatformNotSupportedError -SERVER_PROXY_PORT = 3307 - if TYPE_CHECKING: import pytds diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index 1b2adbb6..e01a5344 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -14,7 +14,7 @@ limitations under the License. 
""" -from functools import partial +import socket from typing import Any from mock import patch @@ -31,15 +31,14 @@ async def test_pg8000(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pg8000.dbapi.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) assert connection is True # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 69d2aba8..66b1f22a 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -14,7 +14,7 @@ limitations under the License. 
""" -from functools import partial +import socket import ssl from typing import Any @@ -40,15 +40,14 @@ async def test_pymysql(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) kwargs["timeout"] = 30 with patch("pymysql.Connection") as mock_connect: mock_connect.return_value = MockConnection - pymysql_connect(ip_addr, context, **kwargs) + pymysql_connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_connect.assert_called_once diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index 633aab74..9efe00ee 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -14,8 +14,8 @@ limitations under the License. 
""" -from functools import partial import platform +import socket from typing import Any from mock import patch @@ -43,16 +43,15 @@ async def test_pytds(kwargs: Any) -> None: ip_addr = "127.0.0.1" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) with patch("pytds.connect") as mock_connect: mock_connect.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert connection is True assert mock_connect.assert_called_once @@ -68,17 +67,16 @@ async def test_pytds_platform_error(kwargs: Any) -> None: assert platform.system() == "Linux" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth to kwargs kwargs["active_directory_auth"] = True # verify that error is thrown with Linux and active_directory_auth with pytest.raises(PlatformNotSupportedError): - connect(ip_addr, context, **kwargs) + connect(ip_addr, sock, **kwargs) @pytest.mark.usefixtures("server") @@ -94,11 +92,10 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: assert platform.system() == "Windows" # build ssl.SSLContext context = await create_ssl_context() - # force all wrap_socket calls to have do_handshake_on_connect=False - setattr( - context, - "wrap_socket", - partial(context.wrap_socket, 
do_handshake_on_connect=False), + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, ) # add active_directory_auth and server_name to kwargs kwargs["active_directory_auth"] = True @@ -107,7 +104,7 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: mock_connect.return_value = True with patch("pytds.login.SspiAuth") as mock_login: mock_login.return_value = True - connection = connect(ip_addr, context, **kwargs) + connection = connect(ip_addr, sock, **kwargs) # verify that driver connection call would be made assert mock_login.assert_called_once assert connection is True From a10100364d6965d4cfc6e9cbcc37e81aa0ee518a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 18:20:03 +0000 Subject: [PATCH 11/16] chore: add additional tests --- README.md | 38 +++++++++++++++++++ google/cloud/sql/connector/monitored_cache.py | 15 ++------ 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 1f0e633b..d79e706d 100644 --- a/README.md +++ b/README.md @@ -428,6 +428,44 @@ with Connector(resolver=DnsResolver) as connector: # ... use SQLAlchemy engine normally ``` +### Automatic failover using DNS domain names + +> [!NOTE] +> +> Usage of the `asyncpg` driver does not currently support automatic failover. + +When the connector is configured using a domain name, the connector will +periodically check if the DNS record for an instance changes. When the connector +detects that the domain name refers to a different instance, the connector will +close all open connections to the old instance. Subsequent connection attempts +will be directed to the new instance. + +For example: suppose application is configured to connect using the +domain name `prod-db.mycompany.example.com`. Initially the private DNS +zone has a TXT record with the value `my-project:region:my-instance`. 
The +application establishes connections to the `my-project:region:my-instance` +Cloud SQL instance. + +Then, to reconfigure the application to use a different database +instance, change the value of the `prod-db.mycompany.example.com` DNS record +from `my-project:region:my-instance` to `my-project:other-region:my-instance-2`. + +The connector inside the application detects the change to this +DNS record. Now, when the application connects to its database using the +domain name `prod-db.mycompany.example.com`, it will connect to the +`my-project:other-region:my-instance-2` Cloud SQL instance. + +The connector will automatically close all existing connections to +`my-project:region:my-instance`. This will force the connection pools to +establish new connections. Also, it may cause database queries in progress +to fail. + +The connector will poll for changes to the DNS name every 30 seconds by default. +You may configure the frequency of these checks using the Connector's +`failover_period` argument (i.e. `Connector(failover_period=60)`). When this is +set to 0, the connector will disable polling and only check if the DNS record +changed when it is creating a new connection. + +### Using the Python Connector with Python Web Frameworks + +The Python Connector can be used alongside popular Python web frameworks such diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 1c9bb798..96c976da 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -39,7 +39,9 @@ def __init__( self.domain_name_ticker: Optional[asyncio.Task] = None self.sockets: list[ssl.SSLSocket] = [] - if self.cache.conn_name.domain_name: + # If domain name is configured for instance and failover period is set, + # poll for DNS record changes. 
+ if self.cache.conn_name.domain_name and failover_period > 0: self.domain_name_ticker = asyncio.create_task( ticker(failover_period, self._check_domain_name) ) @@ -52,15 +54,6 @@ def __init__( def closed(self) -> bool: return self.cache.closed - async def _purge_closed_sockets(self) -> None: - open_sockets = [] - for socket in self.sockets: - # Check fileno as method to check if socket is closed. Will return - # -1 on failure, which will be used to signal socket closed. - if socket.fileno() != -1: - open_sockets.append(socket) - self.sockets = open_sockets - async def _check_domain_name(self) -> None: try: # Resolve domain name and see if Cloud SQL instance connection name @@ -109,7 +102,7 @@ async def close(self) -> None: # Close any still open sockets for socket in self.sockets: - # Check fileno as method to check if socket is closed. Will return + # Check fileno for if socket is closed. Will return # -1 on failure, which will be used to signal socket closed. if socket.fileno() != -1: socket.close() From d260934022a451fa4f97e6d8cd0a8b56c25da77f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 19:00:30 +0000 Subject: [PATCH 12/16] chore: improve tests --- google/cloud/sql/connector/monitored_cache.py | 16 ++ tests/unit/test_monitored_cache.py | 218 ++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 tests/unit/test_monitored_cache.py diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 96c976da..4c119cd3 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -54,7 +54,23 @@ def __init__( def closed(self) -> bool: return self.cache.closed + def _purge_closed_sockets(self) -> None: + """Remove closed sockets from monitored cache. + + If a socket is closed by the database driver we should remove it from + list of sockets. 
+ """ + open_sockets = [] + for socket in self.sockets: + # Check fileno for if socket is closed. Will return + # -1 on failure, which will be used to signal socket closed. + if socket.fileno() != -1: + open_sockets.append(socket) + self.sockets = open_sockets + async def _check_domain_name(self) -> None: + # remove any closed connections from cache + self._purge_closed_sockets() try: # Resolve domain name and see if Cloud SQL instance connection name # has changed. If it has, close all connections. diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py new file mode 100644 index 00000000..976b4e0c --- /dev/null +++ b/tests/unit/test_monitored_cache.py @@ -0,0 +1,218 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import socket + +import dns.message +import dns.rdataclass +import dns.rdatatype +import dns.resolver +from mock import patch +from mocks import create_ssl_context +import pytest + +from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.lazy import LazyRefreshCache +from google.cloud.sql.connector.monitored_cache import MonitoredCache +from google.cloud.sql.connector.resolver import DefaultResolver +from google.cloud.sql.connector.resolver import DnsResolver +from google.cloud.sql.connector.utils import generate_keys + +query_text = """id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD RA +;QUESTION +db.example.com. IN TXT +;ANSWER +db.example.com. 0 IN TXT "test-project:test-region:test-instance" +;AUTHORITY +;ADDITIONAL +""" + + +async def test_MonitoredCache_properties(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache properties work as expected. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test that ticker is not set for instance not using domain name + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +async def test_MonitoredCache_with_DnsResolver(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache with DnsResolver work as expected. 
+ """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + monitored_cache = MonitoredCache(cache, 30, resolver) + # test that ticker is set for instance using domain name + assert type(monitored_cache.domain_name_ticker) is asyncio.Task + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # domain name ticker should be set back to None + assert monitored_cache.domain_name_ticker is None + + +async def test_MonitoredCache_with_disabled_failover( + fake_client: CloudSQLClient, +) -> None: + """ + Test that MonitoredCache disables DNS polling with failover_period=0 + """ + conn_name = ConnectionName( + "test-project", "test-region", "test-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # test that ticker is not set when failover is disabled + assert monitored_cache.domain_name_ticker is None + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache is 
closed when _check_domain_name has domain change. + """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # Patch DNS resolution with valid TXT records + with patch("dns.asyncresolver.Resolver.resolve") as mock_connect: + answer = dns.resolver.Answer( + "db.example.com", + dns.rdatatype.TXT, + dns.rdataclass.IN, + dns.message.from_text(query_text), + ) + mock_connect.return_value = answer + resolver = DnsResolver() + resolver.port = 5053 + + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + # verify socket is open + assert sock.fileno() != -1 + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, resolver) + # add socket to cache + monitored_cache.sockets = [sock] + # check cache is not closed + assert monitored_cache.closed is False + # call _check_domain_name and verify cache is closed + await monitored_cache._check_domain_name() + assert monitored_cache.closed is True + # verify socket was closed + assert sock.fileno() == -1 + + +@pytest.mark.usefixtures("server") +async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache._purge_closed_sockets removes closed sockets from + cache. 
+ """ + conn_name = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" + ) + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + # configure a local socket + ip_addr = "127.0.0.1" + context = await create_ssl_context() + sock = context.wrap_socket( + socket.create_connection((ip_addr, 3307)), + server_hostname=ip_addr, + do_handshake_on_connect=False, + ) + + # set failover to 0 to disable polling + monitored_cache = MonitoredCache(cache, 0, DnsResolver()) + # verify socket is open + assert sock.fileno() != -1 + # add socket to cache + monitored_cache.sockets = [sock] + # call _purge_closed_sockets and verify socket remains + monitored_cache._purge_closed_sockets() + # verify socket is still open + assert sock.fileno() != -1 + assert len(monitored_cache.sockets) == 1 + # close socket + sock.close() + # call _purge_closed_sockets and verify socket is clsoed + monitored_cache._purge_closed_sockets() + assert len(monitored_cache.sockets) == 0 + assert sock.fileno() == -1 From c6b74e83bacdf5812c621a79b8cc0ea3c047352c Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 19:01:59 +0000 Subject: [PATCH 13/16] chore: update header --- tests/unit/test_monitored_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index 976b4e0c..dc72a417 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
From 9c4d4d1044f40e3ff6eeabf7a6050d60ff9bddcd Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 19:04:07 +0000 Subject: [PATCH 14/16] chore: update typo --- tests/unit/test_monitored_cache.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index dc72a417..1d648f1d 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -212,7 +212,6 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) assert len(monitored_cache.sockets) == 1 # close socket sock.close() - # call _purge_closed_sockets and verify socket is clsoed + # call _purge_closed_sockets and verify socket is removed monitored_cache._purge_closed_sockets() assert len(monitored_cache.sockets) == 0 - assert sock.fileno() == -1 From ce2c30aa4f141f57374fae2be81c7ad89b5fb5d0 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 20 Mar 2025 20:00:25 +0000 Subject: [PATCH 15/16] chore: review comments --- google/cloud/sql/connector/connector.py | 5 ++++- google/cloud/sql/connector/monitored_cache.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 2a7c6c96..c76092a4 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -295,7 +295,10 @@ async def connect_async( enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) conn_name = await self._resolver.resolve(instance_connection_string) - if (str(conn_name), enable_iam_auth) in self._cache: + # Cache entry must exist and not be closed + if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[ + (str(conn_name), enable_iam_auth) + ].closed: monitored_cache = self._cache[(str(conn_name), enable_iam_auth)] else: if self._refresh_strategy == RefreshStrategy.LAZY: diff --git 
a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 4c119cd3..1db3cc70 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -116,6 +116,9 @@ async def close(self) -> None: if self.closed: return + # Close underyling ConnectionInfoCache + await self.cache.close() + # Close any still open sockets for socket in self.sockets: # Check fileno for if socket is closed. Will return @@ -123,9 +126,6 @@ async def close(self) -> None: if socket.fileno() != -1: socket.close() - # Close underyling ConnectionInfoCache - await self.cache.close() - async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None: """ From 0998f8dbdfdea43aa4d23f501cd504818bc7feeb Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 21 Mar 2025 16:33:08 +0000 Subject: [PATCH 16/16] chore: update based on feedback --- google/cloud/sql/connector/exceptions.py | 7 ++++++ google/cloud/sql/connector/monitored_cache.py | 8 +++++++ tests/unit/test_monitored_cache.py | 23 +++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/google/cloud/sql/connector/exceptions.py b/google/cloud/sql/connector/exceptions.py index 92e3e566..da39ea25 100644 --- a/google/cloud/sql/connector/exceptions.py +++ b/google/cloud/sql/connector/exceptions.py @@ -77,3 +77,10 @@ class DnsResolutionError(Exception): Exception to be raised when an instance connection name can not be resolved from a DNS record. """ + + +class CacheClosedError(Exception): + """ + Exception to be raised when a ConnectionInfoCache can not be accessed after + it is closed. 
+ """ diff --git a/google/cloud/sql/connector/monitored_cache.py b/google/cloud/sql/connector/monitored_cache.py index 1db3cc70..0c3fc4d0 100644 --- a/google/cloud/sql/connector/monitored_cache.py +++ b/google/cloud/sql/connector/monitored_cache.py @@ -19,6 +19,7 @@ from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.connection_info import ConnectionInfoCache +from google.cloud.sql.connector.exceptions import CacheClosedError from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.resolver import DefaultResolver @@ -95,9 +96,16 @@ async def _check_domain_name(self) -> None: ) async def connect_info(self) -> ConnectionInfo: + if self.closed: + raise CacheClosedError( + "Can not get connection info, cache has already been closed." + ) return await self.cache.connect_info() async def force_refresh(self) -> None: + # if cache is closed do not refresh + if self.closed: + return return await self.cache.force_refresh() async def close(self) -> None: diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index 1d648f1d..1eea4eb4 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -25,6 +25,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_name import ConnectionName +from google.cloud.sql.connector.exceptions import CacheClosedError from google.cloud.sql.connector.lazy import LazyRefreshCache from google.cloud.sql.connector.monitored_cache import MonitoredCache from google.cloud.sql.connector.resolver import DefaultResolver @@ -65,6 +66,28 @@ async def test_MonitoredCache_properties(fake_client: CloudSQLClient) -> None: assert monitored_cache.closed is True +async def test_MonitoredCache_CacheClosedError(fake_client: CloudSQLClient) -> None: + """ + Test that MonitoredCache.connect_info errors 
when cache is closed. + """ + conn_name = ConnectionName("test-project", "test-region", "test-instance") + cache = LazyRefreshCache( + conn_name, + client=fake_client, + keys=asyncio.create_task(generate_keys()), + enable_iam_auth=False, + ) + monitored_cache = MonitoredCache(cache, 30, DefaultResolver()) + # test closed property + assert monitored_cache.closed is False + # close cache and make sure property is updated + await monitored_cache.close() + assert monitored_cache.closed is True + # attempt to get connect info from closed cache + with pytest.raises(CacheClosedError): + await monitored_cache.connect_info() + + async def test_MonitoredCache_with_DnsResolver(fake_client: CloudSQLClient) -> None: """ Test that MonitoredCache with DnsResolver work as expected.