1818
1919import asyncpg # type: ignore
2020import conftest as conftest # python-docs-samples/alloydb/conftest.py
21+ from google .cloud .alloydb .connector import AsyncConnector , IPTypes
2122import pytest
2223import sqlalchemy
23- from google .cloud .alloydb .connector import AsyncConnector , IPTypes
2424from sqlalchemy .ext .asyncio import AsyncEngine , create_async_engine
2525
2626
@@ -43,7 +43,11 @@ async def _init_connection_pool(
4343 region : str ,
4444 password : str ,
4545) -> AsyncEngine :
46- connection_string = f"projects/{ project_id } /locations/{ region } /clusters/{ cluster_name } /instances/{ instance_name } "
46+ connection_string = (
47+ f"projects/{ project_id } /locations/"
48+ f"{ region } /clusters/{ cluster_name } /"
49+ f"instances/{ instance_name } "
50+ )
4751
4852 async def getconn () -> asyncpg .Connection :
4953 conn : asyncpg .Connection = await connector .connect (
@@ -71,27 +75,34 @@ async def test_embeddings_batch_processing(
7175 instance_name : str ,
7276 region : str ,
7377 database_name : str ,
74- username : str ,
7578 password : str ,
7679 table_name : str ,
7780) -> None :
81+ # TODO: Create new table
7882 # Populate the table with embeddings by running the notebook
7983 conftest .run_notebook (
8084 "embeddings_batch_processing.ipynb" ,
8185 variables = {
8286 "project_id" : project_id ,
8387 "cluster_name" : cluster_name ,
8488 "database_name" : database_name ,
85- "username" : username ,
8689 "region" : region ,
8790 "instance_name" : instance_name ,
8891 "table_name" : table_name ,
8992 },
9093 preprocess = preprocess ,
9194 skip_shell_commands = True ,
9295 replace = {
93- "password = input(\" Please provide a password to be used for 'postgres' database user: \" )" : f"password = '{ password } '" ,
94- "await create_db(database_name=database_name, connector=connector)" : "" ,
96+ (
97+ "password = input(\" Please provide "
98+ "a password to be used for 'postgres' "
99+ "database user: \" )"
100+ ): f"password = '{ password } '" ,
101+ (
102+ "await create_db("
103+ "database_name=database_name, "
104+ "connector=connector)"
105+ ): "" ,
95106 },
96107 until_end = True ,
97108 )
@@ -111,25 +122,35 @@ async def test_embeddings_batch_processing(
111122 # Validate that embeddings are non-empty for all rows
112123 result = await conn .execute (
113124 sqlalchemy .text (
114- f"SELECT COUNT(*) FROM { table_name } WHERE analysis_embedding IS NULL"
125+ f"SELECT COUNT(*) FROM "
126+ f"{ table_name } WHERE "
127+ f"analysis_embedding IS NULL"
115128 )
116129 )
117130 row = result .fetchone ()
118131 assert row [0 ] == 0
119132 result = await conn .execute (
120133 sqlalchemy .text (
121- f"SELECT COUNT(*) FROM { table_name } WHERE overview_embedding IS NULL"
134+ f"SELECT COUNT(*) FROM "
135+ f"{ table_name } WHERE "
136+ f"overview_embedding IS NULL"
122137 )
123138 )
124139 row = result .fetchone ()
125140 assert row [0 ] == 0
126141
127142 # Get the table back to the original state
128143 await conn .execute (
129- sqlalchemy .text (f"UPDATE { table_name } set analysis_embedding = NULL" )
144+ sqlalchemy .text (
145+ f"UPDATE { table_name } set "
146+ f"analysis_embedding = NULL"
147+ )
130148 )
131149 await conn .execute (
132- sqlalchemy .text (f"UPDATE { table_name } set overview_embedding = NULL" )
150+ sqlalchemy .text (
151+ f"UPDATE { table_name } set "
152+ f"overview_embedding = NULL"
153+ )
133154 )
134155 await conn .commit ()
135156 await pool .dispose ()
0 commit comments