11import collections
22import dataclasses
33import functools
4- import itertools
5- from typing import Generic , Hashable , Iterable , Optional , Sequence , Tuple , TypeVar
4+ from typing import cast , Generic , Hashable , Iterable , Optional , Sequence , Tuple , TypeVar
65
7- from bigframes .core import agg_expressions , expression , identifiers , nodes
6+ from bigframes .core import agg_expressions , expression , identifiers , nodes , window_spec
87
98_MAX_INLINE_COMPLEXITY = 10
109
1110
12- @dataclasses .dataclass (frozen = True , eq = False )
13- class NamedExpression :
14- expr : expression .Expression
15- name : identifiers .ColumnId
16-
17-
@dataclasses.dataclass(frozen=True, eq=False)
class FactoredExpression:
    """Result of factoring an expression: a root plus extracted column definitions."""

    # The expression that remains at the top level after factoring.
    root_expr: expression.Expression
    # Column definitions (expression + id) that were pulled out of the original
    # expression. NOTE(review): root_expr presumably refers to these by id via
    # DerefOp, as gather_fragments appears to do -- confirm against that function.
    sub_exprs: Tuple[nodes.ColumnDef, ...]
2215
2316
def fragmentize_expression(root: nodes.ColumnDef) -> Sequence[nodes.ColumnDef]:
    """
    Factor a column definition into a root definition plus sub-expression definitions.

    The expression of ``root`` is folded bottom-up with ``gather_fragments``,
    which yields a factored result holding a root expression and the column
    definitions extracted from it.

    Args:
        root: The column definition whose expression should be factored.

    Returns:
        A tuple whose first element is the factored root expression bound to
        the id of ``root``, followed by the extracted sub-expression definitions.
    """
    factored = root.expression.reduce_up(gather_fragments)
    # The root keeps its original id so existing references to it stay valid.
    head = nodes.ColumnDef(factored.root_expr, root.id)
    return (head,) + tuple(factored.sub_exprs)
