Skip to content

Commit 67206cb

Browse files
committed
feat: support pandas series in ai.generate_bool
1 parent 5ce5d63 commit 67206cb

File tree

2 files changed

+46
-9
lines changed

2 files changed

+46
-9
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
import json
2222
from typing import Any, List, Literal, Mapping, Tuple
2323

24-
from bigframes import clients, dtypes, series
25-
from bigframes.core import log_adapter
24+
import pandas as pd
25+
26+
from bigframes import clients, dtypes, series, session
27+
from bigframes.core import convert, log_adapter
2628
from bigframes.operations import ai_ops
2729

2830

2931
@log_adapter.method_logger(custom_base_name="bigquery_ai")
3032
def generate_bool(
31-
prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...],
33+
prompt: series.Series
34+
| pd.Series
35+
| List[str | series.Series | pd.Series]
36+
| Tuple[str | series.Series | pd.Series, ...],
3237
*,
3338
connection_id: str | None = None,
3439
endpoint: str | None = None,
@@ -77,8 +82,9 @@ def generate_bool(
7782
Name: result, dtype: boolean
7883
7984
Args:
80-
prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]):
81-
A mixture of Series and string literals that specifies the prompt to send to the model.
85+
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
86+
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
87+
or pandas Series.
8288
connection_id (str, optional):
8389
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
8490
If not provided, the connection from the current session will be used.
@@ -142,16 +148,17 @@ def _separate_context_and_series(
142148
prompt_context: List[str | None] = []
143149
series_list: List[series.Series] = []
144150

151+
session = None
145152
for item in prompt:
146153
if isinstance(item, str):
147154
prompt_context.append(item)
148155

149-
elif isinstance(item, series.Series):
156+
elif isinstance(item, (series.Series, pd.Series)):
150157
prompt_context.append(None)
151158

152-
if item.dtype == dtypes.OBJ_REF_DTYPE:
153-
# Multi-model support
154-
item = item.blob.read_url()
159+
if isinstance(item, series.Series) and session is None:
160+
# use the session from the first BigFrames session if possible
161+
session = item._session
155162
series_list.append(item)
156163

157164
else:
@@ -160,9 +167,22 @@ def _separate_context_and_series(
160167
if not series_list:
161168
raise ValueError("Please provide at least one Series in the prompt")
162169

170+
series_list = [_convert_series(s, session) for s in series_list]
171+
163172
return prompt_context, series_list
164173

165174

175+
def _convert_series(
176+
s: series.Series | pd.Series, session: session.Session | None
177+
) -> series.Series:
178+
result = convert.to_bf_series(s, default_index=None, session=session)
179+
180+
if result.dtype == dtypes.OBJ_REF_DTYPE:
181+
# Support multimodel
182+
return result.blob.read_url()
183+
return result
184+
185+
166186
def _resolve_connection_id(series: series.Series, connection_id: str | None):
167187
return clients.get_canonical_bq_connection_id(
168188
connection_id or series._session._bq_connection,

tests/system/small/bigquery/test_ai.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ def test_ai_generate_bool(session):
3939
)
4040

4141

42+
def test_ai_generate_bool_with_pandas(session):
43+
s1 = pd.Series(["apple", "bear"])
44+
s2 = bpd.Series(["fruit", "tree"], session=session)
45+
prompt = (s1, " is a ", s2)
46+
47+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash").struct.field(
48+
"result"
49+
)
50+
51+
pandas.testing.assert_series_equal(
52+
result.to_pandas(),
53+
pd.Series([True, False], name="result"),
54+
check_dtype=False,
55+
check_index=False,
56+
)
57+
58+
4259
def test_ai_generate_bool_with_model_params(session):
4360
if sys.version_info < (3, 12):
4461
pytest.skip(

0 commit comments

Comments
 (0)