diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 6b69c72882..7bde159d6b 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -691,11 +691,13 @@ macro_rules! cast_decimal_to_int16_down {
             .map(|value| match value {
                 Some(value) => {
                     let divisor = 10_i128.pow($scale as u32);
-                    let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                    let truncated = value / divisor;
                     let is_overflow = truncated.abs() > i32::MAX.into();
+                    let fmt_str =
+                        format_decimal_str(&value.to_string(), $precision as usize, $scale);
                     if is_overflow {
                         return Err(cast_overflow(
-                            &format!("{}.{}BD", truncated, decimal),
+                            &format!("{}BD", fmt_str),
                             &format!("DECIMAL({},{})", $precision, $scale),
                             $dest_type_str,
                         ));
@@ -704,7 +706,7 @@ macro_rules! cast_decimal_to_int16_down {
                     <$rust_dest_type>::try_from(i32_value)
                         .map_err(|_| {
                             cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!("{}BD", fmt_str),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             )
@@ -786,6 +788,30 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// copied from arrow::datatypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+    }
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -1799,12 +1825,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
         ),
         (DataType::Decimal128(precision, scale), DataType::Int8) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int16) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
             )
         }
        (DataType::Decimal128(precision, scale), DataType::Int32) => {
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 6f58811f80..b2eba9a58b 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(10,2) to ShortType") {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(10,2) to IntegerType") {
@@ -1189,6 +1192,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
         BigDecimal("32768.678"),
         BigDecimal("123456.789"),
         BigDecimal("99999999.999"))
+    generateDecimalsPrecision10Scale2(values)
+  }
+
+  private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
   }
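
Note for reviewers: a minimal standalone sketch, not part of the patch, that duplicates the helper added above so it can be run on its own with rustc. The `main` function and its sample value are illustrative only; it shows why the old `{truncated}.{decimal}` rendering was wrong (the modulo drops leading zeros in the fractional part) and what routing the overflow message through the copied arrow-rs formatter produces instead.

// Sketch only: same body as the format_decimal_str added in this patch,
// repeated here so the example is self-contained.
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
    let (sign, rest) = match value_str.strip_prefix('-') {
        Some(stripped) => ("-", stripped),
        None => ("", value_str),
    };
    let bound = precision.min(rest.len()) + sign.len();
    let value_str = &value_str[0..bound];

    if scale == 0 {
        value_str.to_string()
    } else if scale < 0 {
        let padding = value_str.len() + scale.unsigned_abs() as usize;
        format!("{value_str:0<padding$}")
    } else if rest.len() > scale as usize {
        // Decimal separator falls inside the digit string
        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
        format!("{whole}.{decimal}")
    } else {
        // Fewer digits than the scale: pad with zeros after "0."
        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
    }
}

fn main() {
    // DECIMAL(10,2) value -96833550.07 is stored as the unscaled i128 -9683355007
    // (this is the value exercised by the new ShortType test case).
    let unscaled: i128 = -9683355007;

    // Old message text: `(value % divisor).abs()` yields 7, losing the zero of ".07".
    let (truncated, decimal) = (unscaled / 100, (unscaled % 100).abs());
    assert_eq!(format!("{}.{}BD", truncated, decimal), "-96833550.7BD");

    // New message text reproduces the full decimal string.
    let fmt_str = format_decimal_str(&unscaled.to_string(), 10, 2);
    assert_eq!(format!("{}BD", fmt_str), "-96833550.07BD");
}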