Skip to content

Commit aa23fca

Browse files
authored
Merge branch 'main' into tswast-doctest-boilerplate
2 parents bed4069 + 8514200 commit aa23fca

File tree

28 files changed

+604
-146
lines changed

28 files changed

+604
-146
lines changed

CHANGELOG.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,29 @@
44

55
[1]: https://pypi.org/project/bigframes/#history
66

7+
## [2.24.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.23.0...v2.24.0) (2025-10-07)
8+
9+
10+
### Features
11+
12+
* Add ai.classify() to bigframes.bigquery package ([#2137](https://github.com/googleapis/python-bigquery-dataframes/issues/2137)) ([56e5033](https://github.com/googleapis/python-bigquery-dataframes/commit/56e50331d198b7f517f85695c208f893ab9389d2))
13+
* Add ai.generate() to bigframes.bigquery module ([#2128](https://github.com/googleapis/python-bigquery-dataframes/issues/2128)) ([3810452](https://github.com/googleapis/python-bigquery-dataframes/commit/3810452f16d8d6c9d3eb9075f1537177d98b4725))
14+
* Add ai.if_() and ai.score() to bigframes.bigquery package ([#2132](https://github.com/googleapis/python-bigquery-dataframes/issues/2132)) ([32502f4](https://github.com/googleapis/python-bigquery-dataframes/commit/32502f4195306d262788f39d1ab4206fc84ae50e))
15+
16+
17+
### Bug Fixes
18+
19+
* Fix internal type errors with temporal accessors ([#2125](https://github.com/googleapis/python-bigquery-dataframes/issues/2125)) ([c390da1](https://github.com/googleapis/python-bigquery-dataframes/commit/c390da11b7c2aa710bc2fbc692efb9f06059e4c4))
20+
* Fix row count local execution bug ([#2133](https://github.com/googleapis/python-bigquery-dataframes/issues/2133)) ([ece0762](https://github.com/googleapis/python-bigquery-dataframes/commit/ece07623e354a1dde2bd37020349e13f682e863f))
21+
* Join `on` and `how` args are now positional ([#2140](https://github.com/googleapis/python-bigquery-dataframes/issues/2140)) ([b711815](https://github.com/googleapis/python-bigquery-dataframes/commit/b7118152bfecc6ecf67aa4df23ec3f0a2b08aa30))
22+
* Only show JSON dtype warning when accessing dtypes directly ([#2136](https://github.com/googleapis/python-bigquery-dataframes/issues/2136)) ([eca22ee](https://github.com/googleapis/python-bigquery-dataframes/commit/eca22ee3104104cea96189391e527cad09bd7509))
23+
* Remove noisy AmbiguousWindowWarning from partial ordering mode ([#2129](https://github.com/googleapis/python-bigquery-dataframes/issues/2129)) ([4607f86](https://github.com/googleapis/python-bigquery-dataframes/commit/4607f86ebd77b916aafc37f69725b676e203b332))
24+
25+
26+
### Performance Improvements
27+
28+
* Scale read stream workers to cpu count ([#2135](https://github.com/googleapis/python-bigquery-dataframes/issues/2135)) ([67e46cd](https://github.com/googleapis/python-bigquery-dataframes/commit/67e46cd47933b84b55808003ed344b559e47c498))
29+
730
## [2.23.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.22.0...v2.23.0) (2025-09-29)
831

932

bigframes/bigquery/_operations/ai.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from bigframes import clients, dtypes, series, session
2727
from bigframes.core import convert, log_adapter
28-
from bigframes.operations import ai_ops
28+
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
3131
series.Series,
@@ -43,7 +43,7 @@ def generate(
4343
endpoint: str | None = None,
4444
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
4545
model_params: Mapping[Any, Any] | None = None,
46-
# TODO(b/446974666) Add output_schema parameter
46+
output_schema: Mapping[str, str] | None = None,
4747
) -> series.Series:
4848
"""
4949
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
@@ -63,6 +63,14 @@ def generate(
6363
1 Ottawa\\n
6464
Name: result, dtype: string
6565
66+
You get structured output when the `output_schema` parameter is set:
67+
68+
>>> animals = bpd.Series(["Rabbit", "Spider"])
69+
>>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"})
70+
0 {'is_herbivore': True, 'number_of_legs': 4, 'f...
71+
1 {'is_herbivore': False, 'number_of_legs': 8, '...
72+
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
73+
6674
Args:
6775
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
6876
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
@@ -85,10 +93,14 @@ def generate(
8593
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
8694
model_params (Mapping[Any, Any]):
8795
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
96+
output_schema (Mapping[str, str]):
97+
A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include
98+
`STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
8899
89100
Returns:
90101
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
91102
* "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
103+
If you specify an output schema, the "result" field is replaced by the fields of your custom schema.
92104
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
93105
The generated text is in the text element.
94106
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
@@ -97,12 +109,22 @@ def generate(
97109
prompt_context, series_list = _separate_context_and_series(prompt)
98110
assert len(series_list) > 0
99111

112+
if output_schema is None:
113+
output_schema_str = None
114+
else:
115+
output_schema_str = ", ".join(
116+
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
117+
)
118+
# Validate user input
119+
output_schemas.parse_sql_fields(output_schema_str)
120+
100121
operator = ai_ops.AIGenerate(
101122
prompt_context=tuple(prompt_context),
102123
connection_id=_resolve_connection_id(series_list[0], connection_id),
103124
endpoint=endpoint,
104125
request_type=request_type,
105126
model_params=json.dumps(model_params) if model_params else None,
127+
output_schema=output_schema_str,
106128
)
107129

108130
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,7 @@ def ai_generate(
19851985
op.endpoint, # type: ignore
19861986
op.request_type.upper(), # type: ignore
19871987
op.model_params, # type: ignore
1988+
op.output_schema, # type: ignore
19881989
).to_expr()
19891990

19901991

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import asdict
18-
import typing
1918

2019
import sqlglot.expressions as sge
2120

@@ -105,24 +104,24 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
105104

106105
op_args = asdict(op)
107106

108-
connection_id = typing.cast(str, op_args["connection_id"])
107+
connection_id = op_args["connection_id"]
109108
args.append(
110109
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
111110
)
112111

113-
endpoit = typing.cast(str, op_args.get("endpoint", None))
112+
endpoit = op_args.get("endpoint", None)
114113
if endpoit is not None:
115114
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116115

117-
request_type = typing.cast(str, op_args.get("request_type", None))
116+
request_type = op_args.get("request_type", None)
118117
if request_type is not None:
119118
args.append(
120119
sge.Kwarg(
121120
this="request_type", expression=sge.Literal.string(request_type.upper())
122121
)
123122
)
124123

125-
model_params = typing.cast(str, op_args.get("model_params", None))
124+
model_params = op_args.get("model_params", None)
126125
if model_params is not None:
127126
args.append(
128127
sge.Kwarg(
@@ -133,4 +132,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
133132
)
134133
)
135134

135+
output_schema = op_args.get("output_schema", None)
136+
if output_schema is not None:
137+
args.append(
138+
sge.Kwarg(
139+
this="output_schema",
140+
expression=sge.Literal.string(output_schema),
141+
)
142+
)
143+
136144
return args

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from bigframes import dtypes
2020
from bigframes import operations as ops
21+
from bigframes.core.compile.sqlglot import sqlglot_types
2122
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2223
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
23-
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2626

@@ -29,7 +29,7 @@
2929
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3030
from_type = expr.dtype
3131
to_type = op.to_type
32-
sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
32+
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
3333
sg_expr = expr.expr
3434

3535
if to_type == dtypes.JSON_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def from_pyarrow(
7979
expressions=[
8080
sge.ColumnDef(
8181
this=sge.to_identifier(field.column, quoted=True),
82-
kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype),
82+
kind=sgt.from_bigframes_dtype(field.dtype),
8383
)
8484
for field in schema.items
8585
],
@@ -620,7 +620,7 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
620620

621621

622622
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
623-
sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype)
623+
sqlglot_type = sgt.from_bigframes_dtype(dtype)
624624
if value is None:
625625
return _cast(sge.Null(), sqlglot_type)
626626
elif dtype == dtypes.BYTES_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_types.py

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -25,62 +25,57 @@
2525
import bigframes.dtypes
2626

2727

28-
class SQLGlotType:
29-
@classmethod
30-
def from_bigframes_dtype(
31-
cls,
32-
bigframes_dtype: typing.Union[
33-
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
34-
],
35-
) -> str:
36-
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
37-
return "INT64"
38-
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
39-
return "FLOAT64"
40-
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
41-
return "STRING"
42-
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
43-
return "BOOLEAN"
44-
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
45-
return "DATE"
46-
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
47-
return "TIME"
48-
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
49-
return "DATETIME"
50-
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
51-
return "TIMESTAMP"
52-
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
53-
return "BYTES"
54-
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
55-
return "NUMERIC"
56-
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
57-
return "BIGNUMERIC"
58-
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
59-
return "JSON"
60-
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
61-
return "GEOGRAPHY"
62-
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
63-
return "INT64"
64-
elif isinstance(bigframes_dtype, pd.ArrowDtype):
65-
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
66-
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
67-
bigframes_dtype.pyarrow_dtype.value_type
28+
def from_bigframes_dtype(
29+
bigframes_dtype: typing.Union[
30+
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
31+
],
32+
) -> str:
33+
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
34+
return "INT64"
35+
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
36+
return "FLOAT64"
37+
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
38+
return "STRING"
39+
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
40+
return "BOOLEAN"
41+
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
42+
return "DATE"
43+
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
44+
return "TIME"
45+
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
46+
return "DATETIME"
47+
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
48+
return "TIMESTAMP"
49+
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
50+
return "BYTES"
51+
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
52+
return "NUMERIC"
53+
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
54+
return "BIGNUMERIC"
55+
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
56+
return "JSON"
57+
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
58+
return "GEOGRAPHY"
59+
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
60+
return "INT64"
61+
elif isinstance(bigframes_dtype, pd.ArrowDtype):
62+
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
63+
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
64+
bigframes_dtype.pyarrow_dtype.value_type
65+
)
66+
return f"ARRAY<{from_bigframes_dtype(inner_bigframes_dtype)}>"
67+
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
68+
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
69+
inner_fields: list[str] = []
70+
for i in range(struct_type.num_fields):
71+
field = struct_type.field(i)
72+
key = sg.to_identifier(field.name).sql("bigquery")
73+
dtype = from_bigframes_dtype(
74+
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
6875
)
69-
return (
70-
f"ARRAY<{SQLGlotType.from_bigframes_dtype(inner_bigframes_dtype)}>"
71-
)
72-
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
73-
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
74-
inner_fields: list[str] = []
75-
for i in range(struct_type.num_fields):
76-
field = struct_type.field(i)
77-
key = sg.to_identifier(field.name).sql("bigquery")
78-
dtype = SQLGlotType.from_bigframes_dtype(
79-
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
80-
)
81-
inner_fields.append(f"{key} {dtype}")
82-
return "STRUCT<{}>".format(", ".join(inner_fields))
76+
inner_fields.append(f"{key} {dtype}")
77+
return "STRUCT<{}>".format(", ".join(inner_fields))
8378

84-
raise ValueError(
85-
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
86-
)
79+
raise ValueError(
80+
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
81+
)

bigframes/core/indexes/base.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,58 @@ def item(self):
761761
# Docstring is in third_party/bigframes_vendored/pandas/core/indexes/base.py
762762
return self.to_series().peek(2).item()
763763

764+
def __eq__(self, other) -> Index: # type: ignore
765+
return self._apply_binop(other, ops.eq_op)
766+
767+
def _apply_binop(self, other, op: ops.BinaryOp) -> Index:
768+
# TODO: Handle local objects, or objects not implicitly alignable? Gets ambiguous with partial ordering though
769+
if isinstance(other, (bigframes.series.Series, Index)):
770+
other = Index(other)
771+
if other.nlevels != self.nlevels:
772+
raise ValueError("Dimensions do not match")
773+
774+
lexpr = self._block.expr
775+
rexpr = other._block.expr
776+
join_result = lexpr.try_row_join(rexpr)
777+
if join_result is None:
778+
raise ValueError("Cannot align objects")
779+
780+
expr, (lmap, rmap) = join_result
781+
782+
expr, res_ids = expr.compute_values(
783+
[
784+
op.as_expr(lmap[lid], rmap[rid])
785+
for lid, rid in zip(lexpr.column_ids, rexpr.column_ids)
786+
]
787+
)
788+
return Index(
789+
blocks.Block(
790+
expr.select_columns(res_ids),
791+
index_columns=res_ids,
792+
column_labels=[],
793+
index_labels=[None] * len(res_ids),
794+
)
795+
)
796+
elif (
797+
isinstance(other, bigframes.dtypes.LOCAL_SCALAR_TYPES) and self.nlevels == 1
798+
):
799+
block, id = self._block.project_expr(
800+
op.as_expr(self._block.index_columns[0], ex.const(other))
801+
)
802+
return Index(block.select_column(id))
803+
elif isinstance(other, tuple) and len(other) == self.nlevels:
804+
block = self._block.project_exprs(
805+
[
806+
op.as_expr(self._block.index_columns[i], ex.const(other[i]))
807+
for i in range(self.nlevels)
808+
],
809+
labels=[None] * self.nlevels,
810+
drop=True,
811+
)
812+
return Index(block.set_index(block.value_columns))
813+
else:
814+
return NotImplemented
815+
764816

765817
def _should_create_datetime_index(block: blocks.Block) -> bool:
766818
if len(block.index.dtypes) != 1:

0 commit comments

Comments
 (0)