Skip to content

Commit 1b4dbcb

Browse files
committed
feat: add ai.generate() to bigframes.bigquery module
1 parent 10b2a38 commit 1b4dbcb

File tree

17 files changed

+261
-14
lines changed

17 files changed

+261
-14
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,85 @@
3535
]
3636

3737

38+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
def generate(
    prompt: PROMPT_TYPE,
    *,
    connection_id: str | None = None,
    endpoint: str | None = None,
    request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
    model_params: Mapping[Any, Any] | None = None,
    # TODO(b/446974666) Add output_schema parameter
) -> series.Series:
    """
    Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.

    **Examples:**

    The generated text depends on the model, so the outputs below are
    illustrative and the calls are skipped when doctests run.

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bpd.options.display.progress_bar = None
        >>> country = bpd.Series(["Japan", "Canada"])
        >>> prompt = ("What's the capital city of ", country, "? one word only")
        >>> bbq.ai.generate(prompt)  # doctest: +SKIP
        0    {'result': 'Tokyo', 'full_response': '{"candid...
        1    {'result': 'Ottawa', 'full_response': '{"candi...
        dtype: struct<result: string, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]

        >>> bbq.ai.generate(prompt).struct.field("result")  # doctest: +SKIP
        0     Tokyo
        1    Ottawa
        Name: result, dtype: string

    Args:
        prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
            A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
            or pandas Series. The prompt must contain at least one Series.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
            If not provided, the connection from the current session will be used.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
            generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
            uses the full endpoint of the model. If you don't specify an ENDPOINT value, BigQuery ML selects a recent stable
            version of Gemini to use.
        request_type (Literal["dedicated", "shared", "unspecified"]):
            Specifies the type of inference request to send to the Gemini model. The request type determines what quota the request uses.
            * "dedicated": function only uses Provisioned Throughput quota. The function returns the error Provisioned throughput is not
            purchased or is not active if Provisioned Throughput quota isn't available.
            * "shared": the function only uses dynamic shared quota (DSQ), even if you have purchased Provisioned Throughput quota.
            * "unspecified": If you haven't purchased Provisioned Throughput quota, the function uses DSQ quota.
            If you have purchased Provisioned Throughput quota, the function uses the Provisioned Throughput quota first.
            If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
        model_params (Mapping[Any, Any], optional):
            Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
            The mapping is serialized to JSON before being sent; an empty mapping is treated the same as ``None``.

    Returns:
        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
        * "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
        * "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
        The generated text is in the text element.
        * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
    """

    prompt_context, series_list = _separate_context_and_series(prompt)
    # The first Series anchors the operation: it is used to resolve the
    # connection and is the receiver of the n-ary op, so at least one Series
    # must be present in the prompt.
    assert len(series_list) > 0

    operator = ai_ops.AIGenerate(
        prompt_context=tuple(prompt_context),
        connection_id=_resolve_connection_id(series_list[0], connection_id),
        endpoint=endpoint,
        request_type=request_type,
        # The op stores model parameters as a JSON string (or None when no
        # parameters were supplied).
        model_params=json.dumps(model_params) if model_params else None,
    )

    return series_list[0]._apply_nary_op(operator, series_list[1:])
