Skip to content

Commit 52f4c4f

Browse files
authored
Merge branch 'main' into sycai_ai_generate_output_schema
2 parents 99c7bfc + 56e5033 commit 52f4c4f

File tree

10 files changed

+189
-25
lines changed

10 files changed

+189
-25
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -370,20 +370,20 @@ def if_(
370370
provides optimization such that not all rows are evaluated with the LLM.
371371
372372
**Examples:**
373-
>>> import bigframes.pandas as bpd
374-
>>> import bigframes.bigquery as bbq
375-
>>> bpd.options.display.progress_bar = None
376-
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
377-
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
378-
0 True
379-
1 True
380-
2 False
381-
dtype: boolean
382-
383-
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
384-
0 Massachusetts
385-
1 Illinois
386-
dtype: string
373+
>>> import bigframes.pandas as bpd
374+
>>> import bigframes.bigquery as bbq
375+
>>> bpd.options.display.progress_bar = None
376+
>>> us_state = bpd.Series(["Massachusetts", "Illinois", "Hawaii"])
377+
>>> bbq.ai.if_((us_state, " has a city called Springfield"))
378+
0 True
379+
1 True
380+
2 False
381+
dtype: boolean
382+
383+
>>> us_state[bbq.ai.if_((us_state, " has a city called Springfield"))]
384+
0 Massachusetts
385+
1 Illinois
386+
dtype: string
387387
388388
Args:
389389
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
@@ -408,6 +408,56 @@ def if_(
408408
return series_list[0]._apply_nary_op(operator, series_list[1:])
409409

410410

411+
@log_adapter.method_logger(custom_base_name="bigquery_ai")
412+
def classify(
413+
input: PROMPT_TYPE,
414+
categories: tuple[str, ...] | list[str],
415+
*,
416+
connection_id: str | None = None,
417+
) -> series.Series:
418+
"""
419+
Classifies a given input into one of the specified categories. It always returns the provided category that best fits the input.
420+
421+
**Examples:**
422+
423+
>>> import bigframes.pandas as bpd
424+
>>> import bigframes.bigquery as bbq
425+
>>> bpd.options.display.progress_bar = None
426+
>>> df = bpd.DataFrame({'creature': ['Cat', 'Salmon']})
427+
>>> df['type'] = bbq.ai.classify(df['creature'], ['Mammal', 'Fish'])
428+
>>> df
429+
creature type
430+
0 Cat Mammal
431+
1 Salmon Fish
432+
<BLANKLINE>
433+
[2 rows x 2 columns]
434+
435+
Args:
436+
input (Series | List[str|Series] | Tuple[str|Series, ...]):
437+
A mixture of Series and string literals that specifies the input to send to the model. The Series can be BigFrames Series
438+
or pandas Series.
439+
categories (tuple[str, ...] | list[str]):
440+
Categories to classify the input into.
441+
connection_id (str, optional):
442+
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
443+
If not provided, the connection from the current session will be used.
444+
445+
Returns:
446+
bigframes.series.Series: A new series of strings.
447+
"""
448+
449+
prompt_context, series_list = _separate_context_and_series(input)
450+
assert len(series_list) > 0
451+
452+
operator = ai_ops.AIClassify(
453+
prompt_context=tuple(prompt_context),
454+
categories=tuple(categories),
455+
connection_id=_resolve_connection_id(series_list[0], connection_id),
456+
)
457+
458+
return series_list[0]._apply_nary_op(operator, series_list[1:])
459+
460+
411461
@log_adapter.method_logger(custom_base_name="bigquery_ai")
412462
def score(
413463
prompt: PROMPT_TYPE,
@@ -420,15 +470,16 @@ def score(
420470
rubric with examples in the prompt.
421471
422472
**Examples:**
423-
>>> import bigframes.pandas as bpd
424-
>>> import bigframes.bigquery as bbq
425-
>>> bpd.options.display.progress_bar = None
426-
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
427-
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
428-
0 2.0
429-
1 1.0
430-
2 3.0
431-
dtype: Float64
473+
474+
>>> import bigframes.pandas as bpd
475+
>>> import bigframes.bigquery as bbq
476+
>>> bpd.options.display.progress_bar = None
477+
>>> animal = bpd.Series(["Tiger", "Rabbit", "Blue Whale"])
478+
>>> bbq.ai.score(("Rank the relative weights of ", animal, " on the scale from 1 to 3")) # doctest: +SKIP
479+
0 2.0
480+
1 1.0
481+
2 3.0
482+
dtype: Float64
432483
433484
Args:
434485
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2040,6 +2040,18 @@ def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
20402040
).to_expr()
20412041

20422042

2043+
@scalar_op_compiler.register_nary_op(ops.AIClassify, pass_op=True)
2044+
def ai_classify(
2045+
*values: ibis_types.Value, op: ops.AIClassify
2046+
) -> ibis_types.StructValue:
2047+
2048+
return ai_ops.AIClassify(
2049+
_construct_prompt(values, op.prompt_context), # type: ignore
2050+
op.categories, # type: ignore
2051+
op.connection_id, # type: ignore
2052+
).to_expr()
2053+
2054+
20432055
@scalar_op_compiler.register_nary_op(ops.AIScore, pass_op=True)
20442056
def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructValue:
20452057

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,21 @@ def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
6060
return sge.func("AI.IF", *args)
6161

6262

63+
@register_nary_op(ops.AIClassify, pass_op=True)
64+
def _(*exprs: TypedExpr, op: ops.AIClassify) -> sge.Expression:
65+
category_literals = [sge.Literal.string(cat) for cat in op.categories]
66+
categories_arg = sge.Kwarg(
67+
this="categories", expression=sge.array(*category_literals)
68+
)
69+
70+
args = [
71+
_construct_prompt(exprs, op.prompt_context, param_name="input"),
72+
categories_arg,
73+
] + _construct_named_args(op)
74+
75+
return sge.func("AI.CLASSIFY", *args)
76+
77+
6378
@register_nary_op(ops.AIScore, pass_op=True)
6479
def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6580
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
@@ -68,7 +83,9 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
6883

6984

7085
def _construct_prompt(
71-
exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...]
86+
exprs: tuple[TypedExpr, ...],
87+
prompt_context: tuple[str | None, ...],
88+
param_name: str = "prompt",
7289
) -> sge.Kwarg:
7390
prompt: list[str | sge.Expression] = []
7491
column_ref_idx = 0
@@ -79,7 +96,7 @@ def _construct_prompt(
7996
else:
8097
prompt.append(sge.Literal.string(elem))
8198

82-
return sge.Kwarg(this="prompt", expression=sge.Tuple(expressions=prompt))
99+
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))
83100

84101

85102
def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:

bigframes/operations/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from bigframes.operations.ai_ops import (
18+
AIClassify,
1819
AIGenerate,
1920
AIGenerateBool,
2021
AIGenerateDouble,
@@ -419,6 +420,7 @@
419420
"geo_y_op",
420421
"GeoStDistanceOp",
421422
# AI ops
423+
"AIClassify",
422424
"AIGenerate",
423425
"AIGenerateBool",
424426
"AIGenerateDouble",

bigframes/operations/ai_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,18 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
129129
return dtypes.BOOL_DTYPE
130130

131131

132+
@dataclasses.dataclass(frozen=True)
133+
class AIClassify(base_ops.NaryOp):
134+
name: ClassVar[str] = "ai_classify"
135+
136+
prompt_context: Tuple[str | None, ...]
137+
categories: tuple[str, ...]
138+
connection_id: str
139+
140+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
141+
return dtypes.STRING_DTYPE
142+
143+
132144
@dataclasses.dataclass(frozen=True)
133145
class AIScore(base_ops.NaryOp):
134146
name: ClassVar[str] = "ai_score"

tests/system/small/bigquery/test_ai.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,27 @@ def test_ai_if_multi_model(session):
260260
assert result.dtype == dtypes.BOOL_DTYPE
261261

262262

263+
def test_ai_classify(session):
264+
s = bpd.Series(["cat", "orchid"], session=session)
265+
bpd.options.display.repr_mode = "deferred"
266+
267+
result = bbq.ai.classify(s, ["animal", "plant"])
268+
269+
assert _contains_no_nulls(result)
270+
assert result.dtype == dtypes.STRING_DTYPE
271+
272+
273+
def test_ai_classify_multi_model(session):
274+
df = session.from_glob_path(
275+
"gs://bigframes-dev-testing/a_multimodel/images/*", name="image"
276+
)
277+
278+
result = bbq.ai.classify(df["image"], ["photo", "cartoon"])
279+
280+
assert _contains_no_nulls(result)
281+
assert result.dtype == dtypes.STRING_DTYPE
282+
283+
263284
def test_ai_score(session):
264285
s = bpd.Series(["Tiger", "Rabbit"], session=session)
265286
prompt = ("Rank the relative weights of ", s, " on the scale from 1 to 3")
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`string_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
AI.CLASSIFY(
9+
input => (`bfcol_0`),
10+
categories => ['greeting', 'rejection'],
11+
connection_id => 'bigframes-dev.us.bigframes-default-connection'
12+
) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `result`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,20 @@ def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot):
237237
snapshot.assert_match(sql, "out.sql")
238238

239239

240+
def test_ai_classify(scalar_types_df: dataframe.DataFrame, snapshot):
241+
col_name = "string_col"
242+
243+
op = ops.AIClassify(
244+
prompt_context=(None,),
245+
categories=("greeting", "rejection"),
246+
connection_id=CONNECTION_ID,
247+
)
248+
249+
sql = utils._apply_unary_ops(scalar_types_df, [op.as_expr(col_name)], ["result"])
250+
251+
snapshot.assert_match(sql, "out.sql")
252+
253+
240254
def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot):
241255
col_name = "string_col"
242256

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,9 @@ def visit_AIGenerateDouble(self, op, **kwargs):
11191119
def visit_AIIf(self, op, **kwargs):
11201120
return sge.func("AI.IF", *self._compile_ai_args(**kwargs))
11211121

1122+
def visit_AIClassify(self, op, **kwargs):
1123+
return sge.func("AI.CLASSIFY", *self._compile_ai_args(**kwargs))
1124+
11221125
def visit_AIScore(self, op, **kwargs):
11231126
return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs))
11241127

third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ def dtype(self) -> dt.Struct:
122122
return dt.bool
123123

124124

125+
@public
126+
class AIClassify(Value):
127+
"""Classify the input into one of the provided categories"""
128+
129+
input: Value
130+
categories: Value[dt.Array[dt.String]]
131+
connection_id: Value[dt.String]
132+
133+
shape = rlz.shape_like("input")
134+
135+
@attribute
136+
def dtype(self) -> dt.Struct:
137+
return dt.string
138+
139+
125140
@public
126141
class AIScore(Value):
127142
"""Generate doubles based on the prompt"""

0 commit comments

Comments
 (0)