Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 278 additions & 55 deletions datafusion/functions/src/math/floor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray};
use arrow::compute::{DecimalCast, rescale_decimal};
use arrow::datatypes::{
DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, Float32Type,
Float64Type,
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type,
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
Expand Down Expand Up @@ -77,6 +78,42 @@ impl FloorFunc {
}
}

// ============ Macro for preimage bounds ============
/// Dispatches to the bounds helper for a literal's type family and wraps the
/// helper's `(lower, upper)` pair in the requested `ScalarValue` variant.
///
/// Each arm maps over the helper's return value, so a `None` from the helper
/// is passed through unchanged.
macro_rules! preimage_bounds {
    // `float:` arm — delegates to `float_preimage_bounds`.
    (float: $variant:ident, $value:expr) => {
        float_preimage_bounds($value).map(|(lower_bound, upper_bound)| {
            let lower_bound = ScalarValue::$variant(Some(lower_bound));
            let upper_bound = ScalarValue::$variant(Some(upper_bound));
            (lower_bound, upper_bound)
        })
    };

    // `int:` arm — delegates to `int_preimage_bounds`.
    (int: $variant:ident, $value:expr) => {
        int_preimage_bounds($value).map(|(lower_bound, upper_bound)| {
            let lower_bound = ScalarValue::$variant(Some(lower_bound));
            let upper_bound = ScalarValue::$variant(Some(upper_bound));
            (lower_bound, upper_bound)
        })
    };

    // `decimal:` arm — delegates to `decimal_preimage_bounds::<$decimal_type>` and
    // re-attaches the literal's precision and scale to both bounds.
    (decimal: $variant:ident, $decimal_type:ty, $value:expr, $precision:expr, $scale:expr) => {
        decimal_preimage_bounds::<$decimal_type>($value, $precision, $scale).map(
            |(lower_bound, upper_bound)| {
                let lower_bound = ScalarValue::$variant(Some(lower_bound), $precision, $scale);
                let upper_bound = ScalarValue::$variant(Some(upper_bound), $precision, $scale);
                (lower_bound, upper_bound)
            },
        )
    };
}

impl ScalarUDFImpl for FloorFunc {
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -216,10 +253,8 @@ impl ScalarUDFImpl for FloorFunc {
lit_expr: &Expr,
_info: &SimplifyContext,
) -> Result<PreimageResult> {
// floor takes exactly one argument
if args.len() != 1 {
return Ok(PreimageResult::None);
}
// floor takes exactly one argument and we do not expect to reach here with multiple arguments.
debug_assert!(args.len() == 1, "floor() takes exactly one argument");

let arg = args[0].clone();

Expand All @@ -230,35 +265,34 @@ impl ScalarUDFImpl for FloorFunc {

// Compute lower bound (N) and upper bound (N + 1) using helper functions
let Some((lower, upper)) = (match lit_value {
// Decimal types should be supported and tracked in
// https://github.com/apache/datafusion/issues/20080
// Floating-point types
ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
(
ScalarValue::Float64(Some(lo)),
ScalarValue::Float64(Some(hi)),
)
}),
ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
(
ScalarValue::Float32(Some(lo)),
ScalarValue::Float32(Some(hi)),
)
}),

// Integer types
ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))
}),
ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))
}),
ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))
}),
ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
(ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
}),
ScalarValue::Float64(Some(n)) => preimage_bounds!(float: Float64, *n),
ScalarValue::Float32(Some(n)) => preimage_bounds!(float: Float32, *n),

// Integer types (not reachable from SQL/SLT: floor() only accepts Float64/Float32/Decimal,
// so the RHS literal is always coerced to one of those before preimage runs; kept for
// programmatic use and unit tests)
ScalarValue::Int8(Some(n)) => preimage_bounds!(int: Int8, *n),
ScalarValue::Int16(Some(n)) => preimage_bounds!(int: Int16, *n),
ScalarValue::Int32(Some(n)) => preimage_bounds!(int: Int32, *n),
ScalarValue::Int64(Some(n)) => preimage_bounds!(int: Int64, *n),

