diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py
index 5af6fbd56e..2cc8fdf3f0 100644
--- a/bigframes/core/array_value.py
+++ b/bigframes/core/array_value.py
@@ -268,8 +268,7 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
 
     def compute_general_expression(self, assignments: Sequence[ex.Expression]):
         named_exprs = [
-            expression_factoring.NamedExpression(expr, ids.ColumnId.unique())
-            for expr in assignments
+            nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
         ]
         # TODO: Push this to rewrite later to go from block expression to planning form
         # TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
@@ -279,7 +278,7 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
                 for expr in named_exprs
             )
         )
-        target_ids = tuple(named_expr.name for named_expr in named_exprs)
+        target_ids = tuple(named_expr.id for named_expr in named_exprs)
         new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
         return (ArrayValue(new_root), target_ids)
 
@@ -403,22 +402,24 @@ def aggregate(
 
     def project_window_expr(
         self,
-        expression: agg_expressions.Aggregation,
+        expressions: Sequence[agg_expressions.Aggregation],
         window: WindowSpec,
-        skip_reproject_unsafe: bool = False,
     ):
-        output_name = self._gen_namespaced_uid()
+        id_strings = [self._gen_namespaced_uid() for _ in expressions]
+        agg_exprs = tuple(
+            nodes.ColumnDef(expression, ids.ColumnId(id_str))
+            for expression, id_str in zip(expressions, id_strings)
+        )
+
         return (
             ArrayValue(
                 nodes.WindowOpNode(
                     child=self.node,
-                    expression=expression,
+                    agg_exprs=agg_exprs,
                     window_spec=window,
-                    output_name=ids.ColumnId(output_name),
-                    skip_reproject_unsafe=skip_reproject_unsafe,
                 )
             ),
-            output_name,
+            id_strings,
         )
 
     def isin(
diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py
index 4e7abb1104..773a615fd9 100644
--- a/bigframes/core/block_transforms.py
+++ b/bigframes/core/block_transforms.py
@@ -232,13 +232,11 @@ def _interpolate_column(
         masked_offsets,
         agg_ops.LastNonNullOp(),
         backwards_window,
-        skip_reproject_unsafe=True,
     )
     block, next_value_offset = block.apply_window_op(
         masked_offsets,
         agg_ops.FirstNonNullOp(),
         forwards_window,
-        skip_reproject_unsafe=True,
     )
 
     if interpolate_method == "linear":
diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index 466dbfce72..e45d945e23 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -1091,20 +1091,14 @@ def multi_apply_window_op(
         *,
         skip_null_groups: bool = False,
     ) -> typing.Tuple[Block, typing.Sequence[str]]:
-        block = self
-        result_ids = []
-        for i, col_id in enumerate(columns):
-            label = self.col_id_to_label[col_id]
-            block, result_id = block.apply_window_op(
-                col_id,
-                op,
-                window_spec=window_spec,
-                skip_reproject_unsafe=(i + 1) < len(columns),
-                result_label=label,
-                skip_null_groups=skip_null_groups,
-            )
-            result_ids.append(result_id)
-        return block, result_ids
+        return self.apply_analytic(
+            agg_exprs=(
+                agg_expressions.UnaryAggregation(op, ex.deref(col)) for col in columns
+            ),
+            window=window_spec,
+            result_labels=self._get_labels_for_columns(columns),
+            skip_null_groups=skip_null_groups,
+        )
 
     def multi_apply_unary_op(
         self,
@@ -1181,44 +1175,39 @@ def apply_window_op(
         *,
         result_label: Label = None,
         skip_null_groups: bool = False,
-        skip_reproject_unsafe: bool = False,
     ) -> typing.Tuple[Block, str]:
         agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
-        return self.apply_analytic(
-            agg_expr,
+        block, ids = self.apply_analytic(
+            [agg_expr],
             window_spec,
-            result_label,
-            skip_reproject_unsafe=skip_reproject_unsafe,
+            [result_label],
             skip_null_groups=skip_null_groups,
         )
+        return block, ids[0]
 
     def apply_analytic(
         self,
-        agg_expr: agg_expressions.Aggregation,
+        agg_exprs: Iterable[agg_expressions.Aggregation],
         window: windows.WindowSpec,
-        result_label: Label,
+        result_labels: Iterable[Label],
         *,
-        skip_reproject_unsafe: bool = False,
         skip_null_groups: bool = False,
-    ) -> typing.Tuple[Block, str]:
+    ) -> typing.Tuple[Block, Sequence[str]]:
         block = self
         if skip_null_groups:
             for key in window.grouping_keys:
                 block = block.filter(ops.notnull_op.as_expr(key))
-        expr, result_id = block._expr.project_window_expr(
-            agg_expr,
+        expr, result_ids = block._expr.project_window_expr(
+            tuple(agg_exprs),
             window,
-            skip_reproject_unsafe=skip_reproject_unsafe,
         )
         block = Block(
             expr,
             index_columns=self.index_columns,
-            column_labels=self.column_labels.insert(
-                len(self.column_labels), result_label
-            ),
+            column_labels=self.column_labels.append(pd.Index(result_labels)),
            index_labels=self._index_labels,
         )
-        return (block, result_id)
+        return (block, result_ids)
 
     def copy_values(self, source_column_id: str, destination_column_id: str) -> Block:
         expr = self.expr.assign(source_column_id, destination_column_id)
diff --git a/bigframes/core/compile/ibis_compiler/ibis_compiler.py b/bigframes/core/compile/ibis_compiler/ibis_compiler.py
index b46c66f879..31cd9a0456 100644
--- a/bigframes/core/compile/ibis_compiler/ibis_compiler.py
+++ b/bigframes/core/compile/ibis_compiler/ibis_compiler.py
@@ -265,11 +265,13 @@ def compile_aggregate(node: nodes.AggregateNode, child: compiled.UnorderedIR):
 
 @_compile_node.register
 def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
-    result = child.project_window_op(
-        node.expression,
-        node.window_spec,
-        node.output_name.sql,
-    )
+    result = child
+    for cdef in node.agg_exprs:
+        result = result.project_window_op(
+            cdef.expression,  # type: ignore
+            node.window_spec,
+            cdef.id.sql,
+        )
     return result
 
diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py
index 1c9b0d802d..5988ecaa90 100644
--- a/bigframes/core/compile/polars/compiler.py
+++ b/bigframes/core/compile/polars/compiler.py
@@ -853,20 +853,23 @@ def compile_window(self, node: nodes.WindowOpNode):
                 "min_period not yet supported for polars engine"
             )
 
-        if (window.bounds is None) or (window.is_unbounded):
-            # polars will automatically broadcast the aggregate to the matching input rows
-            agg_pl = self.agg_compiler.compile_agg_expr(node.expression)
-            if window.grouping_keys:
-                agg_pl = agg_pl.over(
-                    self.expr_compiler.compile_expression(key)
-                    for key in window.grouping_keys
-                )
-            result = df.with_columns(agg_pl.alias(node.output_name.sql))
-        else:  # row-bounded window
-            window_result = self._calc_row_analytic_func(
-                df, node.expression, node.window_spec, node.output_name.sql
-            )
-            result = pl.concat([df, window_result], how="horizontal")
+        result = df
+        for cdef in node.agg_exprs:
+            assert isinstance(cdef.expression, agg_expressions.Aggregation)
+            if (window.bounds is None) or (window.is_unbounded):
+                # polars will automatically broadcast the aggregate to the matching input rows
+                agg_pl = self.agg_compiler.compile_agg_expr(cdef.expression)
+                if window.grouping_keys:
+                    agg_pl = agg_pl.over(
+                        self.expr_compiler.compile_expression(key)
+                        for key in window.grouping_keys
+                    )
+                result = result.with_columns(agg_pl.alias(cdef.id.sql))
+            else:  # row-bounded window
+                window_result = self._calc_row_analytic_func(
+                    result, cdef.expression, node.window_spec, cdef.id.sql
+                )
+                result = pl.concat([result, window_result], how="horizontal")
         return result
 
     def _calc_row_analytic_func(
diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index 276751d6e3..7ecc15f6a2 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -19,7 +19,15 @@
 
 import sqlglot.expressions as sge
 
-from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
+from bigframes.core import (
+    agg_expressions,
+    expression,
+    guid,
+    identifiers,
+    nodes,
+    pyarrow_utils,
+    rewrite,
+)
 from bigframes.core.compile import configs
 import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
 from bigframes.core.compile.sqlglot.aggregations import windows
@@ -310,67 +318,71 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG
 @_compile_node.register
 def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
     window_spec = node.window_spec
-    if node.expression.op.order_independent and window_spec.is_unbounded:
-        # notably percentile_cont does not support ordering clause
-        window_spec = window_spec.without_order()
-
-    window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
-
-    inputs: tuple[sge.Expression, ...] = tuple(
-        scalar_compiler.scalar_op_compiler.compile_expression(
-            expression.DerefOp(column)
-        )
-        for column in node.expression.column_references
-    )
-
-    clauses: list[tuple[sge.Expression, sge.Expression]] = []
-    if window_spec.min_periods and len(inputs) > 0:
-        if not node.expression.op.nulls_count_for_min_values:
-            # Most operations do not count NULL values towards min_periods
-            not_null_columns = [
-                sge.Not(this=sge.Is(this=column, expression=sge.Null()))
-                for column in inputs
-            ]
-            # All inputs must be non-null for observation to count
-            if not not_null_columns:
-                is_observation_expr: sge.Expression = sge.convert(True)
-            else:
-                is_observation_expr = not_null_columns[0]
-                for expr in not_null_columns[1:]:
-                    is_observation_expr = sge.And(
-                        this=is_observation_expr, expression=expr
-                    )
-            is_observation = ir._cast(is_observation_expr, "INT64")
-            observation_count = windows.apply_window_if_present(
-                sge.func("SUM", is_observation), window_spec
-            )
-        else:
-            # Operations like count treat even NULLs as valid observations
-            # for the sake of min_periods notnull is just used to convert
-            # null values to non-null (FALSE) values to be counted.
-            is_observation = ir._cast(
-                sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
-                "INT64",
-            )
-            observation_count = windows.apply_window_if_present(
-                sge.func("COUNT", is_observation), window_spec
-            )
-
-        clauses.append(
-            (
-                observation_count < sge.convert(window_spec.min_periods),
-                sge.Null(),
-            )
-        )
-    if clauses:
-        when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
-        window_op = sge.Case(ifs=when_expressions, default=window_op)
-
-    # TODO: check if we can directly window the expression.
-    return child.window(
-        window_op=window_op,
-        output_column_id=node.output_name.sql,
-    )
+    result = child
+    for cdef in node.agg_exprs:
+        assert isinstance(cdef.expression, agg_expressions.Aggregation)
+        if cdef.expression.op.order_independent and window_spec.is_unbounded:
+            # notably percentile_cont does not support ordering clause
+            window_spec = window_spec.without_order()
+
+        window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)
+
+        inputs: tuple[sge.Expression, ...] = tuple(
+            scalar_compiler.scalar_op_compiler.compile_expression(
+                expression.DerefOp(column)
+            )
+            for column in cdef.expression.column_references
+        )
+
+        clauses: list[tuple[sge.Expression, sge.Expression]] = []
+        if window_spec.min_periods and len(inputs) > 0:
+            if not cdef.expression.op.nulls_count_for_min_values:
+                # Most operations do not count NULL values towards min_periods
+                not_null_columns = [
+                    sge.Not(this=sge.Is(this=column, expression=sge.Null()))
+                    for column in inputs
+                ]
+                # All inputs must be non-null for observation to count
+                if not not_null_columns:
+                    is_observation_expr: sge.Expression = sge.convert(True)
+                else:
+                    is_observation_expr = not_null_columns[0]
+                    for expr in not_null_columns[1:]:
+                        is_observation_expr = sge.And(
+                            this=is_observation_expr, expression=expr
+                        )
+                is_observation = ir._cast(is_observation_expr, "INT64")
+                observation_count = windows.apply_window_if_present(
+                    sge.func("SUM", is_observation), window_spec
+                )
+            else:
+                # Operations like count treat even NULLs as valid observations.
+                # For the sake of min_periods, notnull is just used to convert
+                # null values to non-null (FALSE) values to be counted.
+                is_observation = ir._cast(
+                    sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
+                    "INT64",
+                )
+                observation_count = windows.apply_window_if_present(
+                    sge.func("COUNT", is_observation), window_spec
+                )
+
+            clauses.append(
+                (
+                    observation_count < sge.convert(window_spec.min_periods),
+                    sge.Null(),
+                )
+            )
+        if clauses:
+            when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
+            window_op = sge.Case(ifs=when_expressions, default=window_op)
+
+        # TODO: check if we can directly window the expression.
+        result = result.window(
+            window_op=window_op,
+            output_column_id=cdef.id.sql,
+        )
+    return result
 
 
 def _replace_unsupported_ops(node: nodes.BigFrameNode):
diff --git a/bigframes/core/expression_factoring.py b/bigframes/core/expression_factoring.py
index 07d5591bc5..d7ac49b585 100644
--- a/bigframes/core/expression_factoring.py
+++ b/bigframes/core/expression_factoring.py
@@ -1,33 +1,26 @@
 import collections
 import dataclasses
 import functools
-import itertools
-from typing import Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
+from typing import cast, Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
 
-from bigframes.core import agg_expressions, expression, identifiers, nodes
+from bigframes.core import agg_expressions, expression, identifiers, nodes, window_spec
 
 _MAX_INLINE_COMPLEXITY = 10
 
 
-@dataclasses.dataclass(frozen=True, eq=False)
-class NamedExpression:
-    expr: expression.Expression
-    name: identifiers.ColumnId
-
-
 @dataclasses.dataclass(frozen=True, eq=False)
 class FactoredExpression:
     root_expr: expression.Expression
-    sub_exprs: Tuple[NamedExpression, ...]
+    sub_exprs: Tuple[nodes.ColumnDef, ...]
 
 
-def fragmentize_expression(root: NamedExpression) -> Sequence[NamedExpression]:
+def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
     """
     The goal of this functions is to factor out an expression into multiple sub-expressions.
     """
-    factored_expr = root.expr.reduce_up(gather_fragments)
-    root_expr = NamedExpression(factored_expr.root_expr, root.name)
+    factored_expr = root.expression.reduce_up(gather_fragments)
+    root_expr = nodes.ColumnDef(factored_expr.root_expr, root.id)
     return (root_expr, *factored_expr.sub_exprs)
 
@@ -48,7 +41,7 @@ def gather_fragments(
         if not do_inline:
             id = identifiers.ColumnId.unique()
             replacements.append(expression.DerefOp(id))
-            named_exprs.append(NamedExpression(child_result.root_expr, id))
+            named_exprs.append(nodes.ColumnDef(child_result.root_expr, id))
             named_exprs.extend(child_result.sub_exprs)
         else:
             replacements.append(child_result.root_expr)
@@ -116,24 +109,34 @@ def remove_node(self, node: T) -> None:
 
 def push_into_tree(
     root: nodes.BigFrameNode,
-    exprs: Sequence[NamedExpression],
+    exprs: Sequence[nodes.ColumnDef],
     target_ids: Sequence[identifiers.ColumnId],
 ) -> nodes.BigFrameNode:
     curr_root = root
-    by_id = {expr.name: expr for expr in exprs}
+    by_id = {expr.id: expr for expr in exprs}
     # id -> id
     graph = DiGraph(
-        (expr.name, child_id)
+        (expr.id, child_id)
         for expr in exprs
-        for child_id in expr.expr.column_references
+        for child_id in expr.expression.column_references
         if child_id in by_id.keys()
     )
     # TODO: Also prevent inlining expensive or non-deterministic
     # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
     multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2)
-    scalar_ids = set(expr.name for expr in exprs if expr.expr.is_scalar_expr)
+    scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr)
 
-    def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
+    analytic_defs = filter(
+        lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs
+    )
+    analytic_by_window = grouped(
+        map(
+            lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x),
+            analytic_defs,
+        )
+    )
+
+    def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]:
         results: dict[identifiers.ColumnId, expression.Expression] = dict()
         while (
             True
@@ -156,38 +159,55 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
             for id in candidate_ids:
                 graph.remove_node(id)
                 new_exprs = {
-                    id: by_id[id].expr.bind_refs(results, allow_partial_bindings=True)
+                    id: by_id[id].expression.bind_refs(
+                        results, allow_partial_bindings=True
+                    )
                 }
                 results.update(new_exprs)
         # TODO: We can prune expressions that won't be reused here,
-        return tuple(NamedExpression(expr, id) for id, expr in results.items())
+        return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items())
 
     def graph_extract_window_expr() -> Optional[
-        Tuple[identifiers.ColumnId, agg_expressions.WindowExpression]
+        Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]
     ]:
-        candidate = list(
-            itertools.islice((id for id in graph.sinks if id not in scalar_ids), 1)
-        )
-        if not candidate:
-            return None
-        else:
-            id = next(iter(candidate))
-            graph.remove_node(id)
-            result_expr = by_id[id].expr
-            assert isinstance(result_expr, agg_expressions.WindowExpression)
-            return (id, result_expr)
+        for id in graph.sinks:
+            next_def = by_id[id]
+            if isinstance(next_def.expression, agg_expressions.WindowExpression):
+                window = next_def.expression.window
+                window_exprs = [
+                    cdef
+                    for cdef in analytic_by_window[window]
+                    if cdef.id in graph.sinks
+                ]
+                agg_exprs = tuple(
+                    nodes.ColumnDef(
+                        cast(
+                            agg_expressions.WindowExpression, cdef.expression
+                        ).analytic_expr,
+                        cdef.id,
+                    )
+                    for cdef in window_exprs
+                )
+                for cdef in window_exprs:
+                    graph.remove_node(cdef.id)
+                return (agg_exprs, window)
+
+        return None
 
     while not graph.empty:
         pre_size = len(graph.nodes)
         scalar_exprs = graph_extract_scalar_exprs()
         if scalar_exprs:
             curr_root = nodes.ProjectionNode(
-                curr_root, tuple((x.expr, x.name) for x in scalar_exprs)
+                curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
             )
         while result := graph_extract_window_expr():
-            id, window_expr = result
+            defs, window = result
+            assert len(defs) > 0
             curr_root = nodes.WindowOpNode(
-                curr_root, window_expr.analytic_expr, window_expr.window, output_name=id
+                curr_root,
+                tuple(defs),
+                window,
             )
         if len(graph.nodes) >= pre_size:
             raise ValueError("graph didn't shrink")
@@ -208,3 +228,14 @@ def is_simple(expr: expression.Expression) -> bool:
         if count > _MAX_INLINE_COMPLEXITY:
             return False
     return True
+
+
+K = TypeVar("K", bound=Hashable)
+V = TypeVar("V")
+
+
+def grouped(values: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
+    result = collections.defaultdict(list)
+    for k, v in values:
+        result[k].append(v)
+    return result
diff --git a/bigframes/core/groupby/dataframe_group_by.py b/bigframes/core/groupby/dataframe_group_by.py
index 21f0d7f426..2ec3ce2c96 100644
--- a/bigframes/core/groupby/dataframe_group_by.py
+++ b/bigframes/core/groupby/dataframe_group_by.py
@@ -434,12 +434,12 @@ def cumcount(self, ascending: bool = True) -> series.Series:
                 grouping_keys=tuple(self._by_col_ids)
             )
         )
-        block, result_id = self._block.apply_analytic(
-            agg_expressions.NullaryAggregation(agg_ops.size_op),
+        block, result_ids = self._block.apply_analytic(
+            [agg_expressions.NullaryAggregation(agg_ops.size_op)],
             window=window_spec,
-            result_label=None,
+            result_labels=[None],
         )
-        result = series.Series(block.select_column(result_id)) - 1
+        result = series.Series(block.select_columns(result_ids)) - 1
         if self._dropna and (len(self._by_col_ids) == 1):
             result = result.mask(
                 series.Series(block.select_column(self._by_col_ids[0])).isna()
diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py
index e1631c435d..ddccb39ef9 100644
--- a/bigframes/core/nodes.py
+++ b/bigframes/core/nodes.py
@@ -48,6 +48,12 @@
 OVERHEAD_VARIABLES = 5
 
 
+@dataclasses.dataclass(frozen=True, eq=True)
+class ColumnDef:
+    expression: ex.Expression
+    id: identifiers.ColumnId
+
+
 class AdditiveNode:
     """Definition of additive - if you drop added_fields, you end up with the descendent.
@@ -1391,21 +1397,23 @@ def remap_refs(
 
 @dataclasses.dataclass(frozen=True, eq=False)
 class WindowOpNode(UnaryNode, AdditiveNode):
-    expression: agg_expressions.Aggregation
+    agg_exprs: tuple[ColumnDef, ...]  # must be analytic/aggregation op
     window_spec: window.WindowSpec
-    output_name: identifiers.ColumnId
-    skip_reproject_unsafe: bool = False
 
     def _validate(self):
         """Validate the local data in the node."""
         # Since inner order and row bounds are coupled, rank ops can't be row bounded
-        assert (
-            not self.window_spec.is_row_bounded
-        ) or self.expression.op.implicitly_inherits_order
-        assert all(ref in self.child.ids for ref in self.expression.column_references)
-        assert self.added_field.dtype is not None
-        for agg_child in self.expression.children:
-            assert agg_child.is_scalar_expr
+        for cdef in self.agg_exprs:
+            assert isinstance(cdef.expression, agg_expressions.Aggregation)
+            if self.window_spec.is_row_bounded:
+                assert cdef.expression.op.implicitly_inherits_order
+            for agg_child in cdef.expression.children:
+                assert agg_child.is_scalar_expr
+            for ref in cdef.expression.column_references:
+                assert ref in self.child.ids
+
+        assert not any(field.dtype is None for field in self.added_fields)
+
         for window_expr in self.window_spec.expressions:
             assert window_expr.is_scalar_expr
@@ -1415,7 +1423,7 @@ def non_local(self) -> bool:
 
     @property
     def fields(self) -> Sequence[Field]:
-        return sequences.ChainedSequence(self.child.fields, (self.added_field,))
+        return sequences.ChainedSequence(self.child.fields, self.added_fields)
 
     @property
     def variables_introduced(self) -> int:
@@ -1423,49 +1431,54 @@ def variables_introduced(self) -> int:
 
     @property
     def added_fields(self) -> Tuple[Field, ...]:
-        return (self.added_field,)
+        return tuple(
+            Field(
+                cdef.id,
+                ex.bind_schema_fields(
+                    cdef.expression, self.child.field_by_id
+                ).output_type,
+            )
+            for cdef in self.agg_exprs
+        )
 
     @property
     def relation_ops_created(self) -> int:
-        # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
-        return 0 if self.skip_reproject_unsafe else 4
+        return 2
 
     @property
     def row_count(self) -> Optional[int]:
         return self.child.row_count
 
-    @functools.cached_property
-    def added_field(self) -> Field:
-        # TODO: Determine if output could be non-null
-        return Field(
-            self.output_name,
-            ex.bind_schema_fields(self.expression, self.child.field_by_id).output_type,
-        )
-
     @property
     def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
-        return (self.output_name,)
+        return tuple(field.id for field in self.added_fields)
 
     @property
     def consumed_ids(self) -> COLUMN_SET:
-        return frozenset(
-            set(self.ids).difference([self.output_name]).union(self.referenced_ids)
-        )
+        return frozenset(self.ids)
 
     @property
     def referenced_ids(self) -> COLUMN_SET:
+        ids_for_aggs = itertools.chain.from_iterable(
+            cdef.expression.column_references for cdef in self.agg_exprs
+        )
         return (
             frozenset()
-            .union(self.expression.column_references)
+            .union(ids_for_aggs)
             .union(self.window_spec.all_referenced_columns)
         )
 
     @property
     def inherits_order(self) -> bool:
         # does the op both use ordering at all? and if so, can it inherit order?
-        op_inherits_order = (
-            not self.expression.op.order_independent
-        ) and self.expression.op.implicitly_inherits_order
+        aggs = (
+            typing.cast(agg_expressions.Aggregation, cdef.expression)
+            for cdef in self.agg_exprs
+        )
+        op_inherits_order = any(
+            not agg.op.order_independent and agg.op.implicitly_inherits_order
+            for agg in aggs
+        )
         # range-bounded windows do not inherit orders because their ordering are
         # already defined before rewrite time.
         return op_inherits_order or self.window_spec.is_row_bounded
@@ -1476,7 +1489,10 @@ def additive_base(self) -> BigFrameNode:
 
     @property
     def _node_expressions(self):
-        return (self.expression, *self.window_spec.expressions)
+        return (
+            *(cdef.expression for cdef in self.agg_exprs),
+            *self.window_spec.expressions,
+        )
 
     def replace_additive_base(self, node: BigFrameNode) -> WindowOpNode:
         return dataclasses.replace(self, child=node)
@@ -1485,7 +1501,11 @@ def remap_vars(
         self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
     ) -> WindowOpNode:
         return dataclasses.replace(
-            self, output_name=mappings.get(self.output_name, self.output_name)
+            self,
+            agg_exprs=tuple(
+                ColumnDef(cdef.expression, mappings.get(cdef.id, cdef.id))
+                for cdef in self.agg_exprs
+            ),
         )
 
     def remap_refs(
@@ -1493,8 +1513,14 @@ def remap_refs(
     ) -> WindowOpNode:
         return dataclasses.replace(
             self,
-            expression=self.expression.remap_column_refs(
-                mappings, allow_partial_bindings=True
+            agg_exprs=tuple(
+                ColumnDef(
+                    cdef.expression.remap_column_refs(
+                        mappings, allow_partial_bindings=True
+                    ),
+                    cdef.id,
+                )
+                for cdef in self.agg_exprs
             ),
             window_spec=self.window_spec.remap_column_refs(
                 mappings, allow_partial_bindings=True
diff --git a/bigframes/core/rewrite/order.py b/bigframes/core/rewrite/order.py
index 881badd603..6741dfddad 100644
--- a/bigframes/core/rewrite/order.py
+++ b/bigframes/core/rewrite/order.py
@@ -168,11 +168,12 @@ def pull_up_order_inner(
         else:
             # Otherwise we need to generate offsets
             agg = agg_expressions.NullaryAggregation(agg_ops.RowNumberOp())
+            col_def = bigframes.core.nodes.ColumnDef(agg, node.col_id)
             window_spec = bigframes.core.window_spec.unbound(
                 ordering=tuple(child_order.all_ordering_columns)
             )
             new_offsets_node = bigframes.core.nodes.WindowOpNode(
-                child_result, agg, window_spec, node.col_id
+                child_result, (col_def,), window_spec
             )
             return (
                 new_offsets_node,
@@ -289,8 +290,9 @@ def pull_order_concat(
         window_spec = bigframes.core.window_spec.unbound(
             ordering=tuple(order.all_ordering_columns)
         )
+        col_def = bigframes.core.nodes.ColumnDef(agg, offsets_id)
         new_source = bigframes.core.nodes.WindowOpNode(
-            new_source, agg, window_spec, offsets_id
+            new_source, (col_def,), window_spec
         )
         new_source = bigframes.core.nodes.ProjectionNode(
             new_source, ((bigframes.core.expression.const(i), table_id),)
@@ -421,7 +423,9 @@ def rewrite_promote_offsets(
 ) -> bigframes.core.nodes.WindowOpNode:
     agg = agg_expressions.NullaryAggregation(agg_ops.RowNumberOp())
     window_spec = bigframes.core.window_spec.unbound()
-    return bigframes.core.nodes.WindowOpNode(node.child, agg, window_spec, node.col_id)
+    return bigframes.core.nodes.WindowOpNode(
+        node.child, (bigframes.core.nodes.ColumnDef(agg, node.col_id),), window_spec
+    )
 
 
 def rename_cols(
diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py
index fe9143baf2..d874c7c598 100644
--- a/bigframes/core/rewrite/schema_binding.py
+++ b/bigframes/core/rewrite/schema_binding.py
@@ -107,7 +107,13 @@ def bind_schema_to_node(
             )
         return dataclasses.replace(
             node,
-            expression=_bind_schema_to_aggregation_expr(node.expression, node.child),
+            agg_exprs=tuple(
+                nodes.ColumnDef(
+                    _bind_schema_to_aggregation_expr(cdef.expression, node.child),  # type: ignore
+                    cdef.id,
+                )
+                for cdef in node.agg_exprs
+            ),
             window_spec=window_spec,
         )
 
diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py
index 5c7a85ee1b..7190810f71 100644
--- a/bigframes/core/rewrite/timedeltas.py
+++ b/bigframes/core/rewrite/timedeltas.py
@@ -64,10 +64,13 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
     if isinstance(root, nodes.WindowOpNode):
         return nodes.WindowOpNode(
             root.child,
-            _rewrite_aggregation(root.expression, root.schema),
+            tuple(
+                nodes.ColumnDef(
+                    _rewrite_aggregation(cdef.expression, root.schema), cdef.id
+                )
+                for cdef in root.agg_exprs
+            ),
             root.window_spec,
-            root.output_name,
-            root.skip_reproject_unsafe,
         )
 
     if isinstance(root, nodes.AggregateNode):
diff --git a/tests/system/small/engines/test_windowing.py b/tests/system/small/engines/test_windowing.py
index 510a2de3ba..5e4a94d900 100644
--- a/tests/system/small/engines/test_windowing.py
+++ b/tests/system/small/engines/test_windowing.py
@@ -54,12 +54,13 @@ def test_engines_with_rows_window(
     )
     window_node = nodes.WindowOpNode(
         child=scalars_array_value.node,
-        expression=agg_expressions.UnaryAggregation(
-            agg_op, expression.deref("int64_too")
+        agg_exprs=(
+            nodes.ColumnDef(
+                agg_expressions.UnaryAggregation(agg_op, expression.deref("int64_too")),
+                identifiers.ColumnId("agg_int64"),
+            ),
         ),
         window_spec=window,
-        output_name=identifiers.ColumnId("agg_int64"),
-        skip_reproject_unsafe=False,
     )
 
     publisher = events.Publisher()
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py
index 2348b95496..f9ddf3e0c0 100644
--- a/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_nullary_compiler.py
@@ -46,9 +46,8 @@ def _apply_nullary_window_op(
 ) -> str:
     win_node = nodes.WindowOpNode(
         obj._block.expr.node,
-        expression=op,
+        agg_exprs=(nodes.ColumnDef(op, identifiers.ColumnId(new_name)),),
         window_spec=window_spec,
-        output_name=identifiers.ColumnId(new_name),
     )
     result = array_value.ArrayValue(win_node).select_columns([new_name])
 
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
index a21c753896..755fc6eb73 100644
--- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
@@ -54,9 +54,8 @@ def _apply_unary_window_op(
 ) -> str:
     win_node = nodes.WindowOpNode(
         obj._block.expr.node,
-        expression=op,
+        agg_exprs=(nodes.ColumnDef(op, identifiers.ColumnId(new_name)),),
        window_spec=window_spec,
-        output_name=identifiers.ColumnId(new_name),
     )
     result = array_value.ArrayValue(win_node).select_columns([new_name])
 
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql
index 11e3f4773e..b1d498bc76 100644
--- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql
@@ -12,39 +12,13 @@ WITH `bfcte_0` AS (
     `int64_col` AS `bfcol_8`,
     `bool_col` AS `bfcol_9`
   FROM `bfcte_0`
-), `bfcte_2` AS (
+), `bfcte_3` AS (
   SELECT
     *
   FROM `bfcte_1`
   WHERE
     NOT `bfcol_9` IS NULL
-), `bfcte_3` AS (
-  SELECT
-    *,
-    CASE
-      WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
-        PARTITION BY `bfcol_9`
-        ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
-        ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
-      ) < 3
-      THEN NULL
-      ELSE COALESCE(
-        SUM(CAST(`bfcol_7` AS INT64)) OVER (
-          PARTITION BY `bfcol_9`
-          ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
-          ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
-        ),
-        0
-      )
-    END AS `bfcol_15`
-  FROM `bfcte_2`
 ), `bfcte_4` AS (
-  SELECT
-    *
-  FROM `bfcte_3`
-  WHERE
-    NOT `bfcol_9` IS NULL
-), `bfcte_5` AS (
   SELECT
     *,
     CASE
@@ -62,15 +36,15 @@ WITH `bfcte_0` AS (
         ),
         0
       )
-    END AS `bfcol_21`
-  FROM `bfcte_4`
+    END AS `bfcol_16`
+  FROM `bfcte_3`
 )
 SELECT
   `bfcol_9` AS `bool_col`,
   `bfcol_6` AS `rowindex`,
   `bfcol_15` AS `bool_col_1`,
-  `bfcol_21` AS `int64_col`
-FROM `bfcte_5`
+  `bfcol_16` AS `int64_col`
+FROM `bfcte_4`
 ORDER BY
   `bfcol_9` ASC NULLS LAST,
   `rowindex` ASC NULLS LAST
\ No newline at end of file
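
Reviewer note: a minimal usage sketch of the widened `Block.apply_analytic` signature from this diff. It assumes an existing `Block` instance named `block` with value columns `"a"` and `"b"`; the block, column names, and labels are illustrative stand-ins, not part of the change.

```python
# Sketch only: exercises the new multi-aggregation path. Both results are
# planned as a single WindowOpNode instead of two chained single-column nodes.
from bigframes.core import agg_expressions, expression as ex, window_spec
from bigframes.operations import aggregations as agg_ops  # sum_op/mean_op assumed here

shared_window = window_spec.unbound()  # one WindowSpec shared by both aggregations
block, result_ids = block.apply_analytic(
    agg_exprs=[
        agg_expressions.UnaryAggregation(agg_ops.sum_op, ex.deref("a")),
        agg_expressions.UnaryAggregation(agg_ops.mean_op, ex.deref("b")),
    ],
    window=shared_window,
    result_labels=["a_sum", "b_mean"],
)
# result_ids is now a Sequence[str]: one internal column id per aggregation,
# in the same order as agg_exprs.
```

Call sites that previously looped column by column (as the old `multi_apply_window_op` did, threading `skip_reproject_unsafe` through each step) now make one call and receive all ids back, which is what removes the need for that flag.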
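At the plan level, the same batching shows up in `WindowOpNode` construction. This sketch mirrors `test_engines_with_rows_window` from the diff but carries two `ColumnDef`s; `child_node` and the column `"x"` are assumed names for illustration.

```python
# Sketch only: one WindowOpNode now holds several (aggregation, output id)
# pairs as ColumnDefs, all evaluated over the same window_spec.
from bigframes.core import agg_expressions, expression, identifiers, nodes, window_spec
from bigframes.operations import aggregations as agg_ops  # count_op assumed here

node = nodes.WindowOpNode(
    child=child_node,  # assumed pre-existing BigFrameNode with an INT64 column "x"
    agg_exprs=(
        nodes.ColumnDef(
            agg_expressions.UnaryAggregation(agg_ops.sum_op, expression.deref("x")),
            identifiers.ColumnId("x_sum"),
        ),
        nodes.ColumnDef(
            agg_expressions.UnaryAggregation(agg_ops.count_op, expression.deref("x")),
            identifiers.ColumnId("x_count"),
        ),
    ),
    window_spec=window_spec.unbound(),
)
```

Each backend compiler (ibis, polars, sqlglot) then iterates `node.agg_exprs`, emitting one output column per `ColumnDef` while sharing the window definition, and `expression_factoring.push_into_tree` uses the new `grouped` helper to collect same-window `WindowExpression` sinks into a single node per extraction pass.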