115+
116+
38117
@log_adapter.method_logger(custom_base_name="bigquery_ai")
39118
def generate_bool(
40119
prompt: PROMPT_TYPE,

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,20 @@ def struct_op_impl(
19701970
return ibis_types.struct(data)
19711971

19721972

1973+
@scalar_op_compiler.register_nary_op(ops.AIGenerate, pass_op=True)
def ai_generate(
    *values: ibis_types.Value, op: ops.AIGenerate
) -> ibis_types.StructValue:
    """Compile an ``ops.AIGenerate`` scalar op into an ibis expression.

    Args:
        *values: The compiled ibis values for the op's inputs; combined with
            ``op.prompt_context`` to build the prompt argument.
        op: The operator carrying the connection, endpoint, request type and
            model parameters.

    Returns:
        The expression produced by the ``ai_ops.AIGenerate`` node.
    """
    prompt = _construct_prompt(values, op.prompt_context)
    # The request type is upper-cased before it is handed to the ibis node.
    request_type = op.request_type.upper()

    node = ai_ops.AIGenerate(
        prompt,  # type: ignore
        op.connection_id,  # type: ignore
        op.endpoint,  # type: ignore
        request_type,  # type: ignore
        op.model_params,  # type: ignore
    )
    return node.to_expr()
1985+
1986+
19731987
@scalar_op_compiler.register_nary_op(ops.AIGenerateBool, pass_op=True)
19741988
def ai_generate_bool(
19751989
*values: ibis_types.Value, op: ops.AIGenerateBool

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2727

2828

29+
@register_nary_op(ops.AIGenerate, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIGenerate) -> sge.Expression:
    """Compile ``ops.AIGenerate`` into a call to the ``AI.GENERATE`` SQL function.

    The prompt built from ``exprs`` and ``op.prompt_context`` is the first
    argument, followed by the arguments returned by ``_construct_named_args``.
    """
    call_args = [_construct_prompt(exprs, op.prompt_context)]
    call_args += _construct_named_args(op)

    return sge.func("AI.GENERATE", *call_args)
34+
35+
2936
@register_nary_op(ops.AIGenerateBool, pass_op=True)
3037
def _(*exprs: TypedExpr, op: ops.AIGenerateBool) -> sge.Expression:
3138
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)

bigframes/operations/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
from bigframes.operations.ai_ops import AIGenerateBool, AIGenerateDouble, AIGenerateInt
17+
from bigframes.operations.ai_ops import (
18+
AIGenerate,
19+
AIGenerateBool,
20+
AIGenerateDouble,
21+
AIGenerateInt,
22+
)
1823
from bigframes.operations.array_ops import (
1924
ArrayIndexOp,
2025
ArrayReduceOp,
@@ -412,6 +417,7 @@
412417
"geo_y_op",
413418
"GeoStDistanceOp",
414419
# AI ops
420+
"AIGenerate",
415421
"AIGenerateBool",
416422
"AIGenerateDouble",
417423
"AIGenerateInt",

bigframes/operations/ai_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,28 @@
2424
from bigframes.operations import base_ops
2525

2626

27+
@dataclasses.dataclass(frozen=True)
class AIGenerate(base_ops.NaryOp):
    """N-ary operator describing an ``ai_generate`` call.

    The output is always a struct with ``result`` (string),
    ``full_response`` (JSON) and ``status`` (string) fields, regardless of the
    input types.
    """

    name: ClassVar[str] = "ai_generate"

    # Pieces of the prompt; presumably None marks a slot filled by an input
    # column -- confirm against the prompt construction helpers.
    prompt_context: Tuple[str | None, ...]
    # Connection used to reach the model.
    connection_id: str
    # Model endpoint, or None when unspecified.
    endpoint: str | None
    # Which quota the inference request should use.
    request_type: Literal["dedicated", "shared", "unspecified"]
    # Model parameters as a string (or None when absent).
    model_params: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the fixed struct dtype of the op's result; inputs are ignored."""
        struct_fields = (
            pa.field("result", pa.string()),
            pa.field("full_response", dtypes.JSON_ARROW_TYPE),
            pa.field("status", pa.string()),
        )
        return pd.ArrowDtype(pa.struct(struct_fields))
47+
48+
2749
@dataclasses.dataclass(frozen=True)
2850
class AIGenerateBool(base_ops.NaryOp):
2951
name: ClassVar[str] = "ai_generate_bool"

tests/system/small/bigquery/test_ai.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,24 @@ def test_ai_function_compile_model_params(session):
6969
)
7070

7171

72+
def test_ai_generate(session):
    """ai.generate() returns a null-free struct Series with a string result."""
    expected_dtype = pd.ArrowDtype(
        pa.struct(
            (
                pa.field("result", pa.string()),
                pa.field("full_response", dtypes.JSON_ARROW_TYPE),
                pa.field("status", pa.string()),
            )
        )
    )
    country = bpd.Series(["Japan", "Canada"], session=session)

    result = bbq.ai.generate(
        ("What's the capital city of ", country, "? one word only"),
        endpoint="gemini-2.5-flash",
    )

    assert _contains_no_nulls(result)
    assert result.dtype == expected_dtype
88+
89+
7290
def test_ai_generate_bool(session):
7391
s1 = bpd.Series(["apple", "bear"], session=session)
7492
s2 = bpd.Series(["fruit", "tree"], session=session)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.GENERATE(
9+
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
11+
endpoint => 'gemini-2.5-flash',
12+
request_type => 'SHARED'
13+
) AS `bfcol_1`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_1` AS `result`
18+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_BOOL(
99
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10-
connection_id => 'test_connection_id',
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1111
endpoint => 'gemini-2.5-flash',
1212
request_type => 'SHARED'
1313
) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_bool_with_model_param/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_BOOL(
99
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10-
connection_id => 'test_connection_id',
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1111
request_type => 'SHARED',
1212
model_params => JSON '{}'
1313
) AS `bfcol_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_generate_double/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ WITH `bfcte_0` AS (
77
*,
88
AI.GENERATE_DOUBLE(
99
prompt => (`bfcol_0`, ' is the same as ', `bfcol_0`),
10-
connection_id => 'test_connection_id',
10+
connection_id => 'bigframes-dev.us.bigframes-default-connection',
1111
endpoint => 'gemini-2.5-flash',
1212
request_type => 'SHARED'
1313
) AS `bfcol_1`

0 commit comments

Comments
 (0)