Skip to content

Commit 74ef7fc

Browse files
refactor: WindowOpNode can create multiple cols
1 parent e39dfe2 commit 74ef7fc

File tree

12 files changed, with 181 additions and 127 deletions.

bigframes/core/array_value.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -405,17 +405,14 @@ def project_window_expr(
405405
self,
406406
expression: agg_expressions.Aggregation,
407407
window: WindowSpec,
408-
skip_reproject_unsafe: bool = False,
409408
):
410409
output_name = self._gen_namespaced_uid()
411410
return (
412411
ArrayValue(
413412
nodes.WindowOpNode(
414413
child=self.node,
415-
expression=expression,
414+
agg_exprs=(nodes.ColumnDef(expression, ids.ColumnId(output_name)),),
416415
window_spec=window,
417-
output_name=ids.ColumnId(output_name),
418-
skip_reproject_unsafe=skip_reproject_unsafe,
419416
)
420417
),
421418
output_name,

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -265,11 +265,13 @@ def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
265265

266266
@_compile_node.register
267267
def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
268-
result = child.project_window_op(
269-
node.expression,
270-
node.window_spec,
271-
node.output_name.sql,
272-
)
268+
result = child
269+
for cdef in node.agg_exprs:
270+
result = child.project_window_op(
271+
cdef.expression, # type: ignore
272+
node.window_spec,
273+
cdef.id.sql,
274+
)
273275
return result
274276

275277

bigframes/core/compile/polars/compiler.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -853,20 +853,23 @@ def compile_window(self, node: nodes.WindowOpNode):
853853
"min_period not yet supported for polars engine"
854854
)
855855

