Skip to content

Commit 06efb39

Browse files
authored
Merge branch 'main' into series-input
2 parents 081de7a + 770918e commit 06efb39

File tree

32 files changed

+803
-87
lines changed

32 files changed

+803
-87
lines changed

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2424
from bigframes.operations import aggregations as agg_ops
2525

2626
NULLARY_OP_REGISTRATION = reg.OpRegistration()

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2424
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2525
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2626
from bigframes.operations import aggregations as agg_ops

bigframes/core/compile/sqlglot/aggregations/utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import typing
17+
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core import utils, window_spec
21+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
22+
import bigframes.core.ordering as ordering_spec
23+
24+
25+
def apply_window_if_present(
    value: sge.Expression,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Wraps an aggregate expression in an OVER clause described by `window`.

    Returns `value` unchanged when no window is given. Raises ValueError for a
    row-bounded window that has no ordering.
    """
    if window is None:
        return value

    unordered = not window.ordering
    if window.is_row_bounded and unordered:
        raise ValueError("No ordering provided for ordered analytic function")

    if unordered and not window.is_row_bounded and not window.is_range_bounded:
        # Unbound grouping window: no ORDER BY clause at all.
        order_by = None
    elif window.is_range_bounded:
        # A range-bounded window keeps only the first ordering key, because:
        # 1. Manipulating null positions requires more than one ordering key,
        #    which is forbidden by SQL window syntax for range rolling.
        # 2. Pandas does not allow range rolling on timeseries with nulls.
        order_by = get_window_order_by((window.ordering[0],), override_null_order=False)
    else:
        order_by = get_window_order_by(window.ordering, override_null_order=True)

    order = sge.Order(expressions=order_by) if order_by else None

    partition = None
    if window.grouping_keys:
        partition = [
            scalar_compiler.compile_scalar_expression(key)
            for key in window.grouping_keys
        ]

    # With neither bounds nor ordering, emit a bare OVER (PARTITION BY ...).
    # Skipping the frame spec avoids generating an `ORDER BY NULL` clause.
    if not window.bounds and not order:
        return sge.Window(this=value, partition_by=partition)

    frame_kind = (
        "ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
    )

    lower: typing.Union[int, float, None] = None
    upper: typing.Union[int, float, None] = None
    if isinstance(window.bounds, window_spec.RangeWindowBounds):
        # Range bounds are timedeltas; convert them to microsecond offsets.
        if window.bounds.start is not None:
            lower = utils.timedelta_to_micros(window.bounds.start)
        if window.bounds.end is not None:
            upper = utils.timedelta_to_micros(window.bounds.end)
    elif window.bounds:
        lower, upper = window.bounds.start, window.bounds.end

    start_value, start_side = _get_window_bounds(lower, is_preceding=True)
    end_value, end_side = _get_window_bounds(upper, is_preceding=False)

    return sge.Window(
        this=value,
        partition_by=partition,
        order=order,
        spec=sge.WindowSpec(
            kind=frame_kind,
            start=start_value,
            start_side=start_side,
            end=end_value,
            end_side=end_side,
            over="OVER",
        ),
    )
92+
93+
94+
def get_window_order_by(
    ordering: typing.Tuple[ordering_spec.OrderingExpression, ...],
    override_null_order: bool = False,
) -> typing.Optional[tuple[sge.Ordered, ...]]:
    """Returns the SQL order by clause for a window specification."""
    if not ordering:
        return None

    keys = []
    for item in ordering:
        expr = scalar_compiler.compile_scalar_expression(item.scalar_expression)
        desc = not item.direction.is_ascending
        nulls_first = not item.na_last

        # BigQuery SQL considers NULLs to be the "smallest" values, so by
        # default they sort first ascending and last descending. When the
        # requested placement is the opposite (nulls_first == desc covers
        # exactly those two cases), prepend an `<expr> IS NULL` key to
        # override the default.
        if override_null_order and nulls_first == desc:
            keys.append(
                sge.Ordered(
                    this=sge.Is(this=expr, expression=sge.Null()),
                    desc=desc,
                    nulls_first=nulls_first,
                )
            )

        keys.append(sge.Ordered(this=expr, desc=desc, nulls_first=nulls_first))

    return tuple(keys)
139+
140+
141+
def _get_window_bounds(
    value, is_preceding: bool
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:
    """Compiles a single boundary value into its SQL components.

    `None` means unbounded on the given side, `0` means the current row, and
    any other value becomes an absolute offset whose side is inferred from its
    sign (negative = PRECEDING, positive = FOLLOWING).
    """
    if value is None:
        return "UNBOUNDED", ("PRECEDING" if is_preceding else "FOLLOWING")
    if value == 0:
        return "CURRENT ROW", None
    return sge.convert(abs(value)), ("PRECEDING" if value < 0 else "FOLLOWING")

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
2525
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
26+
from bigframes.core.compile.sqlglot.aggregations import windows
2627
from bigframes.core.compile.sqlglot.expressions import typed_expr
2728
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2829
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -272,18 +273,16 @@ def compile_random_sample(
272273
def compile_aggregate(
273274
self, node: nodes.AggregateNode, child: ir.SQLGlotIR
274275
) -> ir.SQLGlotIR:
275-
ordering_cols = tuple(
276-
sge.Ordered(
277-
this=scalar_compiler.compile_scalar_expression(
278-
ordering.scalar_expression
279-
),
280-
desc=ordering.direction.is_ascending is False,
281-
nulls_first=ordering.na_last is False,
282-
)
283-
for ordering in node.order_by
276+
ordering_cols = windows.get_window_order_by(
277+
node.order_by, override_null_order=True
284278
)
285279
aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
286-
(id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols))
280+
(
281+
id.sql,
282+
aggregate_compiler.compile_aggregate(
283+
agg, order_by=ordering_cols if ordering_cols else ()
284+
),
285+
)
287286
for agg, id in node.aggregations
288287
)
289288
by_cols: tuple[sge.Expression, ...] = tuple(

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import typing
1818

19+
import pandas as pd
20+
import pyarrow as pa
1921
import sqlglot
2022
import sqlglot.expressions as sge
2123

@@ -105,6 +107,12 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
105107
)
106108

107109

110+
@UNARY_OP_REGISTRATION.register(ops.AsTypeOp)
def _(op: ops.AsTypeOp, expr: TypedExpr) -> sge.Expression:
    """Casts the operand to the target type carried on the op."""
    # TODO: Support more types for casting, such as JSON, etc.
    # NOTE(review): `to` receives `op.to_type` directly — presumably sqlglot
    # accepts this as a type spec; confirm for non-trivial dtypes.
    target = op.to_type
    return sge.Cast(this=expr.expr, to=target)


@UNARY_OP_REGISTRATION.register(ops.ArrayToStringOp)
def _(op: ops.ArrayToStringOp, expr: TypedExpr) -> sge.Expression:
    """Joins array elements into a string using the op's delimiter."""
    # NOTE(review): the delimiter is spliced into a quoted literal verbatim;
    # a delimiter containing a quote would break the SQL — confirm callers
    # restrict it.
    delimiter_literal = f"'{op.delimiter}'"
    return sge.ArrayToString(this=expr.expr, expression=delimiter_literal)
@@ -234,6 +242,12 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
234242
) - sge.convert(1)
235243

