Skip to content

Commit 58f45f8

Browse files
committed
fix vector search tests
1 parent 05c145c commit 58f45f8

File tree

4 files changed

+45
-74
lines changed

4 files changed

+45
-74
lines changed

bigframes/session/_io/bigquery/read_gbq_query.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ def create_dataframe_from_row_iterator(
6767
"""
6868
pa_table = rows.to_arrow()
6969
bq_schema = list(rows.schema)
70+
is_default_index = not index_col or isinstance(
71+
index_col, bigframes.enums.DefaultIndexKind
72+
)
7073

71-
if not index_col or isinstance(index_col, bigframes.enums.DefaultIndexKind):
74+
if is_default_index:
7275
# We get a sequential index for free, so use that if no index is specified.
7376
# TODO(tswast): Use array_value.promote_offsets() instead once that node is
7477
# supported by the local engine.
@@ -81,6 +84,7 @@ def create_dataframe_from_row_iterator(
8184
index_columns = (index_col,)
8285
index_labels = (index_col,)
8386
else:
87+
index_col = cast(Iterable[str], index_col)
8488
index_columns = tuple(index_col)
8589
index_labels = cast(Tuple[Optional[str], ...], tuple(index_col))
8690

@@ -111,4 +115,7 @@ def create_dataframe_from_row_iterator(
111115
if columns:
112116
df = df[list(columns)]
113117

118+
if not is_default_index:
119+
df = df.sort_index()
120+
114121
return df

bigframes/session/loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,7 @@ def read_gbq_query(
956956
True if use_cache is None else use_cache
957957
)
958958

959+
_check_duplicates("columns", columns)
959960
index_cols = _to_index_cols(index_col)
960961
_check_index_col_param(index_cols, columns)
961962

tests/system/small/bigquery/test_vector_search.py

Lines changed: 35 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -157,80 +157,43 @@ def test_vector_search_basic_params_with_df():
157157
)
158158

159159

160-
def test_vector_search_different_params_with_query():
161-
search_query = bpd.Series([[1.0, 2.0], [3.0, 5.2]])
162-
vector_search_result = bbq.vector_search(
163-
base_table="bigframes-dev.bigframes_tests_sys.base_table",
164-
column_to_search="my_embedding",
165-
query=search_query,
166-
distance_type="cosine",
167-
top_k=2,
168-
).to_pandas() # type:ignore
169-
expected = pd.DataFrame(
160+
def test_vector_search_different_params_with_query(session):
161+
base_df = bpd.DataFrame(
170162
{
171-
"0": [
172-
np.array([1.0, 2.0]),
173-
np.array([1.0, 2.0]),
174-
np.array([3.0, 5.2]),
175-
np.array([3.0, 5.2]),
176-
],
177-
"id": [2, 1, 1, 2],
163+
"id": [1, 2, 3, 4],
178164
"my_embedding": [
179-
np.array([2.0, 4.0]),
180-
np.array([1.0, 2.0]),
181-
np.array([1.0, 2.0]),
182-
np.array([2.0, 4.0]),
165+
np.array([0.0, 1.0]),
166+
np.array([1.0, 0.0]),
167+
np.array([0.0, -1.0]),
168+
np.array([-1.0, 0.0]),
183169
],
184-
"distance": [0.0, 0.0, 0.001777, 0.001777],
185170
},
186-
index=pd.Index([0, 0, 1, 1], dtype="Int64"),
187-
)
188-
pd.testing.assert_frame_equal(
189-
vector_search_result, expected, check_dtype=False, rtol=0.1
190-
)
191-
192-
193-
def test_vector_search_df_with_query_column_to_search():
194-
search_query = bpd.DataFrame(
195-
{
196-
"query_id": ["dog", "cat"],
197-
"embedding": [[1.0, 2.0], [3.0, 5.2]],
198-
"another_embedding": [[1.0, 2.5], [3.3, 5.2]],
199-
}
200-
)
201-
vector_search_result = bbq.vector_search(
202-
base_table="bigframes-dev.bigframes_tests_sys.base_table",
203-
column_to_search="my_embedding",
204-
query=search_query,
205-
query_column_to_search="another_embedding",
206-
top_k=2,
207-
).to_pandas() # type:ignore
208-
expected = pd.DataFrame(
209-
{
210-
"query_id": ["dog", "dog", "cat", "cat"],
211-
"embedding": [
212-
np.array([1.0, 2.0]),
213-
np.array([1.0, 2.0]),
214-
np.array([3.0, 5.2]),
215-
np.array([3.0, 5.2]),
216-
],
217-
"another_embedding": [
218-
np.array([1.0, 2.5]),
219-
np.array([1.0, 2.5]),
220-
np.array([3.3, 5.2]),
221-
np.array([3.3, 5.2]),
222-
],
223-
"id": [1, 4, 2, 5],
224-
"my_embedding": [
225-
np.array([1.0, 2.0]),
226-
np.array([1.0, 3.2]),
227-
np.array([2.0, 4.0]),
228-
np.array([5.0, 5.4]),
229-
],
230-
"distance": [0.5, 0.7, 1.769181, 1.711724],
231-
},
232-
index=pd.Index([0, 0, 1, 1], dtype="Int64"),
233-
)
234-
pd.testing.assert_frame_equal(
235-
vector_search_result, expected, check_dtype=False, rtol=0.1
171+
session=session,
236172
)
173+
base_table = base_df.to_gbq()
174+
try:
175+
search_query = bpd.Series([[0.75, 0.25], [-0.25, -0.75]], session=session)
176+
vector_search_result = bbq.vector_search(
177+
base_table=base_table,
178+
column_to_search="my_embedding",
179+
query=search_query,
180+
distance_type="cosine",
181+
top_k=2,
182+
).to_pandas() # type:ignore
183+
expected = pd.DataFrame(
184+
{
185+
"0": {np.int64(0): [0.75, 0.25], np.int64(1): [-0.25, -0.75]},
186+
"id": {np.int64(0): 1, np.int64(1): 4},
187+
"my_embedding": {np.int64(0): [0.0, 1.0], np.int64(1): [-1.0, 0.0]},
188+
"distance": {
189+
np.int64(0): 0.683772233983162,
190+
np.int64(1): 0.683772233983162,
191+
},
192+
},
193+
index=pd.Index([0, 0, 1, 1], dtype="Int64"),
194+
)
195+
pd.testing.assert_frame_equal(
196+
vector_search_result, expected, check_dtype=False, rtol=0.1
197+
)
198+
finally:
199+
session.bqclient.delete_table(base_table, not_found_ok=True)

tests/system/small/test_unordered.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_unordered_mode_read_gbq(unordered_session):
103103
}
104104
)
105105
# Don't need ignore_order as there is only 1 row
106-
assert_pandas_df_equal(df.to_pandas(), expected)
106+
assert_pandas_df_equal(df.to_pandas(), expected, check_index_type=False)
107107

108108

109109
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)