From cc9b40ed42d9baea11aaa72df3b61b0be6aff930 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Fri, 19 Sep 2025 18:22:52 +0200 Subject: [PATCH 1/2] Reduce cloning in LogicalPlanBuilder - Migrate function arguments from `LogicalPlan` to `impl Into>` - Update usages (mostly in tests) --- datafusion/core/src/physical_planner.rs | 8 +- datafusion/expr/src/expr_rewriter/mod.rs | 45 ++-- datafusion/expr/src/logical_plan/builder.rs | 248 +++++++++--------- .../optimizer/src/analyzer/type_coercion.rs | 7 +- .../optimizer/src/optimize_projections/mod.rs | 11 +- datafusion/optimizer/src/optimize_unions.rs | 11 +- datafusion/optimizer/src/push_down_filter.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 58 ++-- datafusion/optimizer/src/test/mod.rs | 11 +- 9 files changed, 203 insertions(+), 200 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 9eaf1403e5757..0208bc9c6801f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1054,17 +1054,15 @@ impl DefaultPhysicalPlanner { let (left, left_col_keys, left_projected) = wrap_projection_for_join_if_necessary( &left_keys, - original_left.as_ref().clone(), + Arc::clone(original_left), )?; let (right, right_col_keys, right_projected) = wrap_projection_for_join_if_necessary( &right_keys, - original_right.as_ref().clone(), + Arc::clone(original_right), )?; - let column_on = (left_col_keys, right_col_keys); - let left = Arc::new(left); - let right = Arc::new(right); + let column_on = (left_col_keys, right_col_keys); let (new_join, requalified) = Join::try_new_with_project_input( node, Arc::clone(&left), diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index a0faca76e91e4..260c5aa29b348 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -41,7 +41,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Trait for rewriting [`Expr`]s into function calls. /// -/// This trait is used with `FunctionRegistry::register_function_rewrite` to +/// This trait is used with `FunctionRegistry::register_function_rewrite` /// to evaluating `Expr`s using functions that may not be built in to DataFusion /// /// For example, concatenating arrays `a || b` is represented as @@ -49,7 +49,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// `array_concat` from the `functions-nested` crate. // This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it. pub trait FunctionRewrite: Debug { - /// Return a human readable name for this rewrite + /// Return a human-readable name for this rewrite fn name(&self) -> &str; /// Potentially rewrite `expr` to some other expression @@ -219,26 +219,29 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { /// Returns plan with expressions coerced to types compatible with /// schema types pub fn coerce_plan_expr_for_schema( - plan: LogicalPlan, + plan: Arc, schema: &DFSchema, -) -> Result { - match plan { +) -> Result> { + if matches!(plan.as_ref(), LogicalPlan::Projection(_)) { // special case Projection to avoid adding multiple projections - LogicalPlan::Projection(Projection { expr, input, .. }) => { - let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?; - let projection = Projection::try_new(new_exprs, input)?; - Ok(LogicalPlan::Projection(projection)) - } - _ => { - let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); - let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; - let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); - if add_project { - let projection = Projection::try_new(new_exprs, Arc::new(plan))?; - Ok(LogicalPlan::Projection(projection)) - } else { - Ok(plan) - } + let LogicalPlan::Projection(Projection { expr, input, .. }) = + Arc::unwrap_or_clone(plan) + else { + unreachable!() + }; + + let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?; + let projection = Projection::try_new(new_exprs, input)?; + Ok(Arc::new(LogicalPlan::Projection(projection))) + } else { + let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); + let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; + let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); + if add_project { + let projection = Projection::try_new(new_exprs, plan)?; + Ok(Arc::new(LogicalPlan::Projection(projection))) + } else { + Ok(plan) } } } @@ -427,7 +430,7 @@ mod test { fn normalize_cols() { let expr = col("a") + col("b") + col("c"); - // Schemas with some matching and some non matching cols + // Schemas with some matching and some non-matching cols let schema_a = make_schema_with_empty_metadata( vec![Some("tableA".into()), Some("tableA".into())], vec!["a", "aa"], diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 6f654428e41a1..9211c11045941 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -62,7 +62,9 @@ use datafusion_common::{ }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use indexmap::IndexSet; +use itertools::Itertools; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -177,9 +179,10 @@ impl LogicalPlanBuilder { pub fn to_recursive_query( self, name: String, - recursive_term: LogicalPlan, + recursive_term: impl Into>, is_distinct: bool, ) -> Result { + let recursive_term = recursive_term.into(); // Ensure that the static term and the recursive term have the same number of fields let static_fields_len = self.plan.schema().fields().len(); let recursive_fields_len = recursive_term.schema().fields().len(); @@ -191,12 +194,12 @@ impl LogicalPlanBuilder { ); } // Ensure that the recursive term has the same field types as the static term - let coerced_recursive_term = + let recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term: self.plan, - recursive_term: Arc::new(coerced_recursive_term), + recursive_term, is_distinct, }))) } @@ -207,7 +210,7 @@ impl LogicalPlanBuilder { /// /// so it's usually better to override the default names with a table alias list. /// - /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + /// If the values include params/binders such as $1, $2, $3, etc. then the `param_data_types` should be provided. pub fn values(values: Vec>) -> Result { if values.is_empty() { return plan_err!("Values list cannot be empty"); @@ -239,7 +242,7 @@ impl LogicalPlanBuilder { /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. /// - /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + /// If the values include params/binders such as $1, $2, $3, etc. then the `param_data_types` should be provided. pub fn values_with_schema( values: Vec>, schema: &DFSchemaRef, @@ -415,14 +418,14 @@ impl LogicalPlanBuilder { /// Create a [CopyTo] for copying the contents of this builder to the specified file(s) pub fn copy_to( - input: LogicalPlan, + input: impl Into>, output_url: String, file_type: Arc, options: HashMap, partition_by: Vec, ) -> Result { Ok(Self::new(LogicalPlan::Copy(CopyTo::new( - Arc::new(input), + input.into(), output_url, partition_by, file_type, @@ -464,7 +467,7 @@ impl LogicalPlanBuilder { /// # } /// ``` pub fn insert_into( - input: LogicalPlan, + input: impl Into>, table_name: impl Into, target: Arc, insert_op: InsertOp, @@ -473,7 +476,7 @@ impl LogicalPlanBuilder { table_name.into(), target, WriteOp::Insert(insert_op), - Arc::new(input), + input.into(), )))) } @@ -543,10 +546,10 @@ impl LogicalPlanBuilder { /// Wrap a plan in a window pub fn window_plan( - input: LogicalPlan, + input: impl Into>, window_exprs: impl IntoIterator, ) -> Result { - let mut plan = input; + let mut plan = input.into(); let mut groups = group_window_expr_by_sort_keys(window_exprs)?; // To align with the behavior of PostgreSQL, we want the sort_keys sorted as same rule as PostgreSQL that first // we compare the sort key themselves and if one window's sort keys are a prefix of another @@ -568,15 +571,17 @@ impl LogicalPlanBuilder { } key_b.len().cmp(&key_a.len()) }); + for (_, exprs) in groups { let window_exprs = exprs.into_iter().collect::>(); // Partition and sorting is done at physical level, see the EnforceDistribution // and EnforceSorting rules. plan = LogicalPlanBuilder::from(plan) .window(window_exprs)? - .build()?; + .build_arc()?; } - Ok(plan) + + Ok(Arc::unwrap_or_clone(plan)) } /// Apply a projection without alias. @@ -584,7 +589,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator>, ) -> Result { - project(Arc::unwrap_or_clone(self.plan), expr).map(Self::new) + project(self.plan, expr).map(Self::new) } /// Apply a projection without alias with optional validation @@ -593,7 +598,7 @@ impl LogicalPlanBuilder { self, expr: Vec<(impl Into, bool)>, ) -> Result { - project_with_validation(Arc::unwrap_or_clone(self.plan), expr).map(Self::new) + project_with_validation(self.plan, expr).map(Self::new) } /// Select the given column indices @@ -661,7 +666,7 @@ impl LogicalPlanBuilder { /// Apply an alias pub fn alias(self, alias: impl Into) -> Result { - subquery_alias(Arc::unwrap_or_clone(self.plan), alias).map(Self::new) + subquery_alias(self.plan, alias).map(Self::new) } /// Add missing sort columns to all downstream projection @@ -696,43 +701,43 @@ impl LogicalPlanBuilder { curr_plan: LogicalPlan, missing_cols: &IndexSet, is_distinct: bool, - ) -> Result { + ) -> Result> { match curr_plan { LogicalPlan::Projection(Projection { input, mut expr, - schema: _, + schema, }) if missing_cols.iter().all(|c| input.schema().has_column(c)) => { - let mut missing_exprs = missing_cols + let missing_exprs: Vec = missing_cols .iter() .map(|c| normalize_col(Expr::Column(c.clone()), &input)) - .collect::>>()?; + // Do not let duplicate columns to be added, some of the + // missing_cols may be already present but without the new + // projected alias. + .filter_ok(|e| !expr.contains(e)) + .try_collect()?; - // Do not let duplicate columns to be added, some of the - // missing_cols may be already present but without the new - // projected alias. - missing_exprs.retain(|e| !expr.contains(e)); if is_distinct { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } - expr.extend(missing_exprs); - project(Arc::unwrap_or_clone(input), expr) + + if missing_exprs.is_empty() { + Ok(Transformed::no(LogicalPlan::Projection(Projection { + input, + expr, + schema, + }))) + } else { + expr.extend(missing_exprs); + project(input, expr).map(Transformed::yes) + } } _ => { let is_distinct = is_distinct || matches!(curr_plan, LogicalPlan::Distinct(_)); - let new_inputs = curr_plan - .inputs() - .into_iter() - .map(|input_plan| { - Self::add_missing_columns( - (*input_plan).clone(), - missing_cols, - is_distinct, - ) - }) - .collect::>>()?; - curr_plan.with_new_exprs(curr_plan.expressions(), new_inputs) + curr_plan.map_children(|input| { + Self::add_missing_columns(input, missing_cols, is_distinct) + }) } } } @@ -830,13 +835,13 @@ impl LogicalPlanBuilder { // remove pushed down sort columns let new_expr = schema.columns().into_iter().map(Expr::Column).collect(); - let is_distinct = false; let plan = Self::add_missing_columns( Arc::unwrap_or_clone(self.plan), &missing_cols, is_distinct, - )?; + ) + .data()?; let sort_plan = LogicalPlan::Sort(Sort { expr: normalize_sorts(sorts, &plan)?, @@ -850,36 +855,33 @@ impl LogicalPlanBuilder { } /// Apply a union, preserving duplicate rows - pub fn union(self, plan: LogicalPlan) -> Result { - union(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) + pub fn union(self, plan: impl Into>) -> Result { + union(self.plan, plan).map(Self::new) } /// Apply a union by name, preserving duplicate rows - pub fn union_by_name(self, plan: LogicalPlan) -> Result { - union_by_name(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) + pub fn union_by_name(self, plan: impl Into>) -> Result { + union_by_name(self.plan, plan).map(Self::new) } /// Apply a union by name, removing duplicate rows - pub fn union_by_name_distinct(self, plan: LogicalPlan) -> Result { - let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); - let right_plan: LogicalPlan = plan; - + pub fn union_by_name_distinct( + self, + plan: impl Into>, + ) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( - union_by_name(left_plan, right_plan)?, + union_by_name(self.plan, plan)?, ))))) } /// Apply a union, removing duplicate rows - pub fn union_distinct(self, plan: LogicalPlan) -> Result { - let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); - let right_plan: LogicalPlan = plan; - + pub fn union_distinct(self, plan: impl Into>) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( - union(left_plan, right_plan)?, + union(self.plan, plan)?, ))))) } - /// Apply deduplication: Only distinct (different) values are returned) + /// Apply deduplication: Only distinct (different) values are returned pub fn distinct(self) -> Result { Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan)))) } @@ -912,7 +914,7 @@ impl LogicalPlanBuilder { /// Note that in case of outer join, the `filter` is applied to only matched rows. pub fn join( self, - right: LogicalPlan, + right: impl Into>, join_type: JoinType, join_keys: (Vec>, Vec>), filter: Option, @@ -968,7 +970,7 @@ impl LogicalPlanBuilder { /// ``` pub fn join_on( self, - right: LogicalPlan, + right: impl Into>, join_type: JoinType, on_exprs: impl IntoIterator, ) -> Result { @@ -1006,7 +1008,7 @@ impl LogicalPlanBuilder { /// The `null_equality` dictates how `null` values are joined. pub fn join_detailed( self, - right: LogicalPlan, + right: impl Into>, join_type: JoinType, join_keys: (Vec>, Vec>), filter: Option, @@ -1016,6 +1018,7 @@ impl LogicalPlanBuilder { return plan_err!("left_keys and right_keys were not the same length"); } + let right = right.into(); let filter = if let Some(expr) = filter { let filter = normalize_col_with_schemas_and_ambiguity_check( expr, @@ -1121,7 +1124,7 @@ impl LogicalPlanBuilder { Ok(Self::new(LogicalPlan::Join(Join { left: self.plan, - right: Arc::new(right), + right, on, filter, join_type, @@ -1134,10 +1137,11 @@ impl LogicalPlanBuilder { /// Apply a join with using constraint, which duplicates all join columns in output schema. pub fn join_using( self, - right: LogicalPlan, + right: impl Into>, join_type: JoinType, using_keys: Vec, ) -> Result { + let right = right.into(); let left_keys: Vec = using_keys .clone() .into_iter() @@ -1195,7 +1199,7 @@ impl LogicalPlanBuilder { } else { let join = Join::try_new( self.plan, - Arc::new(right), + right, join_on, filters, join_type, @@ -1208,10 +1212,10 @@ impl LogicalPlanBuilder { } /// Apply a cross join - pub fn cross_join(self, right: LogicalPlan) -> Result { + pub fn cross_join(self, right: impl Into>) -> Result { let join = Join::try_new( self.plan, - Arc::new(right), + right.into(), vec![], None, JoinType::Inner, @@ -1310,8 +1314,8 @@ impl LogicalPlanBuilder { /// Process intersect set operator pub fn intersect( - left_plan: LogicalPlan, - right_plan: LogicalPlan, + left_plan: impl Into>, + right_plan: impl Into>, is_all: bool, ) -> Result { LogicalPlanBuilder::intersect_or_except( @@ -1324,8 +1328,8 @@ impl LogicalPlanBuilder { /// Process except set operator pub fn except( - left_plan: LogicalPlan, - right_plan: LogicalPlan, + left_plan: impl Into>, + right_plan: impl Into>, is_all: bool, ) -> Result { LogicalPlanBuilder::intersect_or_except( @@ -1338,11 +1342,13 @@ impl LogicalPlanBuilder { /// Process intersect or except fn intersect_or_except( - left_plan: LogicalPlan, - right_plan: LogicalPlan, + left_plan: impl Into>, + right_plan: impl Into>, join_type: JoinType, is_all: bool, ) -> Result { + let left_plan = left_plan.into(); + let right_plan = right_plan.into(); let left_len = left_plan.schema().fields().len(); let right_len = right_plan.schema().fields().len(); @@ -1368,8 +1374,8 @@ impl LogicalPlanBuilder { .zip(right_plan.schema().fields().iter()) .map(|(left_field, right_field)| { ( - (Column::from_name(left_field.name())), - (Column::from_name(right_field.name())), + Column::from_name(left_field.name()), + Column::from_name(right_field.name()), ) }) .unzip(); @@ -1402,13 +1408,18 @@ impl LogicalPlanBuilder { Ok(Arc::unwrap_or_clone(self.plan)) } + /// Build the plan into an Arc + pub fn build_arc(self) -> Result> { + Ok(self.plan) + } + /// Apply a join with both explicit equijoin and non equijoin predicates. /// /// Note this is a low level API that requires identifying specific /// predicate types. Most users should use [`join_on`](Self::join_on) that /// automatically identifies predicates appropriately. /// - /// `equi_exprs` defines equijoin predicates, of the form `l = r)` for each + /// `equi_exprs` defines equijoin predicates, of the form `l = r` for each /// `(l, r)` tuple. `l`, the first element of the tuple, must only refer /// to columns from the existing input. `r`, the second element of the tuple, /// must only refer to columns from the right input. @@ -1418,7 +1429,7 @@ impl LogicalPlanBuilder { /// than the filter expressions, so they are preferred. pub fn join_with_expr_keys( self, - right: LogicalPlan, + right: impl Into>, join_type: JoinType, equi_exprs: (Vec>, Vec>), filter: Option, @@ -1427,6 +1438,7 @@ impl LogicalPlanBuilder { return plan_err!("left_keys and right_keys were not the same length"); } + let right = right.into(); let join_key_pairs = equi_exprs .0 .into_iter() @@ -1465,7 +1477,7 @@ impl LogicalPlanBuilder { let join = Join::try_new( self.plan, - Arc::new(right), + right, join_key_pairs, filter, join_type, @@ -1478,7 +1490,7 @@ impl LogicalPlanBuilder { /// Unnest the given column. pub fn unnest_column(self, column: impl Into) -> Result { - unnest(Arc::unwrap_or_clone(self.plan), vec![column.into()]).map(Self::new) + unnest(self.plan, vec![column.into()]).map(Self::new) } /// Unnest the given column given [`UnnestOptions`] @@ -1487,12 +1499,7 @@ impl LogicalPlanBuilder { column: impl Into, options: UnnestOptions, ) -> Result { - unnest_with_options( - Arc::unwrap_or_clone(self.plan), - vec![column.into()], - options, - ) - .map(Self::new) + unnest_with_options(self.plan, vec![column.into()], options).map(Self::new) } /// Unnest the given columns with the given [`UnnestOptions`] @@ -1501,8 +1508,7 @@ impl LogicalPlanBuilder { columns: Vec, options: UnnestOptions, ) -> Result { - unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) - .map(Self::new) + unnest_with_options(self.plan, columns, options).map(Self::new) } } @@ -1856,7 +1862,7 @@ pub fn validate_unique_names<'a>( /// Union two [`LogicalPlan`]s. /// -/// Constructs the UNION plan, but does not perform type-coercion. Therefore the +/// Constructs the UNION plan, but does not perform type-coercion. Therefore, the /// subtree expressions will not be properly typed until the optimizer pass. /// /// If a properly typed UNION plan is needed, refer to [`TypeCoercionRewriter::coerce_union`] @@ -1865,22 +1871,25 @@ pub fn validate_unique_names<'a>( /// /// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union /// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html -pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { +pub fn union( + left_plan: impl Into>, + right_plan: impl Into>, +) -> Result { Ok(LogicalPlan::Union(Union::try_new_with_loose_types(vec![ - Arc::new(left_plan), - Arc::new(right_plan), + left_plan.into(), + right_plan.into(), ])?)) } /// Like [`union`], but combine rows from different tables by name, rather than /// by position. pub fn union_by_name( - left_plan: LogicalPlan, - right_plan: LogicalPlan, + left_plan: impl Into>, + right_plan: impl Into>, ) -> Result { Ok(LogicalPlan::Union(Union::try_new_by_name(vec![ - Arc::new(left_plan), - Arc::new(right_plan), + left_plan.into(), + right_plan.into(), ])?)) } @@ -1890,7 +1899,7 @@ pub fn union_by_name( /// * Two or more expressions have the same name /// * An invalid expression is used (e.g. a `sort` expression) pub fn project( - plan: LogicalPlan, + plan: impl Into>, expr: impl IntoIterator>, ) -> Result { project_with_validation(plan, expr.into_iter().map(|e| (e, true))) @@ -1904,9 +1913,10 @@ pub fn project( /// * Two or more expressions have the same name /// * An invalid expression is used (e.g. a `sort` expression) fn project_with_validation( - plan: LogicalPlan, + plan: impl Into>, expr: impl IntoIterator, bool)>, ) -> Result { + let plan = plan.into(); let mut projected_expr = vec![]; for (e, validate) in expr { let e = e.into(); @@ -1961,9 +1971,9 @@ fn project_with_validation( } } } - validate_unique_names("Projections", projected_expr.iter())?; - Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) + validate_unique_names("Projections", projected_expr.iter())?; + Projection::try_new(projected_expr, plan).map(LogicalPlan::Projection) } /// If there is a REPLACE statement in the projected expression in the form of @@ -1990,10 +2000,10 @@ fn replace_columns( /// Create a SubqueryAlias to wrap a LogicalPlan. pub fn subquery_alias( - plan: LogicalPlan, + plan: impl Into>, alias: impl Into, ) -> Result { - SubqueryAlias::try_new(Arc::new(plan), alias).map(LogicalPlan::SubqueryAlias) + SubqueryAlias::try_new(plan.into(), alias).map(LogicalPlan::SubqueryAlias) } /// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema. @@ -2069,8 +2079,9 @@ pub fn table_source_with_constraints( /// Wrap projection for a plan, if the join keys contains normal expression. pub fn wrap_projection_for_join_if_necessary( join_keys: &[Expr], - input: LogicalPlan, -) -> Result<(LogicalPlan, Vec, bool)> { + input: impl Into>, +) -> Result<(Arc, Vec, bool)> { + let input = input.into(); let input_schema = input.schema(); let alias_join_keys: Vec = join_keys .iter() @@ -2109,7 +2120,7 @@ pub fn wrap_projection_for_join_if_necessary( LogicalPlanBuilder::from(input) .project(projection.into_iter().map(SelectExpr::from))? - .build()? + .build_arc()? } else { input }; @@ -2174,7 +2185,10 @@ impl TableSource for LogicalTableSource { } /// Create a [`LogicalPlan::Unnest`] plan -pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { +pub fn unnest( + input: impl Into>, + columns: Vec, +) -> Result { unnest_with_options(input, columns, UnnestOptions::default()) } @@ -2190,8 +2204,8 @@ pub fn get_struct_unnested_columns( /// Create a [`LogicalPlan::Unnest`] plan with options /// This function receive a list of columns to be unnested -/// because multiple unnest can be performed on the same column (e.g unnest with different depth) -/// The new schema will contains post-unnest fields replacing the original field +/// because multiple unnest can be performed on the same column (e.g. unnest with different depth) +/// The new schema will contain post-unnest fields replacing the original field /// /// For example: /// Input schema as @@ -2218,12 +2232,12 @@ pub fn get_struct_unnested_columns( /// +---------+---------+---------------------+---------------------+ /// ``` pub fn unnest_with_options( - input: LogicalPlan, + input: impl Into>, columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { Ok(LogicalPlan::Unnest(Unnest::try_new( - Arc::new(input), + input.into(), columns_to_unnest, options, )?)) @@ -2231,12 +2245,12 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use std::vec; - use super::*; use crate::lit_with_metadata; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; + use arrow::datatypes::Schema; + use std::vec; use crate::test::function_stub::sum; use datafusion_common::{ @@ -2598,7 +2612,7 @@ mod tests { "); // Check unnested field is a scalar - let field = plan.schema().field_with_name(None, "strings").unwrap(); + let field = plan.schema().field_with_name(None, "strings")?; assert_eq!(&DataType::Utf8, field.data_type()); // Unnesting the singular struct column result into 2 new columns for each subfield @@ -2615,8 +2629,7 @@ mod tests { // Check unnested struct field is a scalar let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{field_name}")) - .unwrap(); + .field_with_name(None, &format!("struct_singular.{field_name}"))?; assert_eq!(&DataType::UInt32, field.data_type()); } @@ -2635,7 +2648,7 @@ mod tests { "); // Check unnested struct list field should be a struct. - let field = plan.schema().field_with_name(None, "structs").unwrap(); + let field = plan.schema().field_with_name(None, "structs")?; assert!(matches!(field.data_type(), DataType::Struct(_))); // Unnesting multiple fields at the same time, using infer syntax @@ -2681,25 +2694,18 @@ mod tests { "); // Check output columns has correct type - let field = plan - .schema() - .field_with_name(None, "stringss_depth_1") - .unwrap(); + let field = plan.schema().field_with_name(None, "stringss_depth_1")?; assert_eq!( &DataType::new_list(DataType::Utf8, false), field.data_type() ); - let field = plan - .schema() - .field_with_name(None, "stringss_depth_2") - .unwrap(); + let field = plan.schema().field_with_name(None, "stringss_depth_2")?; assert_eq!(&DataType::Utf8, field.data_type()); // unnesting struct is still correct for field_name in &["a", "b"] { let field = plan .schema() - .field_with_name(None, &format!("struct_singular.{field_name}")) - .unwrap(); + .field_with_name(None, &format!("struct_singular.{field_name}"))?; assert_eq!(&DataType::UInt32, field.data_type()); } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bc317e9c201ce..23031bfd11b88 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -78,7 +78,7 @@ fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result TypeCoercionRewriter<'a> { .inputs .into_iter() .map(|p| { - let plan = - coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?; - match plan { + match Arc::unwrap_or_clone(coerce_plan_expr_for_schema(p, &union_schema)?) + { LogicalPlan::Projection(Projection { expr, input, .. }) => { Ok(Arc::new(project_with_column_index( expr, diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 548eadffa242e..288b528b634a9 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -946,7 +946,7 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::{ - assert_fields_eq, scan_empty, test_table_scan, test_table_scan_fields, + assert_fields_eq, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; use crate::{OptimizerContext, OptimizerRule}; @@ -1839,7 +1839,8 @@ mod tests { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + let table2_scan = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? @@ -1891,7 +1892,8 @@ mod tests { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + let table2_scan = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) .join(table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]), None)? @@ -1946,7 +1948,8 @@ mod tests { let table_scan = test_table_scan()?; let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); - let table2_scan = scan_empty(Some("test2"), &schema, None)?.build()?; + let table2_scan = + datafusion_expr::table_scan(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) .join_using(table2_scan, JoinType::Left, vec!["a".into()])? diff --git a/datafusion/optimizer/src/optimize_unions.rs b/datafusion/optimizer/src/optimize_unions.rs index 900757b9a0607..4848553c4d6b4 100644 --- a/datafusion/optimizer/src/optimize_unions.rs +++ b/datafusion/optimizer/src/optimize_unions.rs @@ -22,7 +22,6 @@ use datafusion_common::Result; use datafusion_common::tree_node::Transformed; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::{Distinct, LogicalPlan, Projection, Union}; -use itertools::Itertools; use std::sync::Arc; #[derive(Default, Debug)] @@ -64,11 +63,11 @@ impl OptimizerRule for OptimizeUnions { let inputs = inputs .into_iter() .flat_map(extract_plans_from_union) - .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) + .map(|plan| coerce_plan_expr_for_schema(Arc::new(plan), &schema)) .collect::>>()?; Ok(Transformed::yes(LogicalPlan::Union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), + inputs, schema, }))) } @@ -79,12 +78,14 @@ impl OptimizerRule for OptimizeUnions { .into_iter() .map(extract_plan_from_distinct) .flat_map(extract_plans_from_union) - .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) + .map(|plan| { + coerce_plan_expr_for_schema(Arc::new(plan), &schema) + }) .collect::>>()?; Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( Arc::new(LogicalPlan::Union(Union { - inputs: inputs.into_iter().map(Arc::new).collect_vec(), + inputs, schema: Arc::clone(&schema), })), )))) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 755ffdbafc869..5e0356f6f0d9b 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -2319,7 +2319,7 @@ mod tests { ]); let right = table_scan(Some("test1"), &schema, None)? .project(vec![col("d"), col("e"), col("f")])? - .build()?; + .build_arc()?; let filter = and(col("test.a").eq(lit(1)), col("test1.d").gt(lit(2))); let plan = LogicalPlanBuilder::from(left) .cross_join(right)? @@ -2349,7 +2349,7 @@ mod tests { let right_table_scan = test_table_scan_with_name("test1")?; let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a"), col("b"), col("c")])? - .build()?; + .build_arc()?; let filter = and(col("test.a").eq(lit(1)), col("test1.a").gt(lit(2))); let plan = LogicalPlanBuilder::from(left) .cross_join(right)? diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 975c234b38836..09ab1c2a992ba 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -100,10 +100,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { ); // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = filter.input.as_ref().clone(); + let mut cur_input = Arc::clone(&filter.input); for (subquery, alias) in subqueries { if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? + build_join(subquery, cur_input, &alias)? { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr @@ -122,7 +122,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } cur_input = optimized_subquery; } else { - // if we can't handle all of the subqueries then bail for now + // if we can't handle all the subqueries then bail for now return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } @@ -145,29 +145,30 @@ impl OptimizerRule for ScalarSubqueryToJoin { let mut all_subqueries = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); - let mut subquery_to_expr_map = HashMap::new(); for expr in projection.expr.iter() { - let (subqueries, rewrite_exprs) = + let (subqueries, rewrite_expr) = self.extract_subquery_exprs(expr, config.alias_generator())?; - for (subquery, _) in &subqueries { - subquery_to_expr_map.insert(subquery.clone(), expr.clone()); - } - all_subqueries.extend(subqueries); - expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); + all_subqueries.extend( + subqueries + .into_iter() + .map(|(subquery, alias)| (subquery, alias, expr)), + ); + expr_to_rewrite_expr_map.insert(expr, rewrite_expr); } + assert_or_internal_err!( !all_subqueries.is_empty(), "Expected subqueries not found in projection" ); + // iterate through all subqueries in predicate, turning each into a left join - let mut cur_input = projection.input.as_ref().clone(); - for (subquery, alias) in all_subqueries { + let mut cur_input = Arc::clone(&projection.input); + for (subquery, alias, expr) in all_subqueries { if let Some((optimized_subquery, expr_check_map)) = - build_join(&subquery, &cur_input, &alias)? + build_join(subquery, cur_input, &alias)? { cur_input = optimized_subquery; if !expr_check_map.is_empty() - && let Some(expr) = subquery_to_expr_map.get(&subquery) && let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { let new_expr = rewrite_expr @@ -187,7 +188,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { expr_to_rewrite_expr_map.insert(expr, new_expr); } } else { - // if we can't handle all of the subqueries then bail for now + // if we can't handle all the subqueries then bail for now return Ok(Transformed::no(LogicalPlan::Projection(projection))); } } @@ -241,12 +242,11 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); - self.sub_query_info - .push((subquery.clone(), subqry_alias.clone())); let scalar_expr = subquery .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; + self.sub_query_info.push((subquery, subqry_alias.clone())); Ok(Transformed::new( Expr::Column(create_col_from_scalar_expr( &scalar_expr, @@ -297,14 +297,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { /// * `filter_input` - The non-subquery portion (from customers) /// * `outer_others` - Any additional parts to the `where` expression (and c.x = y) /// * `subquery_alias` - Subquery aliases +#[expect(clippy::type_complexity)] fn build_join( - subquery: &Subquery, - filter_input: &LogicalPlan, + subquery: Subquery, + filter_input: Arc, subquery_alias: &str, -) -> Result)>> { - let subquery_plan = subquery.subquery.as_ref(); +) -> Result, HashMap)>> { let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); - let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; + let new_plan = Arc::unwrap_or_clone(subquery.subquery) + .rewrite(&mut pull_up) + .data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -313,7 +315,7 @@ fn build_join( pull_up.collected_count_expr_map.get(&new_plan).cloned(); let sub_query_alias = LogicalPlanBuilder::from(new_plan) .alias(subquery_alias.to_string())? - .build()?; + .build_arc()?; let mut all_correlated_cols = BTreeSet::new(); pull_up @@ -329,27 +331,27 @@ fn build_join( // join our sub query into the main plan let new_plan = if join_filter_opt.is_none() { - match filter_input { + match filter_input.as_ref() { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: true, schema: _, }) => sub_query_alias, _ => { // if not correlated, group down to 1 row and left join on that (preserving row count) - LogicalPlanBuilder::from(filter_input.clone()) + LogicalPlanBuilder::from(filter_input) .join_on( sub_query_alias, JoinType::Left, vec![Expr::Literal(ScalarValue::Boolean(Some(true)), None)], )? - .build()? + .build_arc()? } } } else { // left join if correlated, grouping by the join keys so we don't change row count - LogicalPlanBuilder::from(filter_input.clone()) + LogicalPlanBuilder::from(filter_input) .join_on(sub_query_alias, JoinType::Left, join_filter_opt)? - .build()? + .build_arc()? }; let mut computation_project_expr = HashMap::new(); if let Some(expr_map) = collected_count_expr_map { diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index a45983950496d..4507b1e59e84b 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -21,7 +21,7 @@ use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, assert_contains}; -use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, logical_plan::table_scan}; +use datafusion_expr::{LogicalPlan, logical_plan::table_scan}; use std::sync::Arc; pub mod user_defined; @@ -45,15 +45,6 @@ pub fn test_table_scan() -> Result { test_table_scan_with_name("test") } -/// Scan an empty data source, mainly used in tests -pub fn scan_empty( - name: Option<&str>, - table_schema: &Schema, - projection: Option>, -) -> Result { - table_scan(name, table_schema, projection) -} - pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { let actual: Vec = plan .schema() From ac8da54c312f7de1ee464111cff275f62f159dad Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Fri, 26 Sep 2025 17:18:38 +0200 Subject: [PATCH 2/2] Don't clone the schema in logical2physical --- .../src/datasource/physical_plan/parquet.rs | 2 +- datafusion/datasource-parquet/src/opener.rs | 36 +-- .../datasource-parquet/src/row_filter.rs | 34 ++- .../src/row_group_filter.rs | 41 ++-- datafusion/physical-expr/benches/binary_op.rs | 4 +- .../physical-expr/src/expressions/binary.rs | 19 +- datafusion/physical-expr/src/planner.rs | 11 +- .../physical-expr/src/utils/guarantee.rs | 3 +- datafusion/pruning/src/pruning_predicate.rs | 213 ++++++++++-------- 9 files changed, 203 insertions(+), 160 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs index 4703b55ecc0de..510ac09d4e100 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet.rs @@ -159,7 +159,7 @@ mod tests { let predicate = self .predicate .as_ref() - .map(|p| logical2physical(p, &table_schema)); + .map(|p| logical2physical(p, Arc::clone(&table_schema))); let mut source = ParquetSource::new(table_schema); if let Some(predicate) = predicate { diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs index bea970f144863..e4f2da9d2f9b1 100644 --- a/datafusion/datasource-parquet/src/opener.rs +++ b/datafusion/datasource-parquet/src/opener.rs @@ -1350,7 +1350,7 @@ mod test { // A filter on "a" should not exclude any rows even if it matches the data let expr = col("a").eq(lit(1)); - let predicate = logical2physical(&expr, &schema); + let predicate = logical2physical(&expr, Arc::clone(&schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1359,7 +1359,7 @@ mod test { // A filter on `b = 5.0` should exclude all rows let expr = col("b").eq(lit(ScalarValue::Float32(Some(5.0)))); - let predicate = logical2physical(&expr, &schema); + let predicate = logical2physical(&expr, Arc::clone(&schema)); let opener = make_opener(predicate); let stream = opener.open(file).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1405,7 +1405,8 @@ mod test { let expr = col("part").eq(lit(1)); // Mark the expression as dynamic even if it's not to force partition pruning to happen // Otherwise we assume it already happened at the planning stage and won't re-do the work here - let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let predicate = + make_dynamic_expr(logical2physical(&expr, Arc::clone(&table_schema))); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1416,7 +1417,7 @@ mod test { let expr = col("part").eq(lit(2)); // Mark the expression as dynamic even if it's not to force partition pruning to happen // Otherwise we assume it already happened at the planning stage and won't re-do the work here - let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let predicate = make_dynamic_expr(logical2physical(&expr, table_schema)); let opener = make_opener(predicate); let stream = opener.open(file).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1472,7 +1473,7 @@ mod test { // Filter should match the partition value and file statistics let expr = col("part").eq(lit(1)).and(col("b").eq(lit(1.0))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1481,7 +1482,7 @@ mod test { // Should prune based on partition value but not file statistics let expr = col("part").eq(lit(2)).and(col("b").eq(lit(1.0))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1490,7 +1491,7 @@ mod test { // Should prune based on file statistics but not partition value let expr = col("part").eq(lit(1)).and(col("b").eq(lit(7.0))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1499,7 +1500,7 @@ mod test { // Should prune based on both partition value and file statistics let expr = col("part").eq(lit(2)).and(col("b").eq(lit(7.0))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, table_schema); let opener = make_opener(predicate); let stream = opener.open(file).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1545,7 +1546,7 @@ mod test { // Filter should match the partition value and data value let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1554,7 +1555,7 @@ mod test { // Filter should match the partition value but not the data value let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1563,7 +1564,7 @@ mod test { // Filter should not match the partition value but match the data value let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1572,7 +1573,7 @@ mod test { // Filter should not match the partition value or the data value let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3))); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, table_schema); let opener = make_opener(predicate); let stream = opener.open(file).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1625,7 +1626,7 @@ mod test { // This filter could prune based on statistics, but since it's not dynamic it's not applied for pruning // (the assumption is this happened already at planning time) let expr = col("a").eq(lit(42)); - let predicate = logical2physical(&expr, &table_schema); + let predicate = logical2physical(&expr, Arc::clone(&table_schema)); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1634,7 +1635,8 @@ mod test { // If we make the filter dynamic, it should prune. // This allows dynamic filters to prune partitions/files even if they are populated late into execution. - let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let predicate = + make_dynamic_expr(logical2physical(&expr, Arc::clone(&table_schema))); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1644,7 +1646,8 @@ mod test { // If we have a filter that touches partition columns only and is dynamic, it should prune even if there are no stats. file.statistics = Some(Arc::new(Statistics::new_unknown(&file_schema))); let expr = col("part").eq(lit(2)); - let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let predicate = + make_dynamic_expr(logical2physical(&expr, Arc::clone(&table_schema))); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; @@ -1653,7 +1656,8 @@ mod test { // Similarly a filter that combines partition and data columns should prune even if there are no stats. let expr = col("part").eq(lit(2)).and(col("a").eq(lit(42))); - let predicate = make_dynamic_expr(logical2physical(&expr, &table_schema)); + let predicate = + make_dynamic_expr(logical2physical(&expr, Arc::clone(&table_schema))); let opener = make_opener(predicate); let stream = opener.open(file.clone()).unwrap().await.unwrap(); let (num_batches, num_rows) = count_batches_and_rows(stream).await; diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index ba3b29be40d74..529be3e50dcf5 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -487,15 +487,13 @@ mod test { let metadata = reader.metadata(); - let table_schema = + let table_schema = Arc::new( parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) - .expect("parsing schema"); + .expect("parsing schema"), + ); let expr = col("int64_list").is_not_null(); - let expr = logical2physical(&expr, &table_schema); - - let table_schema = Arc::new(table_schema.clone()); - + let expr = logical2physical(&expr, Arc::clone(&table_schema)); let candidate = FilterCandidateBuilder::new(expr, table_schema) .build(metadata) .expect("building candidate"); @@ -516,23 +514,23 @@ mod test { // This is the schema we would like to coerce to, // which is different from the physical schema of the file. - let table_schema = Schema::new(vec![Field::new( + let table_schema = Arc::new(Schema::new(vec![Field::new( "timestamp_col", DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))), false, - )]); + )])); // Test all should fail let expr = col("timestamp_col").lt(Expr::Literal( ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), None, )); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); let expr = DefaultPhysicalExprAdapterFactory {} - .create(Arc::new(table_schema.clone()), Arc::clone(&file_schema)) + .create(Arc::clone(&table_schema), Arc::clone(&file_schema)) .rewrite(expr) .expect("rewriting expression"); - let candidate = FilterCandidateBuilder::new(expr, file_schema.clone()) + let candidate = FilterCandidateBuilder::new(expr, Arc::clone(&file_schema)) .build(&metadata) .expect("building candidate") .expect("candidate expected"); @@ -565,10 +563,10 @@ mod test { ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), None, )); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); // Rewrite the expression to add CastExpr for type coercion let expr = DefaultPhysicalExprAdapterFactory {} - .create(Arc::new(table_schema), Arc::clone(&file_schema)) + .create(table_schema, Arc::clone(&file_schema)) .rewrite(expr) .expect("rewriting expression"); let candidate = FilterCandidateBuilder::new(expr, file_schema) @@ -594,7 +592,7 @@ mod test { let table_schema = Arc::new(get_lists_table_schema()); let expr = col("utf8_list").is_not_null(); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); check_expression_can_evaluate_against_schema(&expr, &table_schema); assert!(!can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); @@ -612,22 +610,22 @@ mod test { #[test] fn basic_expr_doesnt_prevent_pushdown() { - let table_schema = get_basic_table_schema(); + let table_schema = Arc::new(get_basic_table_schema()); let expr = col("string_col").is_null(); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } #[test] fn complex_expr_doesnt_prevent_pushdown() { - let table_schema = get_basic_table_schema(); + let table_schema = Arc::new(get_basic_table_schema()); let expr = col("string_col") .is_not_null() .or(col("bigint_col").gt(Expr::Literal(ScalarValue::Int64(Some(5)), None))); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); assert!(can_expr_be_pushed_down_with_schemas(&expr, &table_schema)); } diff --git a/datafusion/datasource-parquet/src/row_group_filter.rs b/datafusion/datasource-parquet/src/row_group_filter.rs index 1264197609f3f..2b664dc170e89 100644 --- a/datafusion/datasource-parquet/src/row_group_filter.rs +++ b/datafusion/datasource-parquet/src/row_group_filter.rs @@ -518,7 +518,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); @@ -563,7 +563,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expr = col("c1").gt(lit(15)); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32); @@ -606,7 +606,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ])); let expr = col("c1").gt(lit(15)).and(col("c2").rem(lit(2)).eq(lit(0))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let schema_descr = get_test_schema_descr(vec![ @@ -645,7 +645,7 @@ mod tests { // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out let expr = col("c1").gt(lit(15)).or(col("c2").rem(lit(2)).eq(lit(0))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used @@ -671,7 +671,7 @@ mod tests { Field::new("c2", DataType::Int32, false), ])); let expr = col("c1").gt(lit(0)); - let expr = logical2physical(&expr, &table_schema); + let expr = logical2physical(&expr, Arc::clone(&table_schema)); let pruning_predicate = PruningPredicate::try_new(expr, table_schema.clone()).unwrap(); @@ -749,7 +749,7 @@ mod tests { ])); let schema_descr = ArrowSchemaConverter::new().convert(&schema).unwrap(); let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); @@ -780,7 +780,7 @@ mod tests { let expr = col("c1") .gt(lit(15)) .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); @@ -818,7 +818,7 @@ mod tests { .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, @@ -889,7 +889,7 @@ mod tests { lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, @@ -981,7 +981,7 @@ mod tests { .with_precision(18); let schema_descr = get_test_schema_descr(vec![field]); let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let rgm1 = get_row_group_meta_data( &schema_descr, @@ -1042,7 +1042,7 @@ mod tests { // cast the type of c1 to decimal(28,3) let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( @@ -1120,7 +1120,7 @@ mod tests { // cast the type of c1 to decimal(28,3) let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); - let expr = logical2physical(&expr, &schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. let rgm1 = get_row_group_meta_data( @@ -1283,7 +1283,11 @@ mod tests { let data = bytes::Bytes::from(std::fs::read(path).unwrap()); // generate pruning predicate - let schema = Schema::new(vec![Field::new("String", DataType::Utf8, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "String", + DataType::Utf8, + false, + )])); let expr = col(r#""String""#).in_list( (1..25) @@ -1291,9 +1295,8 @@ mod tests { .collect::>(), false, ); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let expr = logical2physical(&expr, Arc::clone(&schema)); + let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( file_name, @@ -1514,9 +1517,9 @@ mod tests { let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); - let expr = logical2physical(&expr, &schema); - let pruning_predicate = - PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); + let schema = Arc::new(schema); + let expr = logical2physical(&expr, Arc::clone(&schema)); + let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( &file_name, diff --git a/datafusion/physical-expr/benches/binary_op.rs b/datafusion/physical-expr/benches/binary_op.rs index 99fc40fa1c91b..1d6d8a96ef985 100644 --- a/datafusion/physical-expr/benches/binary_op.rs +++ b/datafusion/physical-expr/benches/binary_op.rs @@ -203,14 +203,14 @@ fn benchmark_binary_op_in_short_circuit(c: &mut Criterion) { let expr_and = BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::And, - logical2physical(&right_condition_and, &schema), + logical2physical(&right_condition_and, Arc::clone(&schema)), ); // a OR ((b ~ regex) OR (c ~ regex)) let expr_or = BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Or, - logical2physical(&right_condition_or, &schema), + logical2physical(&right_condition_or, schema), ); // Each scenario when the test operator is `and` diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 8df09c22bbd8d..176be9b0a618e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -4896,7 +4896,8 @@ mod tests { .unwrap(); // op: AND left: all false - let left_expr = logical2physical(&logical_col("a").eq(expr_lit(2)), &schema); + let left_expr = + logical2physical(&logical_col("a").eq(expr_lit(2)), Arc::clone(&schema)); let left_value = left_expr.evaluate(&batch).unwrap(); assert!(matches!( check_short_circuit(&left_value, &Operator::And), @@ -4904,7 +4905,8 @@ mod tests { )); // op: AND left: not all false - let left_expr = logical2physical(&logical_col("a").eq(expr_lit(3)), &schema); + let left_expr = + logical2physical(&logical_col("a").eq(expr_lit(3)), Arc::clone(&schema)); let left_value = left_expr.evaluate(&batch).unwrap(); let ColumnarValue::Array(array) = &left_value else { panic!("Expected ColumnarValue::Array"); @@ -4920,7 +4922,8 @@ mod tests { assert_eq!(expected_boolean_arr, boolean_arr); // op: OR left: all true - let left_expr = logical2physical(&logical_col("a").gt(expr_lit(0)), &schema); + let left_expr = + logical2physical(&logical_col("a").gt(expr_lit(0)), Arc::clone(&schema)); let left_value = left_expr.evaluate(&batch).unwrap(); assert!(matches!( check_short_circuit(&left_value, &Operator::Or), @@ -4929,7 +4932,7 @@ mod tests { // op: OR left: not all true let left_expr: Arc = - logical2physical(&logical_col("a").gt(expr_lit(2)), &schema); + logical2physical(&logical_col("a").gt(expr_lit(2)), Arc::clone(&schema)); let left_value = left_expr.evaluate(&batch).unwrap(); assert!(matches!( check_short_circuit(&left_value, &Operator::Or), @@ -4965,7 +4968,7 @@ mod tests { .unwrap(); // Case: Mixed values with nulls - shouldn't short-circuit for AND - let mixed_nulls = logical2physical(&logical_col("c"), &schema_nullable); + let mixed_nulls = logical2physical(&logical_col("c"), schema_nullable); let mixed_nulls_value = mixed_nulls.evaluate(&batch_nullable).unwrap(); assert!(matches!( check_short_circuit(&mixed_nulls_value, &Operator::And), @@ -4986,10 +4989,10 @@ mod tests { ) .unwrap(); - let null_expr = logical2physical(&logical_col("e"), &null_batch.schema()); + let null_expr = logical2physical(&logical_col("e"), null_batch.schema()); let null_value = null_expr.evaluate(&null_batch).unwrap(); - // All nulls shouldn't short-circuit for AND or OR + // All nulls shouldn't short-circuit for `AND` or `OR` assert!(matches!( check_short_circuit(&null_value, &Operator::And), ShortCircuitStrategy::None @@ -5131,7 +5134,7 @@ mod tests { let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::clone(&c_array)]) .unwrap(); - let expr = logical2physical(&logical_col("c").and(expr_lit(true)), &schema); + let expr = logical2physical(&logical_col("c").and(expr_lit(true)), schema); let result = expr.evaluate(&batch).unwrap(); let ColumnarValue::Array(result_arr) = result else { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 84a6aa4309872..b3875886bff71 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -23,7 +23,7 @@ use crate::{ expressions::{self, Column, Literal, binary, like, similar_to}, }; -use arrow::datatypes::Schema; +use arrow::datatypes::SchemaRef; use datafusion_common::config::ConfigOptions; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ @@ -405,10 +405,9 @@ where .collect() } -/// Convert a logical expression to a physical expression (without any simplification, etc) -pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - // TODO this makes a deep copy of the Schema. Should take SchemaRef instead and avoid deep copy - let df_schema = schema.clone().to_dfschema().unwrap(); +/// Convert a logical expression to a physical expression (without any simplification, etc.) +pub fn logical2physical(expr: &Expr, schema: SchemaRef) -> Arc { + let df_schema = schema.to_dfschema().unwrap(); let execution_props = ExecutionProps::new(); create_physical_expr(expr, &df_schema, &execution_props).unwrap() } @@ -416,7 +415,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { #[cfg(test)] mod tests { use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::{Operator, col, lit}; diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index c4ce74fd3a573..104deac22071f 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -1064,8 +1064,7 @@ mod test { /// Tests that analyzing expr results in the expected guarantees fn test_analyze(expr: Expr, expected: Vec) { println!("Begin analyze of {expr}"); - let schema = schema(); - let physical_expr = logical2physical(&expr, &schema); + let physical_expr = logical2physical(&expr, schema()); let actual = LiteralGuarantee::analyze(&physical_expr) .into_iter() diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs index b5b8267d7f93f..a77818b296877 100644 --- a/datafusion/pruning/src/pruning_predicate.rs +++ b/datafusion/pruning/src/pruning_predicate.rs @@ -2371,8 +2371,8 @@ mod tests { Field::new("c2", DataType::Int32, true), ])); let expr = col("c1").eq(lit(100)).and(col("c2").eq(lit(200))); - let expr = logical2physical(&expr, &schema); - let p = PruningPredicate::try_new(expr, Arc::clone(&schema)).unwrap(); + let expr = logical2physical(&expr, Arc::clone(&schema)); + let p = PruningPredicate::try_new(expr, schema).unwrap(); // note pruning expression refers to row_count twice assert_eq!( "c1_null_count@2 != row_count@3 AND c1_min@0 <= 100 AND 100 <= c1_max@1 AND c2_null_count@6 != row_count@3 AND c2_min@4 <= 200 AND 200 <= c2_max@5", @@ -2671,7 +2671,8 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@2 != row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1"; @@ -2692,7 +2693,8 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@2 != row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1)"; @@ -2713,7 +2715,8 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 > 1"; // test column on the left @@ -2733,7 +2736,8 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 >= 1"; // test column on the left @@ -2752,7 +2756,8 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < 1"; // test column on the left @@ -2772,7 +2777,8 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 <= 1"; // test column on the left @@ -2791,11 +2797,11 @@ mod tests { #[test] fn row_group_predicate_and() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), Field::new("c3", DataType::Int32, false), - ]); + ])); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < 1"; @@ -2808,10 +2814,10 @@ mod tests { #[test] fn row_group_predicate_or() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression let expr = col("c1").lt(lit(1)).or(col("c2").rem(lit(2)).eq(lit(0))); let expected_expr = "true"; @@ -2824,7 +2830,8 @@ mod tests { #[test] fn row_group_predicate_not() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "true"; let expr = col("c1").not(); @@ -2837,7 +2844,11 @@ mod tests { #[test] fn row_group_predicate_not_bool() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Boolean, + false, + )])); let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); @@ -2850,7 +2861,11 @@ mod tests { #[test] fn row_group_predicate_bool() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Boolean, + false, + )])); let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); @@ -2981,7 +2996,7 @@ mod tests { // Note that we have no stats, pruning can only happen via partition value pruning from the dynamic filter .with_row_counts("c1", vec![Some(10)]); let dynamic_filter_expr = col("c1").gt(lit(5)).and(col("part").eq(lit("B"))); - let phys_expr = logical2physical(&dynamic_filter_expr, &schema); + let phys_expr = logical2physical(&dynamic_filter_expr, Arc::clone(&schema)); let children = collect_columns(&phys_expr) .iter() .map(|c| Arc::new(c.clone()) as Arc) @@ -3021,7 +3036,11 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Boolean, + false, + )])); let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < true"; // DF doesn't support arithmetic on boolean columns so @@ -3036,10 +3055,10 @@ mod tests { #[test] fn row_group_predicate_required_columns() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); let mut required_columns = RequiredColumns::new(); // c1 < 1 and (c2 = 2 or c2 = 3) let expr = col("c1") @@ -3127,10 +3146,10 @@ mod tests { #[test] fn row_group_predicate_in_list() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test c1 in(1, 2, 3) let expr = Expr::InList(InList::new( Box::new(col("c1")), @@ -3147,10 +3166,10 @@ mod tests { #[test] fn row_group_predicate_in_list_empty() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test c1 in() let expr = Expr::InList(InList::new(Box::new(col("c1")), vec![], false)); let expected_expr = "true"; @@ -3163,10 +3182,10 @@ mod tests { #[test] fn row_group_predicate_in_list_negated() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test c1 not in(1, 2, 3) let expr = Expr::InList(InList::new( Box::new(col("c1")), @@ -3183,10 +3202,10 @@ mod tests { #[test] fn row_group_predicate_between() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test c1 BETWEEN 1 AND 5 let expr1 = col("c1").between(lit(1), lit(5)); @@ -3206,10 +3225,10 @@ mod tests { #[test] fn row_group_predicate_between_with_in_list() -> Result<()> { - let schema = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), - ]); + ])); // test c1 in(1, 2) let expr1 = col("c1").in_list(vec![lit(1), lit(2)], false); @@ -3229,7 +3248,8 @@ mod tests { #[test] fn row_group_predicate_in_list_to_many_values() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); // test c1 in(1..21) // in pruning.rs has MAX_LIST_VALUE_SIZE_REWRITE = 20, more than this value will be rewrite // always true @@ -3245,7 +3265,8 @@ mod tests { #[test] fn row_group_predicate_cast_int_int() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; // test cast(c1 as int64) = 1 @@ -3283,7 +3304,11 @@ mod tests { #[test] fn row_group_predicate_cast_string_string() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Utf8View, + false, + )])); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Utf8) <= 1 AND 1 <= CAST(c1_max@1 AS Utf8)"; // test column on the left @@ -3305,7 +3330,11 @@ mod tests { #[test] fn row_group_predicate_cast_string_int() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Utf8View, + false, + )])); let expected_expr = "true"; // test column on the left @@ -3325,7 +3354,8 @@ mod tests { #[test] fn row_group_predicate_cast_int_string() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); let expected_expr = "true"; // test column on the left @@ -3347,7 +3377,8 @@ mod tests { #[test] fn row_group_predicate_date_date() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Date32, false)])); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Date64) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Date64)"; // test column on the left @@ -3370,7 +3401,8 @@ mod tests { #[test] fn row_group_predicate_dict_string_date() -> Result<()> { // Test with Dictionary for the literal - let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Date32, false)])); let expected_expr = "true"; // test column on the left @@ -3398,11 +3430,11 @@ mod tests { #[test] fn row_group_predicate_date_dict_string() -> Result<()> { // Test with Dictionary for the column - let schema = Schema::new(vec![Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "c1", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), false, - )]); + )])); let expected_expr = "true"; // test column on the left @@ -3425,11 +3457,11 @@ mod tests { #[test] fn row_group_predicate_dict_dict_same_value_type() -> Result<()> { // Test with Dictionary types that have the same value type but different key types - let schema = Schema::new(vec![Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "c1", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), false, - )]); + )])); // Direct comparison with no cast let expr = col("c1").eq(lit(ScalarValue::Utf8(Some("test".to_string())))); @@ -3456,11 +3488,11 @@ mod tests { #[test] fn row_group_predicate_dict_dict_different_value_type() -> Result<()> { // Test with Dictionary types that have different value types - let schema = Schema::new(vec![Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "c1", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Int32)), false, - )]); + )])); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 123 AND 123 <= CAST(c1_max@1 AS Int64)"; // Test with literal of a different type @@ -3476,7 +3508,7 @@ mod tests { #[test] fn row_group_predicate_nested_dict() -> Result<()> { // Test with nested Dictionary types - let schema = Schema::new(vec![Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "c1", DataType::Dictionary( Box::new(DataType::UInt8), @@ -3486,7 +3518,7 @@ mod tests { )), ), false, - )]); + )])); let expected_expr = "c1_null_count@2 != row_count@3 AND c1_min@0 <= test AND test <= c1_max@1"; @@ -3502,11 +3534,11 @@ mod tests { #[test] fn row_group_predicate_dict_date_dict_date() -> Result<()> { // Test with dictionary-wrapped date types for both sides - let schema = Schema::new(vec![Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "c1", DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Date32)), false, - )]); + )])); let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Dictionary(UInt16, Date64)) <= 1970-01-01 AND 1970-01-01 <= CAST(c1_max@1 AS Dictionary(UInt16, Date64))"; // Test with a cast to a different date type @@ -3524,7 +3556,7 @@ mod tests { #[test] fn row_group_predicate_date_string() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, false)])); let expected_expr = "true"; // test column on the left @@ -3546,7 +3578,8 @@ mod tests { #[test] fn row_group_predicate_string_date() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Date32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Date32, false)])); let expected_expr = "true"; // test column on the left @@ -3568,7 +3601,8 @@ mod tests { #[test] fn row_group_predicate_cast_list() -> Result<()> { - let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)])); // test cast(c1 as int64) in int64(1, 2, 3) let expr = Expr::InList(InList::new( Box::new(cast(col("c1"), DataType::Int64)), @@ -4682,7 +4716,7 @@ mod tests { true, // s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep) true, - // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate + // s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no row match. (min, max) maybe truncate // original (min, max) maybe ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}") true, ]; @@ -4785,14 +4819,14 @@ mod tests { #[test] fn test_rewrite_expr_to_prunable() { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); let df_schema = DFSchema::try_from(schema.clone()).unwrap(); // column op lit let left_input = col("a"); - let left_input = logical2physical(&left_input, &schema); + let left_input = logical2physical(&left_input, Arc::clone(&schema)); let right_input = lit(ScalarValue::Int32(Some(12))); - let right_input = logical2physical(&right_input, &schema); + let right_input = logical2physical(&right_input, Arc::clone(&schema)); let (result_left, _, result_right) = rewrite_expr_to_prunable( &left_input, Operator::Eq, @@ -4805,9 +4839,9 @@ mod tests { // cast op lit let left_input = cast(col("a"), DataType::Decimal128(20, 3)); - let left_input = logical2physical(&left_input, &schema); + let left_input = logical2physical(&left_input, Arc::clone(&schema)); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); - let right_input = logical2physical(&right_input, &schema); + let right_input = logical2physical(&right_input, Arc::clone(&schema)); let (result_left, _, result_right) = rewrite_expr_to_prunable( &left_input, Operator::Gt, @@ -4820,9 +4854,9 @@ mod tests { // try_cast op lit let left_input = try_cast(col("a"), DataType::Int64); - let left_input = logical2physical(&left_input, &schema); + let left_input = logical2physical(&left_input, Arc::clone(&schema)); let right_input = lit(ScalarValue::Int64(Some(12))); - let right_input = logical2physical(&right_input, &schema); + let right_input = logical2physical(&right_input, Arc::clone(&schema)); let (result_left, _, result_right) = rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema) .unwrap(); @@ -4845,17 +4879,17 @@ mod tests { } } - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let schema_with_b = Schema::new(vec![ + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let schema_with_b = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), - ]); + ])); let rewriter = PredicateRewriter::new() .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); let transform_expr = |expr| { - let expr = logical2physical(&expr, &schema_with_b); + let expr = logical2physical(&expr, Arc::clone(&schema_with_b)); rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) }; @@ -4863,13 +4897,13 @@ mod tests { let known_expression = col("a").eq(lit(12)); let known_expression_transformed = PredicateRewriter::new() .rewrite_predicate_to_statistics_predicate( - &logical2physical(&known_expression, &schema), + &logical2physical(&known_expression, Arc::clone(&schema)), &schema, ); // an expression referencing an unknown column (that is not in the schema) gets passed to the hook let input = col("b").eq(lit(12)); - let expected = logical2physical(&lit(42), &schema); + let expected = logical2physical(&lit(42), Arc::clone(&schema)); let transformed = transform_expr(input.clone()); assert_eq!(transformed.to_string(), expected.to_string()); @@ -4878,14 +4912,14 @@ mod tests { let expected = phys_expr::BinaryExpr::new( Arc::::clone(&known_expression_transformed), Operator::And, - logical2physical(&lit(42), &schema), + logical2physical(&lit(42), Arc::clone(&schema)), ); let transformed = transform_expr(input.clone()); assert_eq!(transformed.to_string(), expected.to_string()); // an unknown expression gets passed to the hook let input = array_has(make_array(vec![lit(1)]), col("a")); - let expected = logical2physical(&lit(42), &schema); + let expected = logical2physical(&lit(42), Arc::clone(&schema)); let transformed = transform_expr(input.clone()); assert_eq!(transformed.to_string(), expected.to_string()); @@ -4894,7 +4928,7 @@ mod tests { let expected = phys_expr::BinaryExpr::new( Arc::::clone(&known_expression_transformed), Operator::And, - logical2physical(&lit(42), &schema), + logical2physical(&lit(42), Arc::clone(&schema)), ); let transformed = transform_expr(input.clone()); assert_eq!(transformed.to_string(), expected.to_string()); @@ -4904,12 +4938,12 @@ mod tests { fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value // this cast is not supported - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); let df_schema = DFSchema::try_from(schema.clone()).unwrap(); let left_input = cast(col("a"), DataType::Int64); - let left_input = logical2physical(&left_input, &schema); + let left_input = logical2physical(&left_input, Arc::clone(&schema)); let right_input = lit(ScalarValue::Int64(Some(12))); - let right_input = logical2physical(&right_input, &schema); + let right_input = logical2physical(&right_input, Arc::clone(&schema)); let result = rewrite_expr_to_prunable( &left_input, Operator::Gt, @@ -4920,9 +4954,9 @@ mod tests { // other expr let left_input = is_null(col("a")); - let left_input = logical2physical(&left_input, &schema); + let left_input = logical2physical(&left_input, Arc::clone(&schema)); let right_input = lit(ScalarValue::Int64(Some(12))); - let right_input = logical2physical(&right_input, &schema); + let right_input = logical2physical(&right_input, Arc::clone(&schema)); let result = rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema); assert!(result.is_err()); @@ -5376,8 +5410,8 @@ mod tests { expected: &[bool], ) { println!("Pruning with expr: {expr}"); - let expr = logical2physical(&expr, schema); - let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); + let expr = logical2physical(&expr, Arc::clone(schema)); + let p = PruningPredicate::try_new(expr, Arc::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); assert_eq!(result, expected); } @@ -5389,55 +5423,58 @@ mod tests { expected: &[bool], ) { println!("Pruning with expr: {expr}"); - let expr = logical2physical(&expr, schema); + let expr = logical2physical(&expr, Arc::clone(schema)); let simplifier = PhysicalExprSimplifier::new(schema); let expr = simplifier.simplify(expr).unwrap(); - let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); + let p = PruningPredicate::try_new(expr, Arc::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); assert_eq!(result, expected); } fn test_build_predicate_expression( expr: &Expr, - schema: &Schema, + schema: &SchemaRef, required_columns: &mut RequiredColumns, ) -> Arc { - let expr = logical2physical(expr, schema); + let expr = logical2physical(expr, Arc::clone(schema)); let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; - build_predicate_expression( - &expr, - &Arc::new(schema.clone()), - required_columns, - &unhandled_hook, - ) + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } #[test] fn test_build_predicate_expression_with_false() { let expr = lit(ScalarValue::Boolean(Some(false))); - let schema = Schema::empty(); + let schema = Arc::new(Schema::empty()); let res = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); - let expected = logical2physical(&expr, &schema); + let expected = logical2physical(&expr, schema); assert_eq!(&res, &expected); } #[test] fn test_build_predicate_expression_with_and_false() { - let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Utf8View, + false, + )])); let expr = and( col("c1").eq(lit("a")), lit(ScalarValue::Boolean(Some(false))), ); let res = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); - let expected = logical2physical(&lit(ScalarValue::Boolean(Some(false))), &schema); + let expected = logical2physical(&lit(ScalarValue::Boolean(Some(false))), schema); assert_eq!(&res, &expected); } #[test] fn test_build_predicate_expression_with_or_false() { - let schema = Schema::new(vec![Field::new("c1", DataType::Utf8View, false)]); + let schema = Arc::new(Schema::new(vec![Field::new( + "c1", + DataType::Utf8View, + false, + )])); let left_expr = col("c1").eq(lit("a")); let right_expr = lit(ScalarValue::Boolean(Some(false))); let res = test_build_predicate_expression(