diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index ea016015cebd3..031b2ebfb8109 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 0000000000000..464d6c937b328 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int32Array, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; + +fn build_table(values: &[i32]) -> Result { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? + })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &[ + "+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? + })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? + })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c7d825ce1d52f..c8bf59038fefb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -372,6 +372,8 @@ pub enum Expr { Exists(Exists), /// IN subquery InSubquery(InSubquery), + /// Set comparison subquery (e.g. `= ANY`, `> ALL`) + SetComparison(SetComparison), /// Scalar subquery ScalarSubquery(Subquery), /// Represents a reference to all available fields in a specific schema, @@ -1101,6 +1103,54 @@ impl Exists { } } +/// Whether the set comparison uses `ANY`/`SOME` or `ALL` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum SetQuantifier { + /// `ANY` (or `SOME`) + Any, + /// `ALL` + All, +} + +impl Display for SetQuantifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SetQuantifier::Any => write!(f, "ANY"), + SetQuantifier::All => write!(f, "ALL"), + } + } +} + +/// Set comparison subquery (e.g. `= ANY`, `> ALL`) +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct SetComparison { + /// The expression to compare + pub expr: Box, + /// Subquery that will produce a single column of data to compare against + pub subquery: Subquery, + /// Comparison operator (e.g. `=`, `>`, `<`) + pub op: Operator, + /// Quantifier (`ANY`/`ALL`) + pub quantifier: SetQuantifier, +} + +impl SetComparison { + /// Create a new set comparison expression + pub fn new( + expr: Box, + subquery: Subquery, + op: Operator, + quantifier: SetQuantifier, + ) -> Self { + Self { + expr, + subquery, + op, + quantifier, + } + } +} + /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1503,6 +1553,7 @@ impl Expr { Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", + Expr::SetComparison(..) => "SetComparison", Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", @@ -2058,6 +2109,7 @@ impl Expr { | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) + | Expr::SetComparison(..) | Expr::IsFalse(..) | Expr::IsNotFalse(..) | Expr::IsNotNull(..) @@ -2651,6 +2703,16 @@ impl HashNode for Expr { subquery.hash(state); negated.hash(state); } + Expr::SetComparison(SetComparison { + expr: _, + subquery, + op, + quantifier, + }) => { + subquery.hash(state); + op.hash(state); + quantifier.hash(state); + } Expr::ScalarSubquery(subquery) => { subquery.hash(state); } @@ -2841,6 +2903,12 @@ impl Display for SchemaDisplay<'_> { write!(f, "NOT IN") } Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::SetComparison(SetComparison { + expr, + op, + quantifier, + .. + }) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())), Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), Expr::IsNotTrue(expr) => { @@ -3316,6 +3384,12 @@ impl Display for Expr { subquery, negated: false, }) => write!(f, "{expr} IN ({subquery:?})"), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::ScalarFunction(fun) => { @@ -3799,6 +3873,7 @@ mod test { } use super::*; + use crate::logical_plan::{EmptyRelation, LogicalPlan}; #[test] fn test_display_wildcard() { @@ -3889,6 +3964,28 @@ mod test { ) } + #[test] + fn test_display_set_comparison() { + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let expr = Expr::SetComparison(SetComparison::new( + Box::new(Expr::Column(Column::from_name("a"))), + subquery, + Operator::Gt, + SetQuantifier::Any, + )); + + assert_eq!(format!("{expr}"), "a > ANY ()"); + assert_eq!(format!("{}", expr.human_display()), "a > ANY ()"); + } + #[test] fn test_schema_display_alias_with_relation() { assert_eq!( diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 691a8c508f801..8b0e2b4b9cec8 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -196,6 +196,7 @@ impl ExprSchemable for Expr { | Expr::IsNull(_) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Between { .. } | Expr::InList { .. } | Expr::IsNotNull(_) @@ -376,6 +377,7 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => Ok(false), + Expr::SetComparison(_) => Ok(true), Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) @@ -639,6 +641,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 762491a255cbc..b39b23e30f4e8 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -22,7 +22,7 @@ use datafusion_common::{ use crate::{ Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, - expr::{Exists, InSubquery}, + expr::{Exists, InSubquery, SetComparison}, expr_rewriter::strip_outer_reference, utils::{collect_subquery_cols, split_conjunction}, }; @@ -81,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { assert_valid_extension_nodes(&subquery.subquery, check)?; } @@ -133,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } @@ -229,6 +231,20 @@ pub fn check_subquery_expr( ); } } + if let Expr::SetComparison(set_comparison) = expr + && set_comparison.subquery.subquery.schema().fields().len() > 1 + { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -237,7 +253,7 @@ pub fn check_subquery_expr( | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( - "In/Exist subquery can only be used in \ + "In/Exist/SetComparison subquery can only be used in \ Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ but was used in [{}]", outer_plan.display() diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 62a27b0a025ad..b233fe2e0fff0 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ }; use datafusion_common::tree_node::TreeNodeRefContainer; -use crate::expr::{Exists, InSubquery}; +use crate::expr::{Exists, InSubquery, SetComparison}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -815,6 +815,7 @@ impl LogicalPlan { expr.apply(|expr| match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is @@ -856,6 +857,22 @@ impl LogicalPlan { })), _ => internal_err!("Transformation should return Subquery"), }), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + })) + } + _ => internal_err!("Transformation should return Subquery"), + }), Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? .map_data(|s| match s { LogicalPlan::Subquery(subquery) => { diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 954f511651ced..942736c24b757 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -249,13 +249,6 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, such as `expr = ANY(array_expr)` - /// - /// Returns origin binary expression if not possible - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - Ok(PlannerResult::Original(expr)) - } - /// Plans aggregate functions, such as `COUNT()` /// /// Returns original expression arguments if not possible diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 742bae5b2320b..226c512a974d8 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::Expr; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison, + TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use datafusion_common::Result; @@ -58,7 +58,8 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), + | Expr::InSubquery(InSubquery { expr, .. }) + | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), Expr::ScalarFunction(ScalarFunction { args, .. }) => { @@ -128,6 +129,19 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) => Transformed::no(self), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) + }), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index de4ebf5fa96e9..7d13c9c1d24d1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -312,6 +312,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index afb18a44f48ab..e96fdb7d4baca 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -37,7 +37,7 @@ use std::sync::Arc; use crate::map::map_udf; use crate::{ - array_has::{array_has_all, array_has_udf}, + array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, make_array::make_array, @@ -120,20 +120,6 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } - - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - array_has_udf(), - // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` - vec![expr.right, expr.left], - ), - ))) - } else { - plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) - } - } } #[derive(Debug)] diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index bc317e9c201ce..1885dcab2c4dc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, Sort, WindowFunction, + InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -368,6 +368,45 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, )))) } + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => { + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + if (expr_type.is_numeric() + && is_utf8_or_utf8view_or_large_utf8(subquery_type)) + || (subquery_type.is_numeric() + && is_utf8_or_utf8view_or_large_utf8(&expr_type)) + { + return plan_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ); + } + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ), + )?; + let new_subquery = Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }; + Ok(Transformed::yes(Expr::SetComparison(SetComparison::new( + Box::new(expr.cast_to(&common_type, self.schema)?), + cast_subquery(new_subquery, &common_type)?, + op, + quantifier, + )))) + } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index a1a59cb348876..f8ab453591e91 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -66,6 +66,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index ededcec0a47c9..f553a5e66b217 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -227,6 +228,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(RewriteSetComparison::new()), Arc::new(OptimizeUnions::new()), Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 755ffdbafc869..1ff8bdfeff4c0 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -263,6 +263,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) | Expr::Unnest(_) => { diff --git a/datafusion/optimizer/src/rewrite_set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs new file mode 100644 index 0000000000000..0e642606f5746 --- /dev/null +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimizer rule rewriting `SetComparison` subqueries (e.g. `= ANY`, +//! `> ALL`) into boolean expressions built from `EXISTS` subqueries +//! that capture SQL three-valued logic. + +use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err}; +use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; +use datafusion_expr::logical_plan::Subquery; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::{Expr, LogicalPlan, lit}; +use std::sync::Arc; + +use datafusion_expr::utils::merge_schema; + +/// Rewrite `SetComparison` expressions to scalar subqueries that return the +/// correct boolean value (including SQL NULL semantics). After this rule +/// runs, later rules such as `ScalarSubqueryToJoin` can decorrelate and +/// remove the remaining subquery. +#[derive(Debug, Default)] +pub struct RewriteSetComparison; + +impl RewriteSetComparison { + #[allow(missing_docs)] + pub fn new() -> Self { + Self + } + + fn rewrite_plan(&self, plan: LogicalPlan) -> Result> { + let schema = merge_schema(&plan.inputs()); + plan.map_expressions(|expr| { + expr.transform_up(|expr| rewrite_set_comparison(expr, &schema)) + }) + } +} + +impl OptimizerRule for RewriteSetComparison { + fn name(&self) -> &str { + "rewrite_set_comparison" + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) + } +} + +fn rewrite_set_comparison( + expr: Expr, + outer_schema: &DFSchema, +) -> Result> { + match expr { + Expr::SetComparison(set_comparison) => { + let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?; + Ok(Transformed::yes(rewritten)) + } + _ => Ok(Transformed::no(expr)), + } +} + +fn build_set_comparison_subquery( + set_comparison: SetComparison, + outer_schema: &DFSchema, +) -> Result { + let SetComparison { + expr, + subquery, + op, + quantifier, + } = set_comparison; + + let left_expr = to_outer_reference(*expr, outer_schema)?; + let subquery_schema = subquery.subquery.schema(); + if subquery_schema.fields().is_empty() { + return plan_err!("single expression required."); + } + // avoid `head_output_expr` for aggr/window plan, it will gives group-by expr if exists + let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0))); + + let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( + Box::new(left_expr), + op, + Box::new(right_expr), + )); + + let true_exists = + exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?; + let null_exists = + exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?; + + let result_expr = match quantifier { + SetQuantifier::Any => Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(true_exists), Box::new(lit(true))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(false))), + }), + SetQuantifier::All => { + let false_exists = + exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?; + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(false_exists), Box::new(lit(false))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(true))), + }) + } + }; + + Ok(result_expr) +} + +fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result { + let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone()) + .filter(filter)? + .build()?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: subquery.spans.clone(), + }, + negated: false, + })) +} + +fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { + expr.transform_up(|expr| match expr { + Expr::Column(col) => { + let field = outer_schema.field_from_column(&col)?; + Ok(Transformed::yes(Expr::OuterReferenceColumn( + Arc::clone(field), + col, + ))) + } + Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + _ => Ok(Transformed::no(expr)), + }) + .map(|t| t.data) +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 01de44cee1f60..863831216e712 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -646,6 +646,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } | Expr::GroupingSet(_) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6e4e5d0b6eea4..9c326a8f6e435 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -578,6 +578,7 @@ pub fn serialize_expr( Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } + | Expr::SetComparison(_) | Expr::OuterReferenceColumn { .. } => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index a814292a3d71d..b7338cb764d77 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,6 +56,7 @@ bigdecimal = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["sql"] } datafusion-expr = { workspace = true, features = ["sql"] } +datafusion-functions-nested = { workspace = true, features = ["sql"] } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index fcd7d6376d21c..dbf2ce67732ec 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -32,6 +32,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::SetQuantifier; use datafusion_expr::expr::{InList, WildcardOptions}; use datafusion_expr::{ Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, @@ -39,6 +40,7 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_functions_nested::expr_fn::array_has; mod binary_op; mod function; @@ -594,32 +596,44 @@ impl SqlToRel<'_, S> { // ANY/SOME are equivalent, this field specifies which the user // specified but it doesn't affect the plan so ignore the field is_some: _, - } => { - let mut binary_expr = RawBinaryExpr { - op: compare_op, - left: self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?, - right: self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?, - }; - for planner in self.context_provider.get_expr_planners() { - match planner.plan_any(binary_expr)? { - PlannerResult::Planned(expr) => { - return Ok(expr); - } - PlannerResult::Original(expr) => { - binary_expr = expr; - } + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::Any, + schema, + planner_context, + ), + _ => { + if compare_op != BinaryOperator::Eq { + plan_err!( + "Unsupported AnyOp: '{compare_op}', only '=' is supported" + ) + } else { + let left_expr = + self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = + self.sql_to_expr(*right, schema, planner_context)?; + Ok(array_has(right_expr, left_expr)) } } - not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") - } + }, + SQLExpr::AllOp { + left, + compare_op, + right, + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + &compare_op, + SetQuantifier::All, + schema, + planner_context, + ), + _ => not_impl_err!("ALL only supports subquery comparison currently"), + }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { qualifier: None, diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index ec34ff3d53426..d710a1e1f08ea 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -17,10 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Diagnostic, Result, Span, Spans, plan_err}; -use datafusion_expr::expr::{Exists, InSubquery}; +use datafusion_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; use datafusion_expr::{Expr, LogicalPlan, Subquery}; use sqlparser::ast::Expr as SQLExpr; -use sqlparser::ast::{Query, SelectItem, SetExpr}; +use sqlparser::ast::{BinaryOperator, Query, SelectItem, SetExpr}; use std::sync::Arc; impl SqlToRel<'_, S> { @@ -162,4 +162,51 @@ impl SqlToRel<'_, S> { diagnostic.add_help(help_message, None); diagnostic } + + pub(super) fn parse_set_comparison_subquery( + &self, + left_expr: SQLExpr, + subquery: Query, + compare_op: &BinaryOperator, + quantifier: SetQuantifier, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let old_outer_query_schema = + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + + let mut spans = Spans::new(); + if let SetExpr::Select(select) = subquery.body.as_ref() { + for item in &select.projection { + if let SelectItem::ExprWithAlias { alias, .. } = item { + if let Some(span) = Span::try_from_sqlparser_span(alias.span) { + spans.add_span(span); + } + } + } + } + + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.set_outer_query_schema(old_outer_query_schema); + + self.validate_single_column( + &sub_plan, + &spans, + "Too many columns! The subquery should only return one column", + "Select only one column in the subquery", + )?; + + let expr_obj = self.sql_to_expr(left_expr, input_schema, planner_context)?; + Ok(Expr::SetComparison(SetComparison::new( + Box::new(expr_obj), + Subquery { + subquery: Arc::new(sub_plan), + outer_ref_columns, + spans, + }, + self.parse_sql_binary_op(compare_op)?, + quantifier, + ))) + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 5746a568e712b..ac7b467920364 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -45,7 +45,7 @@ use datafusion_common::{ }; use datafusion_expr::{ Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, - expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, + expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, }; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::tokenizer::Span; @@ -393,6 +393,33 @@ impl Unparser<'_> { negated: insubq.negated, }) } + Expr::SetComparison(set_cmp) => { + let left = Box::new(self.expr_to_sql_inner(set_cmp.expr.as_ref())?); + let sub_statement = + self.plan_to_sql(set_cmp.subquery.subquery.as_ref())?; + let sub_query = if let ast::Statement::Query(inner_query) = sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + let compare_op = self.op_to_sql(&set_cmp.op)?; + match set_cmp.quantifier { + SetQuantifier::Any => Ok(ast::Expr::AnyOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + is_some: false, + }), + SetQuantifier::All => Ok(ast::Expr::AllOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + }), + } + } Expr::Exists(Exists { subquery, negated }) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 9087aee56d978..8ce0d2d7046f3 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -197,6 +198,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -538,6 +540,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -559,6 +562,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index da0bfc89d5848..e73f4ec3e32da 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -430,7 +430,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist/SetComparison subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -1469,3 +1469,62 @@ logical_plan statement count 0 drop table person; + +# Set comparison subqueries (ANY/ALL) +statement ok +create table set_cmp_t(v int) as values (1), (6), (10); + +statement ok +create table set_cmp_s(v int) as values (5), (null); + +statement ok +create table set_cmp_empty(v int); + +query I rowsort +select v from set_cmp_t where v > any(select v from set_cmp_s); +---- +10 +6 + +query I rowsort +select v from set_cmp_t where v < all(select v from set_cmp_empty); +---- +1 +10 +6 + +statement count 0 +drop table set_cmp_t; + +statement count 0 +drop table set_cmp_s; + +statement count 0 +drop table set_cmp_empty; + +query TT +explain select v from (values (1), (6), (10)) set_cmp_t(v) where v > any(select v from (values (5), (null)) set_cmp_s(v)); +---- +logical_plan +01)Projection: set_cmp_t.v +02)--Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND NOT __correlated_sq_3.mark AND Boolean(NULL) +03)----LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_3.v IS TRUE +04)------Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND Boolean(NULL) +05)--------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_2.v IS NULL +06)----------Filter: __correlated_sq_1.mark OR Boolean(NULL) +07)------------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_1.v IS TRUE +08)--------------SubqueryAlias: set_cmp_t +09)----------------Projection: column1 AS v +10)------------------Values: (Int64(1)), (Int64(6)), (Int64(10)) +11)--------------SubqueryAlias: __correlated_sq_1 +12)----------------SubqueryAlias: set_cmp_s +13)------------------Projection: column1 AS v +14)--------------------Values: (Int64(5)), (Int64(NULL)) +15)----------SubqueryAlias: __correlated_sq_2 +16)------------SubqueryAlias: set_cmp_s +17)--------------Projection: column1 AS v +18)----------------Values: (Int64(5)), (Int64(NULL)) +19)------SubqueryAlias: __correlated_sq_3 +20)--------SubqueryAlias: set_cmp_s +21)----------Projection: column1 AS v +22)------------Values: (Int64(5)), (Int64(NULL)) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index 15fe7947a1e1f..61a381e9eb407 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -16,12 +16,13 @@ // under the License. use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{DFSchema, Spans, substrait_err}; -use datafusion::logical_expr::expr::{Exists, InSubquery}; -use datafusion::logical_expr::{Expr, Subquery}; +use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err}; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Expr, Operator, Subquery}; use std::sync::Arc; use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::SubqueryType; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; pub async fn from_subquery( @@ -96,8 +97,53 @@ pub async fn from_subquery( ), } } - other_type => { - substrait_err!("Subquery type {other_type:?} not implemented") + SubqueryType::SetComparison(comparison) => { + let left = comparison.left.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a left expression") + })?; + let right = comparison.right.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a right relation") + })?; + let reduction_op = match ReductionOp::try_from(comparison.reduction_op) { + Ok(ReductionOp::Any) => SetQuantifier::Any, + Ok(ReductionOp::All) => SetQuantifier::All, + _ => { + return substrait_err!( + "Unsupported reduction op for SetComparison: {}", + comparison.reduction_op + ); + } + }; + let comparison_op = match ComparisonOp::try_from(comparison.comparison_op) + { + Ok(ComparisonOp::Eq) => Operator::Eq, + Ok(ComparisonOp::Ne) => Operator::NotEq, + Ok(ComparisonOp::Lt) => Operator::Lt, + Ok(ComparisonOp::Gt) => Operator::Gt, + Ok(ComparisonOp::Le) => Operator::LtEq, + Ok(ComparisonOp::Ge) => Operator::GtEq, + _ => { + return substrait_err!( + "Unsupported comparison op for SetComparison: {}", + comparison.comparison_op + ); + } + }; + + let left_expr = consumer.consume_expression(left, input_schema).await?; + let plan = consumer.consume_rel(right).await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + Ok(Expr::SetComparison(SetComparison::new( + Box::new(left_expr), + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + comparison_op, + reduction_op, + ))) } }, None => { diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 5057564d370cf..74b1a65215376 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -141,6 +141,7 @@ pub fn to_substrait_rex( Expr::InList(expr) => producer.handle_in_list(expr, schema), Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema), Expr::ScalarSubquery(expr) => { not_impl_err!("Cannot convert {expr:?} to Substrait") } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs index f2e6ff551223c..e5b9241c10104 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -16,9 +16,11 @@ // under the License. use crate::logical_plan::producer::SubstraitProducer; -use datafusion::common::DFSchemaRef; -use datafusion::logical_expr::expr::InSubquery; +use datafusion::common::{DFSchemaRef, substrait_err}; +use datafusion::logical_expr::Operator; +use datafusion::logical_expr::expr::{InSubquery, SetComparison, SetQuantifier}; use substrait::proto::expression::subquery::InPredicate; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::{RexType, ScalarFunction}; use substrait::proto::function_argument::ArgType; use substrait::proto::{Expression, FunctionArgument}; @@ -70,3 +72,53 @@ pub fn from_in_subquery( Ok(substrait_subquery) } } + +fn comparison_op_to_proto(op: &Operator) -> datafusion::common::Result { + match op { + Operator::Eq => Ok(ComparisonOp::Eq), + Operator::NotEq => Ok(ComparisonOp::Ne), + Operator::Lt => Ok(ComparisonOp::Lt), + Operator::Gt => Ok(ComparisonOp::Gt), + Operator::LtEq => Ok(ComparisonOp::Le), + Operator::GtEq => Ok(ComparisonOp::Ge), + _ => substrait_err!("Unsupported operator {op:?} for SetComparison subquery"), + } +} + +fn reduction_op_to_proto( + quantifier: &SetQuantifier, +) -> datafusion::common::Result { + match quantifier { + SetQuantifier::Any => Ok(ReductionOp::Any), + SetQuantifier::All => Ok(ReductionOp::All), + } +} + +pub fn from_set_comparison( + producer: &mut impl SubstraitProducer, + set_comparison: &SetComparison, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let comparison_op = comparison_op_to_proto(&set_comparison.op)? as i32; + let reduction_op = reduction_op_to_proto(&set_comparison.quantifier)? as i32; + let left = producer.handle_expr(set_comparison.expr.as_ref(), schema)?; + let subquery_plan = + producer.handle_plan(set_comparison.subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetComparison( + Box::new(substrait::proto::expression::subquery::SetComparison { + reduction_op, + comparison_op, + left: Some(Box::new(left)), + right: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index ffc920ffe609e..c7518bd04e4a1 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -20,14 +20,17 @@ use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_projection, from_repartition, from_scalar_function, from_set_comparison, + from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, + from_union, from_values, from_window, from_window_function, to_substrait_rel, + to_substrait_rex, }; use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; use datafusion::execution::registry::SerializerRegistry; -use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use datafusion::logical_expr::expr::{ + Alias, InList, InSubquery, SetComparison, WindowFunction, +}; use datafusion::logical_expr::{ Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias, @@ -361,6 +364,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_set_comparison( + &mut self, + set_comparison: &SetComparison, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_set_comparison(self, set_comparison, schema) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 98b35bf082ec4..f78b255526dc9 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -29,14 +29,15 @@ use std::mem::size_of_val; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::tree_node::Transformed; -use datafusion::common::{DFSchema, DFSchemaRef, not_impl_err, plan_err}; +use datafusion::common::{DFSchema, DFSchemaRef, Spans, not_impl_err, plan_err}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::expr::{SetComparison, SetQuantifier}; use datafusion::logical_expr::{ - EmptyRelation, Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, - Repartition, UserDefinedLogicalNode, Values, Volatility, + EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator, + Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -689,6 +690,29 @@ async fn roundtrip_exists_filter() -> Result<()> { Ok(()) } +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_any_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_set_comparison_plan(&ctx, SetQuantifier::Any, Operator::Gt).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::Gt, SetQuantifier::Any); + Ok(()) +} + +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_all_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = + build_set_comparison_plan(&ctx, SetQuantifier::All, Operator::NotEq).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::NotEq, SetQuantifier::All); + Ok(()) +} + #[tokio::test] async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> { let plan = generate_plan_from_sql( @@ -1865,6 +1889,56 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { Ok(()) } +async fn build_set_comparison_plan( + ctx: &SessionContext, + quantifier: SetQuantifier, + op: Operator, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? + .build()?; + let predicate = Expr::SetComparison(SetComparison::new( + Box::new(col("data.a")), + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + op, + quantifier, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? + .build() +} + +fn assert_set_comparison_predicate( + plan: &LogicalPlan, + expected_op: Operator, + expected_quantifier: SetQuantifier, +) { + let predicate = match plan { + LogicalPlan::Projection(p) => match p.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + match predicate { + Expr::SetComparison(set_comparison) => { + assert_eq!(set_comparison.op, expected_op); + assert_eq!(set_comparison.quantifier, expected_quantifier); + } + other => panic!("expected SetComparison predicate, got {other:?}"), + } +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?;