diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs
index c46df87fd8b3..931790485259 100644
--- a/datafusion/physical-expr/src/projection.rs
+++ b/datafusion/physical-expr/src/projection.rs
@@ -21,15 +21,15 @@ use std::ops::Deref;
 use std::sync::Arc;
 
 use crate::PhysicalExpr;
-use crate::expressions::Column;
+use crate::expressions::{Column, Literal};
 use crate::utils::collect_columns;
 
 use arrow::array::{RecordBatch, RecordBatchOptions};
 use arrow::datatypes::{Field, Schema, SchemaRef};
-use datafusion_common::stats::ColumnStatistics;
+use datafusion_common::stats::{ColumnStatistics, Precision};
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_common::{
-    Result, assert_or_internal_err, internal_datafusion_err, plan_err,
+    Result, ScalarValue, assert_or_internal_err, internal_datafusion_err, plan_err,
 };
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
 
@@ -587,6 +587,54 @@ impl ProjectionExprs {
             let expr = &proj_expr.expr;
             let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
                 std::mem::take(&mut stats.column_statistics[col.index()])
+            } else if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
+                // Handle literal expressions (constants) by calculating proper statistics
+                let data_type = expr.data_type(output_schema)?;
+
+                if literal.value().is_null() {
+                    let null_count = match stats.num_rows {
+                        Precision::Exact(num_rows) => Precision::Exact(num_rows),
+                        _ => Precision::Absent,
+                    };
+
+                    ColumnStatistics {
+                        min_value: Precision::Absent,
+                        max_value: Precision::Absent,
+                        distinct_count: Precision::Exact(1),
+                        null_count,
+                        sum_value: Precision::Absent,
+                        byte_size: Precision::Absent,
+                    }
+                } else {
+                    let value = literal.value();
+                    let distinct_count = Precision::Exact(1);
+                    let null_count = Precision::Exact(0);
+
+                    let byte_size = if let Some(byte_width) = data_type.primitive_width()
+                    {
+                        stats.num_rows.multiply(&Precision::Exact(byte_width))
+                    } else {
+                        // Complex types depend on array encoding, so set to Absent
+                        Precision::Absent
+                    };
+
+                    let sum_value = Precision::<ScalarValue>::from(stats.num_rows)
+                        .cast_to(&value.data_type())
+                        .ok()
+                        .map(|row_count| {
+                            Precision::Exact(value.clone()).multiply(&row_count)
+                        })
+                        .unwrap_or(Precision::Absent);
+
+                    ColumnStatistics {
+                        min_value: Precision::Exact(value.clone()),
+                        max_value: Precision::Exact(value.clone()),
+                        distinct_count,
+                        null_count,
+                        sum_value,
+                        byte_size,
+                    }
+                }
             } else {
                 // TODO stats: estimate more statistics from expressions
                 // (expressions should compute their statistics themselves)
@@ -2593,4 +2641,217 @@ pub(crate) mod tests {
 
         Ok(())
     }
+
+    // Test statistics calculation for non-null literal (numeric constant)
+    #[test]
+    fn test_project_statistics_with_literal() -> Result<()> {
+        let input_stats = get_stats();
+        let input_schema = get_schema();
+
+        // Projection with literal: SELECT 42 AS constant, col0 AS num
+        let projection = ProjectionExprs::new(vec![
+            ProjectionExpr {
+                expr: Arc::new(Literal::new(ScalarValue::Int64(Some(42)))),
+                alias: "constant".to_string(),
+            },
+            ProjectionExpr {
+                expr: Arc::new(Column::new("col0", 0)),
+                alias: "num".to_string(),
+            },
+        ]);
+
+        let output_stats = projection.project_statistics(
+            input_stats,
+            &projection.project_schema(&input_schema)?,
+        )?;
+
+        // Row count should be preserved
+        assert_eq!(output_stats.num_rows, Precision::Exact(5));
+
+        // Should have 2 column statistics
+        assert_eq!(output_stats.column_statistics.len(), 2);
+
+        // First column (literal 42) should have proper constant statistics
+        assert_eq!(
+            output_stats.column_statistics[0].min_value,
+            Precision::Exact(ScalarValue::Int64(Some(42)))
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].max_value,
+            Precision::Exact(ScalarValue::Int64(Some(42)))
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].distinct_count,
+            Precision::Exact(1)
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].null_count,
+            Precision::Exact(0)
+        );
+        // Int64 is 8 bytes, 5 rows = 40 bytes
+        assert_eq!(
+            output_stats.column_statistics[0].byte_size,
+            Precision::Exact(40)
+        );
+        // For a constant column, sum_value = value * num_rows = 42 * 5 = 210
+        assert_eq!(
+            output_stats.column_statistics[0].sum_value,
+            Precision::Exact(ScalarValue::Int64(Some(210)))
+        );
+
+        // Second column (col0) should preserve statistics
+        assert_eq!(
+            output_stats.column_statistics[1].distinct_count,
+            Precision::Exact(5)
+        );
+        assert_eq!(
+            output_stats.column_statistics[1].max_value,
+            Precision::Exact(ScalarValue::Int64(Some(21)))
+        );
+
+        Ok(())
+    }
+
+    // Test statistics calculation for NULL literal (constant NULL column)
+    #[test]
+    fn test_project_statistics_with_null_literal() -> Result<()> {
+        let input_stats = get_stats();
+        let input_schema = get_schema();
+
+        // Projection with NULL literal: SELECT NULL AS null_col, col0 AS num
+        let projection = ProjectionExprs::new(vec![
+            ProjectionExpr {
+                expr: Arc::new(Literal::new(ScalarValue::Int64(None))),
+                alias: "null_col".to_string(),
+            },
+            ProjectionExpr {
+                expr: Arc::new(Column::new("col0", 0)),
+                alias: "num".to_string(),
+            },
+        ]);
+
+        let output_stats = projection.project_statistics(
+            input_stats,
+            &projection.project_schema(&input_schema)?,
+        )?;
+
+        // Row count should be preserved
+        assert_eq!(output_stats.num_rows, Precision::Exact(5));
+
+        // Should have 2 column statistics
+        assert_eq!(output_stats.column_statistics.len(), 2);
+
+        // First column (NULL literal) should have proper constant NULL statistics
+        assert_eq!(
+            output_stats.column_statistics[0].min_value,
+            Precision::Absent // NULLs don't have min/max
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].max_value,
+            Precision::Absent // NULLs don't have min/max
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].distinct_count,
+            Precision::Exact(1) // All NULLs are considered the same
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].null_count,
+            Precision::Exact(5) // All rows are NULL
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].byte_size,
+            Precision::Absent // NULLs don't take space
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].sum_value,
+            Precision::Absent // Sum doesn't make sense for NULLs
+        );
+
+        // Second column (col0) should preserve statistics
+        assert_eq!(
+            output_stats.column_statistics[1].distinct_count,
+            Precision::Exact(5)
+        );
+        assert_eq!(
+            output_stats.column_statistics[1].max_value,
+            Precision::Exact(ScalarValue::Int64(Some(21)))
+        );
+
+        Ok(())
+    }
+
+    // Test statistics calculation for complex type literal (e.g., Utf8 string)
+    #[test]
+    fn test_project_statistics_with_complex_type_literal() -> Result<()> {
+        let input_stats = get_stats();
+        let input_schema = get_schema();
+
+        // Projection with Utf8 literal (complex type): SELECT 'hello' AS text, col0 AS num
+        let projection = ProjectionExprs::new(vec![
+            ProjectionExpr {
+                expr: Arc::new(Literal::new(ScalarValue::Utf8(Some(
+                    "hello".to_string(),
+                )))),
+                alias: "text".to_string(),
+            },
+            ProjectionExpr {
+                expr: Arc::new(Column::new("col0", 0)),
+                alias: "num".to_string(),
+            },
+        ]);
+
+        let output_stats = projection.project_statistics(
+            input_stats,
+            &projection.project_schema(&input_schema)?,
+        )?;
+
+        // Row count should be preserved
+        assert_eq!(output_stats.num_rows, Precision::Exact(5));
+
+        // Should have 2 column statistics
+        assert_eq!(output_stats.column_statistics.len(), 2);
+
+        // First column (Utf8 literal 'hello') should have proper constant statistics
+        // but byte_size should be Absent for complex types
+        assert_eq!(
+            output_stats.column_statistics[0].min_value,
+            Precision::Exact(ScalarValue::Utf8(Some("hello".to_string())))
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].max_value,
+            Precision::Exact(ScalarValue::Utf8(Some("hello".to_string())))
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].distinct_count,
+            Precision::Exact(1)
+        );
+        assert_eq!(
+            output_stats.column_statistics[0].null_count,
+            Precision::Exact(0)
+        );
+        // Complex types (Utf8, List, etc.) should have byte_size = Absent
+        // because we can't calculate exact size without knowing the actual data
+        assert_eq!(
+            output_stats.column_statistics[0].byte_size,
+            Precision::Absent
+        );
+        // Non-numeric types (Utf8) should have sum_value = Absent
+        // because sum is only meaningful for numeric types
+        assert_eq!(
+            output_stats.column_statistics[0].sum_value,
+            Precision::Absent
+        );
+
+        // Second column (col0) should preserve statistics
+        assert_eq!(
+            output_stats.column_statistics[1].distinct_count,
+            Precision::Exact(5)
+        );
+        assert_eq!(
+            output_stats.column_statistics[1].max_value,
+            Precision::Exact(ScalarValue::Int64(Some(21)))
+        );
+
+        Ok(())
+    }
 }