Skip to content

Commit 6d0fe48

Browse files
committed
fix vector search tests again
1 parent d611337 commit 6d0fe48

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

tests/system/small/bigquery/test_vector_search.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,17 @@ def test_vector_search_basic_params_with_df():
123123
"embedding": [[1.0, 2.0], [3.0, 5.2]],
124124
}
125125
)
126-
vector_search_result = bbq.vector_search(
127-
base_table="bigframes-dev.bigframes_tests_sys.base_table",
128-
column_to_search="my_embedding",
129-
query=search_query,
130-
top_k=2,
131-
).to_pandas() # type:ignore
126+
vector_search_result = (
127+
bbq.vector_search(
128+
base_table="bigframes-dev.bigframes_tests_sys.base_table",
129+
column_to_search="my_embedding",
130+
query=search_query,
131+
top_k=2,
132+
)
133+
.sort_values("distance")
134+
.sort_index()
135+
.to_pandas()
136+
) # type:ignore
132137
expected = pd.DataFrame(
133138
{
134139
"query_id": ["cat", "dog", "dog", "cat"],
@@ -173,22 +178,39 @@ def test_vector_search_different_params_with_query(session):
173178
base_table = base_df.to_gbq()
174179
try:
175180
search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session)
176-
vector_search_result = bbq.vector_search(
177-
base_table=base_table,
178-
column_to_search="my_embedding",
179-
query=search_query,
180-
distance_type="cosine",
181-
top_k=2,
182-
).to_pandas() # type:ignore
181+
vector_search_result = (
182+
bbq.vector_search(
183+
base_table=base_table,
184+
column_to_search="my_embedding",
185+
query=search_query,
186+
distance_type="cosine",
187+
top_k=2,
188+
)
189+
.sort_values("distance")
190+
.sort_index()
191+
.to_pandas()
192+
) # type:ignore
183193
expected = pd.DataFrame(
184194
{
185-
"0": {np.int64(0): [0.75, 0.25], np.int64(1): [-0.25, -0.75]},
186-
"id": {np.int64(0): 1, np.int64(1): 4},
187-
"my_embedding": {np.int64(0): [0.0, 1.0], np.int64(1): [-1.0, 0.0]},
188-
"distance": {
189-
np.int64(0): 0.683772233983162,
190-
np.int64(1): 0.683772233983162,
191-
},
195+
"0": [
196+
[0.75, 0.25],
197+
[0.75, 0.25],
198+
[-0.25, -0.75],
199+
[-0.25, -0.75],
200+
],
201+
"id": [2, 1, 3, 4],
202+
"my_embedding": [
203+
[1.0, 0.0],
204+
[0.0, 1.0],
205+
[0.0, -1.0],
206+
[-1.0, 0.0],
207+
],
208+
"distance": [
209+
0.051317,
210+
0.683772,
211+
0.051317,
212+
0.683772,
213+
],
192214
},
193215
index=pd.Index([0, 0, 1, 1], dtype="Int64"),
194216
)

0 commit comments

Comments
 (0)