Skip to content

Commit 00dd022

Browse files
committed
feat: support string literal inputs for AI functions
1 parent cdf2dd5 commit 00dd022

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
import pandas as pd
2525

2626
from bigframes import clients, dtypes, series, session
27-
from bigframes.core import convert, log_adapter
27+
from bigframes.core import convert, global_session, log_adapter
2828
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
31+
str,
3132
series.Series,
3233
pd.Series,
3334
List[Union[str, series.Series, pd.Series]],
@@ -514,9 +515,14 @@ def _separate_context_and_series(
514515
Input: ("str1", series1, "str2", "str3", series2)
515516
Output: ["str1", None, "str2", "str3", None], [series1, series2]
516517
"""
517-
if not isinstance(prompt, (list, tuple, series.Series)):
518+
if not isinstance(prompt, (str, list, tuple, series.Series)):
518519
raise ValueError(f"Unsupported prompt type: {type(prompt)}")
519520

521+
if isinstance(prompt, str):
522+
return [None], [
523+
series.Series([prompt], session=global_session.get_global_session())
524+
]
525+
520526
if isinstance(prompt, series.Series):
521527
if prompt.dtype == dtypes.OBJ_REF_DTYPE:
522528
# Multi-model support

tests/system/small/bigquery/test_ai.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
1517
from packaging import version
1618
import pandas as pd
1719
import pyarrow as pa
@@ -42,6 +44,27 @@ def test_ai_function_pandas_input(session):
4244
)
4345

4446

47+
def test_ai_function_string_input(session):
48+
with mock.patch(
49+
"bigframes.core.global_session.get_global_session"
50+
) as mock_get_session:
51+
mock_get_session.return_value = session
52+
prompt = "Is apple a fruit?"
53+
54+
result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash")
55+
56+
assert _contains_no_nulls(result)
57+
assert result.dtype == pd.ArrowDtype(
58+
pa.struct(
59+
(
60+
pa.field("result", pa.bool_()),
61+
pa.field("full_response", dtypes.JSON_ARROW_TYPE),
62+
pa.field("status", pa.string()),
63+
)
64+
)
65+
)
66+
67+
4568
def test_ai_function_compile_model_params(session):
4669
if version.Version(sqlglot.__version__) < version.Version("25.18.0"):
4770
pytest.skip(

0 commit comments

Comments
 (0)