diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e710138f..b8e6eb58 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -81,6 +81,7 @@ jobs: POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS @@ -102,6 +103,7 @@ jobs: POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}" POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}" POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}" + POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}" SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}" SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}" SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}" diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 8de14d57..dfcc3941 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -16,13 +16,15 @@ import asyncio import os -from typing import Any +from typing import Any, Union import asyncpg import sqlalchemy import sqlalchemy.ext.asyncio from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver async def create_sqlalchemy_engine( @@ -31,6 +33,7 @@ async def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, ) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -64,9 +67,16 @@ async def create_sqlalchemy_engine( Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). """ loop = asyncio.get_running_loop() - connector = Connector(loop=loop, refresh_strategy=refresh_strategy) + connector = Connector( + loop=loop, refresh_strategy=refresh_strategy, resolver=resolver + ) async def getconn() -> asyncpg.Connection: conn: asyncpg.Connection = await connector.connect_async( @@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None: await connector.close_async() +async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] + db = os.environ["POSTGRES_DB"] + + pool, connector = await create_sqlalchemy_engine( + inst_conn_name, user, password, db, resolver=DnsResolver + ) + + async with pool.connect() as conn: + res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() + assert res[0] == 1 + + await connector.close_async() + + async def test_connection_with_asyncpg() -> None: """Basic test to get time from database.""" inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"] diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index b56a8e82..c47b860c 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -18,10 +18,14 @@ import os # [START cloud_sql_connector_postgres_pg8000] +from typing import Union + import pg8000 import sqlalchemy from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import DefaultResolver +from google.cloud.sql.connector import DnsResolver def create_sqlalchemy_engine( @@ -30,6 +34,7 @@ def create_sqlalchemy_engine( password: str, db: str, refresh_strategy: str = "background", + resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver, ) -> tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for a Cloud SQL instance and returns the pool and the connector. Callers are responsible for closing the pool and the @@ -64,8 +69,13 @@ def create_sqlalchemy_engine( Refresh strategy for the Cloud SQL Connector. Can be one of "lazy" or "background". For serverless environments use "lazy" to avoid errors resulting from CPU being throttled. + resolver (Optional[google.cloud.sql.connector.DefaultResolver]): + Resolver class for resolving instance connection name. Use + google.cloud.sql.connector.DnsResolver when resolving DNS domain + names or google.cloud.sql.connector.DefaultResolver for regular + instance connection names ("my-project:my-region:my-instance"). """ - connector = Connector(refresh_strategy=refresh_strategy) + connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver) def getconn() -> pg8000.dbapi.Connection: conn: pg8000.dbapi.Connection = connector.connect( @@ -153,3 +163,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None: curr_time = time[0] assert type(curr_time) is datetime connector.close() + + +def test_custom_SAN_with_dns_pg8000_connection() -> None: + """Basic test to get time from database.""" + inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"] + user = os.environ["POSTGRES_USER"] + password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"] + db = os.environ["POSTGRES_DB"] + + engine, connector = create_sqlalchemy_engine( + inst_conn_name, user, password, db, resolver=DnsResolver + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close()