// Decimal types
// DECIMAL(precision, scale) where precision ≤ 38 -> Decimal128(precision, scale)
// DECIMAL(precision, scale) where precision > 38 -> Decimal256(precision, scale)
// Decimal32 and Decimal64 are unreachable from SQL/SLT.
ScalarValue::Decimal32(Some(n), precision, scale) => {
preimage_bounds!(decimal: Decimal32, Decimal32Type, *n, *precision, *scale)
}
ScalarValue::Decimal64(Some(n), precision, scale) => {
preimage_bounds!(decimal: Decimal64, Decimal64Type, *n, *precision, *scale)
}
ScalarValue::Decimal128(Some(n), precision, scale) => {
preimage_bounds!(decimal: Decimal128, Decimal128Type, *n, *precision, *scale)
}
ScalarValue::Decimal256(Some(n), precision, scale) => {
preimage_bounds!(decimal: Decimal256, Decimal256Type, *n, *precision, *scale)
}

// Unsupported types
_ => None,
Expand Down Expand Up @@ -310,9 +344,49 @@ fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
Some((n, upper))
}

/// Compute preimage bounds for the floor function on decimal types.
///
/// For `floor(x) = n`, the preimage is the half-open range `[n, n + 1)`.
/// `value` is the raw (unscaled) representation of `n`; the returned pair is
/// `(value, value + one)` where `one` is the integer 1 expressed at `scale`.
///
/// Returns `None` (meaning "do not rewrite") if:
/// - `scale` is negative, so the integer 1 is not representable at that scale
/// - `rescale_decimal` cannot express 1 at the given `precision`/`scale`
/// - The value has a fractional part (floor always returns integers)
/// - Adding 1 would overflow the native integer type
///
/// NOTE(review): the upper bound is only checked against native overflow, not
/// against `precision`; confirm that callers tolerate an upper bound with one
/// more digit than `precision` allows.
fn decimal_preimage_bounds<D: DecimalType>(
    value: D::Native,
    precision: u8,
    scale: i8,
) -> Option<(D::Native, D::Native)>
where
    D::Native: DecimalCast + ArrowNativeTypeOp + std::ops::Rem<Output = D::Native>,
{
    // Arrow supports negative scales, where one raw unit stands for 10^(-scale).
    // The integer 1 cannot be expressed there: rescaling it is expected to round
    // down to 0, which would produce the empty range [value, value) and make the
    // rewritten predicate reject every row. Decline the rewrite instead.
    if scale < 0 {
        return None;
    }

    // Use rescale_decimal to compute "1" at the target scale (avoids manual pow):
    // convert integer 1 (precision = 1, scale = 0) to (precision, scale).
    let one_scaled: D::Native =
        rescale_decimal::<D, D>(D::Native::ONE, 1, 0, precision, scale)?;

    // floor always returns an integer, so an equality such as floor(x) = 1.30 has
    // no solution. A non-zero remainder modulo "1" means a fractional part exists.
    // With scale == 0 every value is already an integer, so the check is skipped.
    if scale > 0 && value % one_scaled != D::Native::ZERO {
        return None;
    }

    // Compute the upper bound using checked addition.
    // On the SQL/SLT path the raw value has already been validated against the
    // precision, so this is expected to succeed there; the check protects
    // programmatic callers that pass arbitrary raw values.
    let upper = value.add_checked(one_scaled).ok()?;

    Some((value, upper))
}

