From 51d78a6a8e589aa258e7791d5aee6947c5881986 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 14:57:10 +0800 Subject: [PATCH 1/9] initial impl Signed-off-by: Ruihang Xia --- datafusion/catalog-listing/src/helpers.rs | 1 + datafusion/core/tests/set_comparison.rs | 82 +++++++++ .../execution/src/cache/list_files_cache.rs | 4 +- datafusion/expr/src/expr.rs | 74 ++++++++ datafusion/expr/src/expr_schema.rs | 3 + .../expr/src/logical_plan/invariants.rs | 20 ++- datafusion/expr/src/logical_plan/tree_node.rs | 19 +- datafusion/expr/src/tree_node.rs | 20 ++- datafusion/expr/src/utils.rs | 1 + datafusion/optimizer/src/analyzer/mod.rs | 3 + .../optimizer/src/analyzer/set_comparison.rs | 167 ++++++++++++++++++ .../optimizer/src/analyzer/type_coercion.rs | 34 +++- datafusion/optimizer/src/push_down_filter.rs | 1 + .../simplify_expressions/expr_simplifier.rs | 1 + datafusion/sql/src/expr/mod.rs | 53 +++--- datafusion/sql/src/expr/subquery.rs | 51 +++++- datafusion/sql/src/unparser/expr.rs | 29 ++- 17 files changed, 525 insertions(+), 38 deletions(-) create mode 100644 datafusion/core/tests/set_comparison.rs create mode 100644 datafusion/optimizer/src/analyzer/set_comparison.rs diff --git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 34073338fbd7e..39db6773768b1 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -83,6 +83,7 @@ pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { | Expr::Exists(_) | Expr::InSubquery(_) | Expr::ScalarSubquery(_) + | Expr::SetComparison(_) | Expr::GroupingSet(_) | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs new file mode 100644 index 0000000000000..f0402460ce9e9 --- /dev/null +++ b/datafusion/core/tests/set_comparison.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::Int32Array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_batches_eq, Result}; + +fn build_table(values: &[i32]) -> Result<RecordBatch> { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = + Arc::new(Int32Array::from(values.to_vec())) as Arc<dyn arrow::array::Array>; + RecordBatch::try_new(schema, vec![array]).map_err(Into::into) +} + +#[tokio::test] +async fn set_comparison_any() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + // Include a NULL in the subquery input to ensure we propagate UNKNOWN correctly. + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(5), None])) + as Arc<dyn arrow::array::Array>; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select v from s)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_all_empty() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 6, 10])?)?; + ctx.register_batch( + "e", + RecordBatch::new_empty(Arc::new(Schema::new(vec![Field::new( + "v", + DataType::Int32, + true, + )]))), + )?; + + let df = ctx + .sql("select v from t where v < all(select v from e)") + .await?; + let results = df.collect().await?; + + assert_batches_eq!( + &["+----+", "| v |", "+----+", "| 1 |", "| 6 |", "| 10 |", "+----+",], + &results + ); + Ok(()) +} diff --git a/datafusion/execution/src/cache/list_files_cache.rs b/datafusion/execution/src/cache/list_files_cache.rs index 2bec2a1e70fc7..944cfca8ea2f8 100644 --- a/datafusion/execution/src/cache/list_files_cache.rs +++ b/datafusion/execution/src/cache/list_files_cache.rs @@ -21,9 +21,9 @@ use std::{ }; use datafusion_common::instant::Instant; -use object_store::{path::Path, ObjectMeta}; +use object_store::{ObjectMeta, path::Path}; -use crate::cache::{cache_manager::ListFilesCache, lru_queue::LruQueue, CacheAccessor}; +use crate::cache::{CacheAccessor, cache_manager::ListFilesCache, lru_queue::LruQueue}; /// Default implementation of [`ListFilesCache`] /// diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index a595b59355739..7115c8fb9987c 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -372,6 +372,8 @@ pub enum Expr { Exists(Exists), /// IN subquery InSubquery(InSubquery), + /// Set comparison subquery (e.g. 
`= ANY`, `> ALL`) + SetComparison(SetComparison), /// Scalar subquery ScalarSubquery(Subquery), /// Represents a reference to all available fields in a specific schema, @@ -1101,6 +1103,54 @@ impl Exists { } } +/// Whether the set comparison uses `ANY`/`SOME` or `ALL` +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub enum SetQuantifier { + /// `ANY` (or `SOME`) + Any, + /// `ALL` + All, +} + +impl Display for SetQuantifier { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SetQuantifier::Any => write!(f, "ANY"), + SetQuantifier::All => write!(f, "ALL"), + } + } +} + +/// Set comparison subquery (e.g. `= ANY`, `> ALL`) +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct SetComparison { + /// The expression to compare + pub expr: Box, + /// Subquery that will produce a single column of data to compare against + pub subquery: Subquery, + /// Comparison operator (e.g. `=`, `>`, `<`) + pub op: Operator, + /// Quantifier (`ANY`/`ALL`) + pub quantifier: SetQuantifier, +} + +impl SetComparison { + /// Create a new set comparison expression + pub fn new( + expr: Box, + subquery: Subquery, + op: Operator, + quantifier: SetQuantifier, + ) -> Self { + Self { + expr, + subquery, + op, + quantifier, + } + } +} + /// InList expression #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { @@ -1503,6 +1553,7 @@ impl Expr { Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", + Expr::SetComparison(..) => "SetComparison", Expr::IsNotNull(..) => "IsNotNull", Expr::IsNull(..) => "IsNull", Expr::Like { .. } => "Like", @@ -2058,6 +2109,7 @@ impl Expr { | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) + | Expr::SetComparison(..) | Expr::IsFalse(..) | Expr::IsNotFalse(..) | Expr::IsNotNull(..) 
@@ -2651,6 +2703,16 @@ impl HashNode for Expr { subquery.hash(state); negated.hash(state); } + Expr::SetComparison(SetComparison { + expr: _, + subquery, + op, + quantifier, + }) => { + subquery.hash(state); + op.hash(state); + quantifier.hash(state); + } Expr::ScalarSubquery(subquery) => { subquery.hash(state); } @@ -2841,6 +2903,12 @@ impl Display for SchemaDisplay<'_> { write!(f, "NOT IN") } Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::SetComparison(SetComparison { + expr, + op, + quantifier, + .. + }) => write!(f, "{} {op} {quantifier}", SchemaDisplay(expr.as_ref())), Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), Expr::IsNotTrue(expr) => { @@ -3316,6 +3384,12 @@ impl Display for Expr { subquery, negated: false, }) => write!(f, "{expr} IN ({subquery:?})"), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => write!(f, "{expr} {op} {quantifier} ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), Expr::ScalarFunction(fun) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 0d895310655ca..52cca4b958e04 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -196,6 +196,7 @@ impl ExprSchemable for Expr { | Expr::IsNull(_) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Between { .. } | Expr::InList { .. } | Expr::IsNotNull(_) @@ -380,6 +381,7 @@ impl ExprSchemable for Expr { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Exists { .. } => Ok(false), + Expr::SetComparison(_) => Ok(true), Expr::InSubquery(InSubquery { expr, .. 
}) => expr.nullable(input_schema), Expr::ScalarSubquery(subquery) => { Ok(subquery.subquery.schema().field(0).is_nullable()) @@ -645,6 +647,7 @@ impl ExprSchemable for Expr { | Expr::TryCast(_) | Expr::InList(_) | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index 762491a255cbc..dcd720b2cdee0 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -22,7 +22,7 @@ use datafusion_common::{ use crate::{ Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window, - expr::{Exists, InSubquery}, + expr::{Exists, InSubquery, SetComparison}, expr_rewriter::strip_outer_reference, utils::{collect_subquery_cols, split_conjunction}, }; @@ -81,6 +81,7 @@ fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Re match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. }) | Expr::ScalarSubquery(subquery) => { assert_valid_extension_nodes(&subquery.subquery, check)?; } @@ -133,6 +134,7 @@ fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> { match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. 
}) | Expr::ScalarSubquery(subquery) => { check_subquery_expr(plan, &subquery.subquery, expr)?; } @@ -229,6 +231,20 @@ pub fn check_subquery_expr( ); } } + if let Expr::SetComparison(set_comparison) = expr { + if set_comparison.subquery.subquery.schema().fields().len() > 1 { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); + } + } match outer_plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -237,7 +253,7 @@ pub fn check_subquery_expr( | LogicalPlan::Aggregate(_) | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( - "In/Exist subquery can only be used in \ + "In/Exist/SetComparison subquery can only be used in \ Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \ but was used in [{}]", outer_plan.display() diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 62a27b0a025ad..b233fe2e0fff0 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -46,7 +46,7 @@ use crate::{ }; use datafusion_common::tree_node::TreeNodeRefContainer; -use crate::expr::{Exists, InSubquery}; +use crate::expr::{Exists, InSubquery, SetComparison}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -815,6 +815,7 @@ impl LogicalPlan { expr.apply(|expr| match expr { Expr::Exists(Exists { subquery, .. }) | Expr::InSubquery(InSubquery { subquery, .. }) + | Expr::SetComparison(SetComparison { subquery, .. 
}) | Expr::ScalarSubquery(subquery) => { // use a synthetic plan so the collector sees a // LogicalPlan::Subquery (even though it is @@ -856,6 +857,22 @@ impl LogicalPlan { })), _ => internal_err!("Transformation should return Subquery"), }), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => f(LogicalPlan::Subquery(subquery))?.map_data(|s| match s { + LogicalPlan::Subquery(subquery) => { + Ok(Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + })) + } + _ => internal_err!("Transformation should return Subquery"), + }), Expr::ScalarSubquery(subquery) => f(LogicalPlan::Subquery(subquery))? .map_data(|s| match s { LogicalPlan::Subquery(subquery) => { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 742bae5b2320b..226c512a974d8 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -20,8 +20,8 @@ use crate::Expr; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, SetComparison, + TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use datafusion_common::Result; @@ -58,7 +58,8 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), + | Expr::InSubquery(InSubquery { expr, .. }) + | Expr::SetComparison(SetComparison { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), Expr::ScalarFunction(ScalarFunction { args, .. 
}) => { @@ -128,6 +129,19 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_, _) => Transformed::no(self), + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => expr.map_elements(f)?.update_data(|expr| { + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) + }), Expr::Unnest(Unnest { expr, .. }) => expr .map_elements(f)? .update_data(|expr| Expr::Unnest(Unnest { expr })), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index de4ebf5fa96e9..7d13c9c1d24d1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -312,6 +312,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 272692f983683..742f0e1d7b98e 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -29,6 +29,7 @@ use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{InvariantLevel, LogicalPlan}; use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; +use crate::analyzer::set_comparison::RewriteSetComparison; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -36,6 +37,7 @@ use self::function_rewrite::ApplyFunctionRewrites; pub mod function_rewrite; pub mod resolve_grouping_function; +pub mod set_comparison; pub mod type_coercion; /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make @@ -87,6 +89,7 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(ResolveGroupingFunction::new()), + Arc::new(RewriteSetComparison::new()), Arc::new(TypeCoercion::new()), ]; Self::with_rules(rules) diff --git 
a/datafusion/optimizer/src/analyzer/set_comparison.rs b/datafusion/optimizer/src/analyzer/set_comparison.rs new file mode 100644 index 0000000000000..440c43611c8ef --- /dev/null +++ b/datafusion/optimizer/src/analyzer/set_comparison.rs @@ -0,0 +1,167 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Rewrite `SetComparison` subqueries (e.g. `= ANY`, `> ALL`) into +//! boolean expressions built from `EXISTS` subqueries that capture SQL +//! three-valued logic. + +use super::AnalyzerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{plan_datafusion_err, DFSchema, ExprSchema, Result, ScalarValue}; +use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::logical_plan::Subquery; +use datafusion_expr::{lit, Expr, LogicalPlan}; +use std::sync::Arc; + +use datafusion_expr::utils::merge_schema; + +/// Rewrite `SetComparison` expressions to `CASE` expressions over `EXISTS` +/// subqueries that return the correct boolean value (including SQL NULL +/// semantics). 
After this rule runs, later optimizer rules are left to +/// decorrelate and remove the remaining `EXISTS` subqueries. +#[derive(Debug, Default)] +pub struct RewriteSetComparison; + +impl RewriteSetComparison { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } + + fn rewrite_plan(&self, plan: LogicalPlan) -> Result<Transformed<LogicalPlan>> { + let schema = merge_schema(&plan.inputs()); + plan.map_expressions(|expr| { + expr.transform_up(|expr| rewrite_set_comparison(expr, &schema)) + }) + } +} + +impl AnalyzerRule for RewriteSetComparison { + fn name(&self) -> &str { + "rewrite_set_comparison" + } + + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result<LogicalPlan> { + plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) + .map(|t| t.data) + } +} + +fn rewrite_set_comparison( + expr: Expr, + outer_schema: &DFSchema, +) -> Result<Transformed<Expr>> { + match expr { + Expr::SetComparison(set_comparison) => { + let rewritten = build_set_comparison_subquery(set_comparison, outer_schema)?; + Ok(Transformed::yes(rewritten)) + } + _ => Ok(Transformed::no(expr)), + } +} + +fn build_set_comparison_subquery( + set_comparison: SetComparison, + outer_schema: &DFSchema, +) -> Result<Expr> { + let SetComparison { + expr, + subquery, + op, + quantifier, + } = set_comparison; + + let left_expr = to_outer_reference(*expr, outer_schema)?; + let right_expr = subquery + .subquery + .head_output_expr()? 
+ .ok_or_else(|| plan_datafusion_err!("single expression required."))?; + + let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( + Box::new(left_expr), + op, + Box::new(right_expr), + )); + + let true_exists = + exists_subquery(&subquery, Expr::IsTrue(Box::new(comparison.clone())))?; + let null_exists = + exists_subquery(&subquery, Expr::IsNull(Box::new(comparison.clone())))?; + + let result_expr = match quantifier { + SetQuantifier::Any => Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(true_exists), Box::new(lit(true))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(false))), + }), + SetQuantifier::All => { + let false_exists = + exists_subquery(&subquery, Expr::IsFalse(Box::new(comparison.clone())))?; + Expr::Case(expr::Case { + expr: None, + when_then_expr: vec![ + (Box::new(false_exists), Box::new(lit(false))), + ( + Box::new(null_exists), + Box::new(Expr::Literal(ScalarValue::Boolean(None), None)), + ), + ], + else_expr: Some(Box::new(lit(true))), + }) + } + }; + + Ok(result_expr) +} + +fn exists_subquery(subquery: &Subquery, filter: Expr) -> Result { + let plan = LogicalPlanBuilder::from(subquery.subquery.as_ref().clone()) + .filter(filter)? 
+ .build()?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists { + subquery: Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: subquery.spans.clone(), + }, + negated: false, + })) +} + +fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { + expr.transform_up(|expr| match expr { + Expr::Column(col) => { + let field = outer_schema.field_from_column(&col)?; + Ok(Transformed::yes(Expr::OuterReferenceColumn( + field.clone(), + col, + ))) + } + Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + _ => Ok(Transformed::no(expr)), + }) + .map(|t| t.data) +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a557d3356dba0..35ecf59958120 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -35,7 +35,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::{ self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, Sort, WindowFunction, + InSubquery, Like, ScalarFunction, SetComparison, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -368,6 +368,36 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, )))) } + Expr::SetComparison(SetComparison { + expr, + subquery, + op, + quantifier, + }) => { + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? 
+ .data; + let expr_type = expr.get_type(self.schema)?; + let subquery_type = new_plan.schema().field(0).data_type(); + let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( + plan_datafusion_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ), + )?; + let new_subquery = Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + spans: subquery.spans, + }; + Ok(Transformed::yes(Expr::SetComparison(SetComparison::new( + Box::new(expr.cast_to(&common_type, self.schema)?), + cast_subquery(new_subquery, &common_type)?, + op, + quantifier, + )))) + } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, self.schema, @@ -1127,7 +1157,7 @@ mod test { use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; - use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; + use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction, SetComparison}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; use datafusion_expr::test::function_stub::avg_udaf; use datafusion_expr::{ diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ea0980ad4e1c7..9f2bec79f3312 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -263,6 +263,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. 
} | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) | Expr::Unnest(_) => { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 366c99ce8f28b..11f4bc802f109 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -649,6 +649,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::OuterReferenceColumn(_, _) | Expr::Exists { .. } | Expr::InSubquery(_) + | Expr::SetComparison(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } | Expr::GroupingSet(_) diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 9725025d599fe..63891fd12a4ad 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -32,6 +32,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::SetQuantifier; use datafusion_expr::expr::{InList, WildcardOptions}; use datafusion_expr::{ lit, Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, @@ -594,32 +595,34 @@ impl SqlToRel<'_, S> { // ANY/SOME are equivalent, this field specifies which the user // specified but it doesn't affect the plan so ignore the field is_some: _, - } => { - let mut binary_expr = RawBinaryExpr { - op: compare_op, - left: self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?, - right: self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?, - }; - for planner in self.context_provider.get_expr_planners() { - match planner.plan_any(binary_expr)? 
{ - PlannerResult::Planned(expr) => { - return Ok(expr); - } - PlannerResult::Original(expr) => { - binary_expr = expr; - } - } + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + compare_op, + SetQuantifier::Any, + schema, + planner_context, + ), + _ => { + not_impl_err!("ANY/SOME only supports subquery comparison currently") } - not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") - } + }, + SQLExpr::AllOp { + left, + compare_op, + right, + } => match *right { + SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( + *left, + *subquery, + compare_op, + SetQuantifier::All, + schema, + planner_context, + ), + _ => not_impl_err!("ALL only supports subquery comparison currently"), + }, #[expect(deprecated)] SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { qualifier: None, diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index 4bca6f7e49ba0..bad2a7e70d4f7 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -17,10 +17,10 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{plan_err, DFSchema, Diagnostic, Result, Span, Spans}; -use datafusion_expr::expr::{Exists, InSubquery}; +use datafusion_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; use datafusion_expr::{Expr, LogicalPlan, Subquery}; use sqlparser::ast::Expr as SQLExpr; -use sqlparser::ast::{Query, SelectItem, SetExpr}; +use sqlparser::ast::{BinaryOperator, Query, SelectItem, SetExpr}; use std::sync::Arc; impl SqlToRel<'_, S> { @@ -162,4 +162,51 @@ impl SqlToRel<'_, S> { diagnostic.add_help(help_message, None); diagnostic } + + pub(super) fn parse_set_comparison_subquery( + &self, + left_expr: SQLExpr, + subquery: Query, + compare_op: BinaryOperator, + quantifier: SetQuantifier, + input_schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let old_outer_query_schema = 
+ planner_context.set_outer_query_schema(Some(input_schema.clone().into())); + + let mut spans = Spans::new(); + if let SetExpr::Select(select) = subquery.body.as_ref() { + for item in &select.projection { + if let SelectItem::ExprWithAlias { alias, .. } = item { + if let Some(span) = Span::try_from_sqlparser_span(alias.span) { + spans.add_span(span); + } + } + } + } + + let sub_plan = self.query_to_plan(subquery, planner_context)?; + let outer_ref_columns = sub_plan.all_out_ref_exprs(); + planner_context.set_outer_query_schema(old_outer_query_schema); + + self.validate_single_column( + &sub_plan, + &spans, + "Too many columns! The subquery should only return one column", + "Select only one column in the subquery", + )?; + + let expr_obj = self.sql_to_expr(left_expr, input_schema, planner_context)?; + Ok(Expr::SetComparison(SetComparison::new( + Box::new(expr_obj), + Subquery { + subquery: Arc::new(sub_plan), + outer_ref_columns, + spans, + }, + self.parse_sql_binary_op(&compare_op)?, + quantifier, + ))) + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 575cfd27ee354..78d3ed171ffb9 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -44,7 +44,7 @@ use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, }; use datafusion_expr::{ - expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, + expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; use sqlparser::ast::helpers::attached_token::AttachedToken; @@ -393,6 +393,33 @@ impl Unparser<'_> { negated: insubq.negated, }) } + Expr::SetComparison(set_cmp) => { + let left = Box::new(self.expr_to_sql_inner(set_cmp.expr.as_ref())?); + let sub_statement = + self.plan_to_sql(set_cmp.subquery.subquery.as_ref())?; + let sub_query = if let ast::Statement::Query(inner_query) = 
sub_statement + { + inner_query + } else { + return plan_err!( + "Subquery must be a Query, but found {sub_statement:?}" + ); + }; + let compare_op = self.op_to_sql(&set_cmp.op)?; + match set_cmp.quantifier { + SetQuantifier::Any => Ok(ast::Expr::AnyOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + is_some: false, + }), + SetQuantifier::All => Ok(ast::Expr::AllOp { + left, + compare_op, + right: Box::new(ast::Expr::Subquery(sub_query)), + }), + } + } Expr::Exists(Exists { subquery, negated }) => { let sub_statement = self.plan_to_sql(subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement From a555416fe5bbbebe1c5e022a04711e0d247c7d74 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 14:58:36 +0800 Subject: [PATCH 2/9] fix clippy Signed-off-by: Ruihang Xia --- .../expr/src/logical_plan/invariants.rs | 26 +++++++++---------- .../optimizer/src/analyzer/set_comparison.rs | 4 +-- .../optimizer/src/analyzer/type_coercion.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/src/expr/mod.rs | 4 +-- datafusion/sql/src/expr/subquery.rs | 4 +-- .../src/logical_plan/producer/expr/mod.rs | 3 +++ 7 files changed, 24 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index dcd720b2cdee0..b39b23e30f4e8 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -231,19 +231,19 @@ pub fn check_subquery_expr( ); } } - if let Expr::SetComparison(set_comparison) = expr { - if set_comparison.subquery.subquery.schema().fields().len() > 1 { - return plan_err!( - "Set comparison subquery should only return one column, but found {}: {}", - set_comparison.subquery.subquery.schema().fields().len(), - set_comparison - .subquery - .subquery - .schema() - .field_names() - .join(", ") - ); - } + if let Expr::SetComparison(set_comparison) 
= expr + && set_comparison.subquery.subquery.schema().fields().len() > 1 + { + return plan_err!( + "Set comparison subquery should only return one column, but found {}: {}", + set_comparison.subquery.subquery.schema().fields().len(), + set_comparison + .subquery + .subquery + .schema() + .field_names() + .join(", ") + ); } match outer_plan { LogicalPlan::Projection(_) diff --git a/datafusion/optimizer/src/analyzer/set_comparison.rs b/datafusion/optimizer/src/analyzer/set_comparison.rs index 440c43611c8ef..484c6dc7d754a 100644 --- a/datafusion/optimizer/src/analyzer/set_comparison.rs +++ b/datafusion/optimizer/src/analyzer/set_comparison.rs @@ -41,7 +41,7 @@ pub struct RewriteSetComparison; impl RewriteSetComparison { #[allow(missing_docs)] pub fn new() -> Self { - Self::default() + Self } fn rewrite_plan(&self, plan: LogicalPlan) -> Result> { @@ -156,7 +156,7 @@ fn to_outer_reference(expr: Expr, outer_schema: &DFSchema) -> Result { Expr::Column(col) => { let field = outer_schema.field_from_column(&col)?; Ok(Transformed::yes(Expr::OuterReferenceColumn( - field.clone(), + Arc::clone(field), col, ))) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 35ecf59958120..417f04a7d4771 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1157,7 +1157,7 @@ mod test { use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; - use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction, SetComparison}; + use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; use datafusion_expr::test::function_stub::avg_udaf; use datafusion_expr::{ diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs index 2774b5b6ba7c3..eb373f896b7e5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -579,6 +579,7 @@ pub fn serialize_expr( Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } + | Expr::SetComparison(_) | Expr::OuterReferenceColumn { .. } => { // we would need to add logical plan operators to datafusion.proto to support this // see discussion in https://github.com/apache/datafusion/issues/2565 diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 63891fd12a4ad..654cde51ec9de 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -599,7 +599,7 @@ impl SqlToRel<'_, S> { SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( *left, *subquery, - compare_op, + &compare_op, SetQuantifier::Any, schema, planner_context, @@ -616,7 +616,7 @@ impl SqlToRel<'_, S> { SQLExpr::Subquery(subquery) => self.parse_set_comparison_subquery( *left, *subquery, - compare_op, + &compare_op, SetQuantifier::All, schema, planner_context, diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index bad2a7e70d4f7..24c1c07026b43 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -167,7 +167,7 @@ impl SqlToRel<'_, S> { &self, left_expr: SQLExpr, subquery: Query, - compare_op: BinaryOperator, + compare_op: &BinaryOperator, quantifier: SetQuantifier, input_schema: &DFSchema, planner_context: &mut PlannerContext, @@ -205,7 +205,7 @@ impl SqlToRel<'_, S> { outer_ref_columns, spans, }, - self.parse_sql_binary_op(&compare_op)?, + self.parse_sql_binary_op(compare_op)?, quantifier, ))) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index f4e43fd586773..953ecc6d2f20f 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ 
b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -141,6 +141,9 @@ pub fn to_substrait_rex( Expr::InList(expr) => producer.handle_in_list(expr, schema), Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), + Expr::SetComparison(expr) => { + not_impl_err!("Cannot convert {expr:?} to Substrait") + } Expr::ScalarSubquery(expr) => { not_impl_err!("Cannot convert {expr:?} to Substrait") } From 8f40a43da526ddc67ad549968baab5d3e3692933 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 15:45:11 +0800 Subject: [PATCH 3/9] more tests Signed-off-by: Ruihang Xia --- datafusion/expr/src/expr.rs | 23 +++++++ .../sqllogictest/test_files/subquery.slt | 61 ++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7115c8fb9987c..fc0cd77df1070 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -3873,6 +3873,7 @@ mod test { } use super::*; + use crate::logical_plan::{EmptyRelation, LogicalPlan}; #[test] fn test_display_wildcard() { @@ -3963,6 +3964,28 @@ mod test { ) } + #[test] + fn test_display_set_comparison() { + let subquery = Subquery { + subquery: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })), + outer_ref_columns: vec![], + spans: Spans::new(), + }; + + let expr = Expr::SetComparison(SetComparison::new( + Box::new(Expr::Column(Column::from_name("a"))), + subquery, + Operator::Gt, + SetQuantifier::Any, + )); + + assert_eq!(format!("{expr}"), "a > ANY ()"); + assert_eq!(format!("{}", expr.human_display()), "a > ANY ()"); + } + #[test] fn test_schema_display_alias_with_relation() { assert_eq!( diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 27325d4e5e84d..e322cb4f11b39 100644 --- 
a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -438,7 +438,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] +statement error DataFusion error: Invalid \(non-executable\) plan after Analyzer\ncaused by\nError during planning: In/Exist/SetComparison subquery can only be used in Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -1478,3 +1478,62 @@ logical_plan statement count 0 drop table person; + +# Set comparison subqueries (ANY/ALL) +statement ok +create table set_cmp_t(v int) as values (1), (6), (10); + +statement ok +create table set_cmp_s(v int) as values (5), (null); + +statement ok +create table set_cmp_empty(v int); + +query I rowsort +select v from set_cmp_t where v > any(select v from set_cmp_s); +---- +10 +6 + +query I rowsort +select v from set_cmp_t where v < all(select v from set_cmp_empty); +---- +1 +10 +6 + +statement count 0 +drop table set_cmp_t; + +statement count 0 +drop table set_cmp_s; + +statement count 0 +drop table set_cmp_empty; + +query TT +explain select v from (values (1), (6), (10)) set_cmp_t(v) where v > any(select v from (values (5), (null)) set_cmp_s(v)); +---- +logical_plan +01)Projection: set_cmp_t.v +02)--Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND NOT __correlated_sq_3.mark AND Boolean(NULL) +03)----LeftMark Join: 
Filter: set_cmp_t.v > __correlated_sq_3.v IS TRUE +04)------Filter: __correlated_sq_1.mark OR __correlated_sq_2.mark AND Boolean(NULL) +05)--------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_2.v IS NULL +06)----------Filter: __correlated_sq_1.mark OR Boolean(NULL) +07)------------LeftMark Join: Filter: set_cmp_t.v > __correlated_sq_1.v IS TRUE +08)--------------SubqueryAlias: set_cmp_t +09)----------------Projection: column1 AS v +10)------------------Values: (Int64(1)), (Int64(6)), (Int64(10)) +11)--------------SubqueryAlias: __correlated_sq_1 +12)----------------SubqueryAlias: set_cmp_s +13)------------------Projection: column1 AS v +14)--------------------Values: (Int64(5)), (Int64(NULL)) +15)----------SubqueryAlias: __correlated_sq_2 +16)------------SubqueryAlias: set_cmp_s +17)--------------Projection: column1 AS v +18)----------------Values: (Int64(5)), (Int64(NULL)) +19)------SubqueryAlias: __correlated_sq_3 +20)--------SubqueryAlias: set_cmp_s +21)----------Projection: column1 AS v +22)------------Values: (Int64(5)), (Int64(NULL)) From c1faa8850127ef62534bd772746e8f713d820116 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 16:01:59 +0800 Subject: [PATCH 4/9] remove plan_any Signed-off-by: Ruihang Xia --- datafusion/expr/src/planner.rs | 7 ------- datafusion/functions-nested/src/planner.rs | 16 +--------------- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 794c394d11d49..78034cc7b0051 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -227,13 +227,6 @@ pub trait ExprPlanner: Debug + Send + Sync { ) } - /// Plans `ANY` expression, such as `expr = ANY(array_expr)` - /// - /// Returns origin binary expression if not possible - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - Ok(PlannerResult::Original(expr)) - } - /// Plans aggregate functions, such as `COUNT()` /// /// Returns original expression arguments if 
not possible diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 4fec5e38065b5..107f22a6c32cd 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -37,7 +37,7 @@ use std::sync::Arc; use crate::map::map_udf; use crate::{ - array_has::{array_has_all, array_has_udf}, + array_has::array_has_all, expr_fn::{array_append, array_concat, array_prepend}, extract::{array_element, array_slice}, make_array::make_array, @@ -120,20 +120,6 @@ impl ExprPlanner for NestedFunctionPlanner { ScalarFunction::new_udf(map_udf(), vec![keys, values]), ))) } - - fn plan_any(&self, expr: RawBinaryExpr) -> Result> { - if expr.op == BinaryOperator::Eq { - Ok(PlannerResult::Planned(Expr::ScalarFunction( - ScalarFunction::new_udf( - array_has_udf(), - // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` - vec![expr.right, expr.left], - ), - ))) - } else { - plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) - } - } } #[derive(Debug)] From 5d12582eff6f0729d1448a1c7e24f69ce23cdeac Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 17:05:58 +0800 Subject: [PATCH 5/9] substrait support Signed-off-by: Ruihang Xia --- .../logical_plan/consumer/expr/subquery.rs | 56 +++++++++++-- .../src/logical_plan/producer/expr/mod.rs | 4 +- .../logical_plan/producer/expr/subquery.rs | 56 ++++++++++++- .../producer/substrait_producer.rs | 19 ++++- .../tests/cases/roundtrip_logical_plan.rs | 80 ++++++++++++++++++- 5 files changed, 198 insertions(+), 17 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index 917bcc007716b..22b5f706c0564 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -16,11 +16,12 @@ // under the License. 
use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{substrait_err, DFSchema, Spans}; -use datafusion::logical_expr::expr::{Exists, InSubquery}; -use datafusion::logical_expr::{Expr, Subquery}; +use datafusion::common::{substrait_datafusion_err, substrait_err, DFSchema, Spans}; +use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::{Expr, Operator, Subquery}; use std::sync::Arc; use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use substrait::proto::expression::subquery::SubqueryType; @@ -94,8 +95,53 @@ pub async fn from_subquery( ), } } - other_type => { - substrait_err!("Subquery type {other_type:?} not implemented") + SubqueryType::SetComparison(comparison) => { + let left = comparison.left.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a left expression") + })?; + let right = comparison.right.as_ref().ok_or_else(|| { + substrait_datafusion_err!("SetComparison requires a right relation") + })?; + let reduction_op = match ReductionOp::try_from(comparison.reduction_op) { + Ok(ReductionOp::Any) => SetQuantifier::Any, + Ok(ReductionOp::All) => SetQuantifier::All, + _ => { + return substrait_err!( + "Unsupported reduction op for SetComparison: {}", + comparison.reduction_op + ) + } + }; + let comparison_op = match ComparisonOp::try_from(comparison.comparison_op) + { + Ok(ComparisonOp::Eq) => Operator::Eq, + Ok(ComparisonOp::Ne) => Operator::NotEq, + Ok(ComparisonOp::Lt) => Operator::Lt, + Ok(ComparisonOp::Gt) => Operator::Gt, + Ok(ComparisonOp::Le) => Operator::LtEq, + Ok(ComparisonOp::Ge) => Operator::GtEq, + _ => { + return substrait_err!( + "Unsupported comparison op for SetComparison: {}", + comparison.comparison_op + ) + } + }; + + let left_expr = 
consumer.consume_expression(left, input_schema).await?; + let plan = consumer.consume_rel(right).await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + Ok(Expr::SetComparison(SetComparison::new( + Box::new(left_expr), + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + spans: Spans::new(), + }, + comparison_op, + reduction_op, + ))) } }, None => { diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index 953ecc6d2f20f..637cd6f50b5e0 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -141,9 +141,7 @@ pub fn to_substrait_rex( Expr::InList(expr) => producer.handle_in_list(expr, schema), Expr::Exists(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), Expr::InSubquery(expr) => producer.handle_in_subquery(expr, schema), - Expr::SetComparison(expr) => { - not_impl_err!("Cannot convert {expr:?} to Substrait") - } + Expr::SetComparison(expr) => producer.handle_set_comparison(expr, schema), Expr::ScalarSubquery(expr) => { not_impl_err!("Cannot convert {expr:?} to Substrait") } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs index c1ee78c68c258..0c18530119eb5 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/subquery.rs @@ -16,8 +16,10 @@ // under the License. 
use crate::logical_plan::producer::SubstraitProducer; -use datafusion::common::DFSchemaRef; -use datafusion::logical_expr::expr::InSubquery; +use datafusion::common::{substrait_err, DFSchemaRef}; +use datafusion::logical_expr::expr::{InSubquery, SetComparison, SetQuantifier}; +use datafusion::logical_expr::Operator; +use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::{RexType, ScalarFunction}; use substrait::proto::function_argument::ArgType; @@ -70,3 +72,53 @@ pub fn from_in_subquery( Ok(substrait_subquery) } } + +fn comparison_op_to_proto(op: &Operator) -> datafusion::common::Result { + match op { + Operator::Eq => Ok(ComparisonOp::Eq), + Operator::NotEq => Ok(ComparisonOp::Ne), + Operator::Lt => Ok(ComparisonOp::Lt), + Operator::Gt => Ok(ComparisonOp::Gt), + Operator::LtEq => Ok(ComparisonOp::Le), + Operator::GtEq => Ok(ComparisonOp::Ge), + _ => substrait_err!("Unsupported operator {op:?} for SetComparison subquery"), + } +} + +fn reduction_op_to_proto( + quantifier: &SetQuantifier, +) -> datafusion::common::Result { + match quantifier { + SetQuantifier::Any => Ok(ReductionOp::Any), + SetQuantifier::All => Ok(ReductionOp::All), + } +} + +pub fn from_set_comparison( + producer: &mut impl SubstraitProducer, + set_comparison: &SetComparison, + schema: &DFSchemaRef, +) -> datafusion::common::Result { + let comparison_op = comparison_op_to_proto(&set_comparison.op)? as i32; + let reduction_op = reduction_op_to_proto(&set_comparison.quantifier)? 
as i32; + let left = producer.handle_expr(set_comparison.expr.as_ref(), schema)?; + let subquery_plan = + producer.handle_plan(set_comparison.subquery.subquery.as_ref())?; + + Ok(Expression { + rex_type: Some(RexType::Subquery(Box::new( + substrait::proto::expression::Subquery { + subquery_type: Some( + substrait::proto::expression::subquery::SubqueryType::SetComparison( + Box::new(substrait::proto::expression::subquery::SetComparison { + reduction_op, + comparison_op, + left: Some(Box::new(left)), + right: Some(subquery_plan), + }), + ), + ), + }, + ))), + }) +} diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index 54fa9ea5daa4b..ba6103eaa7fd0 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -20,14 +20,17 @@ use crate::logical_plan::producer::{ from_aggregate, from_aggregate_function, from_alias, from_between, from_binary_expr, from_case, from_cast, from_column, from_distinct, from_empty_relation, from_filter, from_in_list, from_in_subquery, from_join, from_like, from_limit, from_literal, - from_projection, from_repartition, from_scalar_function, from_sort, - from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, from_union, - from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, + from_projection, from_repartition, from_scalar_function, from_set_comparison, + from_sort, from_subquery_alias, from_table_scan, from_try_cast, from_unary_expr, + from_union, from_values, from_window, from_window_function, to_substrait_rel, + to_substrait_rex, }; use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; -use datafusion::logical_expr::expr::{Alias, InList, InSubquery, WindowFunction}; +use 
datafusion::logical_expr::expr::{ + Alias, InList, InSubquery, SetComparison, WindowFunction, +}; use datafusion::logical_expr::{ expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, @@ -359,6 +362,14 @@ pub trait SubstraitProducer: Send + Sync + Sized { ) -> datafusion::common::Result { from_in_subquery(self, in_subquery, schema) } + + fn handle_set_comparison( + &mut self, + set_comparison: &SetComparison, + schema: &DFSchemaRef, + ) -> datafusion::common::Result { + from_set_comparison(self, set_comparison, schema) + } } pub struct DefaultSubstraitProducer<'a> { diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 34cb05fbf7ff8..0cca0fba33e87 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -29,14 +29,15 @@ use std::mem::size_of_val; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::tree_node::Transformed; -use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef, Spans}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::logical_expr::expr::{SetComparison, SetQuantifier}; use datafusion::logical_expr::{ - EmptyRelation, Extension, InvariantLevel, LogicalPlan, PartitionEvaluator, - Repartition, UserDefinedLogicalNode, Values, Volatility, + EmptyRelation, Extension, InvariantLevel, LogicalPlan, Operator, PartitionEvaluator, + Repartition, Subquery, UserDefinedLogicalNode, Values, Volatility, }; use 
datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; @@ -689,6 +690,29 @@ async fn roundtrip_exists_filter() -> Result<()> { Ok(()) } +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_any_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = build_set_comparison_plan(&ctx, SetQuantifier::Any, Operator::Gt).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::Gt, SetQuantifier::Any); + Ok(()) +} + +// assemble logical plan manually to ensure SetComparison expr is present (not rewrite away) +#[tokio::test] +async fn roundtrip_set_comparison_all_substrait() -> Result<()> { + let ctx = create_context().await?; + let plan = + build_set_comparison_plan(&ctx, SetQuantifier::All, Operator::NotEq).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let roundtrip_plan = from_substrait_plan(&ctx.state(), &proto).await?; + assert_set_comparison_predicate(&roundtrip_plan, Operator::NotEq, SetQuantifier::All); + Ok(()) +} + #[tokio::test] async fn roundtrip_not_exists_filter_left_anti_join() -> Result<()> { let plan = generate_plan_from_sql( @@ -1865,6 +1889,56 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { Ok(()) } +async fn build_set_comparison_plan( + ctx: &SessionContext, + quantifier: SetQuantifier, + op: Operator, +) -> Result { + let base_scan = ctx.table("data").await?.into_unoptimized_plan(); + let subquery_scan = ctx.table("data2").await?.into_unoptimized_plan(); + let subquery_plan = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("data2.a")])? 
+ .build()?; + let predicate = Expr::SetComparison(SetComparison::new( + Box::new(col("data.a")), + Subquery { + subquery: Arc::new(subquery_plan), + outer_ref_columns: vec![], + spans: Spans::new(), + }, + op, + quantifier, + )); + + LogicalPlanBuilder::from(base_scan) + .filter(predicate)? + .project(vec![col("data.a")])? + .build() +} + +fn assert_set_comparison_predicate( + plan: &LogicalPlan, + expected_op: Operator, + expected_quantifier: SetQuantifier, +) { + let predicate = match plan { + LogicalPlan::Projection(p) => match p.input.as_ref() { + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter inside Projection, got {other:?}"), + }, + LogicalPlan::Filter(filter) => &filter.predicate, + other => panic!("expected Filter plan, got {other:?}"), + }; + + match predicate { + Expr::SetComparison(set_comparison) => { + assert_eq!(set_comparison.op, expected_op); + assert_eq!(set_comparison.quantifier, expected_quantifier); + } + other => panic!("expected SetComparison predicate, got {other:?}"), + } +} + async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; From 02f19c2f321e6d2adc1ba3177c00c3f6180b6012 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 21:13:33 +0800 Subject: [PATCH 6/9] use optimizer rule instead Signed-off-by: Ruihang Xia --- datafusion/optimizer/src/analyzer/mod.rs | 3 --- datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 ++ ...omparison.rs => rewrite_set_comparison.rs} | 21 +++++++++++-------- datafusion/sql/Cargo.toml | 1 + datafusion/sql/src/expr/mod.rs | 13 +++++++++++- 6 files changed, 28 insertions(+), 13 deletions(-) rename datafusion/optimizer/src/{analyzer/set_comparison.rs => rewrite_set_comparison.rs} (90%) diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 742f0e1d7b98e..272692f983683 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs 
+++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -29,7 +29,6 @@ use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{InvariantLevel, LogicalPlan}; use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; -use crate::analyzer::set_comparison::RewriteSetComparison; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -37,7 +36,6 @@ use self::function_rewrite::ApplyFunctionRewrites; pub mod function_rewrite; pub mod resolve_grouping_function; -pub mod set_comparison; pub mod type_coercion; /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make @@ -89,7 +87,6 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(ResolveGroupingFunction::new()), - Arc::new(RewriteSetComparison::new()), Arc::new(TypeCoercion::new()), ]; Self::with_rules(rules) diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index c4ee23517b4df..e6b24dec87fd8 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -65,6 +65,7 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; +pub mod rewrite_set_comparison; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 421563d5e7e88..810c5add32cc6 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -51,6 +51,7 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; +use crate::rewrite_set_comparison::RewriteSetComparison; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use 
crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -227,6 +228,7 @@ impl Optimizer { /// Create a new optimizer using the recommended list of rules pub fn new() -> Self { let rules: Vec> = vec![ + Arc::new(RewriteSetComparison::new()), Arc::new(OptimizeUnions::new()), Arc::new(SimplifyExpressions::new()), Arc::new(ReplaceDistinctWithAggregate::new()), diff --git a/datafusion/optimizer/src/analyzer/set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs similarity index 90% rename from datafusion/optimizer/src/analyzer/set_comparison.rs rename to datafusion/optimizer/src/rewrite_set_comparison.rs index 484c6dc7d754a..0f40b8ed24bb4 100644 --- a/datafusion/optimizer/src/analyzer/set_comparison.rs +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -//! Rewrite `SetComparison` subqueries (e.g. `= ANY`, `> ALL`) into -//! boolean expressions built from `EXISTS` subqueries that capture SQL -//! three-valued logic. +//! Optimizer rule rewriting `SetComparison` subqueries (e.g. `= ANY`, +//! `> ALL`) into boolean expressions built from `EXISTS` subqueries +//! that capture SQL three-valued logic. 
-use super::AnalyzerRule; -use datafusion_common::config::ConfigOptions; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{plan_datafusion_err, DFSchema, ExprSchema, Result, ScalarValue}; +use datafusion_common::ExprSchema; +use datafusion_common::{plan_datafusion_err, DFSchema, Result, ScalarValue}; use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; use datafusion_expr::logical_plan::Subquery; @@ -52,14 +52,17 @@ impl RewriteSetComparison { } } -impl AnalyzerRule for RewriteSetComparison { +impl OptimizerRule for RewriteSetComparison { fn name(&self) -> &str { "rewrite_set_comparison" } - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { plan.transform_up_with_subqueries(|plan| self.rewrite_plan(plan)) - .map(|t| t.data) } } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index a814292a3d71d..5932d098daeac 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -61,6 +61,7 @@ log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } sqlparser = { workspace = true } +datafusion-functions-nested = { workspace = true, features = ["sql"] } [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 654cde51ec9de..32403bd882a31 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -40,6 +40,7 @@ use datafusion_expr::{ }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_functions_nested::expr_fn::array_has; mod binary_op; mod function; @@ -605,7 +606,17 @@ impl SqlToRel<'_, S> { planner_context, ), _ => { - not_impl_err!("ANY/SOME only supports subquery comparison currently") + 
if compare_op != BinaryOperator::Eq { + plan_err!( + "Unsupported AnyOp: '{compare_op}', only '=' is supported" + ) + } else { + let left_expr = + self.sql_to_expr(*left, schema, planner_context)?; + let right_expr = + self.sql_to_expr(*right, schema, planner_context)?; + Ok(array_has(right_expr, left_expr)) + } } }, SQLExpr::AllOp { From 38f02485951dfbb1712506e57e03a66d4fcefa7b Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Fri, 5 Dec 2025 21:45:52 +0800 Subject: [PATCH 7/9] format toml and update slt result Signed-off-by: Ruihang Xia --- datafusion/sql/Cargo.toml | 2 +- datafusion/sqllogictest/test_files/explain.slt | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index 5932d098daeac..b7338cb764d77 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -56,12 +56,12 @@ bigdecimal = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, features = ["sql"] } datafusion-expr = { workspace = true, features = ["sql"] } +datafusion-functions-nested = { workspace = true, features = ["sql"] } indexmap = { workspace = true } log = { workspace = true } recursive = { workspace = true, optional = true } regex = { workspace = true } sqlparser = { workspace = true } -datafusion-functions-nested = { workspace = true, features = ["sql"] } [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index 918c01b5613af..922cad827fa67 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -176,6 +176,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE 
logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -197,6 +198,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -535,6 +537,7 @@ initial_logical_plan logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE @@ -556,6 +559,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] +logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE From 07e23bd0303dac19e7b26f6138993183511041de Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 25 Dec 2025 16:41:56 +0800 Subject: [PATCH 8/9] more tests including type mismatch, operators and NULL Signed-off-by: Ruihang Xia --- datafusion/core/tests/set_comparison.rs | 95 ++++++++++++++++++- 
.../optimizer/src/analyzer/type_coercion.rs | 9 ++ datafusion/sql/src/unparser/expr.rs | 22 ++--- .../logical_plan/consumer/expr/subquery.rs | 4 +- .../producer/substrait_producer.rs | 10 +- .../tests/cases/roundtrip_logical_plan.rs | 4 +- 6 files changed, 122 insertions(+), 22 deletions(-) diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs index 4cad70bfb3c95..1ee1b940ba790 100644 --- a/datafusion/core/tests/set_comparison.rs +++ b/datafusion/core/tests/set_comparison.rs @@ -17,11 +17,11 @@ use std::sync::Arc; -use arrow::array::Int32Array; +use arrow::array::{Int32Array, StringArray}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion::prelude::SessionContext; -use datafusion_common::{Result, assert_batches_eq}; +use datafusion_common::{Result, assert_batches_eq, assert_contains}; fn build_table(values: &[i32]) -> Result { let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); @@ -82,3 +82,94 @@ async fn set_comparison_all_empty() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn set_comparison_type_mismatch() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1])?)?; + ctx.register_batch("strings", { + let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)])); + let array = Arc::new(StringArray::from(vec![Some("a"), Some("b")])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v > any(select s from strings)") + .await?; + let err = df.collect().await.unwrap_err(); + assert_contains!( + err.to_string(), + "expr type Int32 can't cast to Utf8 in SetComparison" + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_multiple_operators() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 2, 3, 4])?)?; + ctx.register_batch("s", build_table(&[2, 3])?)?; + + let df = ctx + .sql("select v from t where v = any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 2 |", "| 3 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v != all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 1 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v >= all(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &["+---+", "| v |", "+---+", "| 3 |", "| 4 |", "+---+",], + &results + ); + + let df = ctx + .sql("select v from t where v <= any(select v from s) order by v") + .await?; + let results = df.collect().await?; + assert_batches_eq!( + &[ + "+---+", "| v |", "+---+", "| 1 |", "| 2 |", "| 3 |", "+---+", + ], + &results + ); + Ok(()) +} + +#[tokio::test] +async fn set_comparison_null_semantics_all() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[5])?)?; + ctx.register_batch("s", { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])); + let array = Arc::new(Int32Array::from(vec![Some(1), None])) + as Arc; + RecordBatch::try_new(schema, vec![array])? 
+ })?; + + let df = ctx + .sql("select v from t where v != all(select v from s)") + .await?; + let results = df.collect().await?; + let row_count: usize = results.iter().map(|batch| batch.num_rows()).sum(); + assert_eq!(0, row_count); + Ok(()) +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 332ff430254b9..1885dcab2c4dc 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -381,6 +381,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { .data; let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); + if (expr_type.is_numeric() + && is_utf8_or_utf8view_or_large_utf8(subquery_type)) + || (subquery_type.is_numeric() + && is_utf8_or_utf8view_or_large_utf8(&expr_type)) + { + return plan_err!( + "expr type {expr_type} can't cast to {subquery_type} in SetComparison" + ); + } let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( "expr type {expr_type} can't cast to {subquery_type} in SetComparison" diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 3ee092430f4c5..ac7b467920364 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -25,27 +25,27 @@ use sqlparser::ast::{ use std::sync::Arc; use std::vec; -use super::dialect::IntervalStyle; use super::Unparser; +use super::dialect::IntervalStyle; use arrow::array::{ + ArrayRef, Date32Array, Date64Array, PrimitiveArray, types::{ ArrowTemporalType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }, - ArrayRef, Date32Array, Date64Array, PrimitiveArray, }; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, + DataType, 
Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, DecimalType, }; use arrow::util::display::array_value_to_string; use datafusion_common::{ - assert_eq_or_internal_err, assert_or_internal_err, internal_datafusion_err, - internal_err, not_impl_err, plan_err, Column, Result, ScalarValue, + Column, Result, ScalarValue, assert_eq_or_internal_err, assert_or_internal_err, + internal_datafusion_err, internal_err, not_impl_err, plan_err, }; use datafusion_expr::{ - expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, + expr::{Alias, Exists, InList, ScalarFunction, SetQuantifier, Sort, WindowFunction}, }; use sqlparser::ast::helpers::attached_token::AttachedToken; use sqlparser::tokenizer::Span; @@ -1831,12 +1831,12 @@ mod tests { use datafusion_common::{Spans, TableReference}; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - case, cast, col, cube, exists, grouping_set, interval_datetime_lit, - interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, - table_scan, try_cast, when, ColumnarValue, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, WindowFrame, WindowFunctionDefinition, case, cast, col, cube, exists, + grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, + not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, }; - use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; + use datafusion_expr::{ExprFunctionExt, interval_month_day_nano_lit}; use datafusion_functions::datetime::from_unixtime::FromUnixtimeFunc; use datafusion_functions::expr_fn::{get_field, named_struct}; use datafusion_functions_aggregate::count::count_udaf; diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs 
b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs index dab26aab73fad..61a381e9eb407 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/subquery.rs @@ -16,14 +16,14 @@ // under the License. use crate::logical_plan::consumer::SubstraitConsumer; -use datafusion::common::{substrait_datafusion_err, substrait_err, DFSchema, Spans}; +use datafusion::common::{DFSchema, Spans, substrait_datafusion_err, substrait_err}; use datafusion::logical_expr::expr::{Exists, InSubquery, SetComparison, SetQuantifier}; use datafusion::logical_expr::{Expr, Operator, Subquery}; use std::sync::Arc; use substrait::proto::expression as substrait_expression; +use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::subquery::set_comparison::{ComparisonOp, ReductionOp}; use substrait::proto::expression::subquery::set_predicate::PredicateOp; -use substrait::proto::expression::subquery::SubqueryType; pub async fn from_subquery( consumer: &impl SubstraitConsumer, diff --git a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs index ee272d74f5c72..c7518bd04e4a1 100644 --- a/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs +++ b/datafusion/substrait/src/logical_plan/producer/substrait_producer.rs @@ -25,16 +25,16 @@ use crate::logical_plan::producer::{ from_union, from_values, from_window, from_window_function, to_substrait_rel, to_substrait_rex, }; -use datafusion::common::{substrait_err, Column, DFSchemaRef, ScalarValue}; -use datafusion::execution::registry::SerializerRegistry; +use datafusion::common::{Column, DFSchemaRef, ScalarValue, substrait_err}; use datafusion::execution::SessionState; +use datafusion::execution::registry::SerializerRegistry; use datafusion::logical_expr::expr::{ Alias, InList, InSubquery, SetComparison, WindowFunction, }; use 
datafusion::logical_expr::{ - expr, Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, - Extension, Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, - SubqueryAlias, TableScan, TryCast, Union, Values, Window, + Aggregate, Between, BinaryExpr, Case, Cast, Distinct, EmptyRelation, Expr, Extension, + Filter, Join, Like, Limit, LogicalPlan, Projection, Repartition, Sort, SubqueryAlias, + TableScan, TryCast, Union, Values, Window, expr, }; use pbjson_types::Any as ProtoAny; use substrait::proto::aggregate_rel::Measure; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 3ae74c2d69dc4..f78b255526dc9 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -29,7 +29,7 @@ use std::mem::size_of_val; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::tree_node::Transformed; -use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef, Spans}; +use datafusion::common::{DFSchema, DFSchemaRef, Spans, not_impl_err, plan_err}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; @@ -46,7 +46,7 @@ use std::hash::Hash; use std::sync::Arc; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; -use substrait::proto::{plan_rel, Plan, Rel}; +use substrait::proto::{Plan, Rel, plan_rel}; #[derive(Debug)] struct MockSerializerRegistry; From c53451cd4e3a1d50a5aede123854404192a2ad60 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 25 Dec 2025 19:46:11 +0800 Subject: [PATCH 9/9] fix aggr/window plan Signed-off-by: Ruihang Xia --- datafusion/core/tests/set_comparison.rs | 18 ++++++++++++++++++ .../optimizer/src/rewrite_set_comparison.rs | 13 +++++++------ 2 files 
changed, 25 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/set_comparison.rs b/datafusion/core/tests/set_comparison.rs index 1ee1b940ba790..464d6c937b328 100644 --- a/datafusion/core/tests/set_comparison.rs +++ b/datafusion/core/tests/set_comparison.rs @@ -55,6 +55,24 @@ async fn set_comparison_any() -> Result<()> { Ok(()) } +#[tokio::test] +async fn set_comparison_any_aggregate_subquery() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_batch("t", build_table(&[1, 7])?)?; + ctx.register_batch("s", build_table(&[1, 2, 3])?)?; + + let df = ctx + .sql( + "select v from t where v > any(select sum(v) from s group by v % 2) order by v", + ) + .await?; + let results = df.collect().await?; + + assert_batches_eq!(&["+---+", "| v |", "+---+", "| 7 |", "+---+",], &results); + Ok(()) +} + #[tokio::test] async fn set_comparison_all_empty() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/optimizer/src/rewrite_set_comparison.rs b/datafusion/optimizer/src/rewrite_set_comparison.rs index d34a75fd475ba..0e642606f5746 100644 --- a/datafusion/optimizer/src/rewrite_set_comparison.rs +++ b/datafusion/optimizer/src/rewrite_set_comparison.rs @@ -20,9 +20,8 @@ //! that capture SQL three-valued logic. use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::ExprSchema; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DFSchema, Result, ScalarValue, plan_datafusion_err}; +use datafusion_common::{Column, DFSchema, ExprSchema, Result, ScalarValue, plan_err}; use datafusion_expr::expr::{self, Exists, SetComparison, SetQuantifier}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; @@ -91,10 +90,12 @@ fn build_set_comparison_subquery( } = set_comparison; let left_expr = to_outer_reference(*expr, outer_schema)?; - let right_expr = subquery - .subquery - .head_output_expr()? 
- .ok_or_else(|| plan_datafusion_err!("single expression required."))?; + let subquery_schema = subquery.subquery.schema(); + if subquery_schema.fields().is_empty() { + return plan_err!("single expression required."); + } + // avoid `head_output_expr` for aggregate/window plans: it returns the group-by expr if one exists + let right_expr = Expr::Column(Column::from(subquery_schema.qualified_field(0))); let comparison = Expr::BinaryExpr(expr::BinaryExpr::new( Box::new(left_expr),