856-
if (window.bounds is None) or (window.is_unbounded):
857-
# polars will automatically broadcast the aggregate to the matching input rows
858-
agg_pl = self.agg_compiler.compile_agg_expr(node.expression)
859-
if window.grouping_keys:
860-
agg_pl = agg_pl.over(
861-
self.expr_compiler.compile_expression(key)
862-
for key in window.grouping_keys
856+
result = df
857+
for cdef in node.agg_exprs:
858+
assert isinstance(cdef.expression, agg_expressions.Aggregation)
859+
if (window.bounds is None) or (window.is_unbounded):
860+
# polars will automatically broadcast the aggregate to the matching input rows
861+
agg_pl = self.agg_compiler.compile_agg_expr(cdef.expression)
862+
if window.grouping_keys:
863+
agg_pl = agg_pl.over(
864+
self.expr_compiler.compile_expression(key)
865+
for key in window.grouping_keys
866+
)
867+
result = df.with_columns(agg_pl.alias(cdef.id.sql))
868+
else: # row-bounded window
869+
window_result = self._calc_row_analytic_func(
870+
df, cdef.expression, node.window_spec, cdef.id.sql
863871
)
864-
result = df.with_columns(agg_pl.alias(node.output_name.sql))
865-
else: # row-bounded window
866-
window_result = self._calc_row_analytic_func(
867-
df, node.expression, node.window_spec, node.output_name.sql
868-
)
869-
result = pl.concat([df, window_result], how="horizontal")
872+
result = pl.concat([result, window_result], how="horizontal")
870873
return result
871874

872875
def _calc_row_analytic_func(

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 67 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,15 @@
1919

2020
import sqlglot.expressions as sge
2121

22-
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
22+
from bigframes.core import (
23+
agg_expressions,
24+
expression,
25+
guid,
26+
identifiers,
27+
nodes,
28+
pyarrow_utils,
29+
rewrite,
30+
)
2331
from bigframes.core.compile import configs
2432
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
2533
from bigframes.core.compile.sqlglot.aggregations import windows
@@ -310,67 +318,71 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG
310318
@_compile_node.register
311319
def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
312320
window_spec = node.window_spec
313-
if node.expression.op.order_independent and window_spec.is_unbounded:
314-
# notably percentile_cont does not support ordering clause
315-
window_spec = window_spec.without_order()
321+
result = child
322+
for cdef in node.agg_exprs:
323+
assert isinstance(cdef.expression, agg_expressions.Aggregation)
324+
if cdef.expression.op.order_independent and window_spec.is_unbounded:
325+
# notably percentile_cont does not support ordering clause
326+
window_spec = window_spec.without_order()
316327

317-
window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
328+
window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)
318329

319-
inputs: tuple[sge.Expression, ...] = tuple(
320-
scalar_compiler.scalar_op_compiler.compile_expression(
321-
expression.DerefOp(column)
330+
inputs: tuple[sge.Expression, ...] = tuple(
331+
scalar_compiler.scalar_op_compiler.compile_expression(
332+
expression.DerefOp(column)
333+
)
334+
for column in cdef.expression.column_references
322335
)
323-
for column in node.expression.column_references
324-
)
325336

326-
clauses: list[tuple[sge.Expression, sge.Expression]] = []
327-
if window_spec.min_periods and len(inputs) > 0:
328-
if not node.expression.op.nulls_count_for_min_values:
329-
# Most operations do not count NULL values towards min_periods
330-
not_null_columns = [
331-
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
332-
for column in inputs
333-
]
334-
# All inputs must be non-null for observation to count
335-
if not not_null_columns:
336-
is_observation_expr: sge.Expression = sge.convert(True)
337+
clauses: list[tuple[sge.Expression, sge.Expression]] = []
338+
if window_spec.min_periods and len(inputs) > 0:
339+
if not cdef.expression.op.nulls_count_for_min_values:
340+
# Most operations do not count NULL values towards min_periods
341+
not_null_columns = [
342+
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
343+
for column in inputs
344+
]
345+
# All inputs must be non-null for observation to count
346+
if not not_null_columns:
347+
is_observation_expr: sge.Expression = sge.convert(True)
348+
else:
349+
is_observation_expr = not_null_columns[0]
350+
for expr in not_null_columns[1:]:
351+
is_observation_expr = sge.And(
352+
this=is_observation_expr, expression=expr
353+
)
354+
is_observation = ir._cast(is_observation_expr, "INT64")
355+
observation_count = windows.apply_window_if_present(
356+
sge.func("SUM", is_observation), window_spec
357+
)
337358
else:
338-
is_observation_expr = not_null_columns[0]
339-
for expr in not_null_columns[1:]:
340-
is_observation_expr = sge.And(
341-
this=is_observation_expr, expression=expr
342-
)
343-
is_observation = ir._cast(is_observation_expr, "INT64")
344-
observation_count = windows.apply_window_if_present(
345-
sge.func("SUM", is_observation), window_spec
346-
)
347-
else:
348-
# Operations like count treat even NULLs as valid observations
349-
# for the sake of min_periods notnull is just used to convert
350-
# null values to non-null (FALSE) values to be counted.
351-
is_observation = ir._cast(
352-
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
353-
"INT64",
354-
)
355-
observation_count = windows.apply_window_if_present(
356-
sge.func("COUNT", is_observation), window_spec
357-
)
358-
359-
clauses.append(
360-
(
361-
observation_count < sge.convert(window_spec.min_periods),
362-
sge.Null(),
359+
# Operations like count treat even NULLs as valid observations
360+
# for the sake of min_periods notnull is just used to convert
361+
# null values to non-null (FALSE) values to be counted.
362+
is_observation = ir._cast(
363+
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
364+
"INT64",
365+
)
366+
observation_count = windows.apply_window_if_present(
367+
sge.func("COUNT", is_observation), window_spec
368+
)
369+
370+
clauses.append(
371+
(
372+
observation_count < sge.convert(window_spec.min_periods),
373+
sge.Null(),
374+
)
363375
)
376+
if clauses:
377+
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
378+
window_op = sge.Case(ifs=when_expressions, default=window_op)
379+
380+
# TODO: check if we can directly window the expression.
381+
result = child.window(
382+
window_op=window_op,
383+
output_column_id=cdef.id.sql,
364384
)
365-
if clauses:
366-
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
367-
window_op = sge.Case(ifs=when_expressions, default=window_op)
368-
369-
# TODO: check if we can directly window the expression.
370-
return child.window(
371-
window_op=window_op,
372-
output_column_id=node.output_name.sql,
373-
)
385+
return result
374386

375387

376388
def _replace_unsupported_ops(node: nodes.BigFrameNode):

bigframes/core/expression_factoring.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,9 @@ def graph_extract_window_expr() -> Optional[
187187
while result := graph_extract_window_expr():
188188
id, window_expr = result
189189
curr_root = nodes.WindowOpNode(
190-
curr_root, window_expr.analytic_expr, window_expr.window, output_name=id
190+
curr_root,
191+
(nodes.ColumnDef(window_expr.analytic_expr, id),),
192+
window_expr.window,
191193
)
192194
if len(graph.nodes) >= pre_size:
193195
raise ValueError("graph didn't shrink")

0 commit comments

Comments
 (0)