|
25 | 25 | Hashable, |
26 | 26 | Iterable, |
27 | 27 | Iterator, |
| 28 | + Literal, |
28 | 29 | Mapping, |
29 | 30 | Optional, |
30 | 31 | Sequence, |
31 | 32 | Tuple, |
32 | 33 | TypeVar, |
| 34 | + Union, |
33 | 35 | ) |
34 | 36 |
|
35 | 37 | from bigframes.core import ( |
|
38 | 40 | graphs, |
39 | 41 | identifiers, |
40 | 42 | nodes, |
| 43 | + subquery_expression, |
41 | 44 | window_spec, |
42 | 45 | ) |
| 46 | +import bigframes.core.ordered_sets as sets |
43 | 47 |
|
44 | 48 | _MAX_INLINE_COMPLEXITY = 10 |
45 | 49 |
|
46 | 50 | T = TypeVar("T") |
| 51 | +ExprDomain = Union[window_spec.WindowSpec, Literal["Scalar", "Other"]] |
| 52 | + |
| 53 | + |
| 54 | +class ExpressionGraph(graphs.DiGraph[nodes.ColumnDef]): |
| 55 | + def __init__(self, column_defs: Sequence[nodes.ColumnDef]): |
| 56 | + # Assumption: All column defs have unique ids |
| 57 | + expr_ids = set(cdef.id for cdef in column_defs) |
| 58 | + self._graph = graphs.DiGraph( |
| 59 | + (expr.id for expr in column_defs), |
| 60 | + ( |
| 61 | + (expr.id, child_id) |
| 62 | + for expr in column_defs |
| 63 | + for child_id in expr.expression.column_references |
| 64 | + if child_id in expr_ids |
| 65 | + ), |
| 66 | + ) |
| 67 | + self._id_to_cdef = {cdef.id: cdef for cdef in column_defs} |
| 68 | + |
| 69 | + # TODO: Also prevent inlining expensive or non-deterministic |
| 70 | + # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size |
| 71 | + self._multi_parent_ids = set( |
| 72 | + id |
| 73 | + for id in self._graph.graph_nodes |
| 74 | + if len(list(self._graph.parents(id))) > 2 |
| 75 | + ) |
| 76 | + self._free_ids_by_domain: dict[ |
| 77 | + ExprDomain, sets.InsertionOrderedSet[identifiers.ColumnId] |
| 78 | + ] = collections.defaultdict(sets.InsertionOrderedSet) |
| 79 | + |
| 80 | + for id in self._graph.graph_nodes: |
| 81 | + if len(list(self._graph.children(id))) == 0: |
| 82 | + self._mark_free(id) |
| 83 | + |
| 84 | + @property |
| 85 | + def graph_nodes(self) -> Iterable[nodes.ColumnDef]: |
| 86 | + # should be the same set of ids as self._parents |
| 87 | + return map(self._id_to_cdef.__getitem__, self._graph.graph_nodes) |
| 88 | + |
| 89 | + @property |
| 90 | + def empty(self): |
| 91 | + return self._graph.empty |
| 92 | + |
| 93 | + def __len__(self): |
| 94 | + return len(self._graph) |
| 95 | + |
| 96 | + def parents(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]: |
| 97 | + yield from map(self._id_to_cdef.__getitem__, self._graph.parents(node.id)) |
| 98 | + |
| 99 | + def children(self, node: nodes.ColumnDef) -> Iterator[nodes.ColumnDef]: |
| 100 | + yield from map(self._id_to_cdef.__getitem__, self._graph.children(node.id)) |
| 101 | + |
| 102 | + def _expr_domain(self, expr: expression.Expression) -> ExprDomain: |
| 103 | + if expr.is_scalar_expr: |
| 104 | + return "Scalar" |
| 105 | + elif isinstance(expr, agg_expressions.WindowExpression): |
| 106 | + return expr.window |
| 107 | + elif isinstance(expr, subquery_expression.SubqueryExpression): |
| 108 | + return "Other" |
| 109 | + else: |
| 110 | + raise ValueError(f"unrecognized expression {expr}") |
| 111 | + |
| 112 | + def _mark_free(self, id: identifiers.ColumnId): |
| 113 | + cdef = self._id_to_cdef[id] |
| 114 | + expr = cdef.expression |
| 115 | + # If this expands further, probably generalize a compatibility key |
| 116 | + self._free_ids_by_domain[self._expr_domain(expr)].add(id) |
| 117 | + |
| 118 | + def _remove_free_mark(self, id: identifiers.ColumnId): |
| 119 | + cdef = self._id_to_cdef[id] |
| 120 | + expr = cdef.expression |
| 121 | + # If this expands further, probably generalize a compatibility key |
| 122 | + if id in self._free_ids_by_domain[self._expr_domain(expr)]: |
| 123 | + self._free_ids_by_domain[self._expr_domain(expr)].remove(id) |
| 124 | + |
| 125 | + def remove_node(self, node: nodes.ColumnDef) -> None: |
| 126 | + for child in self._children[node]: |
| 127 | + self._parents[child].remove(node) |
| 128 | + for parent in self._parents[node]: |
| 129 | + self._children[parent].remove(node) |
| 130 | + if len(self._children[parent]) == 0: |
| 131 | + self._mark_free(parent.id) |
| 132 | + del self._children[node] |
| 133 | + del self._parents[node] |
| 134 | + self._remove_free_mark(node.id) |
| 135 | + |
| 136 | + def extract_scalar_exprs(self) -> Sequence[nodes.ColumnDef]: |
| 137 | + results: dict[identifiers.ColumnId, expression.Expression] = dict() |
| 138 | + while ( |
| 139 | + True |
| 140 | + ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks |
| 141 | + candidate_ids = list( |
| 142 | + id |
| 143 | + for id in self._free_ids_by_domain["Scalar"] |
| 144 | + if not any( |
| 145 | + ( |
| 146 | + child in self._multi_parent_ids |
| 147 | + and id in results.keys() |
| 148 | + and not is_simple(results[id]) |
| 149 | + ) |
| 150 | + for child in self._graph.children(id) |
| 151 | + ) |
| 152 | + ) |
| 153 | + if len(candidate_ids) == 0: |
| 154 | + break |
| 155 | + for id in candidate_ids: |
| 156 | + self._graph.remove_node(id) |
| 157 | + new_exprs = { |
| 158 | + id: self._id_to_cdef[id].expression.bind_refs( |
| 159 | + results, allow_partial_bindings=True |
| 160 | + ) |
| 161 | + } |
| 162 | + results.update(new_exprs) |
| 163 | + # TODO: We can prune expressions that won't be reused here, |
| 164 | + return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items()) |
| 165 | + |
| 166 | + def extract_window_expr( |
| 167 | + self, |
| 168 | + ) -> Optional[Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec]]: |
| 169 | + window = next( |
| 170 | + ( |
| 171 | + domain |
| 172 | + for domain in self._free_ids_by_domain |
| 173 | + if domain not in ["Scalar", "Other"] |
| 174 | + ), |
| 175 | + None, |
| 176 | + ) |
| 177 | + assert not isinstance(window, str) |
| 178 | + if window: |
| 179 | + window_expr_ids = self._free_ids_by_domain[window] |
| 180 | + window_exprs = (self._id_to_cdef[id] for id in window_expr_ids) |
| 181 | + agg_exprs = tuple( |
| 182 | + nodes.ColumnDef( |
| 183 | + cast( |
| 184 | + agg_expressions.WindowExpression, cdef.expression |
| 185 | + ).analytic_expr, |
| 186 | + cdef.id, |
| 187 | + ) |
| 188 | + for cdef in window_exprs |
| 189 | + ) |
| 190 | + for cdef in window_exprs: |
| 191 | + self.remove_node(cdef) |
| 192 | + return (agg_exprs, window) |
| 193 | + |
| 194 | + return None |
47 | 195 |
|
48 | 196 |
|
49 | 197 | def unique_nodes( |
@@ -324,106 +472,25 @@ def push_into_tree( |
324 | 472 | target_ids: Sequence[identifiers.ColumnId], |
325 | 473 | ) -> nodes.BigFrameNode: |
326 | 474 | curr_root = root |
327 | | - by_id = {expr.id: expr for expr in exprs} |
328 | 475 | # id -> id |
329 | | - graph = graphs.DiGraph( |
330 | | - (expr.id for expr in exprs), |
331 | | - ( |
332 | | - (expr.id, child_id) |
333 | | - for expr in exprs |
334 | | - for child_id in expr.expression.column_references |
335 | | - if child_id in by_id.keys() |
336 | | - ), |
337 | | - ) |
338 | | - # TODO: Also prevent inlining expensive or non-deterministic |
339 | | - # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size |
340 | | - multi_parent_ids = set(id for id in graph.nodes if len(list(graph.parents(id))) > 2) |
341 | | - scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr) |
342 | | - |
343 | | - analytic_defs = filter( |
344 | | - lambda x: isinstance(x.expression, agg_expressions.WindowExpression), exprs |
345 | | - ) |
346 | | - analytic_by_window = grouped( |
347 | | - map( |
348 | | - lambda x: (cast(agg_expressions.WindowExpression, x.expression).window, x), |
349 | | - analytic_defs, |
350 | | - ) |
351 | | - ) |
352 | | - |
353 | | - def graph_extract_scalar_exprs() -> Sequence[nodes.ColumnDef]: |
354 | | - results: dict[identifiers.ColumnId, expression.Expression] = dict() |
355 | | - while ( |
356 | | - True |
357 | | - ): # Will converge as each loop either reduces graph size, or fails to find any candidate and breaks |
358 | | - candidate_ids = list( |
359 | | - id |
360 | | - for id in graph.sinks |
361 | | - if (id in scalar_ids) |
362 | | - and not any( |
363 | | - ( |
364 | | - child in multi_parent_ids |
365 | | - and id in results.keys() |
366 | | - and not is_simple(results[id]) |
367 | | - ) |
368 | | - for child in graph.children(id) |
369 | | - ) |
370 | | - ) |
371 | | - if len(candidate_ids) == 0: |
372 | | - break |
373 | | - for id in candidate_ids: |
374 | | - graph.remove_node(id) |
375 | | - new_exprs = { |
376 | | - id: by_id[id].expression.bind_refs( |
377 | | - results, allow_partial_bindings=True |
378 | | - ) |
379 | | - } |
380 | | - results.update(new_exprs) |
381 | | - # TODO: We can prune expressions that won't be reused here, |
382 | | - return tuple(nodes.ColumnDef(expr, id) for id, expr in results.items()) |
383 | | - |
384 | | - def graph_extract_window_expr() -> Optional[ |
385 | | - Tuple[Sequence[nodes.ColumnDef], window_spec.WindowSpec] |
386 | | - ]: |
387 | | - for id in graph.sinks: |
388 | | - next_def = by_id[id] |
389 | | - if isinstance(next_def.expression, agg_expressions.WindowExpression): |
390 | | - window = next_def.expression.window |
391 | | - window_exprs = [ |
392 | | - cdef |
393 | | - for cdef in analytic_by_window[window] |
394 | | - if cdef.id in graph.sinks |
395 | | - ] |
396 | | - agg_exprs = tuple( |
397 | | - nodes.ColumnDef( |
398 | | - cast( |
399 | | - agg_expressions.WindowExpression, cdef.expression |
400 | | - ).analytic_expr, |
401 | | - cdef.id, |
402 | | - ) |
403 | | - for cdef in window_exprs |
404 | | - ) |
405 | | - for cdef in window_exprs: |
406 | | - graph.remove_node(cdef.id) |
407 | | - return (agg_exprs, window) |
408 | | - |
409 | | - return None |
| 476 | + graph = ExpressionGraph(exprs) |
410 | 477 |
|
411 | 478 | while not graph.empty: |
412 | | - pre_size = len(graph.nodes) |
413 | | - scalar_exprs = graph_extract_scalar_exprs() |
| 479 | + pre_size = len(graph) |
| 480 | + scalar_exprs = graph.extract_scalar_exprs() |
414 | 481 | if scalar_exprs: |
415 | 482 | curr_root = nodes.ProjectionNode( |
416 | 483 | curr_root, tuple((x.expression, x.id) for x in scalar_exprs) |
417 | 484 | ) |
418 | | - while result := graph_extract_window_expr(): |
| 485 | + while result := graph.extract_window_expr(): |
419 | 486 | defs, window = result |
420 | 487 | assert len(defs) > 0 |
421 | 488 | curr_root = nodes.WindowOpNode( |
422 | 489 | curr_root, |
423 | 490 | tuple(defs), |
424 | 491 | window, |
425 | 492 | ) |
426 | | - if len(graph.nodes) >= pre_size: |
| 493 | + if len(graph) >= pre_size: |
427 | 494 | raise ValueError("graph didn't shrink") |
428 | 495 | # TODO: Try to get the ordering right earlier, so can avoid this extra node. |
429 | 496 | post_ids = (*root.ids, *target_ids) |
|
0 commit comments