diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs
index cd476ee3b31a..0a6d2cf85ca3 100644
--- a/datafusion/physical-expr/src/utils/mod.rs
+++ b/datafusion/physical-expr/src/utils/mod.rs
@@ -23,6 +23,7 @@ use std::sync::Arc;
 
 use crate::PhysicalExpr;
 use crate::PhysicalSortExpr;
+use crate::expressions::UnKnownColumn;
 use crate::expressions::{BinaryExpr, Column};
 use crate::tree_node::ExprContext;
 
@@ -238,6 +239,21 @@ pub fn collect_columns(expr: &Arc<dyn PhysicalExpr>) -> HashSet<Column> {
     columns
 }
 
+pub fn have_unknown_columns(expr: &Arc<dyn PhysicalExpr>) -> bool {
+    let mut found = false;
+    expr.apply(|e| {
+        if e.as_any().downcast_ref::<UnKnownColumn>().is_some() {
+            found = true;
+            Ok(TreeNodeRecursion::Stop)
+        } else {
+            Ok(TreeNodeRecursion::Continue)
+        }
+    })
+    .expect("no way to return error during recursion");
+
+    found
+}
+
 /// Re-assign indices of [`Column`]s within the given [`PhysicalExpr`] according to
 /// the provided [`Schema`].
 ///
diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs
index 1274e954eaeb..fe1403bc3a35 100644
--- a/datafusion/physical-plan/src/filter_pushdown.rs
+++ b/datafusion/physical-plan/src/filter_pushdown.rs
@@ -38,7 +38,9 @@ use std::collections::HashSet;
 use std::sync::Arc;
 
 use datafusion_common::Result;
-use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns};
+use datafusion_physical_expr::utils::{
+    collect_columns, have_unknown_columns, reassign_expr_columns,
+};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use itertools::Itertools;
 
@@ -339,7 +341,9 @@ impl ChildFilterDescription {
             .iter()
             .all(|col| child_column_names.contains(col.name()));
 
-        if all_columns_exist {
+        let have_unknown_columns = have_unknown_columns(filter);
+
+        if all_columns_exist && !have_unknown_columns {
             // All columns exist in child - we can push down
             // Need to reassign column indices to match child schema
             let reassigned_filter =
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index ec8e154caec9..79c0e9ef5e37 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -92,6 +92,7 @@ pub mod streaming;
 pub mod tree_node;
 pub mod union;
 pub mod unnest;
+pub mod util;
 pub mod windows;
 pub mod work_table;
 pub mod udaf {
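Note on `have_unknown_columns`: it is an early-exit pre-order traversal over the expression tree that stops at the first `UnKnownColumn` it finds, and `ChildFilterDescription` now uses it to refuse pushdown of filters that reference unresolvable columns. A minimal sketch of the intended call pattern (not part of the patch; the expressions and the `main` wrapper are illustrative, and it assumes the patched `datafusion-physical-expr` crate is on the dependency list):

```rust
use std::sync::Arc;

use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal, UnKnownColumn};
use datafusion_physical_expr::utils::have_unknown_columns;

fn main() {
    // `a@0 > 5` resolves every column reference, so pushdown remains possible.
    let known: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("a", 0)),
        Operator::Gt,
        Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
    ));
    assert!(!have_unknown_columns(&known));

    // An `UnKnownColumn` marks an unresolvable reference; the new guard in
    // `ChildFilterDescription` sees it and does not push this filter down.
    let unknown: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(UnKnownColumn::new("missing")),
        Operator::Gt,
        Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
    ));
    assert!(have_unknown_columns(&unknown));
}
```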
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index a56e9272f119..435071bf1daa 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -32,6 +32,7 @@ use crate::filter_pushdown::{
     FilterPushdownPropagation,
 };
 use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef};
+use crate::util::PhysicalColumnRewriter;
 use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr};
 use std::any::Any;
 use std::collections::HashMap;
@@ -45,7 +46,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
-use datafusion_common::{JoinSide, Result, internal_err};
+use datafusion_common::{DataFusionError, JoinSide, Result, internal_err};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
 use datafusion_physical_expr::projection::Projector;
@@ -190,6 +191,29 @@ impl ProjectionExec {
             input.boundedness(),
         ))
     }
+
+    /// Collect the reverse alias mapping from the projection expressions.
+    /// The resulting hash map maps each aliased [`Column`] seen by the parent to its original expression.
+    fn collect_reverse_alias(
+        &self,
+    ) -> Result<datafusion_common::HashMap<Column, Arc<dyn PhysicalExpr>>> {
+        let mut alias_map = datafusion_common::HashMap::new();
+        for projection in self.projection_expr().iter() {
+            let (aliased_index, _output_field) = self
+                .projector
+                .output_schema()
+                .column_with_name(&projection.alias)
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Expr {} with alias {} not found in output schema",
+                        projection.expr, projection.alias
+                    ))
+                })?;
+            let aliased_col = Column::new(&projection.alias, aliased_index);
+            alias_map.insert(aliased_col, Arc::clone(&projection.expr));
+        }
+        Ok(alias_map)
+    }
 }
 
 impl DisplayAs for ProjectionExec {
@@ -343,10 +367,15 @@ impl ExecutionPlan for ProjectionExec {
         parent_filters: Vec<Arc<dyn PhysicalExpr>>,
         _config: &ConfigOptions,
     ) -> Result<FilterDescription> {
-        // TODO: In future, we can try to handle inverting aliases here.
-        // For the time being, we pass through untransformed filters, so filters on aliases are not handled.
-        // https://github.com/apache/datafusion/issues/17246
-        FilterDescription::from_children(parent_filters, &self.children())
+        // Expand aliased columns in the parent filters back to their original expressions
+        let invert_alias_map = self.collect_reverse_alias()?;
+
+        let mut rewriter = PhysicalColumnRewriter::new(invert_alias_map);
+        let rewritten_filters = parent_filters
+            .into_iter()
+            .map(|filter| filter.rewrite(&mut rewriter).map(|t| t.data))
+            .collect::<Result<Vec<_>>>()?;
+        FilterDescription::from_children(rewritten_filters, &self.children())
     }
 
     fn handle_child_pushdown_result(
Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "y" + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "y".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "y < 10" + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + // Should be converted to "a > 5" and "b < 10" + let expected_filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + let expected_filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // Note: The order of filters is preserved + assert_eq!( + format!("{}", pushed_filters[0].predicate), + format!("{}", expected_filter1) + ); + assert_eq!( + format!("{}", pushed_filters[1].predicate), + format!("{}", expected_filter2) + ); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_mixed_columns() -> Result<()> { + let input_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + let input = Arc::new(StatisticsExec::new( + Statistics { + column_statistics: vec![Default::default(); input_schema.fields().len()], + ..Default::default() + }, + input_schema.clone(), + )); + + // project "a" as "x", "b" as "b" (pass through) + let projection = ProjectionExec::try_new( + vec![ + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "x".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b".to_string(), + }, + ], + input, + )?; + + // filter "x > 5" + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + + // filter "b < 10" (using output index 1 which corresponds to 'b') + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", 1)), + Operator::Lt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; + + let description = projection.gather_filters_for_pushdown( + FilterPushdownPhase::Post, + vec![filter1, filter2], + &ConfigOptions::default(), + )?; + + let pushed_filters = &description.parent_filters()[0]; + assert_eq!(pushed_filters.len(), 2); + // "x" -> "a" (index 0) + let expected_filter1 = "a@0 > 5"; + // "b" -> "b" (index 1) + let expected_filter2 = "b@1 < 10"; + + assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1); + assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2); + + Ok(()) + } + + #[test] + fn test_filter_pushdown_with_complex_expression() -> Result<()> { + let input_schema = 
+
+    #[test]
+    fn test_filter_pushdown_with_mixed_columns() -> Result<()> {
+        let input_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                column_statistics: vec![Default::default(); input_schema.fields().len()],
+                ..Default::default()
+            },
+            input_schema.clone(),
+        ));
+
+        // project "a" as "x", "b" as "b" (pass through)
+        let projection = ProjectionExec::try_new(
+            vec![
+                ProjectionExpr {
+                    expr: Arc::new(Column::new("a", 0)),
+                    alias: "x".to_string(),
+                },
+                ProjectionExpr {
+                    expr: Arc::new(Column::new("b", 1)),
+                    alias: "b".to_string(),
+                },
+            ],
+            input,
+        )?;
+
+        // filter "x > 5"
+        let filter1 = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("x", 0)),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        // filter "b < 10" (using output index 1, which corresponds to 'b')
+        let filter2 = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("b", 1)),
+            Operator::Lt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        let description = projection.gather_filters_for_pushdown(
+            FilterPushdownPhase::Post,
+            vec![filter1, filter2],
+            &ConfigOptions::default(),
+        )?;
+
+        let pushed_filters = &description.parent_filters()[0];
+        assert_eq!(pushed_filters.len(), 2);
+        // "x" -> "a" (index 0)
+        let expected_filter1 = "a@0 > 5";
+        // "b" -> "b" (index 1)
+        let expected_filter2 = "b@1 < 10";
+
+        assert_eq!(format!("{}", pushed_filters[0].predicate), expected_filter1);
+        assert_eq!(format!("{}", pushed_filters[1].predicate), expected_filter2);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_filter_pushdown_with_complex_expression() -> Result<()> {
+        let input_schema =
+            Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                column_statistics: vec![Default::default(); input_schema.fields().len()],
+                ..Default::default()
+            },
+            input_schema.clone(),
+        ));
+
+        // project "a + 1" as "z"
+        let projection = ProjectionExec::try_new(
+            vec![ProjectionExpr {
+                expr: Arc::new(BinaryExpr::new(
+                    Arc::new(Column::new("a", 0)),
+                    Operator::Plus,
+                    Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+                )),
+                alias: "z".to_string(),
+            }],
+            input,
+        )?;
+
+        // filter "z > 10"
+        let filter = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("z", 0)),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(10)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        let description = projection.gather_filters_for_pushdown(
+            FilterPushdownPhase::Post,
+            vec![filter],
+            &ConfigOptions::default(),
+        )?;
+
+        // expands to `a + 1 > 10`
+        let pushed_filters = &description.parent_filters()[0];
+        assert!(matches!(pushed_filters[0].discriminant, PushedDown::Yes));
+        assert_eq!(format!("{}", pushed_filters[0].predicate), "a@0 + 1 > 10");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_filter_pushdown_with_unknown_column() -> Result<()> {
+        let input_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                column_statistics: vec![Default::default(); input_schema.fields().len()],
+                ..Default::default()
+            },
+            input_schema.clone(),
+        ));
+
+        // project "a" as "a"
+        let projection = ProjectionExec::try_new(
+            vec![ProjectionExpr {
+                expr: Arc::new(Column::new("a", 0)),
+                alias: "a".to_string(),
+            }],
+            input,
+        )?;
+
+        // filter "unknown_col > 5" - using a column name that doesn't exist in the projection output
+        // Column constructor: name, index. Index 1 doesn't exist.
+        let filter = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("unknown_col", 1)),
+            Operator::Gt,
+            Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
+        )) as Arc<dyn PhysicalExpr>;
+
+        let description = projection.gather_filters_for_pushdown(
+            FilterPushdownPhase::Post,
+            vec![filter],
+            &ConfigOptions::default(),
+        )?;
+
+        let pushed_filters = &description.parent_filters()[0];
+        assert!(matches!(pushed_filters[0].discriminant, PushedDown::No));
+        // The column isn't found in the alias map, so it becomes an UnKnownColumn
+        assert_eq!(
+            format!("{}", pushed_filters[0].predicate),
+            "unknown_col > 5"
+        );
+
+        Ok(())
+    }
+
+    /// Test that `DynamicFilterPhysicalExpr` correctly updates its child expression,
+    /// i.e. starting with `lit(true)`, after an update it becomes `a > 5`;
+    /// with projection [b - 1 as a], the pushed-down filter should then be `b - 1 > 5`
+    #[test]
+    fn test_dyn_filter_projection_pushdown_update_child() -> Result<()> {
+        let input_schema =
+            Arc::new(Schema::new(vec![Field::new("b", DataType::Int32, false)]));
+
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                column_statistics: vec![Default::default(); input_schema.fields().len()],
+                ..Default::default()
+            },
+            input_schema.as_ref().clone(),
+        ));
+
+        // project "b" - 1 as "a"
+        let projection = ProjectionExec::try_new(
+            vec![ProjectionExpr {
+                expr: binary(
+                    Arc::new(Column::new("b", 0)),
+                    Operator::Minus,
+                    lit(1),
+                    &input_schema,
+                )
+                .unwrap(),
+                alias: "a".to_string(),
+            }],
+            input,
+        )?;
+
+        // simulate the projection's parent creating a dynamic filter on "a"
+        let projected_schema = projection.schema();
+        let col_a = col("a", &projected_schema)?;
+        let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(
+            vec![Arc::clone(&col_a)],
+            lit(true),
+        ));
+        // Initial state should be lit(true)
+        let current = dynamic_filter.current()?;
+        assert_eq!(format!("{current}"), "true");
+
+        let dyn_phy_expr: Arc<dyn PhysicalExpr> = Arc::clone(&dynamic_filter) as _;
+
+        let description = projection.gather_filters_for_pushdown(
+            FilterPushdownPhase::Post,
+            vec![dyn_phy_expr],
+            &ConfigOptions::default(),
+        )?;
+
+        let pushed_filters = &description.parent_filters()[0][0];
+
+        // Before any update, the pushed-down filter renders as empty
+        assert_eq!(
+            format!("{}", pushed_filters.predicate),
+            "DynamicFilter [ empty ]"
+        );
+
+        // Update to a > 5 (after projection, b is now called a)
+        let new_expr =
+            Arc::new(BinaryExpr::new(Arc::clone(&col_a), Operator::Gt, lit(5i32)));
+        dynamic_filter.update(new_expr)?;
+
+        // Now it should be a > 5
+        let current = dynamic_filter.current()?;
+        assert_eq!(format!("{current}"), "a@0 > 5");
+
+        // The pushed-down copy now renders as b - 1 > 5 (because b - 1 is projected as a)
+        assert_eq!(
+            format!("{}", pushed_filters.predicate),
+            "DynamicFilter [ b@0 - 1 > 5 ]"
+        );
+
+        Ok(())
+    }
 }
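For reference, the core of the projection change is the reverse alias map built by `collect_reverse_alias` plus a column rewrite over each parent filter. A minimal sketch of that data flow in isolation (not part of the patch; it uses the `PhysicalColumnRewriter` added in `util.rs` below, and the projection `[a@0 as b]`, the filter, and the `main` wrapper are illustrative):

```rust
use std::sync::Arc;

use datafusion_common::{HashMap, ScalarValue};
use datafusion_common::tree_node::TreeNode;
use datafusion_expr::Operator;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_plan::util::PhysicalColumnRewriter;

fn main() -> datafusion_common::Result<()> {
    // Projection `[a@0 as b]` yields the reverse alias map { b@0 -> a@0 }.
    let mut invert_alias_map: HashMap<Column, Arc<dyn PhysicalExpr>> = HashMap::new();
    invert_alias_map.insert(Column::new("b", 0), Arc::new(Column::new("a", 0)));

    // Parent filter `b@0 > 5`, expressed against the projection's output schema.
    let filter: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("b", 0)),
        Operator::Gt,
        Arc::new(Literal::new(ScalarValue::Int32(Some(5)))),
    ));

    // Rewriting yields `a@0 > 5`, which is expressed against the input schema
    // and can therefore be pushed below the projection.
    let mut rewriter = PhysicalColumnRewriter::new(invert_alias_map);
    let rewritten = filter.rewrite(&mut rewriter)?.data;
    assert_eq!(format!("{rewritten}"), "a@0 > 5");
    Ok(())
}
```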
diff --git a/datafusion/physical-plan/src/util.rs b/datafusion/physical-plan/src/util.rs
new file mode 100644
index 000000000000..df95e7c6d321
--- /dev/null
+++ b/datafusion/physical-plan/src/util.rs
@@ -0,0 +1,312 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion_common::{
+    HashMap,
+    tree_node::{Transformed, TreeNodeRecursion, TreeNodeRewriter},
+};
+use datafusion_physical_expr::{
+    PhysicalExpr,
+    expressions::{Column, UnKnownColumn},
+};
+
+/// Rewrite column references in a physical expression according to a mapping; columns absent from the map become [`UnKnownColumn`]s.
+pub struct PhysicalColumnRewriter {
+    /// Mapping from original column to its replacement expression.
+    pub column_map: HashMap<Column, Arc<dyn PhysicalExpr>>,
+}
+
+impl PhysicalColumnRewriter {
+    /// Create a new `PhysicalColumnRewriter` with the given column mapping.
+    pub fn new(column_map: HashMap<Column, Arc<dyn PhysicalExpr>>) -> Self {
+        Self { column_map }
+    }
+}
+
+impl TreeNodeRewriter for PhysicalColumnRewriter {
+    type Node = Arc<dyn PhysicalExpr>;
+
+    fn f_down(
+        &mut self,
+        node: Self::Node,
+    ) -> datafusion_common::Result<Transformed<Self::Node>> {
+        if let Some(column) = node.as_any().downcast_ref::<Column>() {
+            if let Some(new_column) = self.column_map.get(column) {
+                // Jump to prevent rewriting the new sub-expression again
+                return Ok(Transformed::new(
+                    Arc::clone(new_column),
+                    true,
+                    TreeNodeRecursion::Jump,
+                ));
+            } else {
+                return Ok(Transformed::yes(Arc::new(UnKnownColumn::new(
+                    column.name(),
+                ))));
+            }
+        }
+        Ok(Transformed::no(node))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::{Result, tree_node::TreeNode};
+    use datafusion_physical_expr::{
+        PhysicalExpr,
+        expressions::{Column, binary, col, lit},
+    };
+    use std::sync::Arc;
+
+    /// Helper function to create a test schema
+    fn create_test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+            Field::new("c", DataType::Int32, true),
+            Field::new("d", DataType::Int32, true),
+            Field::new("e", DataType::Int32, true),
+            Field::new("new_col", DataType::Int32, true),
+            Field::new("inner_col", DataType::Int32, true),
+            Field::new("another_col", DataType::Int32, true),
+        ]))
+    }
+
+    /// Helper function to create a complex nested expression with multiple columns
+    /// Create: (col_a + col_b) * (col_c - col_d) + col_e
+    fn create_complex_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
+        let col_a = col("a", schema).unwrap();
+        let col_b = col("b", schema).unwrap();
+        let col_c = col("c", schema).unwrap();
+        let col_d = col("d", schema).unwrap();
+        let col_e = col("e", schema).unwrap();
+
+        let add_expr =
+            binary(col_a, datafusion_expr::Operator::Plus, col_b, schema).unwrap();
+        let sub_expr =
+            binary(col_c, datafusion_expr::Operator::Minus, col_d, schema).unwrap();
+        let mul_expr = binary(
+            add_expr,
+            datafusion_expr::Operator::Multiply,
+            sub_expr,
+            schema,
+        )
+        .unwrap();
+        binary(mul_expr, datafusion_expr::Operator::Plus, col_e, schema).unwrap()
+    }
+
+    /// Helper function to create a deeply nested expression
+    /// Create: col_a + (col_b + (col_c + (col_d + col_e)))
+    fn create_deeply_nested_expression(schema: &Schema) -> Arc<dyn PhysicalExpr> {
+        let col_a = col("a", schema).unwrap();
+        let col_b = col("b", schema).unwrap();
+        let col_c = col("c", schema).unwrap();
+        let col_d = col("d", schema).unwrap();
+        let col_e = col("e", schema).unwrap();
+
+        let inner1 =
+            binary(col_d, datafusion_expr::Operator::Plus, col_e, schema).unwrap();
+        let inner2 =
+            binary(col_c, datafusion_expr::Operator::Plus, inner1, schema).unwrap();
+        let inner3 =
+            binary(col_b, datafusion_expr::Operator::Plus, inner2, schema).unwrap();
+        binary(col_a, datafusion_expr::Operator::Plus, inner3, schema).unwrap()
+    }
+
+    #[test]
+    fn test_simple_column_replacement_with_jump() -> Result<()> {
+        let schema = create_test_schema();
+
+        // Test that Jump prevents re-processing of replaced columns
+        let mut column_map = HashMap::new();
+        column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(42i32));
+        column_map.insert(
+            Column::new_with_schema("b", &schema).unwrap(),
+            lit("replaced_b"),
+        );
+
+        let mut rewriter = PhysicalColumnRewriter::new(column_map);
+        let expr = create_complex_expression(&schema);
+
+        let result = expr.rewrite(&mut rewriter)?;
+
+        // Verify the transformation occurred
+        assert!(result.transformed);
+
+        assert_eq!(
+            format!("{}", result.data),
+            "(42 + replaced_b) * (c - d) + e"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_nested_column_replacement_with_jump() -> Result<()> {
+        let schema = create_test_schema();
+        // Test Jump behavior with deeply nested expressions
+        let mut column_map = HashMap::new();
+        // Replace col_c with a complex expression containing new columns
+        let replacement_expr = binary(
+            lit(100i32),
+            datafusion_expr::Operator::Plus,
+            col("new_col", &schema).unwrap(),
+            &schema,
+        )
+        .unwrap();
+        column_map.insert(
+            Column::new_with_schema("c", &schema).unwrap(),
+            replacement_expr,
+        );
+
+        let mut rewriter = PhysicalColumnRewriter::new(column_map);
+        let expr = create_deeply_nested_expression(&schema);
+
+        let result = expr.rewrite(&mut rewriter)?;
+
+        // Verify transformation occurred
+        assert!(result.transformed);
+
+        assert_eq!(
+            format!("{}", result.data),
+            "a + b + 100 + new_col@5 + d + e"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_circular_reference_prevention() -> Result<()> {
+        let schema = create_test_schema();
+        // Test that Jump prevents infinite recursion with circular references
+        let mut column_map = HashMap::new();
+
+        // Create a circular reference: col_a -> col_b -> col_a (but Jump should prevent the second visit)
+        column_map.insert(
+            Column::new_with_schema("a", &schema).unwrap(),
+            col("b", &schema).unwrap(),
+        );
+        column_map.insert(
+            Column::new_with_schema("b", &schema).unwrap(),
+            col("a", &schema).unwrap(),
+        );
+
+        let mut rewriter = PhysicalColumnRewriter::new(column_map);
+
+        // Start with an expression containing col_a
+        let expr = binary(
+            col("a", &schema).unwrap(),
+            datafusion_expr::Operator::Plus,
+            col("b", &schema).unwrap(),
+            &schema,
+        )
+        .unwrap();
+
+        let result = expr.rewrite(&mut rewriter)?;
+
+        // Verify transformation occurred
+        assert!(result.transformed);
+
+        assert_eq!(format!("{}", result.data), "b@1 + a@0");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_multiple_replacements_in_same_expression() -> Result<()> {
+        let schema = create_test_schema();
+        // Test multiple column replacements in the same complex expression
+        let mut column_map = HashMap::new();
+
+        // Replace multiple columns with literals
+        column_map.insert(Column::new_with_schema("a", &schema).unwrap(), lit(10i32));
+        column_map.insert(Column::new_with_schema("c", &schema).unwrap(), lit(20i32));
+        column_map.insert(Column::new_with_schema("e", &schema).unwrap(), lit(30i32));
+
+        let mut rewriter = PhysicalColumnRewriter::new(column_map);
+        let expr = create_complex_expression(&schema); // (col_a + col_b) * (col_c - col_d) + col_e
+
+        let result = expr.rewrite(&mut rewriter)?;
+
+        // Verify transformation occurred
+        assert!(result.transformed);
+        assert_eq!(format!("{}", result.data), "(10 + b) * (20 - d) + 30");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_jump_with_complex_replacement_expression() -> Result<()> {
+        let schema = create_test_schema();
+        // Test Jump behavior when replacing with very complex expressions
+        let mut column_map = HashMap::new();
+
+        // Replace col_a with a complex nested expression
+        let inner_expr = binary(
+            lit(5i32),
+            datafusion_expr::Operator::Multiply,
+            col("a", &schema).unwrap(),
+            &schema,
+        )
+        .unwrap();
+        let middle_expr = binary(
+            inner_expr,
+            datafusion_expr::Operator::Plus,
+            lit(3i32),
+            &schema,
+        )
+        .unwrap();
+        let complex_replacement = binary(
+            middle_expr,
+            datafusion_expr::Operator::Minus,
+            col("another_col", &schema).unwrap(),
+            &schema,
+        )
+        .unwrap();
+
+        column_map.insert(
+            Column::new_with_schema("a", &schema).unwrap(),
+            complex_replacement,
+        );
+
+        let mut rewriter = PhysicalColumnRewriter::new(column_map);
+
+        // Create expression: col_a + col_b
+        let expr = binary(
+            col("a", &schema).unwrap(),
+            datafusion_expr::Operator::Plus,
+            col("b", &schema).unwrap(),
+            &schema,
+        )
+        .unwrap();
+
+        let result = expr.rewrite(&mut rewriter)?;
+
+        assert_eq!(
+            format!("{}", result.data),
+            "5 * a@0 + 3 - another_col@7 + b"
+        );
+
+        // Verify transformation occurred
+        assert!(result.transformed);
+
+        Ok(())
+    }
+}
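A note on the `TreeNodeRecursion::Jump` choice in `f_down`: the replacement expression may itself contain columns, including the very column being replaced, so the rewriter must not descend into the subtree it just inserted. A minimal sketch of the self-referential case this protects against (not part of the patch; the mapping `{ b@0 -> b@0 + 1 }` and the `main` wrapper are illustrative):

```rust
use std::sync::Arc;

use datafusion_common::{HashMap, ScalarValue};
use datafusion_common::tree_node::TreeNode;
use datafusion_expr::Operator;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_plan::util::PhysicalColumnRewriter;

fn main() -> datafusion_common::Result<()> {
    // Self-referential mapping: replace `b@0` with `b@0 + 1`.
    let mut column_map: HashMap<Column, Arc<dyn PhysicalExpr>> = HashMap::new();
    column_map.insert(
        Column::new("b", 0),
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("b", 0)),
            Operator::Plus,
            Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
        )),
    );

    let mut rewriter = PhysicalColumnRewriter::new(column_map);
    let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("b", 0));

    // The outer `b@0` is replaced exactly once; `Jump` skips the freshly
    // inserted subtree, so the `b@0` inside the replacement is left alone.
    let rewritten = expr.rewrite(&mut rewriter)?.data;
    assert_eq!(format!("{rewritten}"), "b@0 + 1");
    Ok(())
}
```

Had `f_down` returned `Continue` instead, the traversal would descend into the freshly inserted subtree and try to expand the inner `b@0` again, which is exactly the non-termination the `Jump` (and the circular-reference test above) guards against.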
diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt
index aba468d21fd0..8a1fef072229 100644
--- a/datafusion/sqllogictest/test_files/topk.slt
+++ b/datafusion/sqllogictest/test_files/topk.slt
@@ -383,7 +383,7 @@ physical_plan
 03)----ProjectionExec: expr=[__common_expr_1@0 as number_plus, number@1 as number, __common_expr_1@0 as other_number_plus, age@2 as age]
 04)------ProjectionExec: expr=[CAST(number@0 AS Int64) + 1 as __common_expr_1, number@0 as number, age@1 as age]
 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1, maintains_sort_order=true
-06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet
+06)----------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/topk/partial_sorted/1.parquet]]}, projection=[number, age], output_ordering=[number@0 DESC], file_type=parquet, predicate=DynamicFilter [ empty ]
 
 # Cleanup
 statement ok