Skip to content

Commit c3ef622

Browse files
improve expression factoring to create multi-col window nodes
1 parent 204be23 commit c3ef622

File tree

2 files changed

+69
-41
lines changed

2 files changed

+69
-41
lines changed

bigframes/core/array_value.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,7 @@ def compute_values(self, assignments: Sequence[ex.Expression]):
268268

269269
def compute_general_expression(self, assignments: Sequence[ex.Expression]):
270270
named_exprs = [
271-
expression_factoring.NamedExpression(expr, ids.ColumnId.unique())
272-
for expr in assignments
271+
nodes.ColumnDef(expr, ids.ColumnId.unique()) for expr in assignments
273272
]
274273
# TODO: Push this to rewrite later to go from block expression to planning form
275274
# TODO: Jointly fragmentize expressions to more efficiently reuse common sub-expressions
@@ -279,7 +278,7 @@ def compute_general_expression(self, assignments: Sequence[ex.Expression]):
279278
for expr in named_exprs
280279
)
281280
)
282-
target_ids = tuple(named_expr.name for named_expr in named_exprs)
281+
target_ids = tuple(named_expr.id for named_expr in named_exprs)
283282
new_root = expression_factoring.push_into_tree(self.node, fragments, target_ids)
284283
return (ArrayValue(new_root), target_ids)
285284

bigframes/core/expression_factoring.py

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,26 @@
11
import collections
22
import dataclasses
33
import functools
4-
import itertools
5-
from typing import Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
4+
from typing import cast, Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
65

7-
from bigframes.core import agg_expressions, expression, identifiers, nodes
6+
from bigframes.core import agg_expressions, expression, identifiers, nodes, window_spec
87

98
_MAX_INLINE_COMPLEXITY = 10
109

1110

12-
@dataclasses.dataclass(frozen=True, eq=False)
13-
class NamedExpression:
14-
expr: expression.Expression
15-
name: identifiers.ColumnId
16-
17-
1811
@dataclasses.dataclass(frozen=True, eq=False)
1912
class FactoredExpression:
2013
root_expr: expression.Expression
21-
sub_exprs: Tuple[NamedExpression, ...]
14+
sub_exprs: Tuple[nodes.ColumnDef, ...]
2215

2316

24-
def fragmentize_expression(root: NamedExpression) -> Sequence[NamedExpression]:
17+
def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
2518
"""
2619
The goal of this function is to factor out an expression into multiple sub-expressions.
2720
"""
2821

29-
factored_expr = root.expr.reduce_up(gather_fragments)
30-
root_expr = NamedExpression(factored_expr.root_expr, root.name)
22+
factored_expr = root.expression.reduce_up(gather_fragments)
23+
root_expr = nodes.ColumnDef(factored_expr.root_expr, root.id)
3124
return (root_expr, *factored_expr.sub_exprs)
3225

3326

@@ -48,7 +41,7 @@ def gather_fragments(
4841
if not do_inline:
4942
id = identifiers.ColumnId.unique()
5043
replacements.append(expression.DerefOp(id))
51-
named_exprs.append(NamedExpression(child_result.root_expr, id))
44+
named_exprs.append(nodes.ColumnDef(child_result.root_expr, id))
5245
named_exprs.extend(child_result.sub_exprs)
5346
else:
5447
replacements.append(child_result.root_expr)
@@ -116,24 +109,34 @@ def remove_node(self, node: T) -> None:
116109

117110
def push_into_tree(
118111
root: nodes.BigFrameNode,
119-
exprs: Sequence[NamedExpression],
112+
exprs: Sequence[nodes.ColumnDef],
120113
target_ids: Sequence[identifiers.ColumnId],
121114
) -> nodes.BigFrameNode:
122115
curr_root = root
123-
by_id = {expr.name: expr for expr in exprs}
116+
by_id = {expr.id: expr for expr in exprs}
124117
# id -> id
125118
graph = DiGraph(
126-
(expr.name, child_id)
119+
(expr.id, child_id)
127120
for expr in exprs
128-
for child_id in expr.expr.column_references
121+
for child_id in expr.expression.column_references
129122
if child_id in by_id.keys()
130123
)
131124
# TODO: Also prevent inlining expensive or non-deterministic
132125
# We avoid inlining multi-parent ids, as they would be inlined in multiple places, potentially increasing work and/or compiled text size
133126
multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2)
134-
scalar_ids = set(expr.name for expr in exprs if expr.expr.is_scalar_expr)
127+
scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr)
135128

