Skip to content

Commit 9eddf47

Browse files
refactor: Reuse expression factoring algo for implicit joins
1 parent 798af4a commit 9eddf47

File tree

3 files changed

+262
-103
lines changed

3 files changed

+262
-103
lines changed

bigframes/core/expression_factoring.py

Lines changed: 153 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
Hashable,
2626
Iterable,
2727
Iterator,
28+
Literal,
2829
Mapping,
2930
Optional,
3031
Sequence,
3132
Tuple,
3233
TypeVar,
34+
Union,
3335
)
3436

3537
from bigframes.core import (
@@ -38,12 +40,158 @@
3840
graphs,
3941
identifiers,
4042
nodes,
43+
subquery_expression,
4144
window_spec,
4245
)
46+
import bigframes.core.ordered_sets as sets
4347

4448
_MAX_INLINE_COMPLEXITY = 10
4549

4650
T = TypeVar("T")
51+
ExprDomain = Union[window_spec.WindowSpec, Literal["Scalar", "Other"]]
52+
53+
54+
class ExpressionGraph(graphs.DiGraph[nodes.ColumnDef]):
55+
def __init__(self, column_defs: Sequence[nodes.ColumnDef]):
56+
# Assumption: All column defs have unique ids
57+
expr_ids = set(cdef.id for cdef in column_defs)
58+
self._graph = graphs.DiGraph(
59+
(expr.id for expr in column_defs),
60+
(
61+
(expr.id, child_id)
62+
for expr in column_defs
63+
for child_id in expr.expression.column_references
64+
if child_id in expr_ids
65+
),
66+
)
67+
self._id_to_cdef = {cdef.id: cdef for cdef in column_defs}
68+
69+
# TODO: Also prevent inlining expensive or non-deterministic
70+
# We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
71+
self._multi_parent_ids = set(
72+
id
73+
for id in self._graph.graph_nodes
74+
if len(list(self._graph.parents(id))) > 2
75+
)
76+
self._free_ids_by_domain: dict[
77+
ExprDomain, sets.InsertionOrderedSet[identifiers.ColumnId]
78+
] = collections.defaultdict(sets.InsertionOrderedSet)
79+
80+
for id in self._graph.graph_nodes:
81+
if len(list(self._graph.children(id))) == 0:
82+
self._mark_free(id)
83+
84+
@property
85+
def graph_nodes(self) -> Iterable[nodes.ColumnDef]:
86+
# should be the same set of ids as self._parents
87+
return map(self._id_to_cdef.__getitem__, self._graph.graph_nodes)
88+
89+
@property
90+
def empty(self):
91+
return self._graph.empty
92+
93+
def __len__(self):
94+
return len(self._graph)
95+
96+
def parents(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]:
97+
yield from map(self._id_to_cdef.__getitem__, self._graph.parents(node.id))
98+
99+
def children(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]:
100+
yield from map(self._id_to_cdef.__getitem__, self._graph.children(node.id))
101+
102+
def _expr_domain(self, expr: expression.Expression) -> ExprDomain:
103+
if expr.is_scalar_expr:
104+
return "Scalar"
105+
elif isinstance(expr, agg_expressions.WindowExpression):
106+
return expr.window
107+
elif isinstance(expr, subquery_expression.SubqueryExpression):
108+
return "Other"
109+
else:
110+
raise ValueError(f"unrecognized expression {expr}")
111+
112+
def _mark_free(self, id: identifiers.ColumnId):
113+
cdef = self._id_to_cdef[id]
114+
expr = cdef.expression
115+
# If this expands further, probably generalize a compatibility key
116+
self._free_ids_by_domain[self._expr_domain(expr)].add(id)
117+
118+
def _remove_free_mark(self, id: identifiers.ColumnId):
119+
cdef = self._id_to_cdef[id]
120+
expr = cdef.expression
121+
# If this expands further, probably generalize a compatibility key
122+
if id in self._free_ids_by_domain[self._expr_domain(expr)]:
123+
self._free_ids_by_domain[self._expr_domain(expr)].remove(id)
124+
125+
def remove_node(self, node: nodes.ColumnDef) -> None:
126+
for child in self._children[node]:
127+
self._parents[child].remove(node)
128+
for parent in self._parents[node]:
129+
self._children[parent].remove(node)
130+
if len(self._children[parent]) == 0:
131+
self._mark_free(parent.id)
132+
del self._children[node]
133+
del self._parents[node]
134+
self._remove_free_mark(node.id)
135+
136+
def extract_scalar_exprs(self) -> Sequence[nodes.ColumnDef]:
137+
results: dict[identifiers.ColumnId, expression.Expression] = dict()
138+
while (
139+
True
140+
): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks
141+
candidate_ids = list(
142+
id
143+
for id in self._free_ids_by_domain["Scalar"]
144+
if not any(
145+
(
146+
child in self._multi_parent_ids
147+
and id in results.keys()
148+
and not is_simple(results[id])
149+
)
150+
for child in self._graph.children(id)
151+
)
152+
)
153+
if len(candidate_ids) == 0:
154+
break
155+
for id in candidate_ids:
156+
self._graph.remove_node(id)
157+
new_exprs = {
158+
id: self._id_to_cdef[id].expression.bind_refs(
159+
results, allow_partial_bindings=True
160+
)
161+
}
162+
results.update(new_exprs)
163+
# TODO: We can prune expressions that won't be reused here,
164+
return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items())
165+
166+
def extract_window_expr(
167+
self,
168+
) -> Optional[Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]]:
169+
window = next(
170+
(
171+
domain
172+
for domain in self._free_ids_by_domain
173+
if domain not in ["Scalar", "Other"]
174+
),
175+
None,
176+
)
177+
assert not isinstance(window, str)
178+
if window:
179+
window_expr_ids = self._free_ids_by_domain[window]
180+
window_exprs = (self._id_to_cdef[id] for id in window_expr_ids)
181+
agg_exprs = tuple(
182+
nodes.ColumnDef(
183+
cast(
184+
agg_expressions.WindowExpression, cdef.expression
185+
).analytic_expr,
186+
cdef.id,
187+
)
188+
for cdef in window_exprs
189+
)
190+
for cdef in window_exprs:
191+
self.remove_node(cdef)
192+
return (agg_exprs, window)
193+
194+
return None
47195

