@@ -81,6 +81,10 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
8181 if not isinstance (child , nodes .SelectionNode ):
8282 return node
8383
84+ # case where selection must be kept in place to prevent ambiguity
85+ if set (child .child .ids ) & set (node .defined_variables ):
86+ return node
87+
8488 # Schema-preserving nodes
8589 if isinstance (
8690 node ,
@@ -157,9 +161,78 @@ def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
157161 return node
158162
159163
160- def pull_up_selects_under_join (node : nodes .JoinNode ) -> nodes .JoinNode :
161- # Can in theory pull up selects here, but it is a bit dangerous, in particular or self-joins, when there are more transforms to do.
162- # TODO: Safely pull up selects above join
164+ def pull_up_selects_under_join (node : nodes .JoinNode ) -> nodes .BigFrameNode :
165+ if isinstance (node .left_child , nodes .SelectionNode ) and isinstance (
166+ node .right_child , nodes .SelectionNode
167+ ):
168+ conflicts = set (node .left_child .child .ids ) & set (node .right_child .child .ids )
169+ if not conflicts :
170+ lmap = {id : ref .id for ref , id in node .left_child .input_output_pairs }
171+ rmap = {id : ref .id for ref , id in node .right_child .input_output_pairs }
172+ new_join = nodes .JoinNode (
173+ node .left_child .child ,
174+ node .right_child .child ,
175+ conditions = tuple (
176+ (lref .remap_column_refs (lmap ), rref .remap_column_refs (rmap ))
177+ for lref , rref in node .conditions
178+ ),
179+ type = node .type ,
180+ propogate_order = node .propogate_order ,
181+ )
182+ new_select = nodes .SelectionNode (
183+ new_join ,
184+ (
185+ * node .left_child .input_output_pairs ,
186+ * node .right_child .input_output_pairs ,
187+ ),
188+ )
189+ return new_select
190+ elif isinstance (node .left_child , nodes .SelectionNode ):
191+ conflicts = set (node .left_child .child .ids ) & set (node .right_child .ids )
192+ if not conflicts :
193+ lmap = {id : ref .id for ref , id in node .left_child .input_output_pairs }
194+ new_join = nodes .JoinNode (
195+ node .left_child .child ,
196+ node .right_child ,
197+ conditions = tuple (
198+ (lref .remap_column_refs (lmap ), rref )
199+ for lref , rref in node .conditions
200+ ),
201+ type = node .type ,
202+ propogate_order = node .propogate_order ,
203+ )
204+ new_select = nodes .SelectionNode (
205+ new_join ,
206+ (
207+ * node .left_child .input_output_pairs ,
208+ * (nodes .AliasedRef .identity (id ) for id in node .right_child .ids ),
209+ ),
210+ )
211+ return new_select
212+
213+ elif isinstance (node .right_child , nodes .SelectionNode ):
214+ conflicts = set (node .right_child .child .ids ) & set (node .left_child .ids )
215+ if not conflicts :
216+ rmap = {id : ref .id for ref , id in node .right_child .input_output_pairs }
217+ new_join = nodes .JoinNode (
218+ node .left_child ,
219+ node .right_child .child ,
220+ conditions = tuple (
221+ (lref , rref .remap_column_refs (rmap ))
222+ for lref , rref in node .conditions
223+ ),
224+ type = node .type ,
225+ propogate_order = node .propogate_order ,
226+ )
227+ new_select = nodes .SelectionNode (
228+ new_join ,
229+ (
230+ * (nodes .AliasedRef .identity (id ) for id in node .left_child .ids ),
231+ * node .right_child .input_output_pairs ,
232+ ),
233+ )
234+ return new_select
235+
163236 return node
164237
165238
0 commit comments