11import collections
22import dataclasses
33import functools
4+ import itertools
45from typing import Generic , Hashable , Iterable , Optional , Sequence , Tuple , TypeVar
56
67from bigframes .core import agg_expressions , expression , identifiers , nodes
@@ -51,9 +52,7 @@ def gather_fragments(
5152 do_inline = is_leaf | is_window_agg
5253 if not do_inline :
5354 id = identifiers .ColumnId .unique ()
54- replacements .append (
55- expression .DerefOp (id )
56- ) # TODO: Determinism, maybe hash-based?
55+ replacements .append (expression .DerefOp (id ))
5756 named_exprs .append (NamedExpression (child_result .root_expr , id ))
5857 named_exprs .extend (child_result .sub_exprs )
5958 else :
@@ -75,32 +74,31 @@ def replace_children(
7574
7675class DiGraph (Generic [T ]):
7776 def __init__ (self , edges : Iterable [Tuple [T , T ]]):
78- self ._nodes = set ()
7977 self ._parents = collections .defaultdict (set )
8078 self ._children = collections .defaultdict (set ) # specifically, unpushed ones
81- # dict repr of graph
82- self ._sinks = set ()
79+ # use dict for stable ordering, which grants determinism
80+ self ._sinks : dict [ T , None ] = dict ()
8381 for src , dst in edges :
8482 self ._children [src ].add (dst )
8583 self ._parents [dst ].add (src )
86- self ._nodes .add (src )
87- self ._nodes .add (dst )
8884 # sinks have no children
8985 if not self ._children [dst ]:
90- self ._sinks .add (dst )
91- self ._sinks .discard (src )
86+ self ._sinks [dst ] = None
87+ if src in self ._sinks :
88+ del self ._sinks [src ]
9289
9390 @property
9491 def nodes (self ):
95- return self ._nodes
92+ # should be the same set of ids as self._parents
93+ return self ._children .keys ()
9694
9795 @property
98- def sinks (self ) -> set [T ]:
99- return self ._sinks
96+ def sinks (self ) -> Iterable [T ]:
97+ return self ._sinks . keys ()
10098
10199 @property
102100 def empty (self ):
103- return len (self ._nodes ) == 0
101+ return len (self .nodes ) == 0
104102
105103 def parents (self , node : T ) -> set [T ]:
106104 return self ._parents [node ]
@@ -114,11 +112,11 @@ def remove_node(self, node: T) -> None:
114112 for parent in self ._parents [node ]:
115113 self ._children [parent ].remove (node )
116114 if len (self ._children [parent ]) == 0 :
117- self ._sinks . add ( parent )
115+ self ._sinks [ parent ] = None
118116 del self ._children [node ]
119117 del self ._parents [node ]
120- self ._nodes . remove ( node )
121- self ._sinks . discard ( node )
118+ if node in self ._sinks :
119+ del self ._sinks [ node ]
122120
123121
124122def push_into_tree (
@@ -145,11 +143,11 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
145143 while (
146144 True
147145 ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks
148- candidate_ids = graph .sinks .intersection (scalar_ids )
149- bad_inline = set (
146+ candidate_ids = list (
150147 id
151- for id in candidate_ids
152- if any (
148+ for id in graph .sinks
149+ if (id in scalar_ids )
150+ and not any (
153151 (
154152 child in multi_parent_ids
155153 and id in results .keys ()
@@ -158,7 +156,6 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
158156 for child in graph .children (id )
159157 )
160158 )
161- candidate_ids = candidate_ids .difference (bad_inline )
162159 if len (candidate_ids ) == 0 :
163160 break
164161 for id in candidate_ids :
@@ -173,17 +170,20 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
173170 def graph_extract_window_expr () -> Optional [
174171 Tuple [identifiers .ColumnId , agg_expressions .WindowExpression ]
175172 ]:
176- candidate_ids = graph .sinks .difference (scalar_ids )
177- if not candidate_ids :
173+ candidate = list (
174+ itertools .islice ((id for id in graph .sinks if id not in scalar_ids ), 1 )
175+ )
176+ if not candidate :
178177 return None
179178 else :
180- id = next (iter (candidate_ids ))
179+ id = next (iter (candidate ))
181180 graph .remove_node (id )
182181 result_expr = by_id [id ].expr
183182 assert isinstance (result_expr , agg_expressions .WindowExpression )
184183 return (id , result_expr )
185184
186185 while not graph .empty :
186+ pre_size = len (graph .nodes )
187187 scalar_exprs = graph_extract_scalar_exprs ()
188188 if scalar_exprs :
189189 curr_root = nodes .ProjectionNode (
@@ -194,6 +194,8 @@ def graph_extract_window_expr() -> Optional[
194194 curr_root = nodes .WindowOpNode (
195195 curr_root , window_expr .analytic_expr , window_expr .window , output_name = id
196196 )
197+ if len (graph .nodes ) >= pre_size :
198+ raise ValueError ("graph didn't shrink" )
197199 # TODO: Try to get the ordering right earlier, so can avoid this extra node.
198200 post_ids = (* root .ids , * target_ids )
199201 if tuple (curr_root .ids ) != post_ids :
0 commit comments