Skip to content

Commit 4678d30

Browse files
committed
fix: format decimal to string when casting to integral type
1 parent 5ec12d4 commit 4678d30

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,12 @@ macro_rules! cast_decimal_to_int16_down {
692692
.map(|value| match value {
693693
Some(value) => {
694694
let divisor = 10_i128.pow($scale as u32);
695-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
695+
let truncated = value / divisor;
696696
let is_overflow = truncated.abs() > i32::MAX.into();
697+
let fmt_str = format_decimal_str(&value.to_string(), $precision as usize, $scale);
697698
if is_overflow {
698699
return Err(cast_overflow(
699-
&format!("{}.{}BD", truncated, decimal),
700+
&format!("{}BD", fmt_str),
700701
&format!("DECIMAL({},{})", $precision, $scale),
701702
$dest_type_str,
702703
));
@@ -705,7 +706,7 @@ macro_rules! cast_decimal_to_int16_down {
705706
<$rust_dest_type>::try_from(i32_value)
706707
.map_err(|_| {
707708
cast_overflow(
708-
&format!("{}.{}BD", truncated, decimal),
709+
&format!("{}BD", fmt_str),
709710
&format!("DECIMAL({},{})", $precision, $scale),
710711
$dest_type_str,
711712
)
@@ -755,11 +756,12 @@ macro_rules! cast_decimal_to_int32_up {
755756
.map(|value| match value {
756757
Some(value) => {
757758
let divisor = 10_i128.pow($scale as u32);
758-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
759+
let truncated = value / divisor;
759760
let is_overflow = truncated.abs() > $max_dest_val.into();
761+
let fmt_str = format_decimal_str(&value.to_string(), $precision as usize, $scale);
760762
if is_overflow {
761763
return Err(cast_overflow(
762-
&format!("{}.{}BD", truncated, decimal),
764+
&format!("{}BD", fmt_str),
763765
&format!("DECIMAL({},{})", $precision, $scale),
764766
$dest_type_str,
765767
));
@@ -787,6 +789,30 @@ macro_rules! cast_decimal_to_int32_up {
787789
}};
788790
}
789791

792+
/// Formats the base-10 digits of an unscaled decimal value as a decimal string.
///
/// Per the original note, this was copied from Arrow's `Decimal128Type`
/// formatting code because `Decimal128Type::format_decimal` can't be called
/// directly from here.
///
/// * `value_str` - the unscaled integer rendered as text, optionally prefixed with `-`.
/// * `precision` - maximum number of digits (sign excluded) retained from `value_str`.
/// * `scale` - number of fractional digits; a negative scale appends trailing zeros.
///
/// Returns the formatted text, e.g. `("-12345", 10, 2)` yields `"-123.45"` and
/// `("5", 10, 2)` yields `"0.05"`.
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
    // Separate the optional leading minus from the digits.
    let (sign, digits) = value_str
        .strip_prefix('-')
        .map_or(("", value_str), |stripped| ("-", stripped));

    // Keep the sign plus at most `precision` digits.
    let kept = &value_str[..sign.len() + digits.len().min(precision)];

    match scale {
        0 => kept.to_string(),
        negative if negative < 0 => {
            // Negative scale: right-pad with one zero per unit of scale.
            let width = kept.len() + usize::from(negative.unsigned_abs());
            format!("{kept:0<width$}")
        }
        positive => {
            let frac_len = usize::from(positive.unsigned_abs());
            if digits.len() > frac_len {
                // The decimal point falls inside the digit run.
                // NOTE(review): if `precision` trimmed `kept` below `frac_len`
                // digits this subtraction underflows; presumably callers always
                // pass scale <= precision. Confirm.
                let (whole, decimal) = kept.split_at(kept.len() - frac_len);
                format!("{whole}.{decimal}")
            } else {
                // Too few digits: emit a leading "0." and left-pad the fraction with zeros.
                format!("{sign}0.{digits:0>frac_len$}")
            }
        }
    }
}
}
815+
790816
impl Cast {
791817
pub fn new(
792818
child: Arc<dyn PhysicalExpr>,
@@ -1794,12 +1820,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
17941820
),
17951821
(DataType::Decimal128(precision, scale), DataType::Int8) => {
17961822
cast_decimal_to_int16_down!(
1797-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1823+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
17981824
)
17991825
}
18001826
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18011827
cast_decimal_to_int16_down!(
1802-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1828+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18031829
)
18041830
}
18051831
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,8 +1118,8 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11181118

11191119
private def generateDecimalsPrecision10Scale2(): DataFrame = {
11201120
val values = Seq(
1121-
BigDecimal("-99999999.999"),
1122-
BigDecimal("-123456.789"),
1121+
BigDecimal("-99999999.09"),
1122+
BigDecimal("-123456.009"),
11231123
BigDecimal("-32768.678"),
11241124
// Short Min
11251125
BigDecimal("-32767.123"),

0 commit comments

Comments
 (0)