Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 264 additions & 3 deletions datafusion/physical-expr/src/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ use std::ops::Deref;
use std::sync::Arc;

use crate::PhysicalExpr;
use crate::expressions::Column;
use crate::expressions::{Column, Literal};
use crate::utils::collect_columns;

use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::datatypes::{Field, Schema, SchemaRef};
use datafusion_common::stats::ColumnStatistics;
use datafusion_common::stats::{ColumnStatistics, Precision};
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{
Result, assert_or_internal_err, internal_datafusion_err, plan_err,
Result, ScalarValue, assert_or_internal_err, internal_datafusion_err, plan_err,
};

use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
Expand Down Expand Up @@ -587,6 +587,54 @@ impl ProjectionExprs {
let expr = &proj_expr.expr;
let col_stats = if let Some(col) = expr.as_any().downcast_ref::<Column>() {
std::mem::take(&mut stats.column_statistics[col.index()])
} else if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
// Handle literal expressions (constants) by calculating proper statistics
let data_type = expr.data_type(output_schema)?;

if literal.value().is_null() {
let null_count = match stats.num_rows {
Precision::Exact(num_rows) => Precision::Exact(num_rows),
_ => Precision::Absent,
};

ColumnStatistics {
min_value: Precision::Absent,
max_value: Precision::Absent,
distinct_count: Precision::Exact(1),
null_count,
sum_value: Precision::Absent,
byte_size: Precision::Absent,
}
} else {
let value = literal.value();
let distinct_count = Precision::Exact(1);
let null_count = Precision::Exact(0);

let byte_size = if let Some(byte_width) = data_type.primitive_width()
{
stats.num_rows.multiply(&Precision::Exact(byte_width))
} else {
// Complex types depend on array encoding, so set to Absent
Precision::Absent
};

let sum_value = Precision::<ScalarValue>::from(stats.num_rows)
.cast_to(&value.data_type())
.ok()
.map(|row_count| {
Precision::Exact(value.clone()).multiply(&row_count)
})
.unwrap_or(Precision::Absent);

ColumnStatistics {
min_value: Precision::Exact(value.clone()),
max_value: Precision::Exact(value.clone()),
distinct_count,
null_count,
sum_value,
byte_size,
}
}
} else {
// TODO stats: estimate more statistics from expressions
// (expressions should compute their statistics themselves)
Expand Down Expand Up @@ -2593,4 +2641,217 @@ pub(crate) mod tests {

Ok(())
}

// Checks the column statistics produced when a non-null numeric literal is
// projected next to a plain column reference.
#[test]
fn test_project_statistics_with_literal() -> Result<()> {
    let input_stats = get_stats();
    let input_schema = get_schema();

    // Equivalent to: SELECT 42 AS constant, col0 AS num
    let constant_entry = ProjectionExpr {
        expr: Arc::new(Literal::new(ScalarValue::Int64(Some(42)))),
        alias: "constant".to_string(),
    };
    let column_entry = ProjectionExpr {
        expr: Arc::new(Column::new("col0", 0)),
        alias: "num".to_string(),
    };
    let projection = ProjectionExprs::new(vec![constant_entry, column_entry]);

    let output_schema = projection.project_schema(&input_schema)?;
    let output_stats = projection.project_statistics(input_stats, &output_schema)?;

    // Projection must not change the row count, and must emit one statistics
    // entry per projected expression.
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 2);

    // Entry 0: the literal. A constant column has min == max == the value,
    // exactly one distinct value and no nulls.
    let literal_stats = &output_stats.column_statistics[0];
    let forty_two = ScalarValue::Int64(Some(42));
    assert_eq!(literal_stats.min_value, Precision::Exact(forty_two.clone()));
    assert_eq!(literal_stats.max_value, Precision::Exact(forty_two));
    assert_eq!(literal_stats.distinct_count, Precision::Exact(1));
    assert_eq!(literal_stats.null_count, Precision::Exact(0));
    // 8 bytes per Int64 value * 5 rows = 40 bytes.
    assert_eq!(literal_stats.byte_size, Precision::Exact(40));
    // Sum of a constant column is value * num_rows: 42 * 5 = 210.
    assert_eq!(
        literal_stats.sum_value,
        Precision::Exact(ScalarValue::Int64(Some(210)))
    );

    // Entry 1: the pass-through column keeps the statistics of `col0`.
    let column_stats = &output_stats.column_statistics[1];
    assert_eq!(column_stats.distinct_count, Precision::Exact(5));
    assert_eq!(
        column_stats.max_value,
        Precision::Exact(ScalarValue::Int64(Some(21)))
    );

    Ok(())
}

