Skip to content

Commit a3edeab

Browse files
refactor: Aggregation is now an expression subclass (#2048)
1 parent f7196d1 commit a3edeab

File tree

24 files changed

+360
-245
lines changed

24 files changed

+360
-245
lines changed

bigframes/core/agg_expressions.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import abc
18+
import dataclasses
19+
import functools
20+
import itertools
21+
import typing
22+
from typing import Callable, Mapping, TypeVar
23+
24+
from bigframes import dtypes
25+
from bigframes.core import expression
26+
import bigframes.core.identifiers as ids
27+
import bigframes.operations.aggregations as agg_ops
28+
29+
# Type variable bound to Aggregation; used to annotate `self` and the return
# value of methods so their declared return type follows the receiver's type.
TExpression = TypeVar("TExpression", bound="Aggregation")
30+
31+
32+
@dataclasses.dataclass(frozen=True)
class Aggregation(expression.Expression):
    """Expression node for a windowing or aggregation op over input expressions.

    Subclasses provide the ``inputs`` tuple and ``replace_args``; the remaining
    members here are derived from those two and from ``op``.
    """

    # The window/aggregate operation carried by this node.
    op: agg_ops.WindowOp = dataclasses.field()

    @property
    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
        """Column references of all inputs, concatenated in input order."""
        return tuple(
            ref for operand in self.inputs for ref in operand.column_references
        )

    @functools.cached_property
    def is_resolved(self) -> bool:
        """Whether every input reports ``is_resolved`` (True when there are none).

        Stored on the instance after the first access by ``cached_property``.
        """
        for operand in self.inputs:
            if not operand.is_resolved:
                return False
        return True

    @functools.cached_property
    def output_type(self) -> dtypes.ExpressionType:
        """Result of ``op.output_type`` applied to the inputs' output types.

        Raises:
            ValueError: If this expression is not resolved.
        """
        if not self.is_resolved:
            raise ValueError(f"Type of expression {self.op} has not been fixed.")

        # Collect the operand types first, then hand them to the op positionally.
        operand_types = tuple(operand.output_type for operand in self.inputs)
        return self.op.output_type(*operand_types)

    @property
    @abc.abstractmethod
    def inputs(
        self,
    ) -> typing.Tuple[expression.Expression, ...]:
        """The input expressions of this aggregation, in positional order."""
        ...

    @property
    def free_variables(self) -> typing.Tuple[str, ...]:
        """Free variables of all inputs, concatenated in input order."""
        return tuple(
            name for operand in self.inputs for name in operand.free_variables
        )

    @property
    def is_const(self) -> bool:
        """Whether every input reports ``is_const`` (True when there are none)."""
        return all(operand.is_const for operand in self.inputs)

    @abc.abstractmethod
    def replace_args(self: TExpression, *arg) -> TExpression:
        """Return an aggregation built from the given replacement inputs."""
        ...

    def transform_children(
        self: TExpression, t: Callable[[expression.Expression], expression.Expression]
    ) -> TExpression:
        """Apply ``t`` to each input and pass the results to ``replace_args``."""
        rewritten = tuple(t(operand) for operand in self.inputs)
        return self.replace_args(*rewritten)

    def bind_variables(
        self: TExpression,
        bindings: Mapping[str, expression.Expression],
        allow_partial_bindings: bool = False,
    ) -> TExpression:
        """Call ``bind_variables`` on each input and rebuild via ``transform_children``."""

        def rebind(operand: expression.Expression) -> expression.Expression:
            return operand.bind_variables(bindings, allow_partial_bindings)

        return self.transform_children(rebind)

    def bind_refs(
        self: TExpression,
        bindings: Mapping[ids.ColumnId, expression.Expression],
        allow_partial_bindings: bool = False,
    ) -> TExpression:
        """Call ``bind_refs`` on each input and rebuild via ``transform_children``."""

        def rebind(operand: expression.Expression) -> expression.Expression:
            return operand.bind_refs(bindings, allow_partial_bindings)

        return self.transform_children(rebind)
102+
103+
104+
@dataclasses.dataclass(frozen=True)
class NullaryAggregation(Aggregation):
    """Aggregation that takes no input expressions."""

    # Narrows the op type to nullary window ops.
    op: agg_ops.NullaryWindowOp = dataclasses.field()

    @property
    def inputs(
        self,
    ) -> typing.Tuple[expression.Expression, ...]:
        """Always the empty tuple."""
        no_inputs: typing.Tuple[expression.Expression, ...] = ()
        return no_inputs

    def replace_args(self, *arg) -> NullaryAggregation:
        """Return this same instance; any arguments passed are ignored."""
        return self
116+
117+
118+
@dataclasses.dataclass(frozen=True)
class UnaryAggregation(Aggregation):
    """Aggregation over a single input expression."""

    # Narrows the op type to unary window ops.
    op: agg_ops.UnaryWindowOp
    # The single input expression.
    arg: expression.Expression

    @property
    def inputs(
        self,
    ) -> typing.Tuple[expression.Expression, ...]:
        """One-element tuple holding ``arg``."""
        only_input = self.arg
        return (only_input,)

    def replace_args(self, arg: expression.Expression) -> UnaryAggregation:
        """Build a new UnaryAggregation with the same op and the given input."""
        return UnaryAggregation(self.op, arg)
134+
135+
136+
@dataclasses.dataclass(frozen=True)
class BinaryAggregation(Aggregation):
    """Aggregation over two input expressions, ``left`` and ``right``."""

    # Narrows the op type to binary aggregate ops.
    op: agg_ops.BinaryAggregateOp = dataclasses.field()
    # First input expression.
    left: expression.Expression = dataclasses.field()
    # Second input expression.
    right: expression.Expression = dataclasses.field()

    @property
    def inputs(
        self,
    ) -> typing.Tuple[expression.Expression, ...]:
        """Two-element tuple ``(left, right)``."""
        first, second = self.left, self.right
        return (first, second)

    def replace_args(
        self, larg: expression.Expression, rarg: expression.Expression
    ) -> BinaryAggregation:
        """Build a new BinaryAggregation with the same op and the given inputs."""
        rebuilt = BinaryAggregation(self.op, larg, rarg)
        return rebuilt

bigframes/core/array_value.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import pandas
2525
import pyarrow as pa
2626

27+
from bigframes.core import agg_expressions
2728
import bigframes.core.expression as ex
2829
import bigframes.core.guid
2930
import bigframes.core.identifiers as ids
@@ -190,7 +191,7 @@ def row_count(self) -> ArrayValue:
190191
child=self.node,
191192
aggregations=(
192193
(
193-
ex.NullaryAggregation(agg_ops.size_op),
194+
agg_expressions.NullaryAggregation(agg_ops.size_op),
194195
ids.ColumnId(bigframes.core.guid.generate_guid()),
195196
),
196197
),
@@ -379,7 +380,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
379380

380381
def aggregate(
381382
self,
382-
aggregations: typing.Sequence[typing.Tuple[ex.Aggregation, str]],
383+
aggregations: typing.Sequence[typing.Tuple[agg_expressions.Aggregation, str]],
383384
by_column_ids: typing.Sequence[str] = (),
384385
dropna: bool = True,
385386
) -> ArrayValue:
@@ -420,15 +421,15 @@ def project_window_op(
420421
"""
421422

422423
return self.project_window_expr(
423-
ex.UnaryAggregation(op, ex.deref(column_name)),
424+
agg_expressions.UnaryAggregation(op, ex.deref(column_name)),
424425
window_spec,
425426
never_skip_nulls,
426427
skip_reproject_unsafe,
427428
)
428429

429430
def project_window_expr(
430431
self,
431-
expression: ex.Aggregation,
432+
expression: agg_expressions.Aggregation,
432433
window: WindowSpec,
433434
never_skip_nulls=False,
434435
skip_reproject_unsafe: bool = False,

bigframes/core/bigframe_node.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,12 @@
2020
import functools
2121
import itertools
2222
import typing
23-
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple, Union
23+
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Tuple
2424

2525
from bigframes.core import expression, field, identifiers
2626
import bigframes.core.schema as schemata
2727
import bigframes.dtypes
2828

29-
if typing.TYPE_CHECKING:
30-
import bigframes.session
31-
3229
COLUMN_SET = frozenset[identifiers.ColumnId]
3330

3431
T = typing.TypeVar("T")
@@ -281,8 +278,8 @@ def field_by_id(self) -> Mapping[identifiers.ColumnId, field.Field]:
281278
@property
282279
def _node_expressions(
283280
self,
284-
) -> Sequence[Union[expression.Expression, expression.Aggregation]]:
285-
"""List of scalar expressions. Intended for checking engine compatibility with used ops."""
281+
) -> Sequence[expression.Expression]:
282+
"""List of expressions. Intended for checking engine compatibility with used ops."""
286283
return ()
287284

288285
# Plan algorithms

bigframes/core/block_transforms.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
import pandas as pd
2222

2323
import bigframes.constants
24+
from bigframes.core import agg_expressions
2425
import bigframes.core as core
2526
import bigframes.core.blocks as blocks
2627
import bigframes.core.expression as ex
2728
import bigframes.core.ordering as ordering
2829
import bigframes.core.window_spec as windows
29-
import bigframes.dtypes
3030
import bigframes.dtypes as dtypes
3131
import bigframes.operations as ops
3232
import bigframes.operations.aggregations as agg_ops
@@ -133,7 +133,7 @@ def quantile(
133133
block, _ = block.aggregate(
134134
grouping_column_ids,
135135
tuple(
136-
ex.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
136+
agg_expressions.UnaryAggregation(agg_ops.AnyValueOp(), ex.deref(col))
137137
for col in quantile_cols
138138
),
139139
column_labels=pd.Index(labels),
@@ -363,7 +363,7 @@ def value_counts(
363363
block = dropna(block, columns, how="any")
364364
block, agg_ids = block.aggregate(
365365
by_column_ids=(*grouping_keys, *columns),
366-
aggregations=[ex.NullaryAggregation(agg_ops.size_op)],
366+
aggregations=[agg_expressions.NullaryAggregation(agg_ops.size_op)],
367367
dropna=drop_na and not grouping_keys,
368368
)
369369
count_id = agg_ids[0]
@@ -647,15 +647,15 @@ def skew(
647647
# counts, moment3 for each column
648648
aggregations = []
649649
for i, col in enumerate(original_columns):
650-
count_agg = ex.UnaryAggregation(
650+
count_agg = agg_expressions.UnaryAggregation(
651651
agg_ops.count_op,
652652
ex.deref(col),
653653
)
654-
moment3_agg = ex.UnaryAggregation(
654+
moment3_agg = agg_expressions.UnaryAggregation(
655655
agg_ops.mean_op,
656656
ex.deref(delta3_ids[i]),
657657
)
658-
variance_agg = ex.UnaryAggregation(
658+
variance_agg = agg_expressions.UnaryAggregation(
659659
agg_ops.PopVarOp(),
660660
ex.deref(col),
661661
)
@@ -698,9 +698,13 @@ def kurt(
698698
# counts, moment4 for each column
699699
aggregations = []
700700
for i, col in enumerate(original_columns):
701-
count_agg = ex.UnaryAggregation(agg_ops.count_op, ex.deref(col))
702-
moment4_agg = ex.UnaryAggregation(agg_ops.mean_op, ex.deref(delta4_ids[i]))
703-
variance_agg = ex.UnaryAggregation(agg_ops.PopVarOp(), ex.deref(col))
701+
count_agg = agg_expressions.UnaryAggregation(agg_ops.count_op, ex.deref(col))
702+
moment4_agg = agg_expressions.UnaryAggregation(
703+
agg_ops.mean_op, ex.deref(delta4_ids[i])
704+
)
705+
variance_agg = agg_expressions.UnaryAggregation(
706+
agg_ops.PopVarOp(), ex.deref(col)
707+
)
704708
aggregations.extend([count_agg, moment4_agg, variance_agg])
705709

706710
block, agg_ids = block.aggregate(

0 commit comments

Comments
 (0)