@@ -157,80 +157,43 @@ def test_vector_search_basic_params_with_df():
157157 )
158158
159159
def test_vector_search_different_params_with_query(session):
    """Check cosine ``vector_search`` with ``top_k=2`` against a temporary table.

    The base table holds the four axis-aligned unit vectors (ids 1-4). Each of
    the two query vectors therefore has exactly one close neighbour and one
    farther neighbour, so the expected ids and distances can be derived by
    hand:

    * query ``[0.75, 0.25]``   -> id 2 ``[1, 0]``  then id 1 ``[0, 1]``
    * query ``[-0.25, -0.75]`` -> id 3 ``[0, -1]`` then id 4 ``[-1, 0]``

    Args:
        session: Session fixture used to create the base table and the query
            series; its ``bqclient`` deletes the temporary table afterwards.
    """
    base_df = bpd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "my_embedding": [
                np.array([0.0, 1.0]),
                np.array([1.0, 0.0]),
                np.array([0.0, -1.0]),
                np.array([-1.0, 0.0]),
            ],
        },
        session=session,
    )
    base_table = base_df.to_gbq()
    try:
        search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session)
        vector_search_result = bbq.vector_search(
            base_table=base_table,
            column_to_search="my_embedding",
            query=search_query,
            distance_type="cosine",
            top_k=2,
        ).to_pandas()  # type:ignore

        # The comparison below is order-sensitive, and this test does not rely
        # on any particular row order from the search itself. Sort by distance
        # first, then stably by the (duplicated) query index, so each query's
        # hits appear nearest-first.
        vector_search_result = vector_search_result.sort_values(
            "distance", kind="mergesort"
        ).sort_index(kind="mergesort")

        # Cosine distance is 1 - cos(theta). Both queries have norm
        # sqrt(0.625) = sqrt(10) / 4, giving dot products of 0.75 (nearest)
        # and 0.25 (second nearest) against unit base vectors.
        nearest = 1.0 - 3.0 / np.sqrt(10.0)  # ~0.0513
        second_nearest = 1.0 - 1.0 / np.sqrt(10.0)  # ~0.6838

        # One row per (query, neighbour) pair: top_k=2 with two queries yields
        # four rows, indexed by the position of the originating query.
        expected = pd.DataFrame(
            {
                "0": [
                    [0.75, 0.25],
                    [0.75, 0.25],
                    [-0.25, -0.75],
                    [-0.25, -0.75],
                ],
                "id": [2, 1, 3, 4],
                "my_embedding": [
                    [1.0, 0.0],
                    [0.0, 1.0],
                    [0.0, -1.0],
                    [-1.0, 0.0],
                ],
                "distance": [nearest, second_nearest, nearest, second_nearest],
            },
            index=pd.Index([0, 0, 1, 1], dtype="Int64"),
        )
        pd.testing.assert_frame_equal(
            vector_search_result, expected, check_dtype=False, rtol=0.1
        )
    finally:
        # Always drop the temporary table, even when the assertion fails.
        session.bqclient.delete_table(base_table, not_found_ok=True)
0 commit comments