
Commit ee82678

fix: format decimal to string when casting to integral type
1 parent 5ec12d4 commit ee82678

File tree

2 files changed: +38 -8 lines changed


native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 35 additions & 7 deletions
@@ -692,11 +692,13 @@ macro_rules! cast_decimal_to_int16_down {
                 .map(|value| match value {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
-                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let truncated = value / divisor;
                         let is_overflow = truncated.abs() > i32::MAX.into();
+                        let fmt_str =
+                            format_decimal_str(&value.to_string(), $precision as usize, $scale);
                         if is_overflow {
                             return Err(cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!("{}BD", fmt_str),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             ));
@@ -705,7 +707,7 @@ macro_rules! cast_decimal_to_int16_down {
                         <$rust_dest_type>::try_from(i32_value)
                             .map_err(|_| {
                                 cast_overflow(
-                                    &format!("{}.{}BD", truncated, decimal),
+                                    &format!("{}BD", fmt_str),
                                     &format!("DECIMAL({},{})", $precision, $scale),
                                     $dest_type_str,
                                 )
@@ -755,11 +757,13 @@ macro_rules! cast_decimal_to_int32_up {
                 .map(|value| match value {
                     Some(value) => {
                         let divisor = 10_i128.pow($scale as u32);
-                        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                        let truncated = value / divisor;
                         let is_overflow = truncated.abs() > $max_dest_val.into();
+                        let fmt_str =
+                            format_decimal_str(&value.to_string(), $precision as usize, $scale);
                         if is_overflow {
                             return Err(cast_overflow(
-                                &format!("{}.{}BD", truncated, decimal),
+                                &format!("{}BD", fmt_str),
                                 &format!("DECIMAL({},{})", $precision, $scale),
                                 $dest_type_str,
                             ));
@@ -787,6 +791,30 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+    }
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -1794,12 +1822,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
         ),
         (DataType::Decimal128(precision, scale), DataType::Int8) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int16) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
             )
         }
        (DataType::Decimal128(precision, scale), DataType::Int32) => {
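
The effect of the change is easiest to see in isolation. Below is a minimal, self-contained sketch (not part of the commit): it copies format_decimal_str from the hunk above and contrasts the old and new overflow-message formatting for -96833550.07, one of the values added to CometCastSuite. The raw i128 encoding used here is an assumption chosen to represent that value at DECIMAL(10, 2).

    // Copied from the diff above: formats a raw decimal string with the given precision/scale.
    fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
        let (sign, rest) = match value_str.strip_prefix('-') {
            Some(stripped) => ("-", stripped),
            None => ("", value_str),
        };
        let bound = precision.min(rest.len()) + sign.len();
        let value_str = &value_str[0..bound];

        if scale == 0 {
            value_str.to_string()
        } else if scale < 0 {
            let padding = value_str.len() + scale.unsigned_abs() as usize;
            format!("{value_str:0<padding$}")
        } else if rest.len() > scale as usize {
            // Decimal separator is in the middle of the string
            let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
            format!("{whole}.{decimal}")
        } else {
            // String has to be padded
            format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
        }
    }

    fn main() {
        // Assumed raw Decimal128(10, 2) encoding of -96833550.07 (a value added in CometCastSuite).
        let value: i128 = -9_683_355_007;
        let (precision, scale): (usize, i8) = (10, 2);

        // Old formatting: split into integral and fractional parts and rejoin with "{}.{}".
        // The fractional part loses its leading zero, so the message reads "-96833550.7BD".
        let divisor = 10_i128.pow(scale as u32);
        let (truncated, decimal) = (value / divisor, (value % divisor).abs());
        assert_eq!(format!("{}.{}BD", truncated, decimal), "-96833550.7BD");

        // New formatting: render the full decimal string once, preserving leading zeros.
        let fmt_str = format_decimal_str(&value.to_string(), precision, scale);
        assert_eq!(format!("{}BD", fmt_str), "-96833550.07BD");
    }
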

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 3 additions & 1 deletion
@@ -1118,7 +1118,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   private def generateDecimalsPrecision10Scale2(): DataFrame = {
     val values = Seq(
-      BigDecimal("-99999999.999"),
+      BigDecimal("-96833550.07"),
+      BigDecimal("-96833550.7"),
+      BigDecimal("-99999999.99"),
       BigDecimal("-123456.789"),
       BigDecimal("-32768.678"),
       // Short Min
