From 6ce9c4009aaea0e52b42534ad0e1ea2a2d1c432a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 7 Dec 2024 00:57:06 +0000 Subject: [PATCH 1/2] chore: pass ConnectionName over individul args --- google/cloud/sql/connector/client.py | 22 +++++++++---------- google/cloud/sql/connector/connection_info.py | 2 ++ google/cloud/sql/connector/connector.py | 15 ++++--------- google/cloud/sql/connector/instance.py | 13 +++-------- google/cloud/sql/connector/lazy.py | 10 +-------- google/cloud/sql/connector/resolver.py | 5 ++++- tests/unit/test_connection_name.py | 8 +++---- 7 files changed, 27 insertions(+), 48 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 61b77d56..8a31eb9a 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -26,6 +26,7 @@ from google.auth.transport import requests from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.refresh_utils import _downscope_credentials from google.cloud.sql.connector.refresh_utils import retry_50x @@ -245,9 +246,7 @@ async def _get_ephemeral( async def get_connection_info( self, - project: str, - region: str, - instance: str, + conn_name: ConnectionName, keys: asyncio.Future, enable_iam_auth: bool, ) -> ConnectionInfo: @@ -255,10 +254,8 @@ async def get_connection_info( Admin API. Args: - project (str): The name of the project the Cloud SQL instance is - located in. - region (str): The region the Cloud SQL instance is located in. - instance (str): Name of the Cloud SQL instance. + conn_name (ConnectionName): The Cloud SQL instance's + connection name. keys (asyncio.Future): A future to the client's public-private key pair. enable_iam_auth (bool): Whether an automatic IAM database @@ -278,16 +275,16 @@ async def get_connection_info( metadata_task = asyncio.create_task( self._get_metadata( - project, - region, - instance, + conn_name.project, + conn_name.region, + conn_name.instance_name, ) ) ephemeral_task = asyncio.create_task( self._get_ephemeral( - project, - instance, + conn_name.project, + conn_name.instance_name, pub_key, enable_iam_auth, ) @@ -311,6 +308,7 @@ async def get_connection_info( ephemeral_cert, expiration = await ephemeral_task return ConnectionInfo( + conn_name, ephemeral_cert, metadata["server_ca_cert"], priv_key, diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index b738063c..82e3a901 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -21,6 +21,7 @@ from aiofiles.tempfile import TemporaryDirectory +from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import TLSVersionError from google.cloud.sql.connector.utils import write_to_file @@ -38,6 +39,7 @@ class ConnectionInfo: """Contains all necessary information to connect securely to the server-side Proxy running on a Cloud SQL instance.""" + conn_name: ConnectionName client_cert: str server_ca_cert: str private_key: bytes diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 1e67373e..7c5fb20d 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -113,7 +113,6 @@ def __init__( name. To resolve a DNS record to an instance connection name, use DnsResolver. Default: DefaultResolver - """ # if refresh_strategy is str, convert to RefreshStrategy enum if isinstance(refresh_strategy, str): @@ -283,8 +282,7 @@ async def connect_async( conn_name = await self._resolver.resolve(instance_connection_string) if self._refresh_strategy == RefreshStrategy.LAZY: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to lazy refresh" + f"['{conn_name}']: Refresh strategy is set to lazy refresh" ) cache = LazyRefreshCache( conn_name, @@ -294,8 +292,7 @@ async def connect_async( ) else: logger.debug( - f"['{instance_connection_string}']: Refresh strategy is set" - " to backgound refresh" + f"['{conn_name}']: Refresh strategy is set to backgound refresh" ) cache = RefreshAheadCache( conn_name, @@ -303,9 +300,7 @@ async def connect_async( self._keys, enable_iam_auth, ) - logger.debug( - f"['{instance_connection_string}']: Connection info added to cache" - ) + logger.debug(f"['{conn_name}']: Connection info added to cache") self._cache[(instance_connection_string, enable_iam_auth)] = cache connect_func = { @@ -344,9 +339,7 @@ async def connect_async( # the cache and re-raise the error await self._remove_cached(instance_connection_string, enable_iam_auth) raise - logger.debug( - f"['{instance_connection_string}']: Connecting to {ip_address}:3307" - ) + logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307") # format `user` param for automatic IAM database authn if enable_iam_auth: formatted_user = format_database_user( diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 3b0b9263..5df272fe 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -62,11 +62,6 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name self._enable_iam_auth = enable_iam_auth @@ -104,20 +99,18 @@ async def _perform_refresh(self) -> ConnectionInfo: """ self._refresh_in_progress.set() logger.debug( - f"['{self._conn_name}']: Connection info refresh " "operation started" + f"['{self._conn_name}']: Connection info refresh operation started" ) try: await self._refresh_rate_limiter.acquire() connection_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) logger.debug( - f"['{self._conn_name}']: Connection info " "refresh operation complete" + f"['{self._conn_name}']: Connection info refresh operation complete" ) logger.debug( f"['{self._conn_name}']: Current certificate " diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index ab73785d..1bc4f90f 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -55,13 +55,7 @@ def __init__( (Postgres and MySQL) as the default authentication method for all connections. """ - self._project, self._region, self._instance = ( - conn_name.project, - conn_name.region, - conn_name.instance_name, - ) self._conn_name = conn_name - self._enable_iam_auth = enable_iam_auth self._keys = keys self._client = client @@ -101,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo: ) try: conn_info = await self._client.get_connection_info( - self._project, - self._region, - self._instance, + self._conn_name, self._keys, self._enable_iam_auth, ) diff --git a/google/cloud/sql/connector/resolver.py b/google/cloud/sql/connector/resolver.py index 2cdcddbe..2af6ff2f 100644 --- a/google/cloud/sql/connector/resolver.py +++ b/google/cloud/sql/connector/resolver.py @@ -15,6 +15,9 @@ import dns.asyncresolver from google.cloud.sql.connector.connection_name import _parse_connection_name +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.exceptions import DnsResolutionError @@ -52,7 +55,7 @@ async def query_dns(self, dns: str) -> ConnectionName: # Attempt to parse records, returning the first valid record. for record in rdata: try: - conn_name = _parse_connection_name(record) + conn_name = _parse_connection_name_with_domain_name(record, dns) return conn_name except Exception: continue diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py index a62f88d5..ca5bdf45 100644 --- a/tests/unit/test_connection_name.py +++ b/tests/unit/test_connection_name.py @@ -14,14 +14,12 @@ import pytest # noqa F401 Needed to run the tests -# fmt: off from google.cloud.sql.connector.connection_name import _parse_connection_name -from google.cloud.sql.connector.connection_name import \ - _parse_connection_name_with_domain_name +from google.cloud.sql.connector.connection_name import ( + _parse_connection_name_with_domain_name, +) from google.cloud.sql.connector.connection_name import ConnectionName -# fmt: on - def test_ConnectionName() -> None: conn_name = ConnectionName("project", "region", "instance") From 70aa081dab00c7c207d5ebaa3ef435863e988512 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 7 Dec 2024 02:23:42 +0000 Subject: [PATCH 2/2] chore: fix tests --- tests/unit/test_instance.py | 8 ++++---- tests/unit/test_resolver.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index f80bb149..aeedf339 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -47,9 +47,9 @@ async def test_Instance_init( can tell if the connection string that's passed in is formatted correctly. """ assert ( - cache._project == "test-project" - and cache._region == "test-region" - and cache._instance == "test-instance" + cache._conn_name.project == "test-project" + and cache._conn_name.region == "test-region" + and cache._conn_name.instance_name == "test-instance" ) assert cache._enable_iam_auth is False @@ -283,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None async def test_ConnectionInfo_caches_sslcontext() -> None: info = ConnectionInfo( - "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() + "", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now() ) # context should default to None assert info.context is None diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py index d7404890..a9a7f263 100644 --- a/tests/unit/test_resolver.py +++ b/tests/unit/test_resolver.py @@ -26,6 +26,9 @@ conn_str = "my-project:my-region:my-instance" conn_name = ConnectionName("my-project", "my-region", "my-instance") +conn_name_with_domain = ConnectionName( + "my-project", "my-region", "my-instance", "db.example.com" +) async def test_DefaultResolver() -> None: @@ -74,7 +77,7 @@ async def test_DnsResolver_with_dns_name() -> None: resolver.port = 5053 # Resolution should return first value sorted alphabetically result = await resolver.resolve("db.example.com") - assert result == conn_name + assert result == conn_name_with_domain query_text_malformed = """id 1234