From 75bb9338754c778ad7dc1091ba7c6350f772b991 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Wed, 3 Sep 2025 23:50:39 +0000 Subject: [PATCH 1/3] refactor: Aggregation is now an expression subclass --- bigframes/core/array_value.py | 9 +- bigframes/core/bigframe_node.py | 9 +- bigframes/core/block_transforms.py | 22 +-- bigframes/core/blocks.py | 36 +++-- bigframes/core/compile/compiled.py | 11 +- .../ibis_compiler/aggregate_compiler.py | 18 +-- bigframes/core/compile/polars/compiler.py | 24 +-- .../compile/sqlglot/aggregate_compiler.py | 18 +-- bigframes/core/expression.py | 113 ------------- bigframes/core/expression_types.py | 151 ++++++++++++++++++ bigframes/core/groupby/aggs.py | 8 +- bigframes/core/groupby/dataframe_group_by.py | 6 +- bigframes/core/indexes/base.py | 9 +- bigframes/core/nodes.py | 15 +- bigframes/core/rewrite/order.py | 12 +- bigframes/core/rewrite/schema_binding.py | 16 +- bigframes/core/rewrite/timedeltas.py | 19 +-- bigframes/dataframe.py | 12 +- bigframes/operations/aggregations.py | 45 +++++- bigframes/series.py | 12 +- bigframes/session/polars_executor.py | 11 +- .../system/small/engines/test_aggregation.py | 12 +- .../aggregations/test_unary_compiler.py | 4 +- 23 files changed, 348 insertions(+), 244 deletions(-) create mode 100644 bigframes/core/expression_types.py diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index b47637cb59..f0cf64fe45 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -24,6 +24,7 @@ import pandas import pyarrow as pa +from bigframes.core import expression_types import bigframes.core.expression as ex import bigframes.core.guid import bigframes.core.identifiers as ids @@ -190,7 +191,7 @@ def row_count(self) -> ArrayValue: child=self.node, aggregations=( ( - ex.NullaryAggregation(agg_ops.size_op), + expression_types.NullaryAggregation(agg_ops.size_op), ids.ColumnId(bigframes.core.guid.generate_guid()), ), ), @@ -379,7 +380,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue: def aggregate( self, - aggregations: typing.Sequence[typing.Tuple[ex.Aggregation, str]], + aggregations: typing.Sequence[typing.Tuple[expression_types.Aggregation, str]], by_column_ids: typing.Sequence[str] = (), dropna: bool = True, ) -> ArrayValue: @@ -420,7 +421,7 @@ def project_window_op( """ return self.project_window_expr( - ex.UnaryAggregation(op, ex.deref(column_name)), + expression_types.UnaryAggregation(op, ex.deref(column_name)), window_spec, never_skip_nulls, skip_reproject_unsafe, @@ -428,7 +429,7 @@ def project_window_op( def project_window_expr( self, - expression: ex.Aggregation, + expression: expression_types.Aggregation, window: WindowSpec, never_skip_nulls=False, skip_reproject_unsafe: bool = False, diff --git a/bigframes/core/bigframe_node.py b/bigframes/core/bigframe_node.py index 0c6f56f35a..7e40248a00 100644 --- a/bigframes/core/bigframe_node.py +++ b/bigframes/core/bigframe_node.py @@ -20,15 +20,12 @@ import functools import itertools import typing -from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union +from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple from bigframes.core import expression, field, identifiers import bigframes.core.schema as schemata import bigframes.dtypes -if typing.TYPE_CHECKING: - import bigframes.session - COLUMN_SET = frozenset[identifiers.ColumnId] T = typing.TypeVar("T") @@ -281,8 +278,8 @@ def field_by_id(self) -> Mapping[identifiers.ColumnId, field.Field]: @property def 
_node_expressions( self, - ) -> Sequence[Union[expression.Expression, expression.Aggregation]]: - """List of scalar expressions. Intended for checking engine compatibility with used ops.""" + ) -> Sequence[expression.Expression]: + """List of expressions. Intended for checking engine compatibility with used ops.""" return () # Plan algorithms diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py index 465728b0ef..9802884e89 100644 --- a/bigframes/core/block_transforms.py +++ b/bigframes/core/block_transforms.py @@ -21,12 +21,12 @@ import pandas as pd import bigframes.constants +from bigframes.core import expression_types import bigframes.core as core import bigframes.core.blocks as blocks import bigframes.core.expression as ex import bigframes.core.ordering as ordering import bigframes.core.window_spec as windows -import bigframes.dtypes import bigframes.dtypes as dtypes import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops @@ -133,7 +133,7 @@ def quantile( block, _ = block.aggregate( grouping_column_ids, tuple( - ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col)) + expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col)) for col in quantile_cols ), column_labels=pd.Index(labels), @@ -363,7 +363,7 @@ def value_counts( block = dropna(block, columns, how="any") block, agg_ids = block.aggregate( by_column_ids=(*grouping_keys, *columns), - aggregations=[ex.NullaryAggregation(agg_ops.size_op)], + aggregations=[expression_types.NullaryAggregation(agg_ops.size_op)], dropna=drop_na and not grouping_keys, ) count_id = agg_ids[0] @@ -647,15 +647,15 @@ def skew( # counts, moment3 for each column aggregations = [] for i, col in enumerate(original_columns): - count_agg = ex.UnaryAggregation( + count_agg = expression_types.UnaryAggregation( agg_ops.count_op, ex.deref(col), ) - moment3_agg = ex.UnaryAggregation( + moment3_agg = expression_types.UnaryAggregation( agg_ops.mean_op, ex.deref(delta3_ids[i]), ) - variance_agg = ex.UnaryAggregation( + variance_agg = expression_types.UnaryAggregation( agg_ops.PopVarOp(), ex.deref(col), ) @@ -698,9 +698,13 @@ def kurt( # counts, moment4 for each column aggregations = [] for i, col in enumerate(original_columns): - count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(col)) - moment4_agg = ex.UnaryAggregation(agg_ops.mean_op, ex.deref(delta4_ids[i])) - variance_agg = ex.UnaryAggregation(agg_ops.PopVarOp(), ex.deref(col)) + count_agg = expression_types.UnaryAggregation(agg_ops.count_op, ex.deref(col)) + moment4_agg = expression_types.UnaryAggregation( + agg_ops.mean_op, ex.deref(delta4_ids[i]) + ) + variance_agg = expression_types.UnaryAggregation( + agg_ops.PopVarOp(), ex.deref(col) + ) aggregations.extend([count_agg, moment4_agg, variance_agg]) block, agg_ids = block.aggregate( diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py index 07d7e4c45b..bc28bd6283 100644 --- a/bigframes/core/blocks.py +++ b/bigframes/core/blocks.py @@ -51,11 +51,12 @@ from bigframes import session from bigframes._config import sampling_options import bigframes.constants -from bigframes.core import local_data +from bigframes.core import expression_types, local_data import bigframes.core as core import bigframes.core.compile.googlesql as googlesql import bigframes.core.expression as ex import bigframes.core.expression as scalars +import bigframes.core.expression_types as ex_types import bigframes.core.guid as guid import bigframes.core.identifiers import bigframes.core.join_def as 
join_defs @@ -1143,7 +1144,7 @@ def apply_window_op( skip_reproject_unsafe: bool = False, never_skip_nulls: bool = False, ) -> typing.Tuple[Block, str]: - agg_expr = ex.UnaryAggregation(op, ex.deref(column)) + agg_expr = expression_types.UnaryAggregation(op, ex.deref(column)) return self.apply_analytic( agg_expr, window_spec, @@ -1155,7 +1156,7 @@ def apply_window_op( def apply_analytic( self, - agg_expr: ex.Aggregation, + agg_expr: expression_types.Aggregation, window: windows.WindowSpec, result_label: Label, *, @@ -1248,9 +1249,9 @@ def aggregate_all_and_stack( if axis_n == 0: aggregations = [ ( - ex.UnaryAggregation(operation, ex.deref(col_id)) + expression_types.UnaryAggregation(operation, ex.deref(col_id)) if isinstance(operation, agg_ops.UnaryAggregateOp) - else ex.NullaryAggregation(operation), + else expression_types.NullaryAggregation(operation), col_id, ) for col_id in self.value_columns @@ -1279,7 +1280,10 @@ def aggregate_size( ): """Returns a block object to compute the size(s) of groups.""" agg_specs = [ - (ex.NullaryAggregation(agg_ops.SizeOp()), guid.generate_guid()), + ( + expression_types.NullaryAggregation(agg_ops.SizeOp()), + guid.generate_guid(), + ), ] output_col_ids = [agg_spec[1] for agg_spec in agg_specs] result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna) @@ -1350,7 +1354,7 @@ def remap_f(x): def aggregate( self, by_column_ids: typing.Sequence[str] = (), - aggregations: typing.Sequence[ex.Aggregation] = (), + aggregations: typing.Sequence[expression_types.Aggregation] = (), column_labels: Optional[pd.Index] = None, *, dropna: bool = True, @@ -1419,9 +1423,9 @@ def get_stat( aggregations = [ ( - ex.UnaryAggregation(stat, ex.deref(column_id)) + expression_types.UnaryAggregation(stat, ex.deref(column_id)) if isinstance(stat, agg_ops.UnaryAggregateOp) - else ex.NullaryAggregation(stat), + else expression_types.NullaryAggregation(stat), stat.name, ) for stat in stats_to_fetch @@ -1447,7 +1451,7 @@ def get_binary_stat( # TODO(kemppeterson): Add a cache here. aggregations = [ ( - ex.BinaryAggregation( + expression_types.BinaryAggregation( stat, ex.deref(column_id_left), ex.deref(column_id_right) ), f"{stat.name}_{column_id_left}{column_id_right}", @@ -1474,9 +1478,9 @@ def summarize( labels = pd.Index([stat.name for stat in stats]) aggregations = [ ( - ex.UnaryAggregation(stat, ex.deref(col_id)) + expression_types.UnaryAggregation(stat, ex.deref(col_id)) if isinstance(stat, agg_ops.UnaryAggregateOp) - else ex.NullaryAggregation(stat), + else expression_types.NullaryAggregation(stat), f"{col_id}-{stat.name}", ) for stat in stats @@ -1750,7 +1754,7 @@ def pivot( block = block.select_columns(column_ids) aggregations = [ - ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)) + expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id)) for col_id in column_ids ] result_block, _ = block.aggregate( @@ -2018,7 +2022,7 @@ def _generate_resample_label( agg_specs = [ ( - ex.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)), + expression_types.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)), guid.generate_guid(), ), ] @@ -2047,13 +2051,13 @@ def _generate_resample_label( # Generate integer label sequence. 
min_agg_specs = [ ( - ex.UnaryAggregation(agg_ops.min_op, ex.deref(label_col_id)), + ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(label_col_id)), guid.generate_guid(), ), ] max_agg_specs = [ ( - ex.UnaryAggregation(agg_ops.max_op, ex.deref(label_col_id)), + ex_types.UnaryAggregation(agg_ops.max_op, ex.deref(label_col_id)), guid.generate_guid(), ), ] diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index f7de5c051a..a5309446cd 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -35,6 +35,7 @@ import bigframes.core.compile.ibis_compiler.scalar_op_compiler as op_compilers import bigframes.core.compile.ibis_types import bigframes.core.expression as ex +import bigframes.core.expression_types as ex_types from bigframes.core.ordering import OrderingExpression import bigframes.core.sql from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec @@ -215,7 +216,7 @@ def filter(self, predicate: ex.Expression) -> UnorderedIR: def aggregate( self, - aggregations: typing.Sequence[tuple[ex.Aggregation, str]], + aggregations: typing.Sequence[tuple[ex_types.Aggregation, str]], by_column_ids: typing.Sequence[ex.DerefOp] = (), order_by: typing.Sequence[OrderingExpression] = (), ) -> UnorderedIR: @@ -401,7 +402,7 @@ def isin_join( def project_window_op( self, - expression: ex.Aggregation, + expression: ex_types.Aggregation, window_spec: WindowSpec, output_name: str, *, @@ -467,7 +468,9 @@ def project_window_op( lambda x, y: x & y, per_col_does_count ).cast(int) observation_count = agg_compiler.compile_analytic( - ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")), + ex_types.UnaryAggregation( + agg_ops.sum_op, ex.deref("_observation_count") + ), window, bindings={"_observation_count": is_observation}, ) @@ -476,7 +479,7 @@ def project_window_op( # notnull is just used to convert null values to non-null (FALSE) values to be counted is_observation = inputs[0].notnull() observation_count = agg_compiler.compile_analytic( - ex.UnaryAggregation( + ex_types.UnaryAggregation( agg_ops.count_op, ex.deref("_observation_count") ), window, diff --git a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py index 291db44524..3aebb10dcd 100644 --- a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py +++ b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py @@ -26,10 +26,10 @@ import bigframes_vendored.ibis.expr.types as ibis_types import pandas as pd +from bigframes.core import expression_types from bigframes.core.compile import constants as compiler_constants import bigframes.core.compile.ibis_compiler.scalar_op_compiler as scalar_compilers import bigframes.core.compile.ibis_types as compile_ibis_types -import bigframes.core.expression as ex import bigframes.core.window_spec as window_spec import bigframes.operations.aggregations as agg_ops @@ -48,19 +48,19 @@ def approx_quantiles(expression: float, number) -> List[float]: def compile_aggregate( - aggregate: ex.Aggregation, + aggregate: expression_types.Aggregation, bindings: typing.Dict[str, ibis_types.Value], order_by: typing.Sequence[ibis_types.Value] = [], ) -> ibis_types.Value: - if isinstance(aggregate, ex.NullaryAggregation): + if isinstance(aggregate, expression_types.NullaryAggregation): return compile_nullary_agg(aggregate.op) - if isinstance(aggregate, ex.UnaryAggregation): + if isinstance(aggregate, expression_types.UnaryAggregation): input = 
scalar_compiler.compile_expression(aggregate.arg, bindings=bindings) if not aggregate.op.order_independent: return compile_ordered_unary_agg(aggregate.op, input, order_by=order_by) # type: ignore else: return compile_unary_agg(aggregate.op, input) # type: ignore - elif isinstance(aggregate, ex.BinaryAggregation): + elif isinstance(aggregate, expression_types.BinaryAggregation): left = scalar_compiler.compile_expression(aggregate.left, bindings=bindings) right = scalar_compiler.compile_expression(aggregate.right, bindings=bindings) return compile_binary_agg(aggregate.op, left, right) # type: ignore @@ -69,16 +69,16 @@ def compile_aggregate( def compile_analytic( - aggregate: ex.Aggregation, + aggregate: expression_types.Aggregation, window: window_spec.WindowSpec, bindings: typing.Dict[str, ibis_types.Value], ) -> ibis_types.Value: - if isinstance(aggregate, ex.NullaryAggregation): + if isinstance(aggregate, expression_types.NullaryAggregation): return compile_nullary_agg(aggregate.op, window) - elif isinstance(aggregate, ex.UnaryAggregation): + elif isinstance(aggregate, expression_types.UnaryAggregation): input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings) return compile_unary_agg(aggregate.op, input, window) # type: ignore - elif isinstance(aggregate, ex.BinaryAggregation): + elif isinstance(aggregate, expression_types.BinaryAggregation): raise NotImplementedError("binary analytic operations not yet supported") else: raise ValueError(f"Unexpected analytic operation: {aggregate}") diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index 70fa516e51..c1ff0bcf18 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -22,7 +22,7 @@ import pandas as pd import bigframes.core -from bigframes.core import identifiers, nodes, ordering, window_spec +from bigframes.core import expression_types, identifiers, nodes, ordering, window_spec from bigframes.core.compile.polars import lowering import bigframes.core.expression as ex import bigframes.core.guid as guid @@ -443,15 +443,15 @@ class PolarsAggregateCompiler: def get_args( self, - agg: ex.Aggregation, + agg: expression_types.Aggregation, ) -> Sequence[pl.Expr]: """Prepares arguments for aggregation by compiling them.""" - if isinstance(agg, ex.NullaryAggregation): + if isinstance(agg, expression_types.NullaryAggregation): return [] - elif isinstance(agg, ex.UnaryAggregation): + elif isinstance(agg, expression_types.UnaryAggregation): arg = self.scalar_compiler.compile_expression(agg.arg) return [arg] - elif isinstance(agg, ex.BinaryAggregation): + elif isinstance(agg, expression_types.BinaryAggregation): larg = self.scalar_compiler.compile_expression(agg.left) rarg = self.scalar_compiler.compile_expression(agg.right) return [larg, rarg] @@ -460,13 +460,13 @@ def get_args( f"Aggregation {agg} not yet supported in polars engine." 
) - def compile_agg_expr(self, expr: ex.Aggregation): - if isinstance(expr, ex.NullaryAggregation): + def compile_agg_expr(self, expr: expression_types.Aggregation): + if isinstance(expr, expression_types.NullaryAggregation): inputs: Tuple = () - elif isinstance(expr, ex.UnaryAggregation): + elif isinstance(expr, expression_types.UnaryAggregation): assert isinstance(expr.arg, ex.DerefOp) inputs = (expr.arg.id.sql,) - elif isinstance(expr, ex.BinaryAggregation): + elif isinstance(expr, expression_types.BinaryAggregation): assert isinstance(expr.left, ex.DerefOp) assert isinstance(expr.right, ex.DerefOp) inputs = ( @@ -769,7 +769,9 @@ def compile_agg(self, node: nodes.AggregateNode): def _aggregate( self, df: pl.LazyFrame, - aggregations: Sequence[Tuple[ex.Aggregation, identifiers.ColumnId]], + aggregations: Sequence[ + Tuple[expression_types.Aggregation, identifiers.ColumnId] + ], grouping_keys: Tuple[ex.DerefOp, ...], ) -> pl.LazyFrame: # Need to materialize columns to broadcast constants @@ -858,7 +860,7 @@ def compile_window(self, node: nodes.WindowOpNode): def _calc_row_analytic_func( self, frame: pl.LazyFrame, - agg_expr: ex.Aggregation, + agg_expr: expression_types.Aggregation, window: window_spec.WindowSpec, name: str, ) -> pl.LazyFrame: diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py index 52ef4cc26c..85c3a2f75a 100644 --- a/bigframes/core/compile/sqlglot/aggregate_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py @@ -15,7 +15,7 @@ import sqlglot.expressions as sge -from bigframes.core import expression, window_spec +from bigframes.core import expression_types, window_spec from bigframes.core.compile.sqlglot.aggregations import ( binary_compiler, nullary_compiler, @@ -27,13 +27,13 @@ def compile_aggregate( - aggregate: expression.Aggregation, + aggregate: expression_types.Aggregation, order_by: tuple[sge.Expression, ...], ) -> sge.Expression: """Compiles BigFrames aggregation expression into SQLGlot expression.""" - if isinstance(aggregate, expression.NullaryAggregation): + if isinstance(aggregate, expression_types.NullaryAggregation): return nullary_compiler.compile(aggregate.op) - if isinstance(aggregate, expression.UnaryAggregation): + if isinstance(aggregate, expression_types.UnaryAggregation): column = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.arg), aggregate.arg.output_type, @@ -44,7 +44,7 @@ def compile_aggregate( ) else: return unary_compiler.compile(aggregate.op, column) - elif isinstance(aggregate, expression.BinaryAggregation): + elif isinstance(aggregate, expression_types.BinaryAggregation): left = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.left), aggregate.left.output_type, @@ -59,18 +59,18 @@ def compile_aggregate( def compile_analytic( - aggregate: expression.Aggregation, + aggregate: expression_types.Aggregation, window: window_spec.WindowSpec, ) -> sge.Expression: - if isinstance(aggregate, expression.NullaryAggregation): + if isinstance(aggregate, expression_types.NullaryAggregation): return nullary_compiler.compile(aggregate.op) - if isinstance(aggregate, expression.UnaryAggregation): + if isinstance(aggregate, expression_types.UnaryAggregation): column = typed_expr.TypedExpr( scalar_compiler.compile_scalar_expression(aggregate.arg), aggregate.arg.output_type, ) return unary_compiler.compile(aggregate.op, column, window) - elif isinstance(aggregate, expression.BinaryAggregation): + elif isinstance(aggregate, 
expression_types.BinaryAggregation): raise NotImplementedError("binary analytic operations not yet supported") else: raise ValueError(f"Unexpected analytic operation: {aggregate}") diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index 0e94193bd3..59679f1bc4 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -27,7 +27,6 @@ from bigframes.core import field import bigframes.core.identifiers as ids import bigframes.operations -import bigframes.operations.aggregations as agg_ops def const( @@ -44,118 +43,6 @@ def free_var(id: str) -> UnboundVariableExpression: return UnboundVariableExpression(id) -@dataclasses.dataclass(frozen=True) -class Aggregation(abc.ABC): - """Represents windowing or aggregation over a column.""" - - op: agg_ops.WindowOp = dataclasses.field() - - @abc.abstractmethod - def output_type( - self, input_fields: Mapping[ids.ColumnId, field.Field] - ) -> dtypes.ExpressionType: - ... - - @property - def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: - return () - - @abc.abstractmethod - def remap_column_refs( - self, - name_mapping: Mapping[ids.ColumnId, ids.ColumnId], - allow_partial_bindings: bool = False, - ) -> Aggregation: - ... - - -@dataclasses.dataclass(frozen=True) -class NullaryAggregation(Aggregation): - op: agg_ops.NullaryWindowOp = dataclasses.field() - - def output_type( - self, input_fields: Mapping[ids.ColumnId, field.Field] - ) -> dtypes.ExpressionType: - return self.op.output_type() - - def remap_column_refs( - self, - name_mapping: Mapping[ids.ColumnId, ids.ColumnId], - allow_partial_bindings: bool = False, - ) -> NullaryAggregation: - return self - - -@dataclasses.dataclass(frozen=True) -class UnaryAggregation(Aggregation): - op: agg_ops.UnaryWindowOp - arg: Union[DerefOp, ScalarConstantExpression] - - def output_type( - self, input_fields: Mapping[ids.ColumnId, field.Field] - ) -> dtypes.ExpressionType: - # TODO(b/419300717) Remove resolutions once defers are cleaned up. - resolved_expr = bind_schema_fields(self.arg, input_fields) - assert resolved_expr.is_resolved - - return self.op.output_type(resolved_expr.output_type) - - @property - def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: - return self.arg.column_references - - def remap_column_refs( - self, - name_mapping: Mapping[ids.ColumnId, ids.ColumnId], - allow_partial_bindings: bool = False, - ) -> UnaryAggregation: - return UnaryAggregation( - self.op, - self.arg.remap_column_refs( - name_mapping, allow_partial_bindings=allow_partial_bindings - ), - ) - - -@dataclasses.dataclass(frozen=True) -class BinaryAggregation(Aggregation): - op: agg_ops.BinaryAggregateOp = dataclasses.field() - left: Union[DerefOp, ScalarConstantExpression] = dataclasses.field() - right: Union[DerefOp, ScalarConstantExpression] = dataclasses.field() - - def output_type( - self, input_fields: Mapping[ids.ColumnId, field.Field] - ) -> dtypes.ExpressionType: - # TODO(b/419300717) Remove resolutions once defers are cleaned up. 
- left_resolved_expr = bind_schema_fields(self.left, input_fields) - assert left_resolved_expr.is_resolved - right_resolved_expr = bind_schema_fields(self.right, input_fields) - assert right_resolved_expr.is_resolved - - return self.op.output_type( - left_resolved_expr.output_type, left_resolved_expr.output_type - ) - - @property - def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: - return (*self.left.column_references, *self.right.column_references) - - def remap_column_refs( - self, - name_mapping: Mapping[ids.ColumnId, ids.ColumnId], - allow_partial_bindings: bool = False, - ) -> BinaryAggregation: - return BinaryAggregation( - self.op, - self.left.remap_column_refs( - name_mapping, allow_partial_bindings=allow_partial_bindings - ), - self.right.remap_column_refs( - name_mapping, allow_partial_bindings=allow_partial_bindings - ), - ) - - TExpression = TypeVar("TExpression", bound="Expression") diff --git a/bigframes/core/expression_types.py b/bigframes/core/expression_types.py new file mode 100644 index 0000000000..f77525706b --- /dev/null +++ b/bigframes/core/expression_types.py @@ -0,0 +1,151 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import abc +import dataclasses +import functools +import itertools +import typing +from typing import Callable, Mapping, TypeVar + +from bigframes import dtypes +from bigframes.core import expression +import bigframes.core.identifiers as ids +import bigframes.operations.aggregations as agg_ops + +TExpression = TypeVar("TExpression", bound="Aggregation") + + +@dataclasses.dataclass(frozen=True) +class Aggregation(expression.Expression): + """Represents windowing or aggregation over a column.""" + + op: agg_ops.WindowOp = dataclasses.field() + + @property + def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: + return tuple( + itertools.chain.from_iterable( + map(lambda x: x.column_references, self.inputs) + ) + ) + + @functools.cached_property + def is_resolved(self) -> bool: + return all(input.is_resolved for input in self.inputs) + + @functools.cached_property + def output_type(self) -> dtypes.ExpressionType: + if not self.is_resolved: + raise ValueError(f"Type of expression {self.op} has not been fixed.") + + input_types = [input.output_type for input in self.inputs] + + return self.op.output_type(*input_types) + + @property + @abc.abstractmethod + def inputs( + self, + ) -> typing.Tuple[expression.Expression, ...]: + ... + + @property + def free_variables(self) -> typing.Tuple[str, ...]: + return tuple( + itertools.chain.from_iterable(map(lambda x: x.free_variables, self.inputs)) + ) + + @property + def is_const(self) -> bool: + return all(child.is_const for child in self.inputs) + + @abc.abstractmethod + def replace_args(self: TExpression, *arg) -> TExpression: + ... 
+ + def transform_children( + self: TExpression, t: Callable[[expression.Expression], expression.Expression] + ) -> TExpression: + return self.replace_args(*(t(arg) for arg in self.inputs)) + + def bind_variables( + self: TExpression, + bindings: Mapping[str, expression.Expression], + allow_partial_bindings: bool = False, + ) -> TExpression: + return self.transform_children( + lambda x: x.bind_variables(bindings, allow_partial_bindings) + ) + + def bind_refs( + self: TExpression, + bindings: Mapping[ids.ColumnId, expression.Expression], + allow_partial_bindings: bool = False, + ) -> TExpression: + return self.transform_children( + lambda x: x.bind_refs(bindings, allow_partial_bindings) + ) + + +@dataclasses.dataclass(frozen=True) +class NullaryAggregation(Aggregation): + op: agg_ops.NullaryWindowOp = dataclasses.field() + + @property + def inputs( + self, + ) -> typing.Tuple[expression.Expression, ...]: + return () + + def replace_args(self, *arg) -> NullaryAggregation: + return self + + +@dataclasses.dataclass(frozen=True) +class UnaryAggregation(Aggregation): + op: agg_ops.UnaryWindowOp + arg: expression.Expression + + @property + def inputs( + self, + ) -> typing.Tuple[expression.Expression, ...]: + return (self.arg,) + + def replace_args(self, arg: expression.Expression) -> UnaryAggregation: + return UnaryAggregation( + self.op, + arg, + ) + + +@dataclasses.dataclass(frozen=True) +class BinaryAggregation(Aggregation): + op: agg_ops.BinaryAggregateOp = dataclasses.field() + left: expression.Expression = dataclasses.field() + right: expression.Expression = dataclasses.field() + + @property + def inputs( + self, + ) -> typing.Tuple[expression.Expression, ...]: + return (self.left, self.right) + + def replace_args( + self, larg: expression.Expression, rarg: expression.Expression + ) -> BinaryAggregation: + return BinaryAggregation(self.op, larg, rarg) diff --git a/bigframes/core/groupby/aggs.py b/bigframes/core/groupby/aggs.py index 26257cc9b6..424294f3dd 100644 --- a/bigframes/core/groupby/aggs.py +++ b/bigframes/core/groupby/aggs.py @@ -14,13 +14,13 @@ from __future__ import annotations -from bigframes.core import expression +from bigframes.core import expression, expression_types from bigframes.operations import aggregations as agg_ops -def agg(input: str, op: agg_ops.AggregateOp) -> expression.Aggregation: +def agg(input: str, op: agg_ops.AggregateOp) -> expression_types.Aggregation: if isinstance(op, agg_ops.UnaryAggregateOp): - return expression.UnaryAggregation(op, expression.deref(input)) + return expression_types.UnaryAggregation(op, expression.deref(input)) else: assert isinstance(op, agg_ops.NullaryAggregateOp) - return expression.NullaryAggregation(op) + return expression_types.NullaryAggregation(op) diff --git a/bigframes/core/groupby/dataframe_group_by.py b/bigframes/core/groupby/dataframe_group_by.py index e4e4b313f9..68dcaca365 100644 --- a/bigframes/core/groupby/dataframe_group_by.py +++ b/bigframes/core/groupby/dataframe_group_by.py @@ -25,7 +25,7 @@ from bigframes import session from bigframes.core import expression as ex -from bigframes.core import log_adapter +from bigframes.core import expression_types, log_adapter import bigframes.core.block_transforms as block_ops import bigframes.core.blocks as blocks from bigframes.core.groupby import aggs, series_group_by @@ -327,7 +327,7 @@ def cumcount(self, ascending: bool = True) -> series.Series: ) ) block, result_id = self._block.apply_analytic( - ex.NullaryAggregation(agg_ops.size_op), + 
expression_types.NullaryAggregation(agg_ops.size_op), window=window_spec, result_label=None, ) @@ -488,7 +488,7 @@ def _agg_string(self, func: str) -> df.DataFrame: return dataframe if self._as_index else self._convert_index(dataframe) def _agg_dict(self, func: typing.Mapping) -> df.DataFrame: - aggregations: typing.List[ex.Aggregation] = [] + aggregations: typing.List[expression_types.Aggregation] = [] column_labels = [] want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values()) diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py index f8ec38621d..b1c51d11ac 100644 --- a/bigframes/core/indexes/base.py +++ b/bigframes/core/indexes/base.py @@ -30,6 +30,7 @@ import bigframes.core.block_transforms as block_ops import bigframes.core.blocks as blocks import bigframes.core.expression as ex +import bigframes.core.expression_types as ex_types import bigframes.core.ordering as order import bigframes.core.utils as utils import bigframes.core.validations as validations @@ -282,7 +283,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: filtered_block = block_with_offsets.filter_by_id(match_col_id) # Check if key exists at all by counting - count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id)) + count_agg = ex_types.UnaryAggregation(agg_ops.count_op, ex.deref(offsets_id)) count_result = filtered_block._expr.aggregate([(count_agg, "count")]) count_scalar = self._block.session._executor.execute( @@ -294,7 +295,7 @@ def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]: # If only one match, return integer position if count_scalar == 1: - min_agg = ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)) + min_agg = ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)) position_result = filtered_block._expr.aggregate([(min_agg, "position")]) position_scalar = self._block.session._executor.execute( position_result, ex_spec.ExecutionSpec(promise_under_10gb=True) @@ -317,11 +318,11 @@ def _get_monotonic_slice(self, filtered_block, offsets_id: str) -> slice: # Combine min and max aggregations into a single query for efficiency min_max_aggs = [ ( - ex.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)), + ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(offsets_id)), "min_pos", ), ( - ex.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)), + ex_types.UnaryAggregation(agg_ops.max_op, ex.deref(offsets_id)), "max_pos", ), ] diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index cf6e8a7e5c..3e160de49e 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -33,7 +33,7 @@ import google.cloud.bigquery as bq -from bigframes.core import identifiers, local_data, sequences +from bigframes.core import expression_types, identifiers, local_data, sequences from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET import bigframes.core.expression as ex from bigframes.core.field import Field @@ -1337,7 +1337,9 @@ def remap_refs( @dataclasses.dataclass(frozen=True, eq=False) class AggregateNode(UnaryNode): - aggregations: typing.Tuple[typing.Tuple[ex.Aggregation, identifiers.ColumnId], ...] + aggregations: typing.Tuple[ + typing.Tuple[expression_types.Aggregation, identifiers.ColumnId], ... + ] by_column_ids: typing.Tuple[ex.DerefOp, ...] = tuple([]) order_by: Tuple[OrderingExpression, ...] 
= () dropna: bool = True @@ -1360,9 +1362,7 @@ def fields(self) -> Sequence[Field]: agg_items = ( Field( id, - bigframes.dtypes.dtype_for_etype( - agg.output_type(self.child.field_by_id) - ), + ex.bind_schema_fields(agg, self.child.field_by_id).output_type, nullable=True, ) for agg, id in self.aggregations @@ -1437,7 +1437,7 @@ def remap_refs( @dataclasses.dataclass(frozen=True, eq=False) class WindowOpNode(UnaryNode, AdditiveNode): - expression: ex.Aggregation + expression: expression_types.Aggregation window_spec: window.WindowSpec output_name: identifiers.ColumnId never_skip_nulls: bool = False @@ -1478,11 +1478,10 @@ def row_count(self) -> Optional[int]: @functools.cached_property def added_field(self) -> Field: - input_fields = self.child.field_by_id # TODO: Determine if output could be non-null return Field( self.output_name, - bigframes.dtypes.dtype_for_etype(self.expression.output_type(input_fields)), + ex.bind_schema_fields(self.expression, self.child.field_by_id).output_type, ) @property diff --git a/bigframes/core/rewrite/order.py b/bigframes/core/rewrite/order.py index 5b5fb10753..da4d9dfa3b 100644 --- a/bigframes/core/rewrite/order.py +++ b/bigframes/core/rewrite/order.py @@ -15,7 +15,7 @@ import functools from typing import Mapping, Tuple -from bigframes.core import expression, identifiers +from bigframes.core import expression, expression_types, identifiers import bigframes.core.nodes import bigframes.core.ordering import bigframes.core.window_spec @@ -167,9 +167,7 @@ def pull_up_order_inner( ) else: # Otherwise we need to generate offsets - agg = bigframes.core.expression.NullaryAggregation( - agg_ops.RowNumberOp() - ) + agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp()) window_spec = bigframes.core.window_spec.unbound( ordering=tuple(child_order.all_ordering_columns) ) @@ -287,9 +285,7 @@ def pull_order_concat( new_source, ((order_expression.scalar_expression, offsets_id),) ) else: - agg = bigframes.core.expression.NullaryAggregation( - agg_ops.RowNumberOp() - ) + agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp()) window_spec = bigframes.core.window_spec.unbound( ordering=tuple(order.all_ordering_columns) ) @@ -423,7 +419,7 @@ def remove_order_strict( def rewrite_promote_offsets( node: bigframes.core.nodes.PromoteOffsetsNode, ) -> bigframes.core.nodes.WindowOpNode: - agg = bigframes.core.expression.NullaryAggregation(agg_ops.RowNumberOp()) + agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp()) window_spec = bigframes.core.window_spec.unbound() return bigframes.core.nodes.WindowOpNode(node.child, agg, window_spec, node.col_id) diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py index cbecf83035..b497cd5700 100644 --- a/bigframes/core/rewrite/schema_binding.py +++ b/bigframes/core/rewrite/schema_binding.py @@ -17,7 +17,7 @@ from bigframes.core import bigframe_node from bigframes.core import expression as ex -from bigframes.core import nodes, ordering +from bigframes.core import expression_types, nodes, ordering def bind_schema_to_tree( @@ -118,16 +118,16 @@ def bind_schema_to_node( def _bind_schema_to_aggregation_expr( - aggregation: ex.Aggregation, + aggregation: expression_types.Aggregation, child: bigframe_node.BigFrameNode, -) -> ex.Aggregation: +) -> expression_types.Aggregation: assert isinstance( - aggregation, ex.Aggregation + aggregation, expression_types.Aggregation ), f"Expected Aggregation, got {type(aggregation)}" - if isinstance(aggregation, ex.UnaryAggregation): + if 
isinstance(aggregation, expression_types.UnaryAggregation): return typing.cast( - ex.Aggregation, + expression_types.Aggregation, dataclasses.replace( aggregation, arg=typing.cast( @@ -136,9 +136,9 @@ def _bind_schema_to_aggregation_expr( ), ), ) - elif isinstance(aggregation, ex.BinaryAggregation): + elif isinstance(aggregation, expression_types.BinaryAggregation): return typing.cast( - ex.Aggregation, + expression_types.Aggregation, dataclasses.replace( aggregation, left=typing.cast( diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py index ea8e608a84..d023f038f1 100644 --- a/bigframes/core/rewrite/timedeltas.py +++ b/bigframes/core/rewrite/timedeltas.py @@ -21,6 +21,7 @@ from bigframes import dtypes from bigframes import operations as ops from bigframes.core import expression as ex +from bigframes.core import expression_types as ex_types from bigframes.core import nodes, schema, utils from bigframes.operations import aggregations as aggs @@ -219,33 +220,33 @@ def _rewrite_to_timedelta_op(op: ops.ToTimedeltaOp, arg: _TypedExpr): @functools.cache def _rewrite_aggregation( - aggregation: ex.Aggregation, schema: schema.ArraySchema -) -> ex.Aggregation: - if not isinstance(aggregation, ex.UnaryAggregation): + aggregation: ex_types.Aggregation, schema: schema.ArraySchema +) -> ex_types.Aggregation: + if not isinstance(aggregation, ex_types.UnaryAggregation): return aggregation if isinstance(aggregation.arg, ex.DerefOp): input_type = schema.get_type(aggregation.arg.id.sql) else: - input_type = aggregation.arg.dtype + input_type = aggregation.arg.output_type if isinstance(aggregation.op, aggs.DiffOp): if dtypes.is_datetime_like(input_type): - return ex.UnaryAggregation( + return ex_types.UnaryAggregation( aggs.TimeSeriesDiffOp(aggregation.op.periods), aggregation.arg ) elif input_type == dtypes.DATE_DTYPE: - return ex.UnaryAggregation( + return ex_types.UnaryAggregation( aggs.DateSeriesDiffOp(aggregation.op.periods), aggregation.arg ) if isinstance(aggregation.op, aggs.StdOp) and input_type == dtypes.TIMEDELTA_DTYPE: - return ex.UnaryAggregation( + return ex_types.UnaryAggregation( aggs.StdOp(should_floor_result=True), aggregation.arg ) if isinstance(aggregation.op, aggs.MeanOp) and input_type == dtypes.TIMEDELTA_DTYPE: - return ex.UnaryAggregation( + return ex_types.UnaryAggregation( aggs.MeanOp(should_floor_result=True), aggregation.arg ) @@ -253,7 +254,7 @@ def _rewrite_aggregation( isinstance(aggregation.op, aggs.QuantileOp) and input_type == dtypes.TIMEDELTA_DTYPE ): - return ex.UnaryAggregation( + return ex_types.UnaryAggregation( aggs.QuantileOp(q=aggregation.op.q, should_floor_result=True), aggregation.arg, ) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index f9de117b29..bfe5b5a96b 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -57,7 +57,7 @@ import bigframes._config.display_options as display_options import bigframes.constants import bigframes.core -from bigframes.core import log_adapter +from bigframes.core import expression_types, log_adapter import bigframes.core.block_transforms as block_ops import bigframes.core.blocks as blocks import bigframes.core.convert @@ -1363,7 +1363,9 @@ def _fast_stat_matrix(self, op: agg_ops.BinaryAggregateOp) -> DataFrame: block = frame._block aggregations = [ - ex.BinaryAggregation(op, ex.deref(left_col), ex.deref(right_col)) + expression_types.BinaryAggregation( + op, ex.deref(left_col), ex.deref(right_col) + ) for left_col in block.value_columns for right_col in 
block.value_columns ] @@ -1630,7 +1632,7 @@ def corrwith( block, _ = block.aggregate( aggregations=tuple( - ex.BinaryAggregation(agg_ops.CorrOp(), left_ex, right_ex) + expression_types.BinaryAggregation(agg_ops.CorrOp(), left_ex, right_ex) for left_ex, right_ex in expr_pairs ), column_labels=labels, @@ -3189,9 +3191,9 @@ def agg( for agg_func in agg_func_list: agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func)) agg_expr = ( - ex.UnaryAggregation(agg_op, ex.deref(col_id)) + expression_types.UnaryAggregation(agg_op, ex.deref(col_id)) if isinstance(agg_op, agg_ops.UnaryAggregateOp) - else ex.NullaryAggregation(agg_op) + else expression_types.NullaryAggregation(agg_op) ) aggs.append(agg_expr) labels.append(col_label) diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index 6889997a10..38cf90d4bb 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -17,7 +17,7 @@ import abc import dataclasses import typing -from typing import ClassVar, Iterable, Optional +from typing import ClassVar, Iterable, Optional, TYPE_CHECKING import pandas as pd import pyarrow as pa @@ -25,6 +25,9 @@ import bigframes.dtypes as dtypes import bigframes.operations.type as signatures +if TYPE_CHECKING: + from bigframes.core import expression, expression_types + @dataclasses.dataclass(frozen=True) class WindowOp: @@ -110,6 +113,14 @@ class NullaryAggregateOp(AggregateOp, NullaryWindowOp): def arguments(self) -> int: return 0 + def as_expr( + self, + *exprs: typing.Union[str, expression.Expression], + ) -> expression_types.NullaryAggregation: + from bigframes.core import expression_types + + return expression_types.NullaryAggregation(self) + @dataclasses.dataclass(frozen=True) class UnaryAggregateOp(AggregateOp, UnaryWindowOp): @@ -117,6 +128,23 @@ class UnaryAggregateOp(AggregateOp, UnaryWindowOp): def arguments(self) -> int: return 1 + def as_expr( + self, + *exprs: typing.Union[str, expression.Expression], + ) -> expression_types.UnaryAggregation: + from bigframes.core import expression_types + from bigframes.operations.base_ops import _convert_expr_input + + # Keep this in sync with output_type and compilers + inputs: list[expression.Expression] = [] + + for expr in exprs: + inputs.append(_convert_expr_input(expr)) + return expression_types.UnaryAggregation( + self, + inputs[0], + ) + @dataclasses.dataclass(frozen=True) class BinaryAggregateOp(AggregateOp): @@ -124,6 +152,21 @@ class BinaryAggregateOp(AggregateOp): def arguments(self) -> int: return 2 + def as_expr( + self, + *exprs: typing.Union[str, expression.Expression], + ) -> expression_types.BinaryAggregation: + from bigframes.core import expression_types + from bigframes.operations.base_ops import _convert_expr_input + + # Keep this in sync with output_type and compilers + inputs: list[expression.Expression] = [] + + for expr in exprs: + inputs.append(_convert_expr_input(expr)) + + return expression_types.BinaryAggregation(self, inputs[0], inputs[1]) + @dataclasses.dataclass(frozen=True) class SizeOp(NullaryAggregateOp): diff --git a/bigframes/series.py b/bigframes/series.py index c95b2ca37f..7381d386b3 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -49,7 +49,7 @@ import typing_extensions import bigframes.core -from bigframes.core import groupby, log_adapter +from bigframes.core import expression_types, groupby, log_adapter import bigframes.core.block_transforms as block_ops import bigframes.core.blocks as blocks import bigframes.core.expression as ex @@ 
-1391,7 +1391,9 @@ def mode(self) -> Series: block, agg_ids = block.aggregate( by_column_ids=[self._value_column], aggregations=( - ex.UnaryAggregation(agg_ops.count_op, ex.deref(self._value_column)), + expression_types.UnaryAggregation( + agg_ops.count_op, ex.deref(self._value_column) + ), ), ) value_count_col_id = agg_ids[0] @@ -2116,7 +2118,11 @@ def unique(self, keep_order=True) -> Series: return self.drop_duplicates() block, result = self._block.aggregate( [self._value_column], - [ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(self._value_column))], + [ + expression_types.UnaryAggregation( + agg_ops.AnyValueOp(), ex.deref(self._value_column) + ) + ], column_labels=self._block.column_labels, dropna=False, ) diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index d8df558fe4..7f1de99f49 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -18,7 +18,14 @@ import pyarrow as pa -from bigframes.core import array_value, bigframe_node, expression, local_data, nodes +from bigframes.core import ( + array_value, + bigframe_node, + expression, + expression_types, + local_data, + nodes, +) import bigframes.operations from bigframes.operations import aggregations as agg_ops from bigframes.operations import ( @@ -112,7 +119,7 @@ def _is_node_polars_executable(node: nodes.BigFrameNode): if not isinstance(node, _COMPATIBLE_NODES): return False for expr in node._node_expressions: - if isinstance(expr, expression.Aggregation): + if isinstance(expr, expression_types.Aggregation): if not type(expr.op) in _COMPATIBLE_AGG_OPS: return False if isinstance(expr, expression.Expression): diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index c2fc9ad706..466e408114 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -14,7 +14,7 @@ import pytest -from bigframes.core import array_value, expression, identifiers, nodes +from bigframes.core import array_value, expression, expression_types, identifiers, nodes import bigframes.operations.aggregations as agg_ops from bigframes.session import polars_executor from bigframes.testing.engine_utils import assert_equivalence_execution @@ -37,7 +37,7 @@ def apply_agg_to_all_valid( continue try: _ = op.output_type(array.get_column_type(arg)) - expr = expression.UnaryAggregation(op, expression.deref(arg)) + expr = expression_types.UnaryAggregation(op, expression.deref(arg)) name = f"{arg}-{op.name}" exprs_by_name.append((expr, name)) except TypeError: @@ -56,11 +56,11 @@ def test_engines_aggregate_size( scalars_array_value.node, aggregations=( ( - expression.NullaryAggregation(agg_ops.SizeOp()), + expression_types.NullaryAggregation(agg_ops.SizeOp()), identifiers.ColumnId("size_op"), ), ( - expression.UnaryAggregation( + expression_types.UnaryAggregation( agg_ops.SizeUnaryOp(), expression.deref("string_col") ), identifiers.ColumnId("unary_size_op"), @@ -103,11 +103,11 @@ def test_engines_grouped_aggregate( scalars_array_value.node, aggregations=( ( - expression.NullaryAggregation(agg_ops.SizeOp()), + expression_types.NullaryAggregation(agg_ops.SizeOp()), identifiers.ColumnId("size_op"), ), ( - expression.UnaryAggregation( + expression_types.UnaryAggregation( agg_ops.SizeUnaryOp(), expression.deref("string_col") ), identifiers.ColumnId("unary_size_op"), diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py 
b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 96cdceb3c6..f185634741 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -14,7 +14,7 @@ import pytest -from bigframes.core import array_value, expression, identifiers, nodes +from bigframes.core import array_value, expression, expression_types, identifiers, nodes from bigframes.operations import aggregations as agg_ops import bigframes.pandas as bpd @@ -26,7 +26,7 @@ def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) -> obj._block.expr.node, aggregations=( ( - expression.UnaryAggregation(op, expression.deref(arg)), + expression_types.UnaryAggregation(op, expression.deref(arg)), identifiers.ColumnId(arg + "_agg"), ), ), From 6d02dbc81b8cabaa66ca1b5fcd91cf344d5b83f4 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 4 Sep 2025 00:16:28 +0000 Subject: [PATCH 2/3] fix test_windowing --- tests/system/small/engines/test_windowing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/system/small/engines/test_windowing.py b/tests/system/small/engines/test_windowing.py index a5f20a47cd..66a4cdb6a2 100644 --- a/tests/system/small/engines/test_windowing.py +++ b/tests/system/small/engines/test_windowing.py @@ -15,7 +15,14 @@ from google.cloud import bigquery import pytest -from bigframes.core import array_value, expression, identifiers, nodes, window_spec +from bigframes.core import ( + array_value, + expression, + expression_types, + identifiers, + nodes, + window_spec, +) import bigframes.operations.aggregations as agg_ops from bigframes.session import direct_gbq_execution, polars_executor from bigframes.testing.engine_utils import assert_equivalence_execution @@ -48,7 +55,9 @@ def test_engines_with_rows_window( ) window_node = nodes.WindowOpNode( child=scalars_array_value.node, - expression=expression.UnaryAggregation(agg_op, expression.deref("int64_too")), + expression=expression_types.UnaryAggregation( + agg_op, expression.deref("int64_too") + ), window_spec=window, output_name=identifiers.ColumnId("agg_int64"), never_skip_nulls=never_skip_nulls, From 9ec116a16620583b9cf72096fa24828cc75dd213 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 4 Sep 2025 17:53:50 +0000 Subject: [PATCH 3/3] rename expression_types.py to agg_expressions.py --- ...expression_types.py => agg_expressions.py} | 0 bigframes/core/array_value.py | 10 +++---- bigframes/core/block_transforms.py | 18 +++++------ bigframes/core/blocks.py | 30 +++++++++---------- bigframes/core/compile/compiled.py | 2 +- .../ibis_compiler/aggregate_compiler.py | 18 +++++------ bigframes/core/compile/polars/compiler.py | 22 +++++++------- .../compile/sqlglot/aggregate_compiler.py | 18 +++++------ bigframes/core/groupby/aggs.py | 8 ++--- bigframes/core/groupby/dataframe_group_by.py | 7 +++-- bigframes/core/indexes/base.py | 2 +- bigframes/core/nodes.py | 6 ++-- bigframes/core/rewrite/order.py | 8 ++--- bigframes/core/rewrite/schema_binding.py | 18 +++++------ bigframes/core/rewrite/timedeltas.py | 2 +- bigframes/dataframe.py | 10 +++---- bigframes/operations/aggregations.py | 21 ++++++------- bigframes/series.py | 6 ++-- bigframes/session/polars_executor.py | 4 +-- .../system/small/engines/test_aggregation.py | 12 ++++---- tests/system/small/engines/test_windowing.py | 4 +-- .../aggregations/test_unary_compiler.py | 4 +-- 22 files changed, 116 insertions(+), 114 deletions(-) rename 
bigframes/core/{expression_types.py => agg_expressions.py} (100%)

diff --git a/bigframes/core/expression_types.py b/bigframes/core/agg_expressions.py
similarity index 100%
rename from bigframes/core/expression_types.py
rename to bigframes/core/agg_expressions.py
diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py
index f0cf64fe45..b37c581a4a 100644
--- a/bigframes/core/array_value.py
+++ b/bigframes/core/array_value.py
@@ -24,7 +24,7 @@
 import pandas
 import pyarrow as pa
 
-from bigframes.core import expression_types
+from bigframes.core import agg_expressions
 import bigframes.core.expression as ex
 import bigframes.core.guid
 import bigframes.core.identifiers as ids
@@ -191,7 +191,7 @@ def row_count(self) -> ArrayValue:
                 child=self.node,
                 aggregations=(
                     (
-                        expression_types.NullaryAggregation(agg_ops.size_op),
+                        agg_expressions.NullaryAggregation(agg_ops.size_op),
                         ids.ColumnId(bigframes.core.guid.generate_guid()),
                     ),
                 ),
@@ -380,7 +380,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
 
     def aggregate(
         self,
-        aggregations: typing.Sequence[typing.Tuple[expression_types.Aggregation, str]],
+        aggregations: typing.Sequence[typing.Tuple[agg_expressions.Aggregation, str]],
         by_column_ids: typing.Sequence[str] = (),
         dropna: bool = True,
     ) -> ArrayValue:
@@ -421,7 +421,7 @@ def project_window_op(
         """
 
         return self.project_window_expr(
-            expression_types.UnaryAggregation(op, ex.deref(column_name)),
+            agg_expressions.UnaryAggregation(op, ex.deref(column_name)),
            window_spec,
            never_skip_nulls,
            skip_reproject_unsafe,
@@ -429,7 +429,7 @@ def project_window_op(
 
     def project_window_expr(
         self,
-        expression: expression_types.Aggregation,
+        expression: agg_expressions.Aggregation,
         window: WindowSpec,
         never_skip_nulls=False,
         skip_reproject_unsafe: bool = False,
diff --git a/bigframes/core/block_transforms.py b/bigframes/core/block_transforms.py
index 9802884e89..279643b91d 100644
--- a/bigframes/core/block_transforms.py
+++ b/bigframes/core/block_transforms.py
@@ -21,7 +21,7 @@
 import pandas as pd
 
 import bigframes.constants
-from bigframes.core import expression_types
+from bigframes.core import agg_expressions
 import bigframes.core as core
 import bigframes.core.blocks as blocks
 import bigframes.core.expression as ex
@@ -133,7 +133,7 @@ def quantile(
     block, _ = block.aggregate(
         grouping_column_ids,
         tuple(
-            expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
+            agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
            for col in quantile_cols
        ),
        column_labels=pd.Index(labels),
@@ -363,7 +363,7 @@ def value_counts(
         block = dropna(block, columns, how="any")
     block, agg_ids = block.aggregate(
         by_column_ids=(*grouping_keys, *columns),
-        aggregations=[expression_types.NullaryAggregation(agg_ops.size_op)],
+        aggregations=[agg_expressions.NullaryAggregation(agg_ops.size_op)],
         dropna=drop_na and not grouping_keys,
     )
     count_id = agg_ids[0]
@@ -647,15 +647,15 @@ def skew(
     # counts, moment3 for each column
     aggregations = []
     for i, col in enumerate(original_columns):
-        count_agg = expression_types.UnaryAggregation(
+        count_agg = agg_expressions.UnaryAggregation(
             agg_ops.count_op,
             ex.deref(col),
         )
-        moment3_agg = expression_types.UnaryAggregation(
+        moment3_agg = agg_expressions.UnaryAggregation(
             agg_ops.mean_op,
             ex.deref(delta3_ids[i]),
         )
-        variance_agg = expression_types.UnaryAggregation(
+        variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(),
             ex.deref(col),
         )
@@ -698,11 +698,11 @@ def kurt(
     # counts, moment4 for each column
     aggregations = []
     for i, col in enumerate(original_columns):
-        count_agg = expression_types.UnaryAggregation(agg_ops.count_op, ex.deref(col))
-        moment4_agg = expression_types.UnaryAggregation(
+        count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
+        moment4_agg = agg_expressions.UnaryAggregation(
             agg_ops.mean_op, ex.deref(delta4_ids[i])
         )
-        variance_agg = expression_types.UnaryAggregation(
+        variance_agg = agg_expressions.UnaryAggregation(
             agg_ops.PopVarOp(), ex.deref(col)
         )
         aggregations.extend([count_agg, moment4_agg, variance_agg])
diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index bc28bd6283..444a5eddd7 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -51,12 +51,12 @@
 from bigframes import session
 from bigframes._config import sampling_options
 import bigframes.constants
-from bigframes.core import expression_types, local_data
+from bigframes.core import agg_expressions, local_data
 import bigframes.core as core
+import bigframes.core.agg_expressions as ex_types
 import bigframes.core.compile.googlesql as googlesql
 import bigframes.core.expression as ex
 import bigframes.core.expression as scalars
-import bigframes.core.expression_types as ex_types
 import bigframes.core.guid as guid
 import bigframes.core.identifiers
 import bigframes.core.join_def as join_defs
@@ -1144,7 +1144,7 @@ def apply_window_op(
         skip_reproject_unsafe: bool = False,
         never_skip_nulls: bool = False,
     ) -> typing.Tuple[Block, str]:
-        agg_expr = expression_types.UnaryAggregation(op, ex.deref(column))
+        agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
         return self.apply_analytic(
             agg_expr,
             window_spec,
@@ -1156,7 +1156,7 @@ def apply_window_op(
 
     def apply_analytic(
         self,
-        agg_expr: expression_types.Aggregation,
+        agg_expr: agg_expressions.Aggregation,
         window: windows.WindowSpec,
         result_label: Label,
         *,
@@ -1249,9 +1249,9 @@ def aggregate_all_and_stack(
         if axis_n == 0:
             aggregations = [
                 (
-                    expression_types.UnaryAggregation(operation, ex.deref(col_id))
+                    agg_expressions.UnaryAggregation(operation, ex.deref(col_id))
                     if isinstance(operation, agg_ops.UnaryAggregateOp)
-                    else expression_types.NullaryAggregation(operation),
+                    else agg_expressions.NullaryAggregation(operation),
                     col_id,
                 )
                 for col_id in self.value_columns
@@ -1281,7 +1281,7 @@ def aggregate_size(
         """Returns a block object to compute the size(s) of groups."""
         agg_specs = [
             (
-                expression_types.NullaryAggregation(agg_ops.SizeOp()),
+                agg_expressions.NullaryAggregation(agg_ops.SizeOp()),
                 guid.generate_guid(),
             ),
         ]
@@ -1354,7 +1354,7 @@ def remap_f(x):
     def aggregate(
         self,
         by_column_ids: typing.Sequence[str] = (),
-        aggregations: typing.Sequence[expression_types.Aggregation] = (),
+        aggregations: typing.Sequence[agg_expressions.Aggregation] = (),
         column_labels: Optional[pd.Index] = None,
         *,
         dropna: bool = True,
@@ -1423,9 +1423,9 @@ def get_stat(
 
         aggregations = [
             (
-                expression_types.UnaryAggregation(stat, ex.deref(column_id))
+                agg_expressions.UnaryAggregation(stat, ex.deref(column_id))
                 if isinstance(stat, agg_ops.UnaryAggregateOp)
-                else expression_types.NullaryAggregation(stat),
+                else agg_expressions.NullaryAggregation(stat),
                 stat.name,
             )
             for stat in stats_to_fetch
@@ -1451,7 +1451,7 @@ def get_binary_stat(
         # TODO(kemppeterson): Add a cache here.
         aggregations = [
             (
-                expression_types.BinaryAggregation(
+                agg_expressions.BinaryAggregation(
                     stat, ex.deref(column_id_left), ex.deref(column_id_right)
                 ),
                 f"{stat.name}_{column_id_left}{column_id_right}",
@@ -1478,9 +1478,9 @@ def summarize(
         labels = pd.Index([stat.name for stat in stats])
         aggregations = [
             (
-                expression_types.UnaryAggregation(stat, ex.deref(col_id))
+                agg_expressions.UnaryAggregation(stat, ex.deref(col_id))
                 if isinstance(stat, agg_ops.UnaryAggregateOp)
-                else expression_types.NullaryAggregation(stat),
+                else agg_expressions.NullaryAggregation(stat),
                 f"{col_id}-{stat.name}",
             )
             for stat in stats
@@ -1754,7 +1754,7 @@ def pivot(
         block = block.select_columns(column_ids)
 
         aggregations = [
-            expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
+            agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
             for col_id in column_ids
         ]
         result_block, _ = block.aggregate(
@@ -2022,7 +2022,7 @@ def _generate_resample_label(
 
         agg_specs = [
             (
-                expression_types.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)),
+                agg_expressions.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)),
                 guid.generate_guid(),
             ),
         ]
diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py
index a5309446cd..b28880d498 100644
--- a/bigframes/core/compile/compiled.py
+++ b/bigframes/core/compile/compiled.py
@@ -30,12 +30,12 @@
 import pyarrow as pa
 
 from bigframes.core import utils
+import bigframes.core.agg_expressions as ex_types
 import bigframes.core.compile.googlesql
 import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compiler
 import bigframes.core.compile.ibis_compiler.scalar_op_compiler as op_compilers
 import bigframes.core.compile.ibis_types
 import bigframes.core.expression as ex
-import bigframes.core.expression_types as ex_types
 from bigframes.core.ordering import OrderingExpression
 import bigframes.core.sql
 from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
diff --git a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py
index 3aebb10dcd..5e9cba7f8c 100644
--- a/bigframes/core/compile/ibis_compiler/aggregate_compiler.py
+++ b/bigframes/core/compile/ibis_compiler/aggregate_compiler.py
@@ -26,7 +26,7 @@
 import bigframes_vendored.ibis.expr.types as ibis_types
 import pandas as pd
 
-from bigframes.core import expression_types
+from bigframes.core import agg_expressions
 from bigframes.core.compile import constants as compiler_constants
 import bigframes.core.compile.ibis_compiler.scalar_op_compiler as scalar_compilers
 import bigframes.core.compile.ibis_types as compile_ibis_types
@@ -48,19 +48,19 @@ def approx_quantiles(expression: float, number) -> List[float]:
 
 
 def compile_aggregate(
-    aggregate: expression_types.Aggregation,
+    aggregate: agg_expressions.Aggregation,
     bindings: typing.Dict[str, ibis_types.Value],
     order_by: typing.Sequence[ibis_types.Value] = [],
 ) -> ibis_types.Value:
-    if isinstance(aggregate, expression_types.NullaryAggregation):
+    if isinstance(aggregate, agg_expressions.NullaryAggregation):
         return compile_nullary_agg(aggregate.op)
-    if isinstance(aggregate, expression_types.UnaryAggregation):
+    if isinstance(aggregate, agg_expressions.UnaryAggregation):
         input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
         if not aggregate.op.order_independent:
             return compile_ordered_unary_agg(aggregate.op, input, order_by=order_by)  # type: ignore
         else:
             return compile_unary_agg(aggregate.op, input)  # type: ignore
-    elif isinstance(aggregate, expression_types.BinaryAggregation):
+    elif isinstance(aggregate, agg_expressions.BinaryAggregation):
         left = scalar_compiler.compile_expression(aggregate.left, bindings=bindings)
         right = scalar_compiler.compile_expression(aggregate.right, bindings=bindings)
         return compile_binary_agg(aggregate.op, left, right)  # type: ignore
@@ -69,16 +69,16 @@ def compile_aggregate(
 
 
 def compile_analytic(
-    aggregate: expression_types.Aggregation,
+    aggregate: agg_expressions.Aggregation,
     window: window_spec.WindowSpec,
     bindings: typing.Dict[str, ibis_types.Value],
 ) -> ibis_types.Value:
-    if isinstance(aggregate, expression_types.NullaryAggregation):
+    if isinstance(aggregate, agg_expressions.NullaryAggregation):
         return compile_nullary_agg(aggregate.op, window)
-    elif isinstance(aggregate, expression_types.UnaryAggregation):
+    elif isinstance(aggregate, agg_expressions.UnaryAggregation):
         input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
         return compile_unary_agg(aggregate.op, input, window)  # type: ignore
-    elif isinstance(aggregate, expression_types.BinaryAggregation):
+    elif isinstance(aggregate, agg_expressions.BinaryAggregation):
         raise NotImplementedError("binary analytic operations not yet supported")
     else:
         raise ValueError(f"Unexpected analytic operation: {aggregate}")
diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py
index c1ff0bcf18..df84f08852 100644
--- a/bigframes/core/compile/polars/compiler.py
+++ b/bigframes/core/compile/polars/compiler.py
@@ -22,7 +22,7 @@
 import pandas as pd
 
 import bigframes.core
-from bigframes.core import expression_types, identifiers, nodes, ordering, window_spec
+from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec
 from bigframes.core.compile.polars import lowering
 import bigframes.core.expression as ex
 import bigframes.core.guid as guid
@@ -443,15 +443,15 @@ class PolarsAggregateCompiler:
 
     def get_args(
         self,
-        agg: expression_types.Aggregation,
+        agg: agg_expressions.Aggregation,
     ) -> Sequence[pl.Expr]:
         """Prepares arguments for aggregation by compiling them."""
-        if isinstance(agg, expression_types.NullaryAggregation):
+        if isinstance(agg, agg_expressions.NullaryAggregation):
             return []
-        elif isinstance(agg, expression_types.UnaryAggregation):
+        elif isinstance(agg, agg_expressions.UnaryAggregation):
             arg = self.scalar_compiler.compile_expression(agg.arg)
             return [arg]
-        elif isinstance(agg, expression_types.BinaryAggregation):
+        elif isinstance(agg, agg_expressions.BinaryAggregation):
             larg = self.scalar_compiler.compile_expression(agg.left)
             rarg = self.scalar_compiler.compile_expression(agg.right)
             return [larg, rarg]
@@ -460,13 +460,13 @@ def get_args(
                 f"Aggregation {agg} not yet supported in polars engine."
             )
 
-    def compile_agg_expr(self, expr: expression_types.Aggregation):
-        if isinstance(expr, expression_types.NullaryAggregation):
+    def compile_agg_expr(self, expr: agg_expressions.Aggregation):
+        if isinstance(expr, agg_expressions.NullaryAggregation):
             inputs: Tuple = ()
-        elif isinstance(expr, expression_types.UnaryAggregation):
+        elif isinstance(expr, agg_expressions.UnaryAggregation):
             assert isinstance(expr.arg, ex.DerefOp)
             inputs = (expr.arg.id.sql,)
-        elif isinstance(expr, expression_types.BinaryAggregation):
+        elif isinstance(expr, agg_expressions.BinaryAggregation):
             assert isinstance(expr.left, ex.DerefOp)
             assert isinstance(expr.right, ex.DerefOp)
             inputs = (
@@ -770,7 +770,7 @@ def _aggregate(
         self,
         df: pl.LazyFrame,
         aggregations: Sequence[
-            Tuple[expression_types.Aggregation, identifiers.ColumnId]
+            Tuple[agg_expressions.Aggregation, identifiers.ColumnId]
         ],
         grouping_keys: Tuple[ex.DerefOp, ...],
     ) -> pl.LazyFrame:
@@ -860,7 +860,7 @@ def compile_window(self, node: nodes.WindowOpNode):
     def _calc_row_analytic_func(
         self,
         frame: pl.LazyFrame,
-        agg_expr: expression_types.Aggregation,
+        agg_expr: agg_expressions.Aggregation,
         window: window_spec.WindowSpec,
         name: str,
     ) -> pl.LazyFrame:
diff --git a/bigframes/core/compile/sqlglot/aggregate_compiler.py b/bigframes/core/compile/sqlglot/aggregate_compiler.py
index 85c3a2f75a..ccfba1ce0f 100644
--- a/bigframes/core/compile/sqlglot/aggregate_compiler.py
+++ b/bigframes/core/compile/sqlglot/aggregate_compiler.py
@@ -15,7 +15,7 @@
 
 import sqlglot.expressions as sge
 
-from bigframes.core import expression_types, window_spec
+from bigframes.core import agg_expressions, window_spec
 from bigframes.core.compile.sqlglot.aggregations import (
     binary_compiler,
     nullary_compiler,
@@ -27,13 +27,13 @@
 
 
 def compile_aggregate(
-    aggregate: expression_types.Aggregation,
+    aggregate: agg_expressions.Aggregation,
     order_by: tuple[sge.Expression, ...],
 ) -> sge.Expression:
     """Compiles BigFrames aggregation expression into SQLGlot expression."""
-    if isinstance(aggregate, expression_types.NullaryAggregation):
+    if isinstance(aggregate, agg_expressions.NullaryAggregation):
         return nullary_compiler.compile(aggregate.op)
-    if isinstance(aggregate, expression_types.UnaryAggregation):
+    if isinstance(aggregate, agg_expressions.UnaryAggregation):
         column = typed_expr.TypedExpr(
             scalar_compiler.compile_scalar_expression(aggregate.arg),
             aggregate.arg.output_type,
@@ -44,7 +44,7 @@ def compile_aggregate(
             )
         else:
             return unary_compiler.compile(aggregate.op, column)
-    elif isinstance(aggregate, expression_types.BinaryAggregation):
+    elif isinstance(aggregate, agg_expressions.BinaryAggregation):
         left = typed_expr.TypedExpr(
             scalar_compiler.compile_scalar_expression(aggregate.left),
             aggregate.left.output_type,
@@ -59,18 +59,18 @@
 
 
 def compile_analytic(
-    aggregate: expression_types.Aggregation,
+    aggregate: agg_expressions.Aggregation,
     window: window_spec.WindowSpec,
 ) -> sge.Expression:
-    if isinstance(aggregate, expression_types.NullaryAggregation):
+    if isinstance(aggregate, agg_expressions.NullaryAggregation):
         return nullary_compiler.compile(aggregate.op)
-    if isinstance(aggregate, expression_types.UnaryAggregation):
+    if isinstance(aggregate, agg_expressions.UnaryAggregation):
         column = typed_expr.TypedExpr(
             scalar_compiler.compile_scalar_expression(aggregate.arg),
             aggregate.arg.output_type,
         )
         return unary_compiler.compile(aggregate.op, column, window)
-    elif isinstance(aggregate, expression_types.BinaryAggregation):
+    elif isinstance(aggregate, agg_expressions.BinaryAggregation):
         raise NotImplementedError("binary analytic operations not yet supported")
     else:
         raise ValueError(f"Unexpected analytic operation: {aggregate}")
diff --git a/bigframes/core/groupby/aggs.py b/bigframes/core/groupby/aggs.py
index 424294f3dd..9d8b957d54 100644
--- a/bigframes/core/groupby/aggs.py
+++ b/bigframes/core/groupby/aggs.py
@@ -14,13 +14,13 @@
 
 from __future__ import annotations
 
-from bigframes.core import expression, expression_types
+from bigframes.core import agg_expressions, expression
 from bigframes.operations import aggregations as agg_ops
 
 
-def agg(input: str, op: agg_ops.AggregateOp) -> expression_types.Aggregation:
+def agg(input: str, op: agg_ops.AggregateOp) -> agg_expressions.Aggregation:
     if isinstance(op, agg_ops.UnaryAggregateOp):
-        return expression_types.UnaryAggregation(op, expression.deref(input))
+        return agg_expressions.UnaryAggregation(op, expression.deref(input))
     else:
         assert isinstance(op, agg_ops.NullaryAggregateOp)
-        return expression_types.NullaryAggregation(op)
+        return agg_expressions.NullaryAggregation(op)
diff --git a/bigframes/core/groupby/dataframe_group_by.py b/bigframes/core/groupby/dataframe_group_by.py
index 68dcaca365..3f5480436a 100644
--- a/bigframes/core/groupby/dataframe_group_by.py
+++ b/bigframes/core/groupby/dataframe_group_by.py
@@ -24,8 +24,9 @@
 import pandas as pd
 
 from bigframes import session
+from bigframes.core import agg_expressions
 from bigframes.core import expression as ex
-from bigframes.core import expression_types, log_adapter
+from bigframes.core import log_adapter
 import bigframes.core.block_transforms as block_ops
 import bigframes.core.blocks as blocks
 from bigframes.core.groupby import aggs, series_group_by
@@ -327,7 +328,7 @@ def cumcount(self, ascending: bool = True) -> series.Series:
             )
         )
         block, result_id = self._block.apply_analytic(
-            expression_types.NullaryAggregation(agg_ops.size_op),
+            agg_expressions.NullaryAggregation(agg_ops.size_op),
             window=window_spec,
             result_label=None,
         )
@@ -488,7 +489,7 @@ def _agg_string(self, func: str) -> df.DataFrame:
         return dataframe if self._as_index else self._convert_index(dataframe)
 
     def _agg_dict(self, func: typing.Mapping) -> df.DataFrame:
-        aggregations: typing.List[expression_types.Aggregation] = []
+        aggregations: typing.List[agg_expressions.Aggregation] = []
         column_labels = []
 
         want_aggfunc_level = any(utils.is_list_like(aggs) for aggs in func.values())
diff --git a/bigframes/core/indexes/base.py b/bigframes/core/indexes/base.py
index b1c51d11ac..2a35ab6546 100644
--- a/bigframes/core/indexes/base.py
+++ b/bigframes/core/indexes/base.py
@@ -27,10 +27,10 @@
 import pandas
 
 from bigframes import dtypes
+import bigframes.core.agg_expressions as ex_types
 import bigframes.core.block_transforms as block_ops
 import bigframes.core.blocks as blocks
 import bigframes.core.expression as ex
-import bigframes.core.expression_types as ex_types
 import bigframes.core.ordering as order
 import bigframes.core.utils as utils
 import bigframes.core.validations as validations
diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py
index 3e160de49e..b6483689dc 100644
--- a/bigframes/core/nodes.py
+++ b/bigframes/core/nodes.py
@@ -33,7 +33,7 @@
 
 import google.cloud.bigquery as bq
 
-from bigframes.core import expression_types, identifiers, local_data, sequences
+from bigframes.core import agg_expressions, identifiers, local_data, sequences
 from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET
 import bigframes.core.expression as ex
 from bigframes.core.field import Field
@@ -1338,7 +1338,7 @@ def remap_refs(
 @dataclasses.dataclass(frozen=True, eq=False)
 class AggregateNode(UnaryNode):
     aggregations: typing.Tuple[
-        typing.Tuple[expression_types.Aggregation, identifiers.ColumnId], ...
+        typing.Tuple[agg_expressions.Aggregation, identifiers.ColumnId], ...
     ]
     by_column_ids: typing.Tuple[ex.DerefOp, ...] = tuple([])
     order_by: Tuple[OrderingExpression, ...] = ()
@@ -1437,7 +1437,7 @@ def remap_refs(
 
 @dataclasses.dataclass(frozen=True, eq=False)
 class WindowOpNode(UnaryNode, AdditiveNode):
-    expression: expression_types.Aggregation
+    expression: agg_expressions.Aggregation
     window_spec: window.WindowSpec
     output_name: identifiers.ColumnId
     never_skip_nulls: bool = False
diff --git a/bigframes/core/rewrite/order.py b/bigframes/core/rewrite/order.py
index da4d9dfa3b..881badd603 100644
--- a/bigframes/core/rewrite/order.py
+++ b/bigframes/core/rewrite/order.py
@@ -15,7 +15,7 @@
 import functools
 from typing import Mapping, Tuple
 
-from bigframes.core import expression, expression_types, identifiers
+from bigframes.core import agg_expressions, expression, identifiers
 import bigframes.core.nodes
 import bigframes.core.ordering
 import bigframes.core.window_spec
@@ -167,7 +167,7 @@ def pull_up_order_inner(
             )
         else:
             # Otherwise we need to generate offsets
-            agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp())
+            agg = agg_expressions.NullaryAggregation(agg_ops.RowNumberOp())
             window_spec = bigframes.core.window_spec.unbound(
                 ordering=tuple(child_order.all_ordering_columns)
             )
@@ -285,7 +285,7 @@ def pull_order_concat(
                 new_source, ((order_expression.scalar_expression, offsets_id),)
             )
         else:
-            agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp())
+            agg = agg_expressions.NullaryAggregation(agg_ops.RowNumberOp())
             window_spec = bigframes.core.window_spec.unbound(
                 ordering=tuple(order.all_ordering_columns)
             )
@@ -419,7 +419,7 @@ def remove_order_strict(
 
 def rewrite_promote_offsets(
     node: bigframes.core.nodes.PromoteOffsetsNode,
 ) -> bigframes.core.nodes.WindowOpNode:
-    agg = expression_types.NullaryAggregation(agg_ops.RowNumberOp())
+    agg = agg_expressions.NullaryAggregation(agg_ops.RowNumberOp())
     window_spec = bigframes.core.window_spec.unbound()
     return bigframes.core.nodes.WindowOpNode(node.child, agg, window_spec, node.col_id)
diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py
index b497cd5700..8a0bcc4921 100644
--- a/bigframes/core/rewrite/schema_binding.py
+++ b/bigframes/core/rewrite/schema_binding.py
@@ -15,9 +15,9 @@
 import dataclasses
 import typing
 
-from bigframes.core import bigframe_node
+from bigframes.core import agg_expressions, bigframe_node
 from bigframes.core import expression as ex
-from bigframes.core import expression_types, nodes, ordering
+from bigframes.core import nodes, ordering
 
 
 def bind_schema_to_tree(
@@ -118,16 +118,16 @@ def bind_schema_to_node(
 
 
 def _bind_schema_to_aggregation_expr(
-    aggregation: expression_types.Aggregation,
+    aggregation: agg_expressions.Aggregation,
     child: bigframe_node.BigFrameNode,
-) -> expression_types.Aggregation:
+) -> agg_expressions.Aggregation:
     assert isinstance(
-        aggregation, expression_types.Aggregation
+        aggregation, agg_expressions.Aggregation
     ), f"Expected Aggregation, got {type(aggregation)}"
 
-    if isinstance(aggregation, expression_types.UnaryAggregation):
+    if isinstance(aggregation, agg_expressions.UnaryAggregation):
         return typing.cast(
-            expression_types.Aggregation,
+            agg_expressions.Aggregation,
             dataclasses.replace(
                 aggregation,
                 arg=typing.cast(
@@ -136,9 +136,9 @@ def _bind_schema_to_aggregation_expr(
                 ),
             ),
         )
-    elif isinstance(aggregation, expression_types.BinaryAggregation):
+    elif isinstance(aggregation, agg_expressions.BinaryAggregation):
         return typing.cast(
-            expression_types.Aggregation,
+            agg_expressions.Aggregation,
             dataclasses.replace(
                 aggregation,
                 left=typing.cast(
diff --git a/bigframes/core/rewrite/timedeltas.py b/bigframes/core/rewrite/timedeltas.py
index d023f038f1..91c6ab83c6 100644
--- a/bigframes/core/rewrite/timedeltas.py
+++ b/bigframes/core/rewrite/timedeltas.py
@@ -20,8 +20,8 @@
 
 from bigframes import dtypes
 from bigframes import operations as ops
+from bigframes.core import agg_expressions as ex_types
 from bigframes.core import expression as ex
-from bigframes.core import expression_types as ex_types
 from bigframes.core import nodes, schema, utils
 from bigframes.operations import aggregations as aggs
 
diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index bfe5b5a96b..c65bbdd2c8 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -57,7 +57,7 @@
 import bigframes._config.display_options as display_options
 import bigframes.constants
 import bigframes.core
-from bigframes.core import expression_types, log_adapter
+from bigframes.core import agg_expressions, log_adapter
 import bigframes.core.block_transforms as block_ops
 import bigframes.core.blocks as blocks
 import bigframes.core.convert
@@ -1363,7 +1363,7 @@ def _fast_stat_matrix(self, op: agg_ops.BinaryAggregateOp) -> DataFrame:
         block = frame._block
 
         aggregations = [
-            expression_types.BinaryAggregation(
+            agg_expressions.BinaryAggregation(
                 op, ex.deref(left_col), ex.deref(right_col)
             )
             for left_col in block.value_columns
@@ -1632,7 +1632,7 @@ def corrwith(
 
         block, _ = block.aggregate(
             aggregations=tuple(
-                expression_types.BinaryAggregation(agg_ops.CorrOp(), left_ex, right_ex)
+                agg_expressions.BinaryAggregation(agg_ops.CorrOp(), left_ex, right_ex)
                 for left_ex, right_ex in expr_pairs
            ),
            column_labels=labels,
@@ -3191,9 +3191,9 @@ def agg(
             for agg_func in agg_func_list:
                 agg_op = agg_ops.lookup_agg_func(typing.cast(str, agg_func))
                 agg_expr = (
-                    expression_types.UnaryAggregation(agg_op, ex.deref(col_id))
+                    agg_expressions.UnaryAggregation(agg_op, ex.deref(col_id))
                     if isinstance(agg_op, agg_ops.UnaryAggregateOp)
-                    else expression_types.NullaryAggregation(agg_op)
+                    else agg_expressions.NullaryAggregation(agg_op)
                 )
                 aggs.append(agg_expr)
                 labels.append(col_label)
diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py
index 38cf90d4bb..81ab18272c 100644
--- a/bigframes/operations/aggregations.py
+++ b/bigframes/operations/aggregations.py
@@ -22,11 +22,12 @@
 import pandas as pd
 import pyarrow as pa
 
+from bigframes.core import agg_expressions
 import bigframes.dtypes as dtypes
 import bigframes.operations.type as signatures
 
 if TYPE_CHECKING:
-    from bigframes.core import expression, expression_types
+    from bigframes.core import expression
 
 
 @dataclasses.dataclass(frozen=True)
@@ -116,10 +117,10 @@ def arguments(self) -> int:
 
     def as_expr(
         self,
         *exprs: typing.Union[str, expression.Expression],
-    ) -> expression_types.NullaryAggregation:
-        from bigframes.core import expression_types
+    ) -> agg_expressions.NullaryAggregation:
+        from bigframes.core import agg_expressions
 
-        return expression_types.NullaryAggregation(self)
+        return agg_expressions.NullaryAggregation(self)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -131,8 +132,8 @@ def arguments(self) -> int:
     def as_expr(
         self,
         *exprs: typing.Union[str, expression.Expression],
-    ) -> expression_types.UnaryAggregation:
-        from bigframes.core import expression_types
+    ) -> agg_expressions.UnaryAggregation:
+        from bigframes.core import agg_expressions
         from bigframes.operations.base_ops import _convert_expr_input
 
         # Keep this in sync with output_type and compilers
@@ -140,7 +141,7 @@ def as_expr(
         for expr in exprs:
             inputs.append(_convert_expr_input(expr))
 
-        return expression_types.UnaryAggregation(
+        return agg_expressions.UnaryAggregation(
             self,
             inputs[0],
         )
@@ -155,8 +156,8 @@ def arguments(self) -> int:
     def as_expr(
         self,
         *exprs: typing.Union[str, expression.Expression],
-    ) -> expression_types.BinaryAggregation:
-        from bigframes.core import expression_types
+    ) -> agg_expressions.BinaryAggregation:
+        from bigframes.core import agg_expressions
         from bigframes.operations.base_ops import _convert_expr_input
 
         # Keep this in sync with output_type and compilers
@@ -165,7 +166,7 @@ def as_expr(
         for expr in exprs:
             inputs.append(_convert_expr_input(expr))
 
-        return expression_types.BinaryAggregation(self, inputs[0], inputs[1])
+        return agg_expressions.BinaryAggregation(self, inputs[0], inputs[1])
 
 
 @dataclasses.dataclass(frozen=True)
diff --git a/bigframes/series.py b/bigframes/series.py
index 7381d386b3..3e24a75d9b 100644
--- a/bigframes/series.py
+++ b/bigframes/series.py
@@ -49,7 +49,7 @@
 import typing_extensions
 
 import bigframes.core
-from bigframes.core import expression_types, groupby, log_adapter
+from bigframes.core import agg_expressions, groupby, log_adapter
 import bigframes.core.block_transforms as block_ops
 import bigframes.core.blocks as blocks
 import bigframes.core.expression as ex
@@ -1391,7 +1391,7 @@ def mode(self) -> Series:
         block, agg_ids = block.aggregate(
             by_column_ids=[self._value_column],
             aggregations=(
-                expression_types.UnaryAggregation(
+                agg_expressions.UnaryAggregation(
                     agg_ops.count_op, ex.deref(self._value_column)
                 ),
             ),
@@ -2119,7 +2119,7 @@ def unique(self, keep_order=True) -> Series:
         block, result = self._block.aggregate(
             [self._value_column],
             [
-                expression_types.UnaryAggregation(
+                agg_expressions.UnaryAggregation(
                     agg_ops.AnyValueOp(), ex.deref(self._value_column)
                 )
             ],
diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py
index 7f1de99f49..a1e1d436e1 100644
--- a/bigframes/session/polars_executor.py
+++ b/bigframes/session/polars_executor.py
@@ -19,10 +19,10 @@
 import pyarrow as pa
 
 from bigframes.core import (
+    agg_expressions,
     array_value,
     bigframe_node,
     expression,
-    expression_types,
     local_data,
     nodes,
 )
@@ -119,7 +119,7 @@ def _is_node_polars_executable(node: nodes.BigFrameNode):
     if not isinstance(node, _COMPATIBLE_NODES):
         return False
     for expr in node._node_expressions:
-        if isinstance(expr, expression_types.Aggregation):
+        if isinstance(expr, agg_expressions.Aggregation):
             if not type(expr.op) in _COMPATIBLE_AGG_OPS:
                 return False
             if isinstance(expr, expression.Expression):
diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py
index 466e408114..a4a49c622a 100644
--- a/tests/system/small/engines/test_aggregation.py
+++ b/tests/system/small/engines/test_aggregation.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from bigframes.core import array_value, expression, expression_types, identifiers, nodes
+from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
 import bigframes.operations.aggregations as agg_ops
 from bigframes.session import polars_executor
 from bigframes.testing.engine_utils import assert_equivalence_execution
@@ -37,7 +37,7 @@ def apply_agg_to_all_valid(
             continue
         try:
             _ = op.output_type(array.get_column_type(arg))
-            expr = expression_types.UnaryAggregation(op, expression.deref(arg))
+            expr = agg_expressions.UnaryAggregation(op, expression.deref(arg))
             name = f"{arg}-{op.name}"
             exprs_by_name.append((expr, name))
         except TypeError:
@@ -56,11 +56,11 @@ def test_engines_aggregate_size(
         scalars_array_value.node,
         aggregations=(
             (
-                expression_types.NullaryAggregation(agg_ops.SizeOp()),
+                agg_expressions.NullaryAggregation(agg_ops.SizeOp()),
                 identifiers.ColumnId("size_op"),
             ),
             (
-                expression_types.UnaryAggregation(
+                agg_expressions.UnaryAggregation(
                     agg_ops.SizeUnaryOp(), expression.deref("string_col")
                 ),
                 identifiers.ColumnId("unary_size_op"),
@@ -103,11 +103,11 @@ def test_engines_grouped_aggregate(
         scalars_array_value.node,
         aggregations=(
            (
-                expression_types.NullaryAggregation(agg_ops.SizeOp()),
+                agg_expressions.NullaryAggregation(agg_ops.SizeOp()),
                 identifiers.ColumnId("size_op"),
            ),
            (
-                expression_types.UnaryAggregation(
+                agg_expressions.UnaryAggregation(
                     agg_ops.SizeUnaryOp(), expression.deref("string_col")
                 ),
                 identifiers.ColumnId("unary_size_op"),
diff --git a/tests/system/small/engines/test_windowing.py b/tests/system/small/engines/test_windowing.py
index 66a4cdb6a2..f344a3b60a 100644
--- a/tests/system/small/engines/test_windowing.py
+++ b/tests/system/small/engines/test_windowing.py
@@ -16,9 +16,9 @@
 import pytest
 
 from bigframes.core import (
+    agg_expressions,
     array_value,
     expression,
-    expression_types,
     identifiers,
     nodes,
     window_spec,
@@ -55,7 +55,7 @@ def test_engines_with_rows_window(
     )
     window_node = nodes.WindowOpNode(
         child=scalars_array_value.node,
-        expression=expression_types.UnaryAggregation(
+        expression=agg_expressions.UnaryAggregation(
             agg_op, expression.deref("int64_too")
         ),
         window_spec=window,
diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
index f185634741..d12b4dda17 100644
--- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
+++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py
@@ -14,7 +14,7 @@
 
 import pytest
 
-from bigframes.core import array_value, expression, expression_types, identifiers, nodes
+from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
 from bigframes.operations import aggregations as agg_ops
 import bigframes.pandas as bpd
 
@@ -26,7 +26,7 @@ def _apply_unary_op(obj: bpd.DataFrame, op: agg_ops.UnaryWindowOp, arg: str) ->
         obj._block.expr.node,
         aggregations=(
             (
-                expression_types.UnaryAggregation(op, expression.deref(arg)),
+                agg_expressions.UnaryAggregation(op, expression.deref(arg)),
                 identifiers.ColumnId(arg + "_agg"),
             ),
         ),