diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index d82023e4b5..3bafce6166 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -19,16 +19,25 @@ from __future__ import annotations import json -from typing import Any, List, Literal, Mapping, Tuple +from typing import Any, List, Literal, Mapping, Tuple, Union -from bigframes import clients, dtypes, series -from bigframes.core import log_adapter +import pandas as pd + +from bigframes import clients, dtypes, series, session +from bigframes.core import convert, log_adapter from bigframes.operations import ai_ops +PROMPT_TYPE = Union[ + series.Series, + pd.Series, + List[Union[str, series.Series, pd.Series]], + Tuple[Union[str, series.Series, pd.Series], ...], +] + @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_bool( - prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...], + prompt: PROMPT_TYPE, *, connection_id: str | None = None, endpoint: str | None = None, @@ -51,7 +60,7 @@ def generate_bool( 0 {'result': True, 'full_response': '{"candidate... 1 {'result': True, 'full_response': '{"candidate... 2 {'result': False, 'full_response': '{"candidat... - dtype: struct<result: bool, full_response: string, status: string>[pyarrow] + dtype: struct<result: bool, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow] >>> bbq.ai.generate_bool((df["col_1"], " is a ", df["col_2"])).struct.field("result") 0 True @@ -60,8 +69,9 @@ def generate_bool( Name: result, dtype: boolean Args: - prompt (series.Series | List[str|series.Series] | Tuple[str|series.Series, ...]): - A mixture of Series and string literals that specifies the prompt to send to the model. + prompt (Series | List[str|Series] | Tuple[str|Series, ...]): + A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series + or pandas Series. connection_id (str, optional): Specifies the connection to use to communicate with the model. 
For example, `myproject.us.myconnection`. If not provided, the connection from the current session will be used. @@ -84,7 +94,7 @@ def generate_bool( Returns: bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: * "result": a BOOL value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI. - * "full_response": a STRING value containing the JSON response from the projects.locations.endpoints.generateContent call to the model. + * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model. The generated text is in the text element. * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ @@ -104,7 +114,7 @@ def generate_bool( def _separate_context_and_series( - prompt: series.Series | List[str | series.Series] | Tuple[str | series.Series, ...], + prompt: PROMPT_TYPE, ) -> Tuple[List[str | None], List[series.Series]]: """ Returns the two values. The first value is the prompt with all series replaced by None. The second value is all the series @@ -123,18 +133,19 @@ def _separate_context_and_series( return [None], [prompt] prompt_context: List[str | None] = [] - series_list: List[series.Series] = [] + series_list: List[series.Series | pd.Series] = [] + session = None for item in prompt: if isinstance(item, str): prompt_context.append(item) - elif isinstance(item, series.Series): + elif isinstance(item, (series.Series, pd.Series)): prompt_context.append(None) - if item.dtype == dtypes.OBJ_REF_DTYPE: - # Multi-model support - item = item.blob.read_url() + if isinstance(item, series.Series) and session is None: + # Use the first available BF session if there's any. 
+ session = item._session series_list.append(item) else: @@ -143,7 +154,20 @@ def _separate_context_and_series( if not series_list: raise ValueError("Please provide at least one Series in the prompt") - return prompt_context, series_list + converted_list = [_convert_series(s, session) for s in series_list] + + return prompt_context, converted_list + + +def _convert_series( + s: series.Series | pd.Series, session: session.Session | None +) -> series.Series: + result = convert.to_bf_series(s, default_index=None, session=session) + + if result.dtype == dtypes.OBJ_REF_DTYPE: + # Support multimodel + return result.blob.read_url() + return result def _resolve_connection_id(series: series.Series, connection_id: str | None): diff --git a/bigframes/operations/ai_ops.py b/bigframes/operations/ai_ops.py index fe5eb1406f..680c1585fb 100644 --- a/bigframes/operations/ai_ops.py +++ b/bigframes/operations/ai_ops.py @@ -40,7 +40,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT pa.struct( ( pa.field("result", pa.bool_()), - pa.field("full_response", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) ) diff --git a/tests/system/small/bigquery/test_ai.py b/tests/system/small/bigquery/test_ai.py index 443d4c54a3..be67a0d580 100644 --- a/tests/system/small/bigquery/test_ai.py +++ b/tests/system/small/bigquery/test_ai.py @@ -18,7 +18,7 @@ import pyarrow as pa import pytest -from bigframes import series +from bigframes import dtypes, series import bigframes.bigquery as bbq import bigframes.pandas as bpd @@ -35,7 +35,26 @@ def test_ai_generate_bool(session): pa.struct( ( pa.field("result", pa.bool_()), - pa.field("full_response", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_generate_bool_with_pandas(session): + s1 = pd.Series(["apple", "bear"]) + s2 = bpd.Series(["fruit", "tree"], session=session) + prompt = (s1, " is 
a ", s2) + + result = bbq.ai.generate_bool(prompt, endpoint="gemini-2.5-flash") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.bool_()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) ) @@ -62,7 +81,7 @@ def test_ai_generate_bool_with_model_params(session): pa.struct( ( pa.field("result", pa.bool_()), - pa.field("full_response", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) ) @@ -81,7 +100,7 @@ def test_ai_generate_bool_multi_model(session): pa.struct( ( pa.field("result", pa.bool_()), - pa.field("full_response", pa.string()), + pa.field("full_response", dtypes.JSON_ARROW_TYPE), pa.field("status", pa.string()), ) )