Skip to content

Commit d59297c

Browse files
perf: Simplify plans by deferring aliasing above joins where possible
1 parent 64995d6 commit d59297c

File tree

1 file changed

+76
-3
lines changed

1 file changed

+76
-3
lines changed

bigframes/core/rewrite/select_pullup.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
8181
if not isinstance(child, nodes.SelectionNode):
8282
return node
8383

84+
# case where selection must be kept in place to prevent ambiguity
85+
if set(child.child.ids) & set(node.defined_variables):
86+
return node
87+
8488
# Schema-preserving nodes
8589
if isinstance(
8690
node,
@@ -157,9 +161,78 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
157161
return node
158162

159163

160-
def pull_up_selects_under_join(node: nodes.JoinNode) -> nodes.JoinNode:
161-
# Can in theory pull up selects here, but it is a bit dangerous, in particular or self-joins, when there are more transforms to do.
162-
# TODO: Safely pull up selects above join
164+
def pull_up_selects_under_join(node: nodes.JoinNode) -> nodes.BigFrameNode:
165+
if isinstance(node.left_child, nodes.SelectionNode) and isinstance(
166+
node.right_child, nodes.SelectionNode
167+
):
168+
conflicts = set(node.left_child.child.ids) & set(node.right_child.child.ids)
169+
if not conflicts:
170+
lmap = {id: ref.id for ref, id in node.left_child.input_output_pairs}
171+
rmap = {id: ref.id for ref, id in node.right_child.input_output_pairs}
172+
new_join = nodes.JoinNode(
173+
node.left_child.child,
174+
node.right_child.child,
175+
conditions=tuple(
176+
(lref.remap_column_refs(lmap), rref.remap_column_refs(rmap))
177+
for lref, rref in node.conditions
178+
),
179+
type=node.type,
180+
propogate_order=node.propogate_order,
181+
)
182+
new_select = nodes.SelectionNode(
183+
new_join,
184+
(
185+
*node.left_child.input_output_pairs,
186+
*node.right_child.input_output_pairs,
187+
),
188+
)
189+
return new_select
190+
elif isinstance(node.left_child, nodes.SelectionNode):
191+
conflicts = set(node.left_child.child.ids) & set(node.right_child.ids)
192+
if not conflicts:
193+
lmap = {id: ref.id for ref, id in node.left_child.input_output_pairs}
194+
new_join = nodes.JoinNode(
195+
node.left_child.child,
196+
node.right_child,
197+
conditions=tuple(
198+
(lref.remap_column_refs(lmap), rref)
199+
for lref, rref in node.conditions
200+
),
201+
type=node.type,
202+
propogate_order=node.propogate_order,
203+
)
204+
new_select = nodes.SelectionNode(
205+
new_join,
206+
(
207+
*node.left_child.input_output_pairs,
208+
*(nodes.AliasedRef.identity(id) for id in node.right_child.ids),
209+
),
210+
)
211+
return new_select
212+
213+
elif isinstance(node.right_child, nodes.SelectionNode):
214+
conflicts = set(node.right_child.child.ids) & set(node.left_child.ids)
215+
if not conflicts:
216+
rmap = {id: ref.id for ref, id in node.right_child.input_output_pairs}
217+
new_join = nodes.JoinNode(
218+
node.left_child,
219+
node.right_child.child,
220+
conditions=tuple(
221+
(lref, rref.remap_column_refs(rmap))
222+
for lref, rref in node.conditions
223+
),
224+
type=node.type,
225+
propogate_order=node.propogate_order,
226+
)
227+
new_select = nodes.SelectionNode(
228+
new_join,
229+
(
230+
*(nodes.AliasedRef.identity(id) for id in node.left_child.ids),
231+
*node.right_child.input_output_pairs,
232+
),
233+
)
234+
return new_select
235+
163236
return node
164237

165238

0 commit comments

Comments
 (0)