Skip to content

Commit a62198a

Browse files
committed
refactor: add compile_window to the sqlglot compiler
1 parent 92a2377 commit a62198a

File tree

17 files changed

+591
-82
lines changed

17 files changed

+591
-82
lines changed

bigframes/core/compile/sqlglot/aggregate_compiler.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import sqlglot.expressions as sge
1717

18-
from bigframes.core import expression
18+
from bigframes.core import expression, window_spec
1919
from bigframes.core.compile.sqlglot.aggregations import (
2020
binary_compiler,
2121
nullary_compiler,
@@ -56,3 +56,21 @@ def compile_aggregate(
5656
return binary_compiler.compile(aggregate.op, left, right)
5757
else:
5858
raise ValueError(f"Unexpected aggregation: {aggregate}")
59+
60+
61+
def compile_analytic(
62+
aggregate: expression.Aggregation,
63+
window: window_spec.WindowSpec,
64+
) -> sge.Expression:
65+
if isinstance(aggregate, expression.NullaryAggregation):
66+
return nullary_compiler.compile(aggregate.op)
67+
if isinstance(aggregate, expression.UnaryAggregation):
68+
column = typed_expr.TypedExpr(
69+
scalar_compiler.compile_scalar_expression(aggregate.arg),
70+
aggregate.arg.output_type,
71+
)
72+
return unary_compiler.compile(aggregate.op, column, window)
73+
elif isinstance(aggregate, expression.BinaryAggregation):
74+
raise NotImplementedError("binary analytic operations not yet supported")
75+
else:
76+
raise ValueError(f"Unexpected analytic operation: {aggregate}")

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
23+
from bigframes.core.compile.sqlglot.aggregations.window import apply_window_if_present
2424
from bigframes.operations import aggregations as agg_ops
2525

2626
NULLARY_OP_REGISTRATION = reg.OpRegistration()

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818

1919
import sqlglot.expressions as sge
2020

21+
from bigframes import dtypes
2122
from bigframes.core import window_spec
2223
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
24+
from bigframes.core.compile.sqlglot.aggregations.window import apply_window_if_present
2425
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2526
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2627
from bigframes.operations import aggregations as agg_ops
@@ -42,8 +43,11 @@ def _(
4243
column: typed_expr.TypedExpr,
4344
window: typing.Optional[window_spec.WindowSpec] = None,
4445
) -> sge.Expression:
46+
expr = column.expr
47+
if column.dtype == dtypes.BOOL_DTYPE:
48+
expr = sge.Cast(this=column.expr, to="INT64")
4549
# Will be null if all inputs are null. Pandas defaults to zero sum though.
46-
expr = apply_window_if_present(sge.func("SUM", column.expr), window)
50+
expr = apply_window_if_present(sge.func("SUM", expr), window)
4751
return sge.func("IFNULL", expr, ir._literal(0, column.dtype))
4852

4953

bigframes/core/compile/sqlglot/aggregations/utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import typing
17+
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core import utils, window_spec
21+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
22+
import bigframes.core.ordering as ordering_spec
23+
24+
25+
def apply_window_if_present(
    value: sge.Expression,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Wrap ``value`` in a window (OVER) expression described by ``window``.

    Args:
        value: The expression to evaluate over the window.
        window: The window specification. When ``None``, ``value`` is
            returned unchanged.

    Returns:
        ``value`` itself when no window is given, otherwise an ``sge.Window``
        carrying the partitioning, ordering and frame derived from ``window``.

    Raises:
        ValueError: If the window is row- or range-bounded but has no ordering.
    """
    if window is None:
        return value

    # Validate before building anything. A bounded frame is meaningless
    # without an ordering, and the range-bounded path below indexes
    # ``window.ordering[0]``, which would otherwise fail with a bare IndexError.
    if (window.is_row_bounded or window.is_range_bounded) and not window.ordering:
        raise ValueError("No ordering provided for ordered analytic function")

    if window.is_range_bounded:
        # Note that, when the window is range-bounded, we only need one ordering key.
        # There are two reasons:
        # 1. Manipulating null positions requires more than one ordering key, which
        #    is forbidden by SQL window syntax for range rolling.
        # 2. Pandas does not allow range rolling on timeseries with nulls.
        order_by = get_window_order_by((window.ordering[0],), override_null_order=False)
    elif window.is_row_bounded or window.ordering:
        order_by = get_window_order_by(window.ordering, override_null_order=True)
    else:
        # Unbound grouping window.
        order_by = None

    kind = (
        "ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
    )

    start: typing.Union[int, float, None] = None
    end: typing.Union[int, float, None] = None
    if isinstance(window.bounds, window_spec.RangeWindowBounds):
        # Range bounds are converted to microseconds before being emitted.
        if window.bounds.start is not None:
            start = utils.timedelta_to_micros(window.bounds.start)
        if window.bounds.end is not None:
            end = utils.timedelta_to_micros(window.bounds.end)
    elif window.bounds:
        start = window.bounds.start
        end = window.bounds.end

    start_value, start_side = _get_window_bounds(start, is_start_bounds=True)
    end_value, end_side = _get_window_bounds(end, is_start_bounds=False)

    spec = sge.WindowSpec(
        kind=kind,
        start=start_value,
        start_side=start_side,
        end=end_value,
        end_side=end_side,
        over="OVER",
    )

    order = sge.Order(expressions=order_by) if order_by else None

    group_by = (
        [scalar_compiler.compile_scalar_expression(key) for key in window.grouping_keys]
        if window.grouping_keys
        else None
    )

    return sge.Window(this=value, partition_by=group_by, order=order, spec=spec)
87+
88+
89+
def get_window_order_by(
    ordering: typing.Tuple[ordering_spec.OrderingExpression, ...],
    override_null_order: bool = False,
) -> typing.Optional[typing.List[sge.Ordered]]:
    """Build the ORDER BY terms of a window from ordering expressions.

    Args:
        ordering: The ordering expressions, in priority order.
        override_null_order: When true, an extra ``<expr> IS NULL`` term is
            emitted ahead of a key whenever the requested null placement is
            ascending with nulls last, or descending with nulls first.

    Returns:
        ``None`` when ``ordering`` is empty, otherwise the list of
        ``sge.Ordered`` terms.
    """
    if not ordering:
        return None

    terms = []
    for item in ordering:
        compiled = scalar_compiler.compile_scalar_expression(item.scalar_expression)
        descending = not item.direction.is_ascending
        nulls_lead = not item.na_last

        # Bigquery SQL considers NULLS to be "smallest" values, but we need
        # to override in these cases: descending with nulls first, or
        # ascending with nulls last (i.e. both flags agree).
        if override_null_order and nulls_lead == descending:
            terms.append(
                sge.Ordered(
                    this=sge.Is(this=compiled, expression=sge.Null()),
                    desc=descending,
                    nulls_first=nulls_lead,
                )
            )

        terms.append(
            sge.Ordered(this=compiled, desc=descending, nulls_first=nulls_lead)
        )
    return terms
134+
135+
136+
def _get_window_bounds(
137+
value, is_start_bounds: bool
138+
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:
139+
"""Compiles a single boundary value into its SQL components."""
140+
if value is None:
141+
side = "PRECEDING" if is_start_bounds else "FOLLOWING"
142+
return "UNBOUNDED", side
143+
144+
if value == 0:
145+
return "CURRENT ROW", None
146+
147+
side = "PRECEDING" if value < 0 else "FOLLOWING"
148+
return sge.convert(abs(value)), side

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
2525
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
26+
from bigframes.core.compile.sqlglot.aggregations import window
27+
from bigframes.core.compile.sqlglot.aggregations.window import apply_window_if_present
2628
from bigframes.core.compile.sqlglot.expressions import typed_expr
2729
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2830
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -272,18 +274,16 @@ def compile_random_sample(
272274
def compile_aggregate(
273275
self, node: nodes.AggregateNode, child: ir.SQLGlotIR
274276
) -> ir.SQLGlotIR:
275-
ordering_cols = tuple(
276-
sge.Ordered(
277-
this=scalar_compiler.compile_scalar_expression(
278-
ordering.scalar_expression
279-
),
280-
desc=ordering.direction.is_ascending is False,
281-
nulls_first=ordering.na_last is False,
282-
)
283-
for ordering in node.order_by
277+
ordering_cols = window.get_window_order_by(
278+
node.order_by, override_null_order=True
284279
)
285280
aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
286-
(id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols))
281+
(
282+
id.sql,
283+
aggregate_compiler.compile_aggregate(
284+
agg, order_by=tuple(ordering_cols) if ordering_cols else ()
285+
),
286+
)
287287
for agg, id in node.aggregations
288288
)
289289
by_cols: tuple[sge.Expression, ...] = tuple(
@@ -299,6 +299,72 @@ def compile_aggregate(
299299

300300
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
301301

302+
@_compile_node.register
def compile_window(
    self, node: nodes.WindowOpNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compile a window-op node into a windowed column on ``child``.

    The analytic expression is wrapped in a CASE that yields NULL when an
    input is NULL (for null-skipping ops, unless ``node.never_skip_nulls``)
    or when the frame holds fewer than ``min_periods`` observations.

    Args:
        node: The window operation node to compile.
        child: The already-compiled input of ``node``.

    Returns:
        The result of ``child.window`` for the compiled expression, named
        after ``node.output_name``.
    """
    agg = node.expression
    spec = node.window_spec
    if agg.op.order_independent and spec.is_unbounded:
        # notably percentile_cont does not support ordering clause
        spec = spec.without_order()

    window_op = aggregate_compiler.compile_analytic(agg, spec)

    inputs: tuple[sge.Expression, ...] = tuple(
        scalar_compiler.compile_scalar_expression(expression.DerefOp(column))
        for column in agg.column_references
    )

    # (condition, result) pairs for the CASE wrapper built below.
    clauses: list[tuple[sge.Expression, sge.Expression]] = []
    if agg.op.skips_nulls and not node.never_skip_nulls:
        clauses.extend(
            (sge.Is(this=column, expression=sge.Null()), sge.Null())
            for column in inputs
        )

    if spec.min_periods and inputs:
        if agg.op.skips_nulls:
            # Most operations do not count NULL values towards min_periods:
            # all inputs must be non-null for an observation to count.
            is_observation: sge.Expression = sge.Not(
                this=sge.Is(this=inputs[0], expression=sge.Null())
            )
            for column in inputs[1:]:
                is_observation = sge.And(
                    this=is_observation,
                    expression=sge.Not(
                        this=sge.Is(this=column, expression=sge.Null())
                    ),
                )
            observation_count = apply_window_if_present(
                sge.func("SUM", ir._cast(is_observation, "INT64")), spec
            )
            # SUM over an empty frame is NULL, and `NULL < min_periods` is
            # NULL, so the CASE branch would never fire. Treat it as zero.
            observation_count = sge.func(
                "IFNULL", observation_count, sge.convert(0)
            )
        else:
            # Operations like count treat even NULLs as valid observations
            # for the sake of min_periods. `NOT (x IS NULL)` is never NULL,
            # so COUNT over it counts every row in the frame. Summing the
            # 0/1 cast would leave NULL rows out.
            observation_count = apply_window_if_present(
                sge.func(
                    "COUNT",
                    sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
                ),
                spec,
            )

        clauses.append(
            (
                observation_count < sge.convert(spec.min_periods),
                sge.Null(),
            )
        )

    if clauses:
        when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
        window_op = sge.Case(ifs=when_expressions, default=window_op)

    # TODO: check if we can directly window the expression.
    return child.window(
        window_op=window_op,
        output_column_id=node.output_name.sql,
    )
367+
302368

303369
def _replace_unsupported_ops(node: nodes.BigFrameNode):
304370
node = nodes.bottom_up(node, rewrite.rewrite_slice)

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,3 +394,8 @@ def _(op: ops.ParseJSON, expr: TypedExpr) -> sge.Expression:
394394
@UNARY_OP_REGISTRATION.register(ops.ToJSONString)
def _(op: ops.ToJSONString, expr: TypedExpr) -> sge.Expression:
    """Compile ``ToJSONString`` as a ``TO_JSON_STRING`` call on the operand."""
    operand = expr.expr
    return sge.func("TO_JSON_STRING", operand)
397+
398+
399+
@UNARY_OP_REGISTRATION.register(ops.UnixMicros)
def _(op: ops.UnixMicros, expr: TypedExpr) -> sge.Expression:
    """Compile ``UnixMicros`` as a ``UNIX_MICROS`` call on the operand."""
    operand = expr.expr
    return sge.func("UNIX_MICROS", operand)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
@functools.singledispatch
def compile_scalar_expression(
    expr: expression.Expression,
) -> sge.Expression:
    """Compiles BigFrames scalar expression into SQLGlot expression.

    This is the single-dispatch fallback, reached only when no
    implementation is registered for the type of ``expr``.

    Raises:
        ValueError: Always, naming the node that could not be compiled.
    """
    # Interpolate the argument, not ``expression``: after the parameter was
    # renamed, that name refers to the imported module, not the node.
    raise ValueError(f"Can't compile unrecognized node: {expr}")

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,13 @@ def aggregate(
409409
new_expr = new_expr.where(condition, append=False)
410410
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
411411

412+
def window(
    self,
    window_op: sge.Expression,
    output_column_id: str,
) -> SQLGlotIR:
    """Project ``window_op`` as a column named ``output_column_id``.

    Delegates to ``self.project`` with a single ``(id, expression)`` pair.
    """
    projection = ((output_column_id, window_op),)
    return self.project(projection)
418+
412419
def insert(
413420
self,
414421
destination: bigquery.TableReference,

0 commit comments

Comments
 (0)