Skip to content

Commit 7a1c53a

Browse files
make more deterministic
1 parent 616eccf commit 7a1c53a

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

bigframes/core/expression_factoring.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import dataclasses
33
import functools
4+
import itertools
45
from typing import Generic, Hashable, Iterable, Optional, Sequence, Tuple, TypeVar
56

67
from bigframes.core import agg_expressions, expression, identifiers, nodes
@@ -51,9 +52,7 @@ def gather_fragments(
5152
do_inline = is_leaf | is_window_agg
5253
if not do_inline:
5354
id = identifiers.ColumnId.unique()
54-
replacements.append(
55-
expression.DerefOp(id)
56-
) # TODO: Determinism, maybe hash-based?
55+
replacements.append(expression.DerefOp(id))
5756
named_exprs.append(NamedExpression(child_result.root_expr, id))
5857
named_exprs.extend(child_result.sub_exprs)
5958
else:
@@ -75,32 +74,31 @@ def replace_children(
7574

7675
class DiGraph(Generic[T]):
7776
def __init__(self, edges: Iterable[Tuple[T, T]]):
78-
self._nodes = set()
7977
self._parents = collections.defaultdict(set)
8078
self._children = collections.defaultdict(set) # specifically, unpushed ones
81-
# dict repr of graph
82-
self._sinks = set()
79+
# use dict for stable ordering, which grants determinism
80+
self._sinks: dict[T, None] = dict()
8381
for src, dst in edges:
8482
self._children[src].add(dst)
8583
self._parents[dst].add(src)
86-
self._nodes.add(src)
87-
self._nodes.add(dst)
8884
# sinks have no children
8985
if not self._children[dst]:
90-
self._sinks.add(dst)
91-
self._sinks.discard(src)
86+
self._sinks[dst] = None
87+
if src in self._sinks:
88+
del self._sinks[src]
9289

9390
@property
9491
def nodes(self):
95-
return self._nodes
92+
# should be the same set of ids as self._parents
93+
return self._children.keys()
9694

9795
@property
98-
def sinks(self) -> set[T]:
99-
return self._sinks
96+
def sinks(self) -> Iterable[T]:
97+
return self._sinks.keys()
10098

10199
@property
102100
def empty(self):
103-
return len(self._nodes) == 0
101+
return len(self.nodes) == 0
104102

105103
def parents(self, node: T) -> set[T]:
106104
return self._parents[node]
@@ -114,11 +112,11 @@ def remove_node(self, node: T) -> None:
114112
for parent in self._parents[node]:
115113
self._children[parent].remove(node)
116114
if len(self._children[parent]) == 0:
117-
self._sinks.add(parent)
115+
self._sinks[parent] = None
118116
del self._children[node]
119117
del self._parents[node]
120-
self._nodes.remove(node)
121-
self._sinks.discard(node)
118+
if node in self._sinks:
119+
del self._sinks[node]
122120

123121

124122
def push_into_tree(
@@ -145,11 +143,11 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
145143
while (
146144
True
147145
): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks
148-
candidate_ids = graph.sinks.intersection(scalar_ids)
149-
bad_inline = set(
146+
candidate_ids = list(
150147
id
151-
for id in candidate_ids
152-
if any(
148+
for id in graph.sinks
149+
if (id in scalar_ids)
150+
and not any(
153151
(
154152
child in multi_parent_ids
155153
and id in results.keys()
@@ -158,7 +156,6 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
158156
for child in graph.children(id)
159157
)
160158
)
161-
candidate_ids = candidate_ids.difference(bad_inline)
162159
if len(candidate_ids) == 0:
163160
break
164161
for id in candidate_ids:
@@ -173,17 +170,20 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
173170
def graph_extract_window_expr() -> Optional[
174171
Tuple[identifiers.ColumnId, agg_expressions.WindowExpression]
175172
]:
176-
candidate_ids = graph.sinks.difference(scalar_ids)
177-
if not candidate_ids:
173+
candidate = list(
174+
itertools.islice((id for id in graph.sinks if id not in scalar_ids), 1)
175+
)
176+
if not candidate:
178177
return None
179178
else:
180-
id = next(iter(candidate_ids))
179+
id = next(iter(candidate))
181180
graph.remove_node(id)
182181
result_expr = by_id[id].expr
183182
assert isinstance(result_expr, agg_expressions.WindowExpression)
184183
return (id, result_expr)
185184

186185
while not graph.empty:
186+
pre_size = len(graph.nodes)
187187
scalar_exprs = graph_extract_scalar_exprs()
188188
if scalar_exprs:
189189
curr_root = nodes.ProjectionNode(
@@ -194,6 +194,8 @@ def graph_extract_window_expr() -> Optional[
194194
curr_root = nodes.WindowOpNode(
195195
curr_root, window_expr.analytic_expr, window_expr.window, output_name=id
196196
)
197+
if len(graph.nodes) >= pre_size:
198+
raise ValueError("graph didn't shrink")
197199
# TODO: Try to get the ordering right earlier, so can avoid this extra node.
198200
post_ids = (*root.ids, *target_ids)
199201
if tuple(curr_root.ids) != post_ids:

0 commit comments

Comments
 (0)