3225
3326
@@ -48,7 +41,7 @@ def gather_fragments(
4841 if not do_inline :
4942 id = identifiers .ColumnId .unique ()
5043 replacements .append (expression .DerefOp (id ))
51- named_exprs .append (NamedExpression (child_result .root_expr , id ))
44+ named_exprs .append (nodes . ColumnDef (child_result .root_expr , id ))
5245 named_exprs .extend (child_result .sub_exprs )
5346 else :
5447 replacements .append (child_result .root_expr )
@@ -116,24 +109,34 @@ def remove_node(self, node: T) -> None:
116109
117110def push_into_tree (
118111 root : nodes .BigFrameNode ,
119- exprs : Sequence [NamedExpression ],
112+ exprs : Sequence [nodes . ColumnDef ],
120113 target_ids : Sequence [identifiers .ColumnId ],
121114) -> nodes .BigFrameNode :
122115 curr_root = root
123- by_id = {expr .name : expr for expr in exprs }
116+ by_id = {expr .id : expr for expr in exprs }
124117 # id -> id
125118 graph = DiGraph (
126- (expr .name , child_id )
119+ (expr .id , child_id )
127120 for expr in exprs
128- for child_id in expr .expr .column_references
121+ for child_id in expr .expression .column_references
129122 if child_id in by_id .keys ()
130123 )
131124 # TODO: Also prevent inlining expensive or non-deterministic
132125 # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
133126 multi_parent_ids = set (id for id in graph .nodes if len (graph .parents (id )) > 2 )
134- scalar_ids = set (expr .name for expr in exprs if expr .expr .is_scalar_expr )
127+ scalar_ids = set (expr .id for expr in exprs if expr .expression .is_scalar_expr )
135128
136- def graph_extract_scalar_exprs () -> Sequence [NamedExpression ]:
129+ analytic_defs = filter (
130+ lambda x : isinstance (x .expression , agg_expressions .WindowExpression ), exprs
131+ )
132+ analytic_by_window = grouped (
133+ map (
134+ lambda x : (cast (agg_expressions .WindowExpression , x .expression ).window , x ),
135+ analytic_defs ,
136+ )
137+ )
138+
139+ def graph_extract_scalar_exprs () -> Sequence [nodes .ColumnDef ]:
137140 results : dict [identifiers .ColumnId , expression .Expression ] = dict ()
138141 while (
139142 True
@@ -156,40 +159,55 @@ def graph_extract_scalar_exprs() -> Sequence[NamedExpression]:
156159 for id in candidate_ids :
157160 graph .remove_node (id )
158161 new_exprs = {
159- id : by_id [id ].expr .bind_refs (results , allow_partial_bindings = True )
162+ id : by_id [id ].expression .bind_refs (
163+ results , allow_partial_bindings = True
164+ )
160165 }
161166 results .update (new_exprs )
162167 # TODO: We can prune expressions that won't be reused here,
163- return tuple (NamedExpression (expr , id ) for id , expr in results .items ())
168+ return tuple (nodes . ColumnDef (expr , id ) for id , expr in results .items ())
164169
165170 def graph_extract_window_expr () -> Optional [
166- Tuple [identifiers . ColumnId , agg_expressions . WindowExpression ]
171+ Tuple [Sequence [ nodes . ColumnDef ], window_spec . WindowSpec ]
167172 ]:
168- candidate = list (
169- itertools .islice ((id for id in graph .sinks if id not in scalar_ids ), 1 )
170- )
171- if not candidate :
172- return None
173- else :
174- id = next (iter (candidate ))
175- graph .remove_node (id )
176- result_expr = by_id [id ].expr
177- assert isinstance (result_expr , agg_expressions .WindowExpression )
178- return (id , result_expr )
173+ for id in graph .sinks :
174+ next_def = by_id [id ]
175+ if isinstance (next_def .expression , agg_expressions .WindowExpression ):
176+ window = next_def .expression .window
177+ window_exprs = [
178+ cdef
179+ for cdef in analytic_by_window [window ]
180+ if cdef .id in graph .sinks
181+ ]
182+ agg_exprs = tuple (
183+ nodes .ColumnDef (
184+ cast (
185+ agg_expressions .WindowExpression , cdef .expression
186+ ).analytic_expr ,
187+ cdef .id ,
188+ )
189+ for cdef in window_exprs
190+ )
191+ for cdef in window_exprs :
192+ graph .remove_node (cdef .id )
193+ return (agg_exprs , window )
194+
195+ return None
179196
180197 while not graph .empty :
181198 pre_size = len (graph .nodes )
182199 scalar_exprs = graph_extract_scalar_exprs ()
183200 if scalar_exprs :
184201 curr_root = nodes .ProjectionNode (
185- curr_root , tuple ((x .expr , x .name ) for x in scalar_exprs )
202+ curr_root , tuple ((x .expression , x .id ) for x in scalar_exprs )
186203 )
187204 while result := graph_extract_window_expr ():
188- id , window_expr = result
205+ defs , window = result
206+ assert len (defs ) > 0
189207 curr_root = nodes .WindowOpNode (
190208 curr_root ,
191- ( nodes . ColumnDef ( window_expr . analytic_expr , id ), ),
192- window_expr . window ,
209+ tuple ( defs ),
210+ window ,
193211 )
194212 if len (graph .nodes ) >= pre_size :
195213 raise ValueError ("graph didn't shrink" )
@@ -210,3 +228,14 @@ def is_simple(expr: expression.Expression) -> bool:
210228 if count > _MAX_INLINE_COMPLEXITY :
211229 return False
212230 return True
231+
232+
K = TypeVar("K", bound=Hashable)
V = TypeVar("V")


def grouped(values: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
    """Group the values of ``(key, value)`` pairs by their key.

    Args:
        values: An iterable of two-element ``(key, value)`` tuples. It is
            consumed exactly once, so one-shot iterators are accepted.

    Returns:
        A plain ``dict`` mapping each key to the list of values seen with it.
        Keys appear in order of first occurrence and values keep their input
        order. Looking up a key that was never seen raises ``KeyError``.
    """
    # Build a regular dict rather than returning a defaultdict: a defaultdict
    # would silently insert an empty list whenever a caller indexes a missing
    # key, mutating the result and hiding lookup mistakes.
    buckets: dict[K, list[V]] = {}
    for key, value in values:
        buckets.setdefault(key, []).append(value)
    return buckets
0 commit comments