48196

49197
def unique_nodes(
@@ -324,106 +472,25 @@ def push_into_tree(
324472
target_ids: Sequence[identifiers.ColumnId],
325473
) -> nodes.BigFrameNode:
326474
curr_root = root
327-
by_id = {expr.id: expr for expr in exprs}
328475
# id -> id
329-
graph = graphs.DiGraph(
330-
(expr.id for expr in exprs),
331-
(
332-
(expr.id, child_id)
333-
for expr in exprs
334-
for child_id in expr.expression.column_references
335-
if child_id in by_id.keys()
336-
),
337-
)
338-
# TODO: Also prevent inlining expensive or non-deterministic
339-
# We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
340-
multi_parent_ids = set(id for id in graph.nodes if len(list(graph.parents(id))) > 2)
341-
scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr)
342-
343-
analytic_defs = filter(
344-
lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs
345-
)
346-
analytic_by_window = grouped(
347-
map(
348-
lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x),
349-
analytic_defs,
350-
)
351-
)
352-
353-
def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]:
354-
results: dict[identifiers.ColumnId, expression.Expression] = dict()
355-
while (
356-
True
357-
): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks
358-
candidate_ids = list(
359-
id
360-
for id in graph.sinks
361-
if (id in scalar_ids)
362-
and not any(
363-
(
364-
child in multi_parent_ids
365-
and id in results.keys()
366-
and not is_simple(results[id])
367-
)
368-
for child in graph.children(id)
369-
)
370-
)
371-
if len(candidate_ids) == 0:
372-
break
373-
for id in candidate_ids:
374-
graph.remove_node(id)
375-
new_exprs = {
376-
id: by_id[id].expression.bind_refs(
377-
results, allow_partial_bindings=True
378-
)
379-
}
380-
results.update(new_exprs)
381-
# TODO: We can prune expressions that won't be reused here,
382-
return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items())
383-
384-
def graph_extract_window_expr() -> Optional[
385-
Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]
386-
]:
387-
for id in graph.sinks:
388-
next_def = by_id[id]
389-
if isinstance(next_def.expression, agg_expressions.WindowExpression):
390-
window = next_def.expression.window
391-
window_exprs = [
392-
cdef
393-
for cdef in analytic_by_window[window]
394-
if cdef.id in graph.sinks
395-
]
396-
agg_exprs = tuple(
397-
nodes.ColumnDef(
398-
cast(
399-
agg_expressions.WindowExpression, cdef.expression
400-
).analytic_expr,
401-
cdef.id,
402-
)
403-
for cdef in window_exprs
404-
)
405-
for cdef in window_exprs:
406-
graph.remove_node(cdef.id)
407-
return (agg_exprs, window)
408-
409-
return None
476+
graph = ExpressionGraph(exprs)
410477

