|
13 | 13 | # limitations under the License. |
14 | 14 | import dataclasses |
15 | 15 | import functools |
| 16 | +import itertools |
16 | 17 | import typing |
17 | 18 |
|
18 | 19 | from bigframes.core import identifiers, nodes |
@@ -51,24 +52,22 @@ def prune_columns(node: nodes.BigFrameNode): |
51 | 52 | if isinstance(node, nodes.SelectionNode): |
52 | 53 | result = prune_selection_child(node) |
53 | 54 | elif isinstance(node, nodes.ResultNode): |
54 | | - result = node.replace_child( |
55 | | - prune_node( |
56 | | - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) |
57 | | - ) |
58 | | - ) |
| 55 | + result = node.replace_child(prune_node(node.child, node.consumed_ids)) |
59 | 56 | elif isinstance(node, nodes.AggregateNode): |
60 | | - result = node.replace_child( |
61 | | - prune_node( |
62 | | - node.child, node.consumed_ids or frozenset(list(node.child.ids)[0:1]) |
63 | | - ) |
64 | | - ) |
| 57 | + result = node.replace_child(prune_node(node.child, node.consumed_ids)) |
65 | 58 | elif isinstance(node, nodes.InNode): |
66 | 59 | result = dataclasses.replace( |
67 | 60 | node, |
68 | 61 | right_child=prune_node(node.right_child, frozenset([node.right_col.id])), |
69 | 62 | ) |
70 | 63 | else: |
71 | 64 | result = node |
| 65 | + |
| 66 | + if len(set(result.ids)) == 0: |
| 67 | + raise ValueError() |
| 68 | + for child in result.child_nodes: |
| 69 | + if len(set(child.ids)) == 0: |
| 70 | + raise ValueError() |
72 | 71 | return result |
73 | 72 |
|
74 | 73 |
|
@@ -149,9 +148,13 @@ def prune_node( |
149 | 148 | if not (set(node.ids) - ids): |
150 | 149 | return node |
151 | 150 | else: |
| 151 | + # If no child ids are needed, probably a size op or numbering op above, keep a single column always |
| 152 | + ids_to_keep = tuple(id for id in node.ids if id in ids) or tuple( |
| 153 | + itertools.islice(node.ids, 0, 1) |
| 154 | + ) |
152 | 155 | return nodes.SelectionNode( |
153 | 156 | node, |
154 | | - tuple(nodes.AliasedRef.identity(id) for id in node.ids if id in ids), |
| 157 | + tuple(nodes.AliasedRef.identity(id) for id in ids_to_keep), |
155 | 158 | ) |
156 | 159 |
|
157 | 160 |
|
|
0 commit comments