236244

245+
@UNARY_OP_REGISTRATION.register(ops.FloorDtOp)
def _(op: ops.FloorDtOp, expr: TypedExpr) -> sge.Expression:
    """Truncates a timestamp down to the op's frequency unit."""
    # TODO: Remove this method when it is covered by ops.FloorOp
    unit = sge.Identifier(this=op.freq)
    return sge.TimestampTrunc(this=expr.expr, unit=unit)


@UNARY_OP_REGISTRATION.register(ops.floor_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Numeric floor: FLOOR(x)."""
    return sge.Floor(this=expr.expr)
@@ -249,6 +263,26 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
249263
return sge.func("ST_ASTEXT", expr.expr)
250264

251265

266+
@UNARY_OP_REGISTRATION.register(ops.geo_st_boundary_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Boundary of a geography: ST_BOUNDARY(geo)."""
    return sge.func("ST_BOUNDARY", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.geo_st_geogfromtext_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Parses WKT text into a geography; SAFE. prefix yields NULL on bad input."""
    return sge.func("SAFE.ST_GEOGFROMTEXT", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.geo_st_isclosed_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Whether each geography is closed: ST_ISCLOSED(geo)."""
    return sge.func("ST_ISCLOSED", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.GeoStLengthOp)
def _(op: ops.GeoStLengthOp, expr: TypedExpr) -> sge.Expression:
    """Length of a geography: ST_LENGTH(geo)."""
    return sge.func("ST_LENGTH", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.geo_x_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """X coordinate of a point; SAFE. prefix yields NULL for non-points."""
    return sge.func("SAFE.ST_X", expr.expr)
@@ -274,6 +308,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
274308
return sge.BitwiseNot(this=expr.expr)
275309

276310

311+
@UNARY_OP_REGISTRATION.register(ops.IsInOp)
def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression:
    """Membership test: expr IN (v1, v2, ...)."""
    # NOTE(review): a None among op.values becomes SQL NULL, and SQL
    # `x IN (NULL)` is NULL rather than TRUE; an empty op.values produces
    # `IN ()` — confirm both cases are handled upstream.
    candidates = [sge.convert(v) for v in op.values]
    return sge.In(this=expr.expr, expressions=candidates)


@UNARY_OP_REGISTRATION.register(ops.isalnum_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """True when the string is entirely letters and/or digits."""
    return sge.RegexpLike(this=expr.expr, expression=sge.convert(r"^(\p{N}|\p{L})+$"))
@@ -517,6 +556,26 @@ def _(op: ops.StrSliceOp, expr: TypedExpr) -> sge.Expression:
517556
)
518557

519558

559+
@UNARY_OP_REGISTRATION.register(ops.StrftimeOp)
def _(op: ops.StrftimeOp, expr: TypedExpr) -> sge.Expression:
    """Formats a timestamp with the op's format string via FORMAT_TIMESTAMP."""
    return sge.func("FORMAT_TIMESTAMP", sge.convert(op.date_format), expr.expr)


@UNARY_OP_REGISTRATION.register(ops.StructFieldOp)
def _(op: ops.StructFieldOp, expr: TypedExpr) -> sge.Expression:
    """Accesses one field of a STRUCT value, by name or positional index."""
    if isinstance(op.name_or_index, str):
        field_name = op.name_or_index
    else:
        # Positional access: resolve the index to a field name through the
        # pyarrow struct type carried on the expression's dtype.
        arrow_dtype = typing.cast(pd.ArrowDtype, expr.dtype)
        struct_type = typing.cast(pa.StructType, arrow_dtype.pyarrow_dtype)
        field_name = struct_type.field(op.name_or_index).name

    # NOTE(review): the struct expression is passed as `catalog` so sqlglot
    # renders it as the qualifier of the column — confirm the emitted SQL is
    # `<expr>.<field>`.
    return sge.Column(
        this=sge.to_identifier(field_name, quoted=True),
        catalog=expr.expr,
    )


@UNARY_OP_REGISTRATION.register(ops.tan_op)
def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
    """Trigonometric tangent: TAN(x)."""
    return sge.func("TAN", expr.expr)
@@ -537,6 +596,36 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
537596
return sge.Floor(this=expr.expr)
538597

539598

599+
@UNARY_OP_REGISTRATION.register(ops.ToDatetimeOp)
def _(op: ops.ToDatetimeOp, expr: TypedExpr) -> sge.Expression:
    """Interprets the operand as epoch seconds and casts to DATETIME."""
    as_timestamp = sge.func("TIMESTAMP_SECONDS", expr.expr)
    return sge.Cast(this=as_timestamp, to="DATETIME")


@UNARY_OP_REGISTRATION.register(ops.ToTimestampOp)
def _(op: ops.ToTimestampOp, expr: TypedExpr) -> sge.Expression:
    """Interprets the operand as epoch seconds: TIMESTAMP_SECONDS(x)."""
    return sge.func("TIMESTAMP_SECONDS", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.ToTimedeltaOp)
def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression:
    """Builds an INTERVAL from the operand."""
    # NOTE(review): the unit is hard-coded to SECOND — confirm ToTimedeltaOp
    # never carries a different unit.
    return sge.Interval(this=expr.expr, unit=sge.Identifier(this="SECOND"))


@UNARY_OP_REGISTRATION.register(ops.UnixMicros)
def _(op: ops.UnixMicros, expr: TypedExpr) -> sge.Expression:
    """Epoch microseconds of a timestamp: UNIX_MICROS(ts)."""
    return sge.func("UNIX_MICROS", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.UnixMillis)
def _(op: ops.UnixMillis, expr: TypedExpr) -> sge.Expression:
    """Epoch milliseconds of a timestamp: UNIX_MILLIS(ts)."""
    return sge.func("UNIX_MILLIS", expr.expr)


@UNARY_OP_REGISTRATION.register(ops.UnixSeconds)
def _(op: ops.UnixSeconds, expr: TypedExpr) -> sge.Expression:
    """Epoch seconds of a timestamp: UNIX_SECONDS(ts)."""
    return sge.func("UNIX_SECONDS", expr.expr)
627+
628+
540629
# JSON Ops
541630
@UNARY_OP_REGISTRATION.register(ops.JSONExtract)
542631
def _(op: ops.JSONExtract, expr: TypedExpr) -> sge.Expression:

bigframes/functions/function_typing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
def __init__(self, type_, supported_types):
    """Error for a type that BigFrames functions cannot handle.

    Args:
        type_: the unsupported type that was provided.
        supported_types: the collection of accepted types.
    """
    self.type = type_
    self.supported_types = supported_types
    message = (
        f"'{type_}' must be one of the supported types ({supported_types}) "
        "or a list of one of those types."
    )
    super().__init__(message)
0 commit comments

Comments
 (0)