From bbd29ad5a19b6024a4d0516af6454e7ddb78073b Mon Sep 17 00:00:00 2001 From: Jonathan Hess Date: Tue, 11 Mar 2025 21:43:16 -0600 Subject: [PATCH] refactor: Use new ConnectSettings.DnsNames field to validate the server TLS certificate. --- .gitignore | 3 +++ google/cloud/sql/connector/client.py | 21 +++++++++++++++++---- tests/unit/mocks.py | 14 +++++++++++++- tests/unit/test_client.py | 22 ++++++++++++++++++++++ 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 9ef6a906..9f449ce4 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ venv .python-version cloud_sql_python_connector.egg-info/ dist/ +.idea +.coverage +sponge_log.xml diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 8a31eb9a..556a01bd 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -156,10 +156,23 @@ async def _get_metadata( # resolve dnsName into IP address for PSC # Note that we have to check for PSC enablement also because CAS # instances also set the dnsName field. - # Remove trailing period from DNS name. Required for SSL in Python - dns_name = ret_dict.get("dnsName", "").rstrip(".") - if dns_name and ret_dict.get("pscEnabled"): - ip_addresses["PSC"] = dns_name + if ret_dict.get("pscEnabled"): + # Find PSC instance DNS name in the dns_names field + psc_dns_names = [ + d["name"] + for d in ret_dict.get("dnsNames", []) + if d["connectionType"] == "PRIVATE_SERVICE_CONNECT" + and d["dnsScope"] == "INSTANCE" + ] + dns_name = psc_dns_names[0] if psc_dns_names else None + + # Fall back do dns_name field if dns_names is not set + if dns_name is None: + dns_name = ret_dict.get("dnsName", None) + + # Remove trailing period from DNS name. Required for SSL in Python + if dns_name: + ip_addresses["PSC"] = dns_name.rstrip(".") return { "ip_addresses": ip_addresses, diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 5d863677..cd3299b7 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -225,6 +225,7 @@ def __init__( "PRIMARY": "127.0.0.1", "PRIVATE": "10.0.0.1", }, + legacy_dns_name: bool = False, cert_before: datetime = datetime.datetime.now(datetime.timezone.utc), cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1), @@ -237,6 +238,7 @@ def __init__( self.psc_enabled = False self.cert_before = cert_before self.cert_expiration = cert_expiration + self.legacy_dns_name = legacy_dns_name # create self signed CA cert self.server_ca, self.server_key = generate_cert( self.project, self.name, cert_before, cert_expiration @@ -255,12 +257,22 @@ async def connect_settings(self, request: Any) -> web.Response: "instance": self.name, "expirationTime": str(self.cert_expiration), }, - "dnsName": "abcde.12345.us-central1.sql.goog", "pscEnabled": self.psc_enabled, "ipAddresses": ip_addrs, "region": self.region, "databaseVersion": self.db_version, } + if self.legacy_dns_name: + response["dnsName"] = "abcde.12345.us-central1.sql.goog" + else: + response["dnsNames"] = [ + { + "name": "abcde.12345.us-central1.sql.goog", + "connectionType": "PRIVATE_SERVICE_CONNECT", + "dnsScope": "INSTANCE", + } + ] + return web.Response(content_type="application/json", body=json.dumps(response)) async def generate_ephemeral(self, request: Any) -> web.Response: diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index af42af0a..cfe50947 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -65,6 +65,28 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None: assert isinstance(resp["server_ca_cert"], str) +@pytest.mark.asyncio +async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> None: + """ + Test _get_metadata returns successfully with PSC IP type. + """ + # set PSC to enabled on test instance + fake_client.instance.psc_enabled = True + fake_client.instance.legacy_dns_name = True + resp = await fake_client._get_metadata( + "test-project", + "test-region", + "test-instance", + ) + assert resp["database_version"] == "POSTGRES_15" + assert resp["ip_addresses"] == { + "PRIMARY": "127.0.0.1", + "PRIVATE": "10.0.0.1", + "PSC": "abcde.12345.us-central1.sql.goog", + } + assert isinstance(resp["server_ca_cert"], str) + + @pytest.mark.asyncio async def test_get_ephemeral(fake_client: CloudSQLClient) -> None: """