Skip to content

Commit 75bb933

Browse files
refactor: Aggregation is now an expression subclass
1 parent: 88115fa · commit: 75bb933

File tree

23 files changed

+348
-244
lines changed

23 files changed

+348
-244
lines changed

bigframes/core/array_value.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import pandas
2525
import pyarrow as pa
2626

27+
from bigframes.core import expression_types
2728
import bigframes.core.expression as ex
2829
import bigframes.core.guid
2930
import bigframes.core.identifiers as ids
@@ -190,7 +191,7 @@ def row_count(self) -> ArrayValue:
190191
child=self.node,
191192
aggregations=(
192193
(
193-
ex.NullaryAggregation(agg_ops.size_op),
194+
expression_types.NullaryAggregation(agg_ops.size_op),
194195
ids.ColumnId(bigframes.core.guid.generate_guid()),
195196
),
196197
),
@@ -379,7 +380,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
379380

380381
def aggregate(
381382
self,
382-
aggregations: typing.Sequence[typing.Tuple[ex.Aggregation, str]],
383+
aggregations: typing.Sequence[typing.Tuple[expression_types.Aggregation, str]],
383384
by_column_ids: typing.Sequence[str] = (),
384385
dropna: bool = True,
385386
) -> ArrayValue:
@@ -420,15 +421,15 @@ def project_window_op(
420421
"""
421422

422423
return self.project_window_expr(
423-
ex.UnaryAggregation(op, ex.deref(column_name)),
424+
expression_types.UnaryAggregation(op, ex.deref(column_name)),
424425
window_spec,
425426
never_skip_nulls,
426427
skip_reproject_unsafe,
427428
)
428429

429430
def project_window_expr(
430431
self,
431-
expression: ex.Aggregation,
432+
expression: expression_types.Aggregation,
432433
window: WindowSpec,
433434
never_skip_nulls=False,
434435
skip_reproject_unsafe: bool = False,

bigframes/core/bigframe_node.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@
2020
import functools
2121
import itertools
2222
import typing
23-
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union
23+
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple
2424

2525
from bigframes.core import expression, field, identifiers
2626
import bigframes.core.schema as schemata
2727
import bigframes.dtypes
2828

29-
if typing.TYPE_CHECKING:
30-
import bigframes.session
31-
3229
COLUMN_SET = frozenset[identifiers.ColumnId]
3330

3431
T = typing.TypeVar("T")
@@ -281,8 +278,8 @@ def field_by_id(self) -> Mapping[identifiers.ColumnId, field.Field]:
281278
@property
282279
def _node_expressions(
283280
self,
284-
) -> Sequence[Union[expression.Expression, expression.Aggregation]]:
285-
"""List of scalar expressions. Intended for checking engine compatibility with used ops."""
281+
) -> Sequence[expression.Expression]:
282+
"""List of expressions. Intended for checking engine compatibility with used ops."""
286283
return ()
287284

288285
# Plan algorithms

bigframes/core/block_transforms.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
import pandas as pd
2222

2323
import bigframes.constants
24+
from bigframes.core import expression_types
2425
import bigframes.core as core
2526
import bigframes.core.blocks as blocks
2627
import bigframes.core.expression as ex
2728
import bigframes.core.ordering as ordering
2829
import bigframes.core.window_spec as windows
29-
import bigframes.dtypes
3030
import bigframes.dtypes as dtypes
3131
import bigframes.operations as ops
3232
import bigframes.operations.aggregations as agg_ops
@@ -133,7 +133,7 @@ def quantile(
133133
block, _ = block.aggregate(
134134
grouping_column_ids,
135135
tuple(
136-
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
136+
expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
137137
for col in quantile_cols
138138
),
139139
column_labels=pd.Index(labels),
@@ -363,7 +363,7 @@ def value_counts(
363363
block = dropna(block, columns, how="any")
364364
block, agg_ids = block.aggregate(
365365
by_column_ids=(*grouping_keys, *columns),
366-
aggregations=[ex.NullaryAggregation(agg_ops.size_op)],
366+
aggregations=[expression_types.NullaryAggregation(agg_ops.size_op)],
367367
dropna=drop_na and not grouping_keys,
368368
)
369369
count_id = agg_ids[0]
@@ -647,15 +647,15 @@ def skew(
647647
# counts, moment3 for each column
648648
aggregations = []
649649
for i, col in enumerate(original_columns):
650-
count_agg = ex.UnaryAggregation(
650+
count_agg = expression_types.UnaryAggregation(
651651
agg_ops.count_op,
652652
ex.deref(col),
653653
)
654-
moment3_agg = ex.UnaryAggregation(
654+
moment3_agg = expression_types.UnaryAggregation(
655655
agg_ops.mean_op,
656656
ex.deref(delta3_ids[i]),
657657
)
658-
variance_agg = ex.UnaryAggregation(
658+
variance_agg = expression_types.UnaryAggregation(
659659
agg_ops.PopVarOp(),
660660
ex.deref(col),
661661
)
@@ -698,9 +698,13 @@ def kurt(
698698
# counts, moment4 for each column
699699
aggregations = []
700700
for i, col in enumerate(original_columns):
701-
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(col))
702-
moment4_agg = ex.UnaryAggregation(agg_ops.mean_op, ex.deref(delta4_ids[i]))
703-
variance_agg = ex.UnaryAggregation(agg_ops.PopVarOp(), ex.deref(col))
701+
count_agg = expression_types.UnaryAggregation(agg_ops.count_op, ex.deref(col))
702+
moment4_agg = expression_types.UnaryAggregation(
703+
agg_ops.mean_op, ex.deref(delta4_ids[i])
704+
)
705+
variance_agg = expression_types.UnaryAggregation(
706+
agg_ops.PopVarOp(), ex.deref(col)
707+
)
704708
aggregations.extend([count_agg, moment4_agg, variance_agg])
705709

706710
block, agg_ids = block.aggregate(

bigframes/core/blocks.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,12 @@
5151
from bigframes import session
5252
from bigframes._config import sampling_options
5353
import bigframes.constants
54-
from bigframes.core import local_data
54+
from bigframes.core import expression_types, local_data
5555
import bigframes.core as core
5656
import bigframes.core.compile.googlesql as googlesql
5757
import bigframes.core.expression as ex
5858
import bigframes.core.expression as scalars
59+
import bigframes.core.expression_types as ex_types
5960
import bigframes.core.guid as guid
6061
import bigframes.core.identifiers
6162
import bigframes.core.join_def as join_defs
@@ -1143,7 +1144,7 @@ def apply_window_op(
11431144
skip_reproject_unsafe: bool = False,
11441145
never_skip_nulls: bool = False,
11451146
) -> typing.Tuple[Block, str]:
1146-
agg_expr = ex.UnaryAggregation(op, ex.deref(column))
1147+
agg_expr = expression_types.UnaryAggregation(op, ex.deref(column))
11471148
return self.apply_analytic(
11481149
agg_expr,
11491150
window_spec,
@@ -1155,7 +1156,7 @@ def apply_window_op(
11551156

11561157
def apply_analytic(
11571158
self,
1158-
agg_expr: ex.Aggregation,
1159+
agg_expr: expression_types.Aggregation,
11591160
window: windows.WindowSpec,
11601161
result_label: Label,
11611162
*,
@@ -1248,9 +1249,9 @@ def aggregate_all_and_stack(
12481249
if axis_n == 0:
12491250
aggregations = [
12501251
(
1251-
ex.UnaryAggregation(operation, ex.deref(col_id))
1252+
expression_types.UnaryAggregation(operation, ex.deref(col_id))
12521253
if isinstance(operation, agg_ops.UnaryAggregateOp)
1253-
else ex.NullaryAggregation(operation),
1254+
else expression_types.NullaryAggregation(operation),
12541255
col_id,
12551256
)
12561257
for col_id in self.value_columns
@@ -1279,7 +1280,10 @@ def aggregate_size(
12791280
):
12801281
"""Returns a block object to compute the size(s) of groups."""
12811282
agg_specs = [
1282-
(ex.NullaryAggregation(agg_ops.SizeOp()), guid.generate_guid()),
1283+
(
1284+
expression_types.NullaryAggregation(agg_ops.SizeOp()),
1285+
guid.generate_guid(),
1286+
),
12831287
]
12841288
output_col_ids = [agg_spec[1] for agg_spec in agg_specs]
12851289
result_expr = self.expr.aggregate(agg_specs, by_column_ids, dropna=dropna)
@@ -1350,7 +1354,7 @@ def remap_f(x):
13501354
def aggregate(
13511355
self,
13521356
by_column_ids: typing.Sequence[str] = (),
1353-
aggregations: typing.Sequence[ex.Aggregation] = (),
1357+
aggregations: typing.Sequence[expression_types.Aggregation] = (),
13541358
column_labels: Optional[pd.Index] = None,
13551359
*,
13561360
dropna: bool = True,
@@ -1419,9 +1423,9 @@ def get_stat(
14191423

14201424
aggregations = [
14211425
(
1422-
ex.UnaryAggregation(stat, ex.deref(column_id))
1426+
expression_types.UnaryAggregation(stat, ex.deref(column_id))
14231427
if isinstance(stat, agg_ops.UnaryAggregateOp)
1424-
else ex.NullaryAggregation(stat),
1428+
else expression_types.NullaryAggregation(stat),
14251429
stat.name,
14261430
)
14271431
for stat in stats_to_fetch
@@ -1447,7 +1451,7 @@ def get_binary_stat(
14471451
# TODO(kemppeterson): Add a cache here.
14481452
aggregations = [
14491453
(
1450-
ex.BinaryAggregation(
1454+
expression_types.BinaryAggregation(
14511455
stat, ex.deref(column_id_left), ex.deref(column_id_right)
14521456
),
14531457
f"{stat.name}_{column_id_left}{column_id_right}",
@@ -1474,9 +1478,9 @@ def summarize(
14741478
labels = pd.Index([stat.name for stat in stats])
14751479
aggregations = [
14761480
(
1477-
ex.UnaryAggregation(stat, ex.deref(col_id))
1481+
expression_types.UnaryAggregation(stat, ex.deref(col_id))
14781482
if isinstance(stat, agg_ops.UnaryAggregateOp)
1479-
else ex.NullaryAggregation(stat),
1483+
else expression_types.NullaryAggregation(stat),
14801484
f"{col_id}-{stat.name}",
14811485
)
14821486
for stat in stats
@@ -1750,7 +1754,7 @@ def pivot(
17501754

17511755
block = block.select_columns(column_ids)
17521756
aggregations = [
1753-
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
1757+
expression_types.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col_id))
17541758
for col_id in column_ids
17551759
]
17561760
result_block, _ = block.aggregate(
@@ -2018,7 +2022,7 @@ def _generate_resample_label(
20182022

20192023
agg_specs = [
20202024
(
2021-
ex.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)),
2025+
expression_types.UnaryAggregation(agg_ops.min_op, ex.deref(col_id)),
20222026
guid.generate_guid(),
20232027
),
20242028
]
@@ -2047,13 +2051,13 @@ def _generate_resample_label(
20472051
# Generate integer label sequence.
20482052
min_agg_specs = [
20492053
(
2050-
ex.UnaryAggregation(agg_ops.min_op, ex.deref(label_col_id)),
2054+
ex_types.UnaryAggregation(agg_ops.min_op, ex.deref(label_col_id)),
20512055
guid.generate_guid(),
20522056
),
20532057
]
20542058
max_agg_specs = [
20552059
(
2056-
ex.UnaryAggregation(agg_ops.max_op, ex.deref(label_col_id)),
2060+
ex_types.UnaryAggregation(agg_ops.max_op, ex.deref(label_col_id)),
20572061
guid.generate_guid(),
20582062
),
20592063
]

bigframes/core/compile/compiled.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import bigframes.core.compile.ibis_compiler.scalar_op_compiler as op_compilers
3636
import bigframes.core.compile.ibis_types
3737
import bigframes.core.expression as ex
38+
import bigframes.core.expression_types as ex_types
3839
from bigframes.core.ordering import OrderingExpression
3940
import bigframes.core.sql
4041
from bigframes.core.window_spec import RangeWindowBounds, RowsWindowBounds, WindowSpec
@@ -215,7 +216,7 @@ def filter(self, predicate: ex.Expression) -> UnorderedIR:
215216

216217
def aggregate(
217218
self,
218-
aggregations: typing.Sequence[tuple[ex.Aggregation, str]],
219+
aggregations: typing.Sequence[tuple[ex_types.Aggregation, str]],
219220
by_column_ids: typing.Sequence[ex.DerefOp] = (),
220221
order_by: typing.Sequence[OrderingExpression] = (),
221222
) -> UnorderedIR:
@@ -401,7 +402,7 @@ def isin_join(
401402

402403
def project_window_op(
403404
self,
404-
expression: ex.Aggregation,
405+
expression: ex_types.Aggregation,
405406
window_spec: WindowSpec,
406407
output_name: str,
407408
*,
@@ -467,7 +468,9 @@ def project_window_op(
467468
lambda x, y: x & y, per_col_does_count
468469
).cast(int)
469470
observation_count = agg_compiler.compile_analytic(
470-
ex.UnaryAggregation(agg_ops.sum_op, ex.deref("_observation_count")),
471+
ex_types.UnaryAggregation(
472+
agg_ops.sum_op, ex.deref("_observation_count")
473+
),
471474
window,
472475
bindings={"_observation_count": is_observation},
473476
)
@@ -476,7 +479,7 @@ def project_window_op(
476479
# notnull is just used to convert null values to non-null (FALSE) values to be counted
477480
is_observation = inputs[0].notnull()
478481
observation_count = agg_compiler.compile_analytic(
479-
ex.UnaryAggregation(
482+
ex_types.UnaryAggregation(
480483
agg_ops.count_op, ex.deref("_observation_count")
481484
),
482485
window,

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@
2626
import bigframes_vendored.ibis.expr.types as ibis_types
2727
import pandas as pd
2828

29+
from bigframes.core import expression_types
2930
from bigframes.core.compile import constants as compiler_constants
3031
import bigframes.core.compile.ibis_compiler.scalar_op_compiler as scalar_compilers
3132
import bigframes.core.compile.ibis_types as compile_ibis_types
32-
import bigframes.core.expression as ex
3333
import bigframes.core.window_spec as window_spec
3434
import bigframes.operations.aggregations as agg_ops
3535

@@ -48,19 +48,19 @@ def approx_quantiles(expression: float, number) -> List[float]:
4848

4949

5050
def compile_aggregate(
51-
aggregate: ex.Aggregation,
51+
aggregate: expression_types.Aggregation,
5252
bindings: typing.Dict[str, ibis_types.Value],
5353
order_by: typing.Sequence[ibis_types.Value] = [],
5454
) -> ibis_types.Value:
55-
if isinstance(aggregate, ex.NullaryAggregation):
55+
if isinstance(aggregate, expression_types.NullaryAggregation):
5656
return compile_nullary_agg(aggregate.op)
57-
if isinstance(aggregate, ex.UnaryAggregation):
57+
if isinstance(aggregate, expression_types.UnaryAggregation):
5858
input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
5959
if not aggregate.op.order_independent:
6060
return compile_ordered_unary_agg(aggregate.op, input, order_by=order_by) # type: ignore
6161
else:
6262
return compile_unary_agg(aggregate.op, input) # type: ignore
63-
elif isinstance(aggregate, ex.BinaryAggregation):
63+
elif isinstance(aggregate, expression_types.BinaryAggregation):
6464
left = scalar_compiler.compile_expression(aggregate.left, bindings=bindings)
6565
right = scalar_compiler.compile_expression(aggregate.right, bindings=bindings)
6666
return compile_binary_agg(aggregate.op, left, right) # type: ignore
@@ -69,16 +69,16 @@ def compile_aggregate(
6969

7070

7171
def compile_analytic(
72-
aggregate: ex.Aggregation,
72+
aggregate: expression_types.Aggregation,
7373
window: window_spec.WindowSpec,
7474
bindings: typing.Dict[str, ibis_types.Value],
7575
) -> ibis_types.Value:
76-
if isinstance(aggregate, ex.NullaryAggregation):
76+
if isinstance(aggregate, expression_types.NullaryAggregation):
7777
return compile_nullary_agg(aggregate.op, window)
78-
elif isinstance(aggregate, ex.UnaryAggregation):
78+
elif isinstance(aggregate, expression_types.UnaryAggregation):
7979
input = scalar_compiler.compile_expression(aggregate.arg, bindings=bindings)
8080
return compile_unary_agg(aggregate.op, input, window) # type: ignore
81-
elif isinstance(aggregate, ex.BinaryAggregation):
81+
elif isinstance(aggregate, expression_types.BinaryAggregation):
8282
raise NotImplementedError("binary analytic operations not yet supported")
8383
else:
8484
raise ValueError(f"Unexpected analytic operation: {aggregate}")

0 commit comments

Comments (0)