Skip to content

Commit 5b9bf8c

Browse files
Merge remote-tracking branch 'github/main' into plot_kinds
2 parents 1526678 + 8514200 commit 5b9bf8c

File tree

18 files changed

+505
-90
lines changed

18 files changed

+505
-90
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from bigframes import clients, dtypes, series, session
2727
from bigframes.core import convert, log_adapter
28-
from bigframes.operations import ai_ops
28+
from bigframes.operations import ai_ops, output_schemas
2929

3030
PROMPT_TYPE = Union[
3131
series.Series,
@@ -43,7 +43,7 @@ def generate(
4343
endpoint: str | None = None,
4444
request_type: Literal["dedicated", "shared", "unspecified"] = "unspecified",
4545
model_params: Mapping[Any, Any] | None = None,
46-
# TODO(b/446974666) Add output_schema parameter
46+
output_schema: Mapping[str, str] | None = None,
4747
) -> series.Series:
4848
"""
4949
Returns the AI analysis based on the prompt, which can be any combination of text and unstructured data.
@@ -64,6 +64,14 @@ def generate(
6464
1 Ottawa\\n
6565
Name: result, dtype: string
6666
67+
You get structured output when the `output_schema` parameter is set:
68+
69+
>>> animals = bpd.Series(["Rabbit", "Spider"])
70+
>>> bbq.ai.generate(animals, output_schema={"number_of_legs": "INT64", "is_herbivore": "BOOL"})
71+
0 {'is_herbivore': True, 'number_of_legs': 4, 'f...
72+
1 {'is_herbivore': False, 'number_of_legs': 8, '...
73+
dtype: struct<is_herbivore: bool, number_of_legs: int64, full_response: extension<dbjson<JSONArrowType>>, status: string>[pyarrow]
74+
6775
Args:
6876
prompt (Series | List[str|Series] | Tuple[str|Series, ...]):
6977
A mixture of Series and string literals that specifies the prompt to send to the model. The Series can be BigFrames Series
@@ -86,10 +94,14 @@ def generate(
8694
If requests exceed the Provisioned Throughput quota, the overflow traffic uses DSQ quota.
8795
model_params (Mapping[Any, Any]):
8896
Provides additional parameters to the model. The MODEL_PARAMS value must conform to the generateContent request body format.
97+
output_schema (Mapping[str, str]):
98+
A mapping value that specifies the schema of the output, in the form {field_name: data_type}. Supported data types include
99+
`STRING`, `INT64`, `FLOAT64`, `BOOL`, `ARRAY`, and `STRUCT`.
89100
90101
Returns:
91102
bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
92103
* "result": a STRING value containing the model's response to the prompt. The result is None if the request fails or is filtered by responsible AI.
104+
If you specify an output schema then result is replaced by your custom schema.
93105
* "full_response": a JSON value containing the response from the projects.locations.endpoints.generateContent call to the model.
94106
The generated text is in the text element.
95107
* "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
@@ -98,12 +110,22 @@ def generate(
98110
prompt_context, series_list = _separate_context_and_series(prompt)
99111
assert len(series_list) > 0
100112

113+
if output_schema is None:
114+
output_schema_str = None
115+
else:
116+
output_schema_str = ", ".join(
117+
[f"{name} {sql_type}" for name, sql_type in output_schema.items()]
118+
)
119+
# Validate user input
120+
output_schemas.parse_sql_fields(output_schema_str)
121+
101122
operator = ai_ops.AIGenerate(
102123
prompt_context=tuple(prompt_context),
103124
connection_id=_resolve_connection_id(series_list[0], connection_id),
104125
endpoint=endpoint,
105126
request_type=request_type,
106127
model_params=json.dumps(model_params) if model_params else None,
128+
output_schema=output_schema_str,
107129
)
108130

109131
return series_list[0]._apply_nary_op(operator, series_list[1:])

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,7 @@ def ai_generate(
19851985
op.endpoint, # type: ignore
19861986
op.request_type.upper(), # type: ignore
19871987
op.model_params, # type: ignore
1988+
op.output_schema, # type: ignore
19881989
).to_expr()
19891990

19901991

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import asdict
18-
import typing
1918

2019
import sqlglot.expressions as sge
2120

@@ -105,24 +104,24 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
105104

106105
op_args = asdict(op)
107106

108-
connection_id = typing.cast(str, op_args["connection_id"])
107+
connection_id = op_args["connection_id"]
109108
args.append(
110109
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
111110
)
112111

113-
endpoit = typing.cast(str, op_args.get("endpoint", None))
112+
endpoit = op_args.get("endpoint", None)
114113
if endpoit is not None:
115114
args.append(sge.Kwarg(this="endpoint", expression=sge.Literal.string(endpoit)))
116115

117-
request_type = typing.cast(str, op_args.get("request_type", None))
116+
request_type = op_args.get("request_type", None)
118117
if request_type is not None:
119118
args.append(
120119
sge.Kwarg(
121120
this="request_type", expression=sge.Literal.string(request_type.upper())
122121
)
123122
)
124123

125-
model_params = typing.cast(str, op_args.get("model_params", None))
124+
model_params = op_args.get("model_params", None)
126125
if model_params is not None:
127126
args.append(
128127
sge.Kwarg(
@@ -133,4 +132,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
133132
)
134133
)
135134

135+
output_schema = op_args.get("output_schema", None)
136+
if output_schema is not None:
137+
args.append(
138+
sge.Kwarg(
139+
this="output_schema",
140+
expression=sge.Literal.string(output_schema),
141+
)
142+
)
143+
136144
return args

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from bigframes import dtypes
2020
from bigframes import operations as ops
21+
from bigframes.core.compile.sqlglot import sqlglot_types
2122
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2223
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
23-
from bigframes.core.compile.sqlglot.sqlglot_types import SQLGlotType
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
2626

@@ -29,7 +29,7 @@
2929
def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
3030
from_type = expr.dtype
3131
to_type = op.to_type
32-
sg_to_type = SQLGlotType.from_bigframes_dtype(to_type)
32+
sg_to_type = sqlglot_types.from_bigframes_dtype(to_type)
3333
sg_expr = expr.expr
3434

3535
if to_type == dtypes.JSON_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def from_pyarrow(
7979
expressions=[
8080
sge.ColumnDef(
8181
this=sge.to_identifier(field.column, quoted=True),
82-
kind=sgt.SQLGlotType.from_bigframes_dtype(field.dtype),
82+
kind=sgt.from_bigframes_dtype(field.dtype),
8383
)
8484
for field in schema.items
8585
],
@@ -620,7 +620,7 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
620620

621621

622622
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
623-
sqlglot_type = sgt.SQLGlotType.from_bigframes_dtype(dtype)
623+
sqlglot_type = sgt.from_bigframes_dtype(dtype)
624624
if value is None:
625625
return _cast(sge.Null(), sqlglot_type)
626626
elif dtype == dtypes.BYTES_DTYPE:

bigframes/core/compile/sqlglot/sqlglot_types.py

Lines changed: 52 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -25,62 +25,57 @@
2525
import bigframes.dtypes
2626

2727

28-
class SQLGlotType:
29-
@classmethod
30-
def from_bigframes_dtype(
31-
cls,
32-
bigframes_dtype: typing.Union[
33-
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
34-
],
35-
) -> str:
36-
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
37-
return "INT64"
38-
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
39-
return "FLOAT64"
40-
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
41-
return "STRING"
42-
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
43-
return "BOOLEAN"
44-
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
45-
return "DATE"
46-
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
47-
return "TIME"
48-
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
49-
return "DATETIME"
50-
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
51-
return "TIMESTAMP"
52-
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
53-
return "BYTES"
54-
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
55-
return "NUMERIC"
56-
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
57-
return "BIGNUMERIC"
58-
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
59-
return "JSON"
60-
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
61-
return "GEOGRAPHY"
62-
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
63-
return "INT64"
64-
elif isinstance(bigframes_dtype, pd.ArrowDtype):
65-
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
66-
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
67-
bigframes_dtype.pyarrow_dtype.value_type
28+
def from_bigframes_dtype(
29+
bigframes_dtype: typing.Union[
30+
bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype, np.dtype[typing.Any]
31+
],
32+
) -> str:
33+
if bigframes_dtype == bigframes.dtypes.INT_DTYPE:
34+
return "INT64"
35+
elif bigframes_dtype == bigframes.dtypes.FLOAT_DTYPE:
36+
return "FLOAT64"
37+
elif bigframes_dtype == bigframes.dtypes.STRING_DTYPE:
38+
return "STRING"
39+
elif bigframes_dtype == bigframes.dtypes.BOOL_DTYPE:
40+
return "BOOLEAN"
41+
elif bigframes_dtype == bigframes.dtypes.DATE_DTYPE:
42+
return "DATE"
43+
elif bigframes_dtype == bigframes.dtypes.TIME_DTYPE:
44+
return "TIME"
45+
elif bigframes_dtype == bigframes.dtypes.DATETIME_DTYPE:
46+
return "DATETIME"
47+
elif bigframes_dtype == bigframes.dtypes.TIMESTAMP_DTYPE:
48+
return "TIMESTAMP"
49+
elif bigframes_dtype == bigframes.dtypes.BYTES_DTYPE:
50+
return "BYTES"
51+
elif bigframes_dtype == bigframes.dtypes.NUMERIC_DTYPE:
52+
return "NUMERIC"
53+
elif bigframes_dtype == bigframes.dtypes.BIGNUMERIC_DTYPE:
54+
return "BIGNUMERIC"
55+
elif bigframes_dtype == bigframes.dtypes.JSON_DTYPE:
56+
return "JSON"
57+
elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE:
58+
return "GEOGRAPHY"
59+
elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE:
60+
return "INT64"
61+
elif isinstance(bigframes_dtype, pd.ArrowDtype):
62+
if pa.types.is_list(bigframes_dtype.pyarrow_dtype):
63+
inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
64+
bigframes_dtype.pyarrow_dtype.value_type
65+
)
66+
return f"ARRAY<{from_bigframes_dtype(inner_bigframes_dtype)}>"
67+
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
68+
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
69+
inner_fields: list[str] = []
70+
for i in range(struct_type.num_fields):
71+
field = struct_type.field(i)
72+
key = sg.to_identifier(field.name).sql("bigquery")
73+
dtype = from_bigframes_dtype(
74+
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
6875
)
69-
return (
70-
f"ARRAY<{SQLGlotType.from_bigframes_dtype(inner_bigframes_dtype)}>"
71-
)
72-
elif pa.types.is_struct(bigframes_dtype.pyarrow_dtype):
73-
struct_type = typing.cast(pa.StructType, bigframes_dtype.pyarrow_dtype)
74-
inner_fields: list[str] = []
75-
for i in range(struct_type.num_fields):
76-
field = struct_type.field(i)
77-
key = sg.to_identifier(field.name).sql("bigquery")
78-
dtype = SQLGlotType.from_bigframes_dtype(
79-
bigframes.dtypes.arrow_dtype_to_bigframes_dtype(field.type)
80-
)
81-
inner_fields.append(f"{key} {dtype}")
82-
return "STRUCT<{}>".format(", ".join(inner_fields))
76+
inner_fields.append(f"{key} {dtype}")
77+
return "STRUCT<{}>".format(", ".join(inner_fields))
8378

84-
raise ValueError(
85-
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
86-
)
79+
raise ValueError(
80+
f"Unsupported type for {bigframes_dtype}. {constants.FEEDBACK_LINK}"
81+
)

bigframes/core/indexes/base.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,58 @@ def item(self):
754754
# Docstring is in third_party/bigframes_vendored/pandas/core/indexes/base.py
755755
return self.to_series().peek(2).item()
756756

757+
def __eq__(self, other) -> Index: # type: ignore
758+
return self._apply_binop(other, ops.eq_op)
759+
760+
def _apply_binop(self, other, op: ops.BinaryOp) -> Index:
761+
# TODO: Handle local objects, or objects not implicitly alignable? Gets ambiguous with partial ordering though
762+
if isinstance(other, (bigframes.series.Series, Index)):
763+
other = Index(other)
764+
if other.nlevels != self.nlevels:
765+
raise ValueError("Dimensions do not match")
766+
767+
lexpr = self._block.expr
768+
rexpr = other._block.expr
769+
join_result = lexpr.try_row_join(rexpr)
770+
if join_result is None:
771+
raise ValueError("Cannot align objects")
772+
773+
expr, (lmap, rmap) = join_result
774+
775+
expr, res_ids = expr.compute_values(
776+
[
777+
op.as_expr(lmap[lid], rmap[rid])
778+
for lid, rid in zip(lexpr.column_ids, rexpr.column_ids)
779+
]
780+
)
781+
return Index(
782+
blocks.Block(
783+
expr.select_columns(res_ids),
784+
index_columns=res_ids,
785+
column_labels=[],
786+
index_labels=[None] * len(res_ids),
787+
)
788+
)
789+
elif (
790+
isinstance(other, bigframes.dtypes.LOCAL_SCALAR_TYPES) and self.nlevels == 1
791+
):
792+
block, id = self._block.project_expr(
793+
op.as_expr(self._block.index_columns[0], ex.const(other))
794+
)
795+
return Index(block.select_column(id))
796+
elif isinstance(other, tuple) and len(other) == self.nlevels:
797+
block = self._block.project_exprs(
798+
[
799+
op.as_expr(self._block.index_columns[i], ex.const(other[i]))
800+
for i in range(self.nlevels)
801+
],
802+
labels=[None] * self.nlevels,
803+
drop=True,
804+
)
805+
return Index(block.set_index(block.value_columns))
806+
else:
807+
return NotImplemented
808+
757809

758810
def _should_create_datetime_index(block: blocks.Block) -> bool:
759811
if len(block.index.dtypes) != 1:

bigframes/core/indexes/multi.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import bigframes_vendored.pandas.core.indexes.multi as vendored_pandas_multindex
2020
import pandas
2121

22+
from bigframes.core import blocks
23+
from bigframes.core import expression as ex
2224
from bigframes.core.indexes.base import Index
2325

2426

@@ -46,3 +48,26 @@ def from_arrays(
4648
pd_index = pandas.MultiIndex.from_arrays(arrays, sortorder, names)
4749
# Index.__new__ should detect multiple levels and properly create a multiindex
4850
return cast(MultiIndex, Index(pd_index))
51+
52+
def __eq__(self, other) -> Index: # type: ignore
53+
import bigframes.operations as ops
54+
import bigframes.operations.aggregations as agg_ops
55+
56+
eq_result = self._apply_binop(other, ops.eq_op)._block.expr
57+
58+
as_array = ops.ToArrayOp().as_expr(
59+
*(
60+
ops.fillna_op.as_expr(col, ex.const(False))
61+
for col in eq_result.column_ids
62+
)
63+
)
64+
reduced = ops.ArrayReduceOp(agg_ops.all_op).as_expr(as_array)
65+
result_expr, result_ids = eq_result.compute_values([reduced])
66+
return Index(
67+
blocks.Block(
68+
result_expr.select_columns(result_ids),
69+
index_columns=result_ids,
70+
column_labels=(),
71+
index_labels=[None],
72+
)
73+
)

0 commit comments

Comments
 (0)