136-
def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
129+
analytic_defs = filter(
130+
lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs
131+
)
132+
analytic_by_window = grouped(
133+
map(
134+
lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x),
135+
analytic_defs,
136+
)
137+
)
138+
139+
def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]:
137140
results: dict[identifiers.ColumnId, expression.Expression] = dict()
138141
while (
139142
True
@@ -156,40 +159,55 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
156159
for id in candidate_ids:
157160
graph.remove_node(id)
158161
new_exprs = {
159-
id: by_id[id].expr.bind_refs(results, allow_partial_bindings=True)
162+
id: by_id[id].expression.bind_refs(
163+
results, allow_partial_bindings=True
164+
)
160165
}
161166
results.update(new_exprs)
162167
# TODO: We can prune expressions that won't be reused here,
163-
return tuple(NamedExpression(expr, id) for id, expr in results.items())
168+
return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items())
164169

165170
def graph_extract_window_expr() -> Optional[
166-
Tuple[identifiers.ColumnId, agg_expressions.WindowExpression]
171+
Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]
167172
]:
168-
candidate = list(
169-
itertools.islice((id for id in graph.sinks if id not in scalar_ids), 1)
170-
)
171-
if not candidate:
172-
return None
173-
else:
174-
id = next(iter(candidate))
175-
graph.remove_node(id)
176-
result_expr = by_id[id].expr
177-
assert isinstance(result_expr, agg_expressions.WindowExpression)
178-
return (id, result_expr)
173+
for id in graph.sinks:
174+
next_def = by_id[id]
175+
if isinstance(next_def.expression, agg_expressions.WindowExpression):
176+
window = next_def.expression.window
177+
window_exprs = [
178+
cdef
179+
for cdef in analytic_by_window[window]
180+
if cdef.id in graph.sinks
181+
]
182+
agg_exprs = tuple(
183+
nodes.ColumnDef(
184+
cast(
185+
agg_expressions.WindowExpression, cdef.expression
186+
).analytic_expr,
187+
cdef.id,
188+
)
189+
for cdef in window_exprs
190+
)
191+
for cdef in window_exprs:
192+
graph.remove_node(cdef.id)
193+
return (agg_exprs, window)
194+
195+
return None
179196

180197
while not graph.empty:
181198
pre_size = len(graph.nodes)
182199
scalar_exprs = graph_extract_scalar_exprs()
183200
if scalar_exprs:
184201
curr_root = nodes.ProjectionNode(
185-
curr_root, tuple((x.expr, x.name) for x in scalar_exprs)
202+
curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
186203
)
187204
while result := graph_extract_window_expr():
188-
id, window_expr = result
205+
defs, window = result
206+
assert len(defs) > 0
189207
curr_root = nodes.WindowOpNode(
190208
curr_root,
191-
(nodes.ColumnDef(window_expr.analytic_expr, id),),
192-
window_expr.window,
209+
tuple(defs),
210+
window,
193211
)
194212
if len(graph.nodes) >= pre_size:
195213
raise ValueError("graph didn't shrink")
@@ -210,3 +228,14 @@ def is_simple(expr: expression.Expression) -> bool:
210228
if count > _MAX_INLINE_COMPLEXITY:
211229
return False
212230
return True
231+
232+
233+
K = TypeVar("K", bound=Hashable)
V = TypeVar("V")


def grouped(values: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
    """Bucket ``(key, value)`` pairs by key.

    Args:
        values: Any iterable of two-element ``(key, value)`` pairs. It is
            consumed exactly once, so single-pass iterators (``map``,
            ``filter``, generators) are fine.

    Returns:
        A mapping from each distinct key to the list of values that were
        paired with it, in the order they were encountered. The mapping is a
        ``collections.defaultdict(list)``, so looking up an absent key yields
        (and stores) an empty list instead of raising ``KeyError``.
    """
    buckets = collections.defaultdict(list)
    for pair in values:
        key, item = pair
        buckets[key].append(item)
    return buckets

0 commit comments

Comments
 (0)