
Commit b2e4efa

factor out data structures to classes
1 parent c9a9921 · commit b2e4efa

5 files changed: +252 −93 lines

bigframes/core/array_value.py

Lines changed: 0 additions & 1 deletion
@@ -295,7 +295,6 @@ def compute_general_reduction(
         new_root = expression_factoring.plan_general_aggregation(
             plan, named_exprs, grouping_keys=[ex.deref(by) for by in by_column_ids]
         )
-        new_root.validate_tree()
         target_ids = tuple(named_expr.id for named_expr in named_exprs)
         return (ArrayValue(new_root), target_ids)

bigframes/core/blocks.py

Lines changed: 0 additions & 13 deletions
@@ -1176,19 +1176,6 @@ def reduce_general(
         *,
         dropna: bool = True,
     ) -> typing.Tuple[Block, typing.Sequence[str]]:
-        """
-        Apply aggregations to the block.
-
-        Arguments:
-            by_column_id: column id of the aggregation key, this is preserved through the transform and used as index.
-            aggregations: input_column_id, operation tuples
-            dropna: whether null keys should be dropped
-
-        Returns:
-            Tuple[Block, Sequence[str]]:
-                The first element is the grouped block. The second is the
-                column IDs corresponding to each applied aggregation.
-        """
        if column_labels is None:
            column_labels = pd.Index(range(len(aggregations)))

bigframes/core/expression_factoring.py

Lines changed: 37 additions & 79 deletions
@@ -1,23 +1,42 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
 import collections
 import dataclasses
 import functools
 import itertools
 from typing import (
     cast,
-    Dict,
-    Generic,
     Hashable,
     Iterable,
     Iterator,
     Mapping,
     Optional,
     Sequence,
-    Set,
     Tuple,
     TypeVar,
 )
 
-from bigframes.core import agg_expressions, expression, identifiers, nodes, window_spec
+from bigframes.core import (
+    agg_expressions,
+    expression,
+    graphs,
+    identifiers,
+    nodes,
+    window_spec,
+)
 
 _MAX_INLINE_COMPLEXITY = 10
 
@@ -45,7 +64,6 @@ def plan_general_aggregation(
     all_inputs = list(
         itertools.chain(*(factored_agg.agg_inputs for factored_agg in factored_aggs))
     )
-    # TODO: Windowize
     window_def = window_spec.WindowSpec(grouping_keys=tuple(grouping_keys))
     windowized_inputs = [
         nodes.ColumnDef(windowize(cdef.expression, window_def), cdef.id)
@@ -123,8 +141,10 @@ def factor_aggregation(root: nodes.ColumnDef) -> FactoredAggregation:
     2. The set of underlying primitive aggregations
     3. A final post-aggregate scalar expression
     """
-    final_aggs = set(find_final_aggregations(root.expression))
-    agg_inputs = set(itertools.chain.from_iterable(map(find_agg_inputs, final_aggs)))
+    final_aggs = list(dedupe(find_final_aggregations(root.expression)))
+    agg_inputs = list(
+        dedupe(itertools.chain.from_iterable(map(find_agg_inputs, final_aggs)))
+    )
 
     agg_input_defs = tuple(
         nodes.ColumnDef(expr, identifiers.ColumnId.unique()) for expr in agg_inputs
@@ -219,64 +239,6 @@ def replace_children(
     return root.transform_children(lambda x: mapping.get(x, x))
 
 
-T = TypeVar("T", bound=Hashable)
-
-
-class DiGraph(Generic[T]):
-    def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
-        self._parents: Dict[T, Set[T]] = collections.defaultdict(set)
-        self._children: Dict[T, Set[T]] = collections.defaultdict(
-            set
-        )  # specifically, unpushed ones
-        # use dict for stable ordering, which grants determinism
-        self._sinks: dict[T, None] = dict()
-        for node in nodes:
-            self._children[node]
-            self._parents[node]
-            self._sinks[node] = None
-        for src, dst in edges:
-            assert src in self.nodes
-            assert dst in self.nodes
-            self._children[src].add(dst)
-            self._parents[dst].add(src)
-            # sinks have no children
-            if src in self._sinks:
-                del self._sinks[src]
-
-    @property
-    def nodes(self):
-        # should be the same set of ids as self._parents
-        return self._children.keys()
-
-    @property
-    def sinks(self) -> Iterable[T]:
-        return self._sinks.keys()
-
-    @property
-    def empty(self):
-        return len(self.nodes) == 0
-
-    def parents(self, node: T) -> set[T]:
-        assert node in self._parents
-        return self._parents[node]
-
-    def children(self, node: T) -> set[T]:
-        assert node in self._children
-        return self._children[node]
-
-    def remove_node(self, node: T) -> None:
-        for child in self._children[node]:
-            self._parents[child].remove(node)
-        for parent in self._parents[node]:
-            self._children[parent].remove(node)
-            if len(self._children[parent]) == 0:
-                self._sinks[parent] = None
-        del self._children[node]
-        del self._parents[node]
-        if node in self._sinks:
-            del self._sinks[node]
-
-
 def push_into_tree(
     root: nodes.BigFrameNode,
     exprs: Sequence[nodes.ColumnDef],
@@ -285,7 +247,7 @@ def push_into_tree(
     curr_root = root
     by_id = {expr.id: expr for expr in exprs}
     # id -> id
-    graph = DiGraph(
+    graph = graphs.DiGraph(
        (expr.id for expr in exprs),
        (
            (expr.id, child_id)
@@ -296,7 +258,7 @@
    )
    # TODO: Also prevent inlining expensive or non-deterministic
    # We avoid inlining multi-parent ids, as they would be inlined multiple places, potentially increasing work and/or compiled text size
-    multi_parent_ids = set(id for id in graph.nodes if len(graph.parents(id)) > 2)
+    multi_parent_ids = set(id for id in graph.nodes if len(list(graph.parents(id))) > 2)
    scalar_ids = set(expr.id for expr in exprs if expr.expression.is_scalar_expr)
 
    analytic_defs = filter(
@@ -367,21 +329,13 @@ def graph_extract_window_expr() -> Optional[
 
        return None
 
-    must_be_pushed = set(target_ids) - set(graph.nodes)
-    if not must_be_pushed.issubset(curr_root.ids):
-        missing = must_be_pushed - set(curr_root.ids)
-        raise ValueError(f"hmmm, missing {missing}")
-
    while not graph.empty:
        pre_size = len(graph.nodes)
        scalar_exprs = graph_extract_scalar_exprs()
        if scalar_exprs:
            curr_root = nodes.ProjectionNode(
                curr_root, tuple((x.expression, x.id) for x in scalar_exprs)
            )
-            must_be_pushed = set(target_ids) - set(graph.nodes)
-            if not must_be_pushed.issubset(curr_root.ids):
-                raise ValueError("hmmm")
        while result := graph_extract_window_expr():
            defs, window = result
            assert len(defs) > 0
@@ -390,10 +344,6 @@ def graph_extract_window_expr() -> Optional[
                tuple(defs),
                window,
            )
-            must_be_pushed = set(target_ids) - set(graph.nodes)
-            if not must_be_pushed.issubset(curr_root.ids):
-                missing = must_be_pushed - set(curr_root.ids)
-                raise ValueError(f"hmmm, missing {missing}")
        if len(graph.nodes) >= pre_size:
            raise ValueError("graph didn't shrink")
    # TODO: Try to get the ordering right earlier, so can avoid this extra node.
@@ -424,3 +374,11 @@ def grouped(values: Iterable[tuple[K, V]]) -> dict[K, list[V]]:
    for k, v in values:
        result[k].append(v)
    return result
+
+
+def dedupe(values: Iterable[K]) -> Iterator[K]:
+    seen = set()
+    for k in values:
+        if k not in seen:
+            seen.add(k)
+            yield k
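Note on the changes in this file: the former set(...) calls around final_aggs and agg_inputs become list(dedupe(...)) so the factored aggregations keep a deterministic, first-seen order, and the multi_parent_ids check gains a list(...) wrapper because DiGraph.parents now yields an iterator instead of returning a set. The following is not part of the commit, just a quick sketch of what the new dedupe helper buys over set(), using illustrative string values in place of real expression objects:

from typing import Hashable, Iterable, Iterator, TypeVar

K = TypeVar("K", bound=Hashable)


def dedupe(values: Iterable[K]) -> Iterator[K]:
    # Mirrors the helper added above: skip repeats, keep first-seen order.
    seen = set()
    for k in values:
        if k not in seen:
            seen.add(k)
            yield k


exprs = ["sum(x)", "mean(y)", "sum(x)", "count(z)"]  # illustrative placeholders
print(list(dedupe(exprs)))  # ['sum(x)', 'mean(y)', 'count(z)'] -- stable across runs
print(list(set(exprs)))     # same unique values, but iteration order is not guaranteed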

bigframes/core/graphs.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+from typing import Dict, Generic, Hashable, Iterable, Iterator, Tuple, TypeVar
+
+import bigframes.core.ordered_sets as sets
+
+T = TypeVar("T", bound=Hashable)
+
+
+class DiGraph(Generic[T]):
+    def __init__(self, nodes: Iterable[T], edges: Iterable[Tuple[T, T]]):
+        self._parents: Dict[T, sets.InsertionOrderedSet[T]] = collections.defaultdict(
+            sets.InsertionOrderedSet
+        )
+        self._children: Dict[T, sets.InsertionOrderedSet[T]] = collections.defaultdict(
+            sets.InsertionOrderedSet
+        )
+        self._sinks: sets.InsertionOrderedSet[T] = sets.InsertionOrderedSet()
+        for node in nodes:
+            self._children[node]
+            self._parents[node]
+            self._sinks.add(node)
+        for src, dst in edges:
+            assert src in self.nodes
+            assert dst in self.nodes
+            self._children[src].add(dst)
+            self._parents[dst].add(src)
+            # sinks have no children
+            if src in self._sinks:
+                self._sinks.remove(src)
+
+    @property
+    def nodes(self):
+        # should be the same set of ids as self._parents
+        return self._children.keys()
+
+    @property
+    def sinks(self) -> Iterable[T]:
+        return self._sinks
+
+    @property
+    def empty(self):
+        return len(self.nodes) == 0
+
+    def parents(self, node: T) -> Iterator[T]:
+        assert node in self._parents
+        yield from self._parents[node]
+
+    def children(self, node: T) -> Iterator[T]:
+        assert node in self._children
+        yield from self._children[node]
+
+    def remove_node(self, node: T) -> None:
+        for child in self._children[node]:
+            self._parents[child].remove(node)
+        for parent in self._parents[node]:
+            self._children[parent].remove(node)
+            if len(self._children[parent]) == 0:
+                self._sinks.add(parent)
+        del self._children[node]
+        del self._parents[node]
+        if node in self._sinks:
+            self._sinks.remove(node)
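The factored-out DiGraph keeps its adjacency in bigframes.core.ordered_sets.InsertionOrderedSet rather than plain sets, so iteration order (and therefore the planner's output) stays deterministic. That module is not shown in this view (presumably it is the remaining changed file); a hypothetical stand-in consistent with how graphs.py uses it, assuming it is just a dict-backed set that iterates in insertion order, might look like:

from typing import Generic, Hashable, Iterator, TypeVar

T = TypeVar("T", bound=Hashable)


class InsertionOrderedSet(Generic[T]):
    """Hypothetical stand-in, not the bigframes implementation: a set backed
    by a dict so that iteration follows insertion order."""

    def __init__(self) -> None:
        self._items: dict[T, None] = {}

    def add(self, item: T) -> None:
        self._items[item] = None  # dicts preserve insertion order

    def remove(self, item: T) -> None:
        del self._items[item]

    def __contains__(self, item: T) -> bool:
        return item in self._items

    def __iter__(self) -> Iterator[T]:
        return iter(self._items)

    def __len__(self) -> int:
        return len(self._items)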

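With that in place, the graph can be drained sink-first, roughly the consumption pattern push_into_tree relies on: take the nodes with no remaining children, push them into the tree, remove them, and repeat until the graph is empty. A small usage sketch (not from the commit), assuming InsertionOrderedSet iterates in insertion order:

from bigframes.core import graphs

# Toy dependency graph: an edge (a, b) means "a depends on b".
g = graphs.DiGraph(
    nodes=["a", "b", "c", "d"],
    edges=[("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")],
)

# Repeatedly peel off sinks (nodes with no remaining children) until the
# graph is empty; this yields a reverse-topological order.
order = []
while not g.empty:
    ready = list(g.sinks)
    for node in ready:
        g.remove_node(node)
    order.extend(ready)

print(order)  # ['d', 'b', 'c', 'a'] with insertion-ordered iteration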