diff --git a/datafusion-examples/examples/relation_planner/table_sample.rs b/datafusion-examples/examples/relation_planner/table_sample.rs index 207fffe1327a3..362d35dcf4cac 100644 --- a/datafusion-examples/examples/relation_planner/table_sample.rs +++ b/datafusion-examples/examples/relation_planner/table_sample.rs @@ -83,13 +83,12 @@ use std::{ any::Any, fmt::{self, Debug, Formatter}, hash::{Hash, Hasher}, - ops::{Add, Div, Mul, Sub}, pin::Pin, - str::FromStr, sync::Arc, task::{Context, Poll}, }; +use arrow::datatypes::{Float64Type, Int64Type}; use arrow::{ array::{ArrayRef, Int32Array, RecordBatch, StringArray, UInt32Array}, compute, @@ -102,6 +101,7 @@ use futures::{ use rand::{Rng, SeedableRng, rngs::StdRng}; use tonic::async_trait; +use datafusion::optimizer::simplify_expressions::simplify_literal::parse_literal; use datafusion::{ execution::{ RecordBatchStream, SendableRecordBatchStream, SessionState, SessionStateBuilder, @@ -410,11 +410,12 @@ impl RelationPlanner for TableSamplePlanner { "TABLESAMPLE requires a quantity (percentage, fraction, or row count)" ); }; + let quantity_value_expr = context.sql_to_expr(quantity.value, input.schema())?; match quantity.unit { // TABLESAMPLE (N ROWS) - exact row limit Some(TableSampleUnit::Rows) => { - let rows = parse_quantity::(&quantity.value)?; + let rows: i64 = parse_literal::(&quantity_value_expr)?; if rows < 0 { return plan_err!("row count must be non-negative, got {}", rows); } @@ -426,7 +427,7 @@ impl RelationPlanner for TableSamplePlanner { // TABLESAMPLE (N PERCENT) - percentage sampling Some(TableSampleUnit::Percent) => { - let percent = parse_quantity::(&quantity.value)?; + let percent: f64 = parse_literal::(&quantity_value_expr)?; let fraction = percent / 100.0; let plan = TableSamplePlanNode::new(input, fraction, seed).into_plan(); Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) @@ -434,7 +435,7 @@ impl RelationPlanner for TableSamplePlanner { // TABLESAMPLE (N) - fraction if <1.0, row limit if >=1.0 None => { - let value = parse_quantity::(&quantity.value)?; + let value = parse_literal::(&quantity_value_expr)?; if value < 0.0 { return plan_err!("sample value must be non-negative, got {}", value); } @@ -453,40 +454,6 @@ impl RelationPlanner for TableSamplePlanner { } } -/// Parse a SQL expression as a numeric value (supports basic arithmetic). -fn parse_quantity(expr: &ast::Expr) -> Result -where - T: FromStr + Add + Sub + Mul + Div, -{ - eval_numeric_expr(expr) - .ok_or_else(|| plan_datafusion_err!("invalid numeric expression: {:?}", expr)) -} - -/// Recursively evaluate numeric SQL expressions. -fn eval_numeric_expr(expr: &ast::Expr) -> Option -where - T: FromStr + Add + Sub + Mul + Div, -{ - match expr { - ast::Expr::Value(v) => match &v.value { - ast::Value::Number(n, _) => n.to_string().parse().ok(), - _ => None, - }, - ast::Expr::BinaryOp { left, op, right } => { - let l = eval_numeric_expr::(left)?; - let r = eval_numeric_expr::(right)?; - match op { - ast::BinaryOperator::Plus => Some(l + r), - ast::BinaryOperator::Minus => Some(l - r), - ast::BinaryOperator::Multiply => Some(l * r), - ast::BinaryOperator::Divide => Some(l / r), - _ => None, - } - } - _ => None, - } -} - /// Custom logical plan node representing a TABLESAMPLE operation. /// /// Stores sampling parameters (bounds, seed) and wraps the input plan. diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index e238fca32689d..58a4eadb5c078 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -22,6 +22,7 @@ pub mod expr_simplifier; mod inlist_simplifier; mod regex; pub mod simplify_exprs; +pub mod simplify_literal; mod simplify_predicates; mod unwrap_cast; mod utils; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs new file mode 100644 index 0000000000000..168a6ebb461f0 --- /dev/null +++ b/datafusion/optimizer/src/simplify_expressions/simplify_literal.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Parses and simplifies an expression to a literal of a given type. +//! +//! This module provides functionality to parse and simplify static expressions +//! used in SQL constructs like `FROM TABLE SAMPLE (10 + 50 * 2)`. If they are required +//! in a planning (not an execution) phase, they need to be reduced to literals of a given type. + +use crate::simplify_expressions::ExprSimplifier; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::{ + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, plan_datafusion_err, + plan_err, +}; +use datafusion_expr::Expr; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::simplify::SimplifyContext; +use std::sync::Arc; + +/// Parse and simplifies an expression to a numeric literal, +/// corresponding to an arrow primitive type `T` (for example, Float64Type). +/// +/// This function simplifies and coerces the expression, then extracts the underlying +/// native type using `TryFrom`. +/// +/// # Example +/// ```ignore +/// let value: f64 = parse_literal::(expr)?; +/// ``` +pub fn parse_literal(expr: &Expr) -> Result +where + T: ArrowPrimitiveType, + T::Native: TryFrom, +{ + // Empty schema is sufficient because it parses only literal expressions + let schema = DFSchemaRef::new(DFSchema::empty()); + + log::debug!("Parsing expr {:?} to type {}", expr, T::DATA_TYPE); + + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(Arc::clone(&schema)), + ); + + // Simplify and coerce expression in case of constant arithmetic operations (e.g., 10 + 5) + let simplified_expr: Expr = simplifier + .simplify(expr.clone()) + .map_err(|err| plan_datafusion_err!("Cannot simplify {expr:?}: {err}"))?; + let coerced_expr: Expr = simplifier.coerce(simplified_expr, schema.as_ref())?; + log::debug!("Coerced expression: {:?}", &coerced_expr); + + match coerced_expr { + Expr::Literal(scalar_value, _) => { + // It is a literal - proceed to the underlying value + // Cast to the target type if needed + let casted_scalar = scalar_value.cast_to(&T::DATA_TYPE)?; + + // Extract the native type + T::Native::try_from(casted_scalar).map_err(|err| { + plan_datafusion_err!( + "Cannot extract {} from scalar value: {err}", + std::any::type_name::() + ) + }) + } + actual => { + plan_err!( + "Cannot extract literal from coerced {actual:?} expression given {expr:?} expression" + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{Float64Type, Int64Type}; + use datafusion_expr::{BinaryExpr, lit}; + use datafusion_expr_common::operator::Operator; + + #[test] + fn test_parse_sql_float_literal() { + let test_cases = vec![ + (Expr::Literal(ScalarValue::Float64(Some(0.0)), None), 0.0), + (Expr::Literal(ScalarValue::Float64(Some(1.0)), None), 1.0), + ( + Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(50.0)), + Operator::Minus, + Box::new(lit(10.0)), + )), + 40.0, + ), + ( + Expr::Literal(ScalarValue::Utf8(Some("1e2".into())), None), + 100.0, + ), + ( + Expr::Literal(ScalarValue::Utf8(Some("2.5e-1".into())), None), + 0.25, + ), + ]; + + for (expr, expected) in test_cases { + let result: Result = parse_literal::(&expr); + + match result { + Ok(value) => { + assert!( + (value - expected).abs() < 1e-10, + "For expression '{expr}': expected {expected}, got {value}", + ); + } + Err(e) => panic!("Failed to parse expression '{expr}': {e}"), + } + } + } + + #[test] + fn test_parse_sql_integer_literal() { + let expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(2)), + Operator::Plus, + Box::new(lit(4)), + )); + + let result: Result = parse_literal::(&expr); + + match result { + Ok(value) => { + assert_eq!(6, value); + } + Err(e) => panic!("Failed to parse expression: {e}"), + } + } +} diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index b21eb52920ab5..a81960aa4ed5c 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -48,7 +48,7 @@ mod expr; pub mod parser; pub mod planner; mod query; -mod relation; +pub mod relation; pub mod resolve; mod select; mod set_expr; diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 3115d8dfffbd2..5f96779d1a1f6 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -33,11 +33,23 @@ use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; mod join; -struct SqlToRelRelationContext<'a, 'b, S: ContextProvider> { +pub struct SqlToRelRelationContext<'a, 'b, S: ContextProvider> { planner: &'a SqlToRel<'b, S>, planner_context: &'a mut PlannerContext, } +impl<'a, 'b, S: ContextProvider> SqlToRelRelationContext<'a, 'b, S> { + pub fn new( + planner: &'a SqlToRel<'b, S>, + planner_context: &'a mut PlannerContext, + ) -> Self { + Self { + planner, + planner_context, + } + } +} + // Implement RelationPlannerContext impl<'a, 'b, S: ContextProvider> RelationPlannerContext for SqlToRelRelationContext<'a, 'b, S> @@ -117,11 +129,7 @@ impl SqlToRel<'_, S> { let mut current_relation = relation; for planner in planners.iter() { - let mut context = SqlToRelRelationContext { - planner: self, - planner_context, - }; - + let mut context = SqlToRelRelationContext::new(self, planner_context); match planner.plan_relation(current_relation, &mut context)? { RelationPlanning::Planned(planned) => { return Ok(RelationPlanning::Planned(planned));