// Checks the column statistics produced when a typed NULL literal is projected
// next to a plain column reference.
#[test]
fn test_project_statistics_with_null_literal() -> Result<()> {
    let input_stats = get_stats();
    let input_schema = get_schema();

    // Equivalent to: SELECT NULL AS null_col, col0 AS num
    let null_entry = ProjectionExpr {
        expr: Arc::new(Literal::new(ScalarValue::Int64(None))),
        alias: "null_col".to_string(),
    };
    let column_entry = ProjectionExpr {
        expr: Arc::new(Column::new("col0", 0)),
        alias: "num".to_string(),
    };
    let projection = ProjectionExprs::new(vec![null_entry, column_entry]);

    let output_schema = projection.project_schema(&input_schema)?;
    let output_stats = projection.project_statistics(input_stats, &output_schema)?;

    // Projection must not change the row count, and must emit one statistics
    // entry per projected expression.
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 2);

    // Entry 0: the NULL literal.
    let null_stats = &output_stats.column_statistics[0];
    // An all-NULL column has no min/max to report.
    assert_eq!(null_stats.min_value, Precision::Absent);
    assert_eq!(null_stats.max_value, Precision::Absent);
    // The constant NULL is counted as a single distinct value.
    assert_eq!(null_stats.distinct_count, Precision::Exact(1));
    // Every one of the 5 rows is NULL.
    assert_eq!(null_stats.null_count, Precision::Exact(5));
    // No byte size and no sum are reported for the NULL literal.
    assert_eq!(null_stats.byte_size, Precision::Absent);
    assert_eq!(null_stats.sum_value, Precision::Absent);

    // Entry 1: the pass-through column keeps the statistics of `col0`.
    let column_stats = &output_stats.column_statistics[1];
    assert_eq!(column_stats.distinct_count, Precision::Exact(5));
    assert_eq!(
        column_stats.max_value,
        Precision::Exact(ScalarValue::Int64(Some(21)))
    );

    Ok(())
}

// Checks the column statistics produced when a non-primitive literal (a Utf8
// string) is projected next to a plain column reference.
#[test]
fn test_project_statistics_with_complex_type_literal() -> Result<()> {
    let input_stats = get_stats();
    let input_schema = get_schema();

    // Equivalent to: SELECT 'hello' AS text, col0 AS num
    let text_entry = ProjectionExpr {
        expr: Arc::new(Literal::new(ScalarValue::Utf8(Some("hello".to_string())))),
        alias: "text".to_string(),
    };
    let column_entry = ProjectionExpr {
        expr: Arc::new(Column::new("col0", 0)),
        alias: "num".to_string(),
    };
    let projection = ProjectionExprs::new(vec![text_entry, column_entry]);

    let output_schema = projection.project_schema(&input_schema)?;
    let output_stats = projection.project_statistics(input_stats, &output_schema)?;

    // Projection must not change the row count, and must emit one statistics
    // entry per projected expression.
    assert_eq!(output_stats.num_rows, Precision::Exact(5));
    assert_eq!(output_stats.column_statistics.len(), 2);

    // Entry 0: the Utf8 literal. Value-based statistics are still exact for a
    // constant: min == max == 'hello', one distinct value, no nulls.
    let text_stats = &output_stats.column_statistics[0];
    let hello = ScalarValue::Utf8(Some("hello".to_string()));
    assert_eq!(text_stats.min_value, Precision::Exact(hello.clone()));
    assert_eq!(text_stats.max_value, Precision::Exact(hello));
    assert_eq!(text_stats.distinct_count, Precision::Exact(1));
    assert_eq!(text_stats.null_count, Precision::Exact(0));
    // Utf8 has no fixed primitive width, so no byte size is reported.
    assert_eq!(text_stats.byte_size, Precision::Absent);
    // A sum is only meaningful for numeric types, so none is reported here.
    assert_eq!(text_stats.sum_value, Precision::Absent);

    // Entry 1: the pass-through column keeps the statistics of `col0`.
    let column_stats = &output_stats.column_statistics[1];
    assert_eq!(column_stats.distinct_count, Precision::Exact(5));
    assert_eq!(
        column_stats.max_value,
        Precision::Exact(ScalarValue::Int64(Some(21)))
    );

    Ok(())
}
}