#[cfg(test)]
mod tests {
use super::*;
use arrow_buffer::i256;
use datafusion_expr::col;

/// Helper to test valid preimage cases that should return a Range
Expand Down Expand Up @@ -434,33 +508,182 @@ mod tests {
assert_preimage_none(ScalarValue::Int64(None));
}

// ============ Decimal32 Tests (mirrors float/int tests) ============

#[test]
fn test_floor_preimage_invalid_inputs() {
let floor_func = FloorFunc::new();
let info = SimplifyContext::default();
fn test_floor_preimage_decimal_valid_cases() {
// ===== Decimal32 =====
// Positive integer decimal: 100.00 (scale=2, so raw=10000)
// floor(x) = 100.00 -> x in [100.00, 101.00)
assert_preimage_range(
ScalarValue::Decimal32(Some(10000), 9, 2),
ScalarValue::Decimal32(Some(10000), 9, 2), // 100.00
ScalarValue::Decimal32(Some(10100), 9, 2), // 101.00
);

// Non-literal comparison value
let result = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
assert!(
matches!(result, PreimageResult::None),
"Expected None for non-literal"
// Smaller positive: 50.00
assert_preimage_range(
ScalarValue::Decimal32(Some(5000), 9, 2),
ScalarValue::Decimal32(Some(5000), 9, 2), // 50.00
ScalarValue::Decimal32(Some(5100), 9, 2), // 51.00
);

// Wrong argument count (too many)
let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);
let result = floor_func
.preimage(&[col("x"), col("y")], &lit, &info)
.unwrap();
assert!(
matches!(result, PreimageResult::None),
"Expected None for wrong arg count"
// Negative integer decimal: -5.00
assert_preimage_range(
ScalarValue::Decimal32(Some(-500), 9, 2),
ScalarValue::Decimal32(Some(-500), 9, 2), // -5.00
ScalarValue::Decimal32(Some(-400), 9, 2), // -4.00
);

// Wrong argument count (zero)
let result = floor_func.preimage(&[], &lit, &info).unwrap();
assert!(
matches!(result, PreimageResult::None),
"Expected None for zero args"
// Zero: 0.00
assert_preimage_range(
ScalarValue::Decimal32(Some(0), 9, 2),
ScalarValue::Decimal32(Some(0), 9, 2), // 0.00
ScalarValue::Decimal32(Some(100), 9, 2), // 1.00
);

// Scale 0 (pure integer): 42
assert_preimage_range(
ScalarValue::Decimal32(Some(42), 9, 0),
ScalarValue::Decimal32(Some(42), 9, 0),
ScalarValue::Decimal32(Some(43), 9, 0),
);

// ===== Decimal64 =====
assert_preimage_range(
ScalarValue::Decimal64(Some(10000), 18, 2),
ScalarValue::Decimal64(Some(10000), 18, 2), // 100.00
ScalarValue::Decimal64(Some(10100), 18, 2), // 101.00
);

// Negative
assert_preimage_range(
ScalarValue::Decimal64(Some(-500), 18, 2),
ScalarValue::Decimal64(Some(-500), 18, 2), // -5.00
ScalarValue::Decimal64(Some(-400), 18, 2), // -4.00
);

// Zero
assert_preimage_range(
ScalarValue::Decimal64(Some(0), 18, 2),
ScalarValue::Decimal64(Some(0), 18, 2),
ScalarValue::Decimal64(Some(100), 18, 2),
);

// ===== Decimal128 =====
assert_preimage_range(
ScalarValue::Decimal128(Some(10000), 38, 2),
ScalarValue::Decimal128(Some(10000), 38, 2), // 100.00
ScalarValue::Decimal128(Some(10100), 38, 2), // 101.00
);

// Negative
assert_preimage_range(
ScalarValue::Decimal128(Some(-500), 38, 2),
ScalarValue::Decimal128(Some(-500), 38, 2), // -5.00
ScalarValue::Decimal128(Some(-400), 38, 2), // -4.00
);

// Zero
assert_preimage_range(
ScalarValue::Decimal128(Some(0), 38, 2),
ScalarValue::Decimal128(Some(0), 38, 2),
ScalarValue::Decimal128(Some(100), 38, 2),
);

// ===== Decimal256 =====
assert_preimage_range(
ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2),
ScalarValue::Decimal256(Some(i256::from(10000)), 76, 2), // 100.00
ScalarValue::Decimal256(Some(i256::from(10100)), 76, 2), // 101.00
);

// Negative
assert_preimage_range(
ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2),
ScalarValue::Decimal256(Some(i256::from(-500)), 76, 2), // -5.00
ScalarValue::Decimal256(Some(i256::from(-400)), 76, 2), // -4.00
);

// Zero
assert_preimage_range(
ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
ScalarValue::Decimal256(Some(i256::ZERO), 76, 2),
ScalarValue::Decimal256(Some(i256::from(100)), 76, 2),
);
}

#[test]
fn test_floor_preimage_decimal_non_integer() {
    // floor(x) = 1.30 has NO SOLUTION because floor always returns an integer.
    // Therefore preimage should return None for non-integer decimals.

    // Decimal32
    assert_preimage_none(ScalarValue::Decimal32(Some(130), 9, 2)); // 1.30
    assert_preimage_none(ScalarValue::Decimal32(Some(-250), 9, 2)); // -2.50
    assert_preimage_none(ScalarValue::Decimal32(Some(370), 9, 2)); // 3.70
    assert_preimage_none(ScalarValue::Decimal32(Some(1), 9, 2)); // 0.01

    // Decimal64
    assert_preimage_none(ScalarValue::Decimal64(Some(130), 18, 2)); // 1.30
    assert_preimage_none(ScalarValue::Decimal64(Some(-250), 18, 2)); // -2.50

    // Decimal128
    assert_preimage_none(ScalarValue::Decimal128(Some(130), 38, 2)); // 1.30
    assert_preimage_none(ScalarValue::Decimal128(Some(-250), 38, 2)); // -2.50

    // Decimal256
    assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(130)), 76, 2)); // 1.30
    assert_preimage_none(ScalarValue::Decimal256(Some(i256::from(-250)), 76, 2)); // -2.50

    // Decimal32: i32::MAX - 50 (raw value ends in ...97, i.e. a .97 fraction at scale 2).
    // This returns None because the value is not an integer, not because it is out of range.
    // Precision 9 is the maximum Arrow accepts for Decimal32; a larger precision would be
    // rejected before the fractional check is reached.
    assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX - 50), 9, 2));

    // Decimal64: i64::MAX - 50 (raw value ends in ...57, i.e. a .57 fraction at scale 2).
    // This returns None because the value is not an integer, not because it is out of range.
    // Precision 18 is the maximum Arrow accepts for Decimal64.
    assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX - 50), 18, 2));
}

