Skip to content

Commit 20431f7

Browse files
fix: Fix row count local execution bug
1 parent c3c292c commit 20431f7

File tree

2 files changed

+33
-11
lines changed

2 files changed

+33
-11
lines changed

bigframes/core/rewrite/pruning.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import dataclasses
1515
import functools
16+
import itertools
1617
import typing
1718

1819
from bigframes.core import identifiers, nodes
@@ -51,24 +52,22 @@ def prune_columns(node: nodes.BigFrameNode):
5152
if isinstance(node, nodes.SelectionNode):
5253
result = prune_selection_child(node)
5354
elif isinstance(node, nodes.ResultNode):
54-
result = node.replace_child(
55-
prune_node(
56-
node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
57-
)
58-
)
55+
result = node.replace_child(prune_node(node.child, node.consumed_ids))
5956
elif isinstance(node, nodes.AggregateNode):
60-
result = node.replace_child(
61-
prune_node(
62-
node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1])
63-
)
64-
)
57+
result = node.replace_child(prune_node(node.child, node.consumed_ids))
6558
elif isinstance(node, nodes.InNode):
6659
result = dataclasses.replace(
6760
node,
6861
right_child=prune_node(node.right_child, frozenset([node.right_col.id])),
6962
)
7063
else:
7164
result = node
65+
66+
if len(set(result.ids)) == 0:
67+
raise ValueError()
68+
for child in result.child_nodes:
69+
if len(set(child.ids)) == 0:
70+
raise ValueError()
7271
return result
7372

7473

@@ -149,9 +148,13 @@ def prune_node(
149148
if not (set(node.ids) - ids):
150149
return node
151150
else:
151+
# If no child ids are needed, there is probably a size op or numbering op above; always keep a single column.
152+
ids_to_keep = tuple(id for id in node.ids if id in ids) or tuple(
153+
itertools.islice(node.ids, 0, 1)
154+
)
152155
return nodes.SelectionNode(
153156
node,
154-
tuple(nodes.AliasedRef.identity(id) for id in node.ids if id in ids),
157+
tuple(nodes.AliasedRef.identity(id) for id in ids_to_keep),
155158
)
156159

157160

tests/system/small/engines/test_aggregation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,25 @@ def apply_agg_to_all_valid(
4848
return new_arr
4949

5050

51+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
52+
def test_engines_aggregate_post_filter_size(
53+
scalars_array_value: array_value.ArrayValue,
54+
engine,
55+
):
56+
w_offsets, offsets_id = (
57+
scalars_array_value.select_columns(("bool_col", "string_col"))
58+
.filter(expression.deref("bool_col"))
59+
.promote_offsets()
60+
)
61+
plan = (
62+
w_offsets.select_columns((offsets_id, "bool_col", "string_col"))
63+
.row_count()
64+
.node
65+
)
66+
67+
assert_equivalence_execution(plan, REFERENCE_ENGINE, engine)
68+
69+
5170
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5271
def test_engines_aggregate_size(
5372
scalars_array_value: array_value.ArrayValue,

0 commit comments

Comments
 (0)