21 changes: 11 additions & 10 deletions bigframes/core/array_value.py
@@ -268,8 +268,7 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
 
     def compute_general_expression(self, assignments: Sequence[ex.Expression]):
         named_exprs = [
-            expression_factoring.NamedExpression(expr, ids.ColumnId.unique())
-            for expr in assignments
+            nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
         ]
         # TODO: Push this to rewrite later to go from block expression to planning form
         # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
@@ -279,7 +278,7 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
                 for expr in named_exprs
             )
         )
-        target_ids = tuple(named_expr.name for named_expr in named_exprs)
+        target_ids = tuple(named_expr.id for named_expr in named_exprs)
         new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
         return (ArrayValue(new_root), target_ids)

@@ -403,22 +402,24 @@ def aggregate(
 
     def project_window_expr(
         self,
-        expression: agg_expressions.Aggregation,
+        expressions: Sequence[agg_expressions.Aggregation],
         window: WindowSpec,
-        skip_reproject_unsafe: bool = False,
     ):
-        output_name = self._gen_namespaced_uid()
+        id_strings = [self._gen_namespaced_uid() for _ in expressions]
+        agg_exprs = tuple(
+            nodes.ColumnDef(expression, ids.ColumnId(id_str))
+            for expression, id_str in zip(expressions, id_strings)
+        )
 
         return (
             ArrayValue(
                 nodes.WindowOpNode(
                     child=self.node,
-                    expression=expression,
+                    agg_exprs=agg_exprs,
                     window_spec=window,
-                    output_name=ids.ColumnId(output_name),
-                    skip_reproject_unsafe=skip_reproject_unsafe,
                 )
             ),
-            output_name,
+            id_strings,
         )
 
     def isin(
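Review note: a minimal usage sketch of the batched API, assuming an existing `ArrayValue` named `value`, a `WindowSpec` named `window`, and value columns `"a"` and `"b"` (none of these appear in the PR; the op and expression classes are the ones used elsewhere in this diff):

```python
# Sketch only: `value`, `window`, and the column ids are assumptions.
exprs = [
    agg_expressions.UnaryAggregation(agg_ops.LastNonNullOp(), ex.deref("a")),
    agg_expressions.UnaryAggregation(agg_ops.FirstNonNullOp(), ex.deref("b")),
]
# One call now yields a single WindowOpNode carrying both aggregations,
# where the old API required one node (and one call) per expression.
new_value, out_ids = value.project_window_expr(exprs, window)
assert len(out_ids) == len(exprs)  # one generated column id per aggregation
```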
2 changes: 0 additions & 2 deletions bigframes/core/block_transforms.py
@@ -232,13 +232,11 @@ def _interpolate_column(
         masked_offsets,
         agg_ops.LastNonNullOp(),
         backwards_window,
-        skip_reproject_unsafe=True,
     )
     block, next_value_offset = block.apply_window_op(
         masked_offsets,
         agg_ops.FirstNonNullOp(),
         forwards_window,
-        skip_reproject_unsafe=True,
     )
 
     if interpolate_method == "linear":
49 changes: 19 additions & 30 deletions bigframes/core/blocks.py
@@ -1091,20 +1091,14 @@ def multi_apply_window_op(
         *,
         skip_null_groups: bool = False,
     ) -> typing.Tuple[Block, typing.Sequence[str]]:
-        block = self
-        result_ids = []
-        for i, col_id in enumerate(columns):
-            label = self.col_id_to_label[col_id]
-            block, result_id = block.apply_window_op(
-                col_id,
-                op,
-                window_spec=window_spec,
-                skip_reproject_unsafe=(i + 1) < len(columns),
-                result_label=label,
-                skip_null_groups=skip_null_groups,
-            )
-            result_ids.append(result_id)
-        return block, result_ids
+        return self.apply_analytic(
+            agg_exprs=(
+                agg_expressions.UnaryAggregation(op, ex.deref(col)) for col in columns
+            ),
+            window=window_spec,
+            result_labels=self._get_labels_for_columns(columns),
+            skip_null_groups=skip_null_groups,
+        )
 
     def multi_apply_unary_op(
         self,
@@ -1181,44 +1175,39 @@ def apply_window_op(
         *,
         result_label: Label = None,
         skip_null_groups: bool = False,
-        skip_reproject_unsafe: bool = False,
     ) -> typing.Tuple[Block, str]:
         agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
-        return self.apply_analytic(
-            agg_expr,
+        block, ids = self.apply_analytic(
+            [agg_expr],
             window_spec,
-            result_label,
-            skip_reproject_unsafe=skip_reproject_unsafe,
+            [result_label],
             skip_null_groups=skip_null_groups,
         )
+        return block, ids[0]
 
     def apply_analytic(
         self,
-        agg_expr: agg_expressions.Aggregation,
+        agg_exprs: Iterable[agg_expressions.Aggregation],
         window: windows.WindowSpec,
-        result_label: Label,
+        result_labels: Iterable[Label],
         *,
-        skip_reproject_unsafe: bool = False,
         skip_null_groups: bool = False,
-    ) -> typing.Tuple[Block, str]:
+    ) -> typing.Tuple[Block, Sequence[str]]:
         block = self
         if skip_null_groups:
            for key in window.grouping_keys:
                block = block.filter(ops.notnull_op.as_expr(key))
-        expr, result_id = block._expr.project_window_expr(
-            agg_expr,
+        expr, result_ids = block._expr.project_window_expr(
+            tuple(agg_exprs),
             window,
-            skip_reproject_unsafe=skip_reproject_unsafe,
         )
         block = Block(
             expr,
             index_columns=self.index_columns,
-            column_labels=self.column_labels.insert(
-                len(self.column_labels), result_label
-            ),
+            column_labels=self.column_labels.append(pd.Index(result_labels)),
             index_labels=self._index_labels,
         )
-        return (block, result_id)
+        return (block, result_ids)
 
     def copy_values(self, source_column_id: str, destination_column_id: str) -> Block:
         expr = self.expr.assign(source_column_id, destination_column_id)
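Review note: `apply_window_op` is now a thin wrapper over the batched `apply_analytic`. A sketch of the equivalence, assuming a `Block` named `block`, a value column `"a"`, and a window spec `w` (all illustrative, not from this PR):

```python
# Single-column helper...
block2, result_id = block.apply_window_op("a", agg_ops.FirstNonNullOp(), w)

# ...now delegates to the batched call with one-element sequences and
# returns the sole element of the resulting id sequence.
block3, result_ids = block.apply_analytic(
    [agg_expressions.UnaryAggregation(agg_ops.FirstNonNullOp(), ex.deref("a"))],
    w,
    [None],  # result_label defaults to None
)
```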
12 changes: 7 additions & 5 deletions bigframes/core/compile/ibis_compiler/ibis_compiler.py
@@ -265,11 +265,13 @@ def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
 
 @_compile_node.register
 def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
-    result = child.project_window_op(
-        node.expression,
-        node.window_spec,
-        node.output_name.sql,
-    )
+    result = child
+    for cdef in node.agg_exprs:
+        result = result.project_window_op(
+            cdef.expression,  # type: ignore
+            node.window_spec,
+            cdef.id.sql,
+        )
     return result
 
 
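Review note: the ibis path consumes the batched node by looping, so a multi-aggregation `WindowOpNode` still lowers to one chained projection per aggregation; batching changes the node shape, not the compiled plan. Unrolled sketch, assuming `node.agg_exprs == (cdef0, cdef1)`:

```python
# Each iteration layers one more windowed column onto the IR.
result = child.project_window_op(cdef0.expression, node.window_spec, cdef0.id.sql)
result = result.project_window_op(cdef1.expression, node.window_spec, cdef1.id.sql)
```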
29 changes: 16 additions & 13 deletions bigframes/core/compile/polars/compiler.py
@@ -853,20 +853,23 @@ def compile_window(self, node: nodes.WindowOpNode):
                 "min_period not yet supported for polars engine"
             )
 
-        if (window.bounds is None) or (window.is_unbounded):
-            # polars will automatically broadcast the aggregate to the matching input rows
-            agg_pl = self.agg_compiler.compile_agg_expr(node.expression)
-            if window.grouping_keys:
-                agg_pl = agg_pl.over(
-                    self.expr_compiler.compile_expression(key)
-                    for key in window.grouping_keys
-                )
-            result = df.with_columns(agg_pl.alias(node.output_name.sql))
-        else:  # row-bounded window
-            window_result = self._calc_row_analytic_func(
-                df, node.expression, node.window_spec, node.output_name.sql
-            )
-            result = pl.concat([df, window_result], how="horizontal")
+        result = df
+        for cdef in node.agg_exprs:
+            assert isinstance(cdef.expression, agg_expressions.Aggregation)
+            if (window.bounds is None) or (window.is_unbounded):
+                # polars will automatically broadcast the aggregate to the matching input rows
+                agg_pl = self.agg_compiler.compile_agg_expr(cdef.expression)
+                if window.grouping_keys:
+                    agg_pl = agg_pl.over(
+                        self.expr_compiler.compile_expression(key)
+                        for key in window.grouping_keys
+                    )
+                result = result.with_columns(agg_pl.alias(cdef.id.sql))
+            else:  # row-bounded window
+                window_result = self._calc_row_analytic_func(
+                    result, cdef.expression, node.window_spec, cdef.id.sql
+                )
+                result = pl.concat([result, window_result], how="horizontal")
         return result
 
     def _calc_row_analytic_func(
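Review note on the unbounded branch: the "polars will automatically broadcast" comment refers to `Expr.over`, which maps a group aggregate back onto every matching input row. A standalone polars illustration with toy data (not taken from this PR):

```python
import polars as pl

df = pl.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 5]})
# The sum is computed per group "g", then broadcast to each row of that group.
out = df.with_columns(pl.col("x").sum().over("g").alias("x_sum"))
# out["x_sum"] == [3, 3, 5]
```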
122 changes: 67 additions & 55 deletions bigframes/core/compile/sqlglot/compiler.py
@@ -19,7 +19,15 @@
 
 import sqlglot.expressions as sge
 
-from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
+from bigframes.core import (
+    agg_expressions,
+    expression,
+    guid,
+    identifiers,
+    nodes,
+    pyarrow_utils,
+    rewrite,
+)
 from bigframes.core.compile import configs
 import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
 from bigframes.core.compile.sqlglot.aggregations import windows
@@ -310,67 +318,71 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
 @_compile_node.register
 def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
     window_spec = node.window_spec
-    if node.expression.op.order_independent and window_spec.is_unbounded:
-        # notably percentile_cont does not support ordering clause
-        window_spec = window_spec.without_order()
+    result = child
+    for cdef in node.agg_exprs:
+        assert isinstance(cdef.expression, agg_expressions.Aggregation)
+        if cdef.expression.op.order_independent and window_spec.is_unbounded:
+            # notably percentile_cont does not support ordering clause
+            window_spec = window_spec.without_order()
 
-    window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
+        window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)
 
-    inputs: tuple[sge.Expression, ...] = tuple(
-        scalar_compiler.scalar_op_compiler.compile_expression(
-            expression.DerefOp(column)
-        )
-        for column in node.expression.column_references
-    )
+        inputs: tuple[sge.Expression, ...] = tuple(
+            scalar_compiler.scalar_op_compiler.compile_expression(
+                expression.DerefOp(column)
+            )
+            for column in cdef.expression.column_references
+        )
 
-    clauses: list[tuple[sge.Expression, sge.Expression]] = []
-    if window_spec.min_periods and len(inputs) > 0:
-        if not node.expression.op.nulls_count_for_min_values:
-            # Most operations do not count NULL values towards min_periods
-            not_null_columns = [
-                sge.Not(this=sge.Is(this=column, expression=sge.Null()))
-                for column in inputs
-            ]
-            # All inputs must be non-null for observation to count
-            if not not_null_columns:
-                is_observation_expr: sge.Expression = sge.convert(True)
-            else:
-                is_observation_expr = not_null_columns[0]
-                for expr in not_null_columns[1:]:
-                    is_observation_expr = sge.And(
-                        this=is_observation_expr, expression=expr
-                    )
-            is_observation = ir._cast(is_observation_expr, "INT64")
-            observation_count = windows.apply_window_if_present(
-                sge.func("SUM", is_observation), window_spec
-            )
-        else:
-            # Operations like count treat even NULLs as valid observations
-            # for the sake of min_periods notnull is just used to convert
-            # null values to non-null (FALSE) values to be counted.
-            is_observation = ir._cast(
-                sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
-                "INT64",
-            )
-            observation_count = windows.apply_window_if_present(
-                sge.func("COUNT", is_observation), window_spec
-            )
-
-        clauses.append(
-            (
-                observation_count < sge.convert(window_spec.min_periods),
-                sge.Null(),
-            )
-        )
-    if clauses:
-        when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
-        window_op = sge.Case(ifs=when_expressions, default=window_op)
-
-    # TODO: check if we can directly window the expression.
-    return child.window(
-        window_op=window_op,
-        output_column_id=node.output_name.sql,
-    )
+        clauses: list[tuple[sge.Expression, sge.Expression]] = []
+        if window_spec.min_periods and len(inputs) > 0:
+            if not cdef.expression.op.nulls_count_for_min_values:
+                # Most operations do not count NULL values towards min_periods
+                not_null_columns = [
+                    sge.Not(this=sge.Is(this=column, expression=sge.Null()))
+                    for column in inputs
+                ]
+                # All inputs must be non-null for observation to count
+                if not not_null_columns:
+                    is_observation_expr: sge.Expression = sge.convert(True)
+                else:
+                    is_observation_expr = not_null_columns[0]
+                    for expr in not_null_columns[1:]:
+                        is_observation_expr = sge.And(
+                            this=is_observation_expr, expression=expr
+                        )
+                is_observation = ir._cast(is_observation_expr, "INT64")
+                observation_count = windows.apply_window_if_present(
+                    sge.func("SUM", is_observation), window_spec
+                )
+            else:
+                # Operations like count treat even NULLs as valid observations
+                # for the sake of min_periods notnull is just used to convert
+                # null values to non-null (FALSE) values to be counted.
+                is_observation = ir._cast(
+                    sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
+                    "INT64",
+                )
+                observation_count = windows.apply_window_if_present(
+                    sge.func("COUNT", is_observation), window_spec
+                )
+
+            clauses.append(
+                (
+                    observation_count < sge.convert(window_spec.min_periods),
+                    sge.Null(),
+                )
+            )
+        if clauses:
+            when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
+            window_op = sge.Case(ifs=when_expressions, default=window_op)
+
+        # TODO: check if we can directly window the expression.
+        result = child.window(
+            window_op=window_op,
+            output_column_id=cdef.id.sql,
+        )
+    return result
 
 
 def _replace_unsupported_ops(node: nodes.BigFrameNode):
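Review note: for each aggregation, the min_periods guard above wraps the compiled analytic op in a CASE that yields NULL until enough observations are in the window. A sketch of the emitted SQL shape for `min_periods=3` on a SUM; the column, grouping key, and output name are illustrative, since the real identifiers are generated by the compiler:

```python
# Illustrative only; actual names and window clauses come from sqlglot.
emitted_sql_shape = """
CASE
  WHEN SUM(CAST(x IS NOT NULL AS INT64)) OVER (PARTITION BY g) < 3 THEN NULL
  ELSE SUM(x) OVER (PARTITION BY g)
END AS out_col
"""
```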