#[test]
fn test_floor_preimage_decimal_overflow() {
    // Adding 1 to the native MAX must be rejected by the checked addition.
    //
    // Each case uses the maximum precision Arrow accepts for its type
    // (9 / 18 / 38 / 76) with scale 0, so "1" is representable and the
    // fractional check is skipped. That leaves the overflow of `value + 1`
    // as the only way to get None. A precision above the type's maximum
    // would be rejected earlier and never reach the addition.

    // Decimal32: i32::MAX
    assert_preimage_none(ScalarValue::Decimal32(Some(i32::MAX), 9, 0));

    // Decimal64: i64::MAX
    assert_preimage_none(ScalarValue::Decimal64(Some(i64::MAX), 18, 0));

    // Decimal128: i128::MAX
    assert_preimage_none(ScalarValue::Decimal128(Some(i128::MAX), 38, 0));

    // Decimal256: i256::MAX
    assert_preimage_none(ScalarValue::Decimal256(Some(i256::MAX), 76, 0));
}

#[test]
fn test_floor_preimage_decimal_edge_cases() {
    // At scale 2 the integer 1 is stored as a raw 100.
    let one = 100;

    // Whole-number raw values near both ends of what the test treats as the
    // Decimal(9,2) range (±9,999,999.99, stored as ±999,999,999):
    //    999_999_900 ->  9,999,999.00
    //   -999_999_900 -> -9,999,999.00
    // Both are divisible by 100, so each should map to [raw, raw + 100).
    for raw in [999_999_900, -999_999_900] {
        assert_preimage_range(
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(raw), 9, 2),
            ScalarValue::Decimal32(Some(raw + one), 9, 2),
        );
    }
}

#[test]
fn test_floor_preimage_decimal_null() {
    // A NULL decimal literal of any width must yield no preimage.
    let null_literals = [
        ScalarValue::Decimal32(None, 9, 2),
        ScalarValue::Decimal64(None, 18, 2),
        ScalarValue::Decimal128(None, 38, 2),
        ScalarValue::Decimal256(None, 76, 2),
    ];

    for literal in null_literals {
        assert_preimage_none(literal);
    }
}
}
Loading