1+ # Copyright 2025 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+
116import collections
217import dataclasses
318import functools
419import itertools
520from typing import (
621 cast ,
7- Dict ,
8- Generic ,
922 Hashable ,
1023 Iterable ,
1124 Iterator ,
1225 Mapping ,
1326 Optional ,
1427 Sequence ,
15- Set ,
1628 Tuple ,
1729 TypeVar ,
1830)
1931
20- from bigframes .core import agg_expressions , expression , identifiers , nodes , window_spec
32+ from bigframes .core import (
33+ agg_expressions ,
34+ expression ,
35+ graphs ,
36+ identifiers ,
37+ nodes ,
38+ window_spec ,
39+ )
2140
2241_MAX_INLINE_COMPLEXITY = 10
2342
@@ -45,7 +64,6 @@ def plan_general_aggregation(
4564 all_inputs = list (
4665 itertools .chain (* (factored_agg .agg_inputs for factored_agg in factored_aggs ))
4766 )
48- # TODO: Windowize
4967 window_def = window_spec .WindowSpec (grouping_keys = tuple (grouping_keys ))
5068 windowized_inputs = [
5169 nodes .ColumnDef (windowize (cdef .expression , window_def ), cdef .id )
@@ -123,8 +141,10 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
123141 2. The set of underlying primitive aggregations
124142 3. A final post-aggregate scalar expression
125143 """
126- final_aggs = set (find_final_aggregations (root .expression ))
127- agg_inputs = set (itertools .chain .from_iterable (map (find_agg_inputs , final_aggs )))
144+ final_aggs = list (dedupe (find_final_aggregations (root .expression )))
145+ agg_inputs = list (
146+ dedupe (itertools .chain .from_iterable (map (find_agg_inputs , final_aggs )))
147+ )
128148
129149 agg_input_defs = tuple (
130150 nodes .ColumnDef (expr , identifiers .ColumnId .unique ()) for expr in agg_inputs
@@ -219,64 +239,6 @@ def replace_children(
219239 return root .transform_children (lambda x : mapping .get (x , x ))
220240
221241
T = TypeVar("T", bound=Hashable)


class DiGraph(Generic[T]):
    """Mutable directed graph over hashable node keys.

    For every node the graph tracks the set of its parents (sources of
    incoming edges) and children (targets of outgoing edges), plus an
    insertion-ordered collection of sinks: nodes that currently have no
    children.
    """

    def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
        """Build the graph from its nodes and ``(src, dst)`` edges.

        Args:
            nodes: Keys of all nodes. Fully consumed before ``edges`` is read.
            edges: ``(src, dst)`` pairs; both ends must appear in ``nodes``.

        Raises:
            ValueError: If an edge references a node that is not in ``nodes``.
        """
        # Plain dicts rather than defaultdicts: a lookup must never be able
        # to silently insert a phantom node into only one of the two maps.
        self._parents: dict[T, set[T]] = {}
        self._children: dict[T, set[T]] = {}  # specifically, unpushed ones
        # use dict for stable ordering, which grants determinism
        self._sinks: dict[T, None] = {}
        for node in nodes:
            self._children.setdefault(node, set())
            self._parents.setdefault(node, set())
            self._sinks[node] = None
        for src, dst in edges:
            # Explicit raise instead of assert so the check survives `-O`.
            for endpoint in (src, dst):
                if endpoint not in self._children:
                    raise ValueError(
                        f"edge ({src!r}, {dst!r}) references unknown node {endpoint!r}"
                    )
            self._children[src].add(dst)
            self._parents[dst].add(src)
            # sinks have no children
            self._sinks.pop(src, None)

    @property
    def nodes(self):
        """Keys view of all nodes currently in the graph."""
        # should be the same set of ids as self._parents
        return self._children.keys()

    @property
    def sinks(self) -> Iterable[T]:
        """Nodes that currently have no children, in stable order."""
        return self._sinks.keys()

    @property
    def empty(self) -> bool:
        """True when the graph contains no nodes."""
        return len(self.nodes) == 0

    def parents(self, node: T) -> set[T]:
        """Return the nodes that have an edge into ``node``.

        The returned set is the graph's internal state, not a copy.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self._parents[node]

    def children(self, node: T) -> set[T]:
        """Return the nodes that ``node`` has an edge to.

        The returned set is the graph's internal state, not a copy.

        Raises:
            KeyError: If ``node`` is not in the graph.
        """
        return self._children[node]

    def remove_node(self, node: T) -> None:
        """Remove ``node`` and every edge touching it.

        A parent left without children by the removal becomes a sink.
        Removing a node that is not in the graph is a no-op.
        """
        if node not in self._children:
            return
        for child in self._children[node]:
            self._parents[child].remove(node)
        for parent in self._parents[node]:
            remaining = self._children[parent]
            remaining.remove(node)
            if not remaining:
                self._sinks[parent] = None
        del self._children[node]
        del self._parents[node]
        self._sinks.pop(node, None)
279-
280242def push_into_tree (
281243 root : nodes .BigFrameNode ,
282244 exprs : Sequence [nodes .ColumnDef ],
@@ -285,7 +247,7 @@ def push_into_tree(
285247 curr_root = root
286248 by_id = {expr .id : expr for expr in exprs }
287249 # id -> id
288- graph = DiGraph (
250+ graph = graphs . DiGraph (
289251 (expr .id for expr in exprs ),
290252 (
291253 (expr .id , child_id )
@@ -296,7 +258,7 @@ def push_into_tree(
296258 )
297259 # TODO: Also prevent inlining expensive or non-deterministic
298260 # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
299- multi_parent_ids = set (id for id in graph .nodes if len (graph .parents (id )) > 2 )
261+ multi_parent_ids = set (id for id in graph .nodes if len (list ( graph .parents (id ) )) > 2 )
300262 scalar_ids = set (expr .id for expr in exprs if expr .expression .is_scalar_expr )
301263
302264 analytic_defs = filter (
@@ -367,21 +329,13 @@ def graph_extract_window_expr() -> Optional[
367329
368330 return None
369331
370- must_be_pushed = set (target_ids ) - set (graph .nodes )
371- if not must_be_pushed .issubset (curr_root .ids ):
372- missing = must_be_pushed - set (curr_root .ids )
373- raise ValueError (f"hmmm, missing { missing } " )
374-
375332 while not graph .empty :
376333 pre_size = len (graph .nodes )
377334 scalar_exprs = graph_extract_scalar_exprs ()
378335 if scalar_exprs :
379336 curr_root = nodes .ProjectionNode (
380337 curr_root , tuple ((x .expression , x .id ) for x in scalar_exprs )
381338 )
382- must_be_pushed = set (target_ids ) - set (graph .nodes )
383- if not must_be_pushed .issubset (curr_root .ids ):
384- raise ValueError ("hmmm" )
385339 while result := graph_extract_window_expr ():
386340 defs , window = result
387341 assert len (defs ) > 0
@@ -390,10 +344,6 @@ def graph_extract_window_expr() -> Optional[
390344 tuple (defs ),
391345 window ,
392346 )
393- must_be_pushed = set (target_ids ) - set (graph .nodes )
394- if not must_be_pushed .issubset (curr_root .ids ):
395- missing = must_be_pushed - set (curr_root .ids )
396- raise ValueError (f"hmmm, missing { missing } " )
397347 if len (graph .nodes ) >= pre_size :
398348 raise ValueError ("graph didn't shrink" )
399349 # TODO: Try to get the ordering right earlier, so can avoid this extra node.
@@ -424,3 +374,11 @@ def grouped(values: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
424374 for k , v in values :
425375 result [k ].append (v )
426376 return result
377+
378+
def dedupe(values: Iterable[K]) -> Iterator[K]:
    """Lazily yield each distinct value once, in first-seen order."""
    emitted: set[K] = set()
    for item in values:
        if item in emitted:
            continue
        # Record before yielding so the value counts as seen even if the
        # consumer stops pulling right after receiving it.
        emitted.add(item)
        yield item
0 commit comments