411478
while not graph.empty:
412-
pre_size = len(graph.nodes)
413-
scalar_exprs = graph_extract_scalar_exprs()
479+
pre_size = len(graph)
480+
scalar_exprs = graph.extract_scalar_exprs()
414481
if scalar_exprs:
415482
curr_root = nodes.ProjectionNode(
416483
curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
417484
)
418-
while result := graph_extract_window_expr():
485+
while result := graph.extract_window_expr():
419486
defs, window = result
420487
assert len(defs) > 0
421488
curr_root = nodes.WindowOpNode(
422489
curr_root,
423490
tuple(defs),
424491
window,
425492
)
426-
if len(graph.nodes) >= pre_size:
493+
if len(graph) >= pre_size:
427494
raise ValueError("graph didn't shrink")
428495
# TODO: Try to get the ordering right earlier, so can avoid this extra node.
429496
post_ids = (*root.ids, *target_ids)

bigframes/core/graphs.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import collections
16+
import collections.abc
1617
from typing import Dict, Generic, Hashable, Iterable, Iterator, Tuple, TypeVar
1718

1819
import bigframes.core.ordered_sets as sets
@@ -28,32 +29,26 @@ def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
2829
self._children: Dict[T, sets.InsertionOrderedSet[T]] = collections.defaultdict(
2930
sets.InsertionOrderedSet
3031
)
31-
self._sinks: sets.InsertionOrderedSet[T] = sets.InsertionOrderedSet()
3232
for node in nodes:
3333
self._children[node]
3434
self._parents[node]
35-
self._sinks.add(node)
3635
for src, dst in edges:
37-
assert src in self.nodes
38-
assert dst in self.nodes
36+
assert src in self.graph_nodes
37+
assert dst in self.graph_nodes
3938
self._children[src].add(dst)
4039
self._parents[dst].add(src)
41-
# sinks have no children
42-
if src in self._sinks:
43-
self._sinks.remove(src)
40+
41+
def __len__(self):
42+
return len(self._children.keys())
4443

4544
@property
46-
def nodes(self):
45+
def graph_nodes(self) -> Iterable[T]:
4746
# should be the same set of ids as self._parents
4847
return self._children.keys()
4948

50-
@property
51-
def sinks(self) -> Iterable[T]:
52-
return self._sinks
53-
5449
@property
5550
def empty(self):
56-
return len(self.nodes) == 0
51+
return len(self) == 0
5752

5853
def parents(self, node: T) -> Iterator[T]:
5954
assert node in self._parents
@@ -68,9 +63,5 @@ def remove_node(self, node: T) -> None:
6863
self._parents[child].remove(node)
6964
for parent in self._parents[node]:
7065
self._children[parent].remove(node)
71-
if len(self._children[parent]) == 0:
72-
self._sinks.add(parent)
7366
del self._children[node]
7467
del self._parents[node]
75-
if node in self._sinks:
76-
self._sinks.remove(node)

0 commit comments

Comments
 (0)