Skip to content

Commit 0e40d69

Browse files
committed
merge main
2 parents 8fac06c + bc33c98 commit 0e40d69

File tree

34 files changed

+620
-87
lines changed

34 files changed

+620
-87
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def generate(
8888
or pandas Series.
8989
connection_id (str, optional):
9090
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
91-
If not provided, the connection from the current session will be used.
91+
If not provided, the query uses your end-user credential.
9292
endpoint (str, optional):
9393
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
9494
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -131,7 +131,7 @@ def generate(
131131

132132
operator = ai_ops.AIGenerate(
133133
prompt_context=tuple(prompt_context),
134-
connection_id=_resolve_connection_id(series_list[0], connection_id),
134+
connection_id=connection_id,
135135
endpoint=endpoint,
136136
request_type=request_type,
137137
model_params=json.dumps(model_params) if model_params else None,
@@ -186,7 +186,7 @@ def generate_bool(
186186
or pandas Series.
187187
connection_id (str, optional):
188188
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
189-
If not provided, the connection from the current session will be used.
189+
If not provided, the query uses your end-user credential.
190190
endpoint (str, optional):
191191
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
192192
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -216,7 +216,7 @@ def generate_bool(
216216

217217
operator = ai_ops.AIGenerateBool(
218218
prompt_context=tuple(prompt_context),
219-
connection_id=_resolve_connection_id(series_list[0], connection_id),
219+
connection_id=connection_id,
220220
endpoint=endpoint,
221221
request_type=request_type,
222222
model_params=json.dumps(model_params) if model_params else None,
@@ -267,7 +267,7 @@ def generate_int(
267267
or pandas Series.
268268
connection_id (str, optional):
269269
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
270-
If not provided, the connection from the current session will be used.
270+
If not provided, the query uses your end-user credential.
271271
endpoint (str, optional):
272272
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
273273
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -297,7 +297,7 @@ def generate_int(
297297

298298
operator = ai_ops.AIGenerateInt(
299299
prompt_context=tuple(prompt_context),
300-
connection_id=_resolve_connection_id(series_list[0], connection_id),
300+
connection_id=connection_id,
301301
endpoint=endpoint,
302302
request_type=request_type,
303303
model_params=json.dumps(model_params) if model_params else None,
@@ -348,7 +348,7 @@ def generate_double(
348348
or pandas Series.
349349
connection_id (str, optional):
350350
Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.
351-
If not provided, the connection from the current session will be used.
351+
If not provided, the query uses your end-user credential.
352352
endpoint (str, optional):
353353
Specifies the Vertex AI endpoint to use for the model. For example `"gemini-2.5-flash"`. You can specify any
354354
generally available or preview Gemini model. If you specify the model name, BigQuery ML automatically identifies and
@@ -378,7 +378,7 @@ def generate_double(
378378

379379
operator = ai_ops.AIGenerateDouble(
380380
prompt_context=tuple(prompt_context),
381-
connection_id=_resolve_connection_id(series_list[0], connection_id),
381+
connection_id=connection_id,
382382
endpoint=endpoint,
383383
request_type=request_type,
384384
model_params=json.dumps(model_params) if model_params else None,

bigframes/core/compile/sqlglot/aggregations/ordered_unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def compile(
2727
op: agg_ops.WindowOp,
2828
column: typed_expr.TypedExpr,
2929
*,
30-
order_by: tuple[sge.Expression, ...],
30+
order_by: tuple[sge.Expression, ...] = (),
3131
) -> sge.Expression:
3232
return ORDERED_UNARY_OP_REGISTRATION[op](op, column, order_by=order_by)
3333

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ def _(
4949
return sge.func("IFNULL", result, sge.true())
5050

5151

52+
@UNARY_OP_REGISTRATION.register(agg_ops.AnyOp)
53+
def _(
54+
op: agg_ops.AnyOp,
55+
column: typed_expr.TypedExpr,
56+
window: typing.Optional[window_spec.WindowSpec] = None,
57+
) -> sge.Expression:
58+
expr = column.expr
59+
expr = apply_window_if_present(sge.func("LOGICAL_OR", expr), window)
60+
61+
# BQ will return null for empty column, result would be false in pandas.
62+
return sge.func("COALESCE", expr, sge.convert(False))
63+
64+
5265
@UNARY_OP_REGISTRATION.register(agg_ops.ApproxQuartilesOp)
5366
def _(
5467
op: agg_ops.ApproxQuartilesOp,

bigframes/core/compile/sqlglot/expressions/ai_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,13 @@ def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
104104

105105
op_args = asdict(op)
106106

107-
connection_id = op_args["connection_id"]
108-
args.append(
109-
sge.Kwarg(this="connection_id", expression=sge.Literal.string(connection_id))
110-
)
107+
connection_id = op_args.get("connection_id", None)
108+
if connection_id is not None:
109+
args.append(
110+
sge.Kwarg(
111+
this="connection_id", expression=sge.Literal.string(connection_id)
112+
)
113+
)
111114

112115
endpoit = op_args.get("endpoint", None)
113116
if endpoit is not None:

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,16 @@
1616

1717
import typing
1818

19-
import sqlglot
19+
import sqlglot as sg
2020
import sqlglot.expressions as sge
2121

2222
from bigframes import operations as ops
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
25+
import bigframes.dtypes as dtypes
2526

2627
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
27-
28-
29-
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
30-
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
31-
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
28+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
3229

3330

3431
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
@@ -41,17 +38,45 @@ def _(expr: TypedExpr, op: ops.ArrayIndexOp) -> sge.Expression:
4138
)
4239

4340

41+
@register_unary_op(ops.ArrayReduceOp, pass_op=True)
42+
def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
43+
sub_expr = sg.to_identifier("bf_arr_reduce_uid")
44+
sub_type = dtypes.get_array_inner_type(expr.dtype)
45+
46+
if op.aggregation.order_independent:
47+
from bigframes.core.compile.sqlglot.aggregations import unary_compiler
48+
49+
agg_expr = unary_compiler.compile(op.aggregation, TypedExpr(sub_expr, sub_type))
50+
else:
51+
from bigframes.core.compile.sqlglot.aggregations import ordered_unary_compiler
52+
53+
agg_expr = ordered_unary_compiler.compile(
54+
op.aggregation, TypedExpr(sub_expr, sub_type)
55+
)
56+
57+
return (
58+
sge.select(agg_expr)
59+
.from_(
60+
sge.Unnest(
61+
expressions=[expr.expr],
62+
alias=sge.TableAlias(columns=[sub_expr]),
63+
)
64+
)
65+
.subquery()
66+
)
67+
68+
4469
@register_unary_op(ops.ArraySliceOp, pass_op=True)
4570
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
46-
slice_idx = sqlglot.to_identifier("slice_idx")
71+
slice_idx = sg.to_identifier("slice_idx")
4772

4873
conditions: typing.List[sge.Predicate] = [slice_idx >= op.start]
4974

5075
if op.stop is not None:
5176
conditions.append(slice_idx < op.stop)
5277

5378
# local name for each element in the array
54-
el = sqlglot.to_identifier("el")
79+
el = sg.to_identifier("el")
5580

5681
selected_elements = (
5782
sge.select(el)
@@ -66,3 +91,27 @@ def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
6691
)
6792

6893
return sge.array(selected_elements)
94+
95+
96+
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
97+
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
98+
return sge.ArrayToString(this=expr.expr, expression=f"'{op.delimiter}'")
99+
100+
101+
@register_nary_op(ops.ToArrayOp)
102+
def _(*exprs: TypedExpr) -> sge.Expression:
103+
do_upcast_bool = any(
104+
dtypes.is_numeric(expr.dtype, include_bool=False) for expr in exprs
105+
)
106+
if do_upcast_bool:
107+
sg_exprs = [_coerce_bool_to_int(expr) for expr in exprs]
108+
else:
109+
sg_exprs = [expr.expr for expr in exprs]
110+
return sge.Array(expressions=sg_exprs)
111+
112+
113+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
114+
"""Coerce boolean expression to integer."""
115+
if typed_expr.dtype == dtypes.BOOL_DTYPE:
116+
return sge.Cast(this=typed_expr.expr, to="INT64")
117+
return typed_expr.expr

bigframes/display/anywidget.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ class TableWidget(WIDGET_BASE):
5656

5757
page = traitlets.Int(0).tag(sync=True)
5858
page_size = traitlets.Int(0).tag(sync=True)
59-
row_count = traitlets.Int(0).tag(sync=True)
59+
row_count = traitlets.Union(
60+
[traitlets.Int(), traitlets.Instance(type(None))],
61+
default_value=None,
62+
allow_none=True,
63+
).tag(sync=True)
6064
table_html = traitlets.Unicode().tag(sync=True)
6165
sort_column = traitlets.Unicode("").tag(sync=True)
6266
sort_ascending = traitlets.Bool(True).tag(sync=True)
@@ -103,12 +107,17 @@ def __init__(self, dataframe: bigframes.dataframe.DataFrame):
103107
# obtain the row counts
104108
# TODO(b/428238610): Start iterating over the result of `to_pandas_batches()`
105109
# before we get here so that the count might already be cached.
106-
# TODO(b/452747934): Allow row_count to be None and check to see if
107-
# there are multiple pages and show "page 1 of many" in this case
108110
self._reset_batches_for_new_page_size()
109-
if self._batches is None or self._batches.total_rows is None:
110-
self._error_message = "Could not determine total row count. Data might be unavailable or an error occurred."
111-
self.row_count = 0
111+
112+
if self._batches is None:
113+
self._error_message = "Could not retrieve data batches. Data might be unavailable or an error occurred."
114+
self.row_count = None
115+
elif self._batches.total_rows is None:
116+
# Total rows is unknown, this is an expected state.
117+
# TODO(b/461536343): Cheaply discover if we have exactly 1 page.
118+
# There are cases where total rows is not set, but there are no additional
119+
# pages. We could disable the "next" button in these cases.
120+
self.row_count = None
112121
else:
113122
self.row_count = self._batches.total_rows
114123

@@ -141,11 +150,22 @@ def _validate_page(self, proposal: Dict[str, Any]) -> int:
141150
Returns:
142151
The validated and clamped page number as an integer.
143152
"""
144-
145153
value = proposal["value"]
154+
155+
if value < 0:
156+
raise ValueError("Page number cannot be negative.")
157+
158+
# If truly empty or invalid page size, stay on page 0.
159+
# This handles cases where row_count is 0 or page_size is 0, preventing
160+
# division by zero or nonsensical pagination, regardless of row_count being None.
146161
if self.row_count == 0 or self.page_size == 0:
147162
return 0
148163

164+
# If row count is unknown, allow any non-negative page. The previous check
165+
# ensures that invalid page_size (0) is already handled.
166+
if self.row_count is None:
167+
return value
168+
149169
# Calculate the zero-indexed maximum page number.
150170
max_page = max(0, math.ceil(self.row_count / self.page_size) - 1)
151171

@@ -260,6 +280,23 @@ def _set_table_html(self) -> None:
260280
# Get the data for the current page
261281
page_data = cached_data.iloc[start:end]
262282

283+
# Handle case where user navigated beyond available data with unknown row count
284+
is_unknown_count = self.row_count is None
285+
is_beyond_data = self._all_data_loaded and len(page_data) == 0 and self.page > 0
286+
if is_unknown_count and is_beyond_data:
287+
# Calculate the last valid page (zero-indexed)
288+
total_rows = len(cached_data)
289+
if total_rows > 0:
290+
last_valid_page = max(0, math.ceil(total_rows / self.page_size) - 1)
291+
# Navigate back to the last valid page
292+
self.page = last_valid_page
293+
# Recursively call to display the correct page
294+
return self._set_table_html()
295+
else:
296+
# If no data at all, stay on page 0 with empty display
297+
self.page = 0
298+
return self._set_table_html()
299+
263300
# Generate HTML table
264301
self.table_html = bigframes.display.html.render_html(
265302
dataframe=page_data,

bigframes/display/table_widget.js

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,21 @@ function render({ model, el }) {
9999
const rowCount = model.get(ModelProperty.ROW_COUNT);
100100
const pageSize = model.get(ModelProperty.PAGE_SIZE);
101101
const currentPage = model.get(ModelProperty.PAGE);
102-
const totalPages = Math.ceil(rowCount / pageSize);
103-
104-
rowCountLabel.textContent = `${rowCount.toLocaleString()} total rows`;
105-
paginationLabel.textContent = `Page ${(
106-
currentPage + 1
107-
).toLocaleString()} of ${(totalPages || 1).toLocaleString()}`;
108-
prevPage.disabled = currentPage === 0;
109-
nextPage.disabled = currentPage >= totalPages - 1;
102+
103+
if (rowCount === null) {
104+
// Unknown total rows
105+
rowCountLabel.textContent = "Total rows unknown";
106+
paginationLabel.textContent = `Page ${(currentPage + 1).toLocaleString()} of many`;
107+
prevPage.disabled = currentPage === 0;
108+
nextPage.disabled = false; // Allow navigation until we hit the end
109+
} else {
110+
// Known total rows
111+
const totalPages = Math.ceil(rowCount / pageSize);
112+
rowCountLabel.textContent = `${rowCount.toLocaleString()} total rows`;
113+
paginationLabel.textContent = `Page ${(currentPage + 1).toLocaleString()} of ${rowCount.toLocaleString()}`;
114+
prevPage.disabled = currentPage === 0;
115+
nextPage.disabled = currentPage >= totalPages - 1;
116+
}
110117
pageSizeSelect.value = pageSize;
111118
}
112119

bigframes/operations/ai_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class AIGenerate(base_ops.NaryOp):
2929
name: ClassVar[str] = "ai_generate"
3030

3131
prompt_context: Tuple[str | None, ...]
32-
connection_id: str
32+
connection_id: str | None
3333
endpoint: str | None
3434
request_type: Literal["dedicated", "shared", "unspecified"]
3535
model_params: str | None
@@ -57,7 +57,7 @@ class AIGenerateBool(base_ops.NaryOp):
5757
name: ClassVar[str] = "ai_generate_bool"
5858

5959
prompt_context: Tuple[str | None, ...]
60-
connection_id: str
60+
connection_id: str | None
6161
endpoint: str | None
6262
request_type: Literal["dedicated", "shared", "unspecified"]
6363
model_params: str | None
@@ -79,7 +79,7 @@ class AIGenerateInt(base_ops.NaryOp):
7979
name: ClassVar[str] = "ai_generate_int"
8080

8181
prompt_context: Tuple[str | None, ...]
82-
connection_id: str
82+
connection_id: str | None
8383
endpoint: str | None
8484
request_type: Literal["dedicated", "shared", "unspecified"]
8585
model_params: str | None
@@ -101,7 +101,7 @@ class AIGenerateDouble(base_ops.NaryOp):
101101
name: ClassVar[str] = "ai_generate_double"
102102

103103
prompt_context: Tuple[str | None, ...]
104-
connection_id: str
104+
connection_id: str | None
105105
endpoint: str | None
106106
request_type: Literal["dedicated", "shared", "unspecified"]
107107
model_params: str | None

0 commit comments

Comments
 (0)