Skip to content

Commit a4279f3

Browse files
committed
fix: format decimal to string when casting to integral type
1 parent 5ec12d4 commit a4279f3

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,13 @@ macro_rules! cast_decimal_to_int16_down {
692692
.map(|value| match value {
693693
Some(value) => {
694694
let divisor = 10_i128.pow($scale as u32);
695-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
695+
let truncated = value / divisor;
696696
let is_overflow = truncated.abs() > i32::MAX.into();
697+
let fmt_str =
698+
format_decimal_str(&value.to_string(), $precision as usize, $scale);
697699
if is_overflow {
698700
return Err(cast_overflow(
699-
&format!("{}.{}BD", truncated, decimal),
701+
&format!("{}BD", fmt_str),
700702
&format!("DECIMAL({},{})", $precision, $scale),
701703
$dest_type_str,
702704
));
@@ -705,7 +707,7 @@ macro_rules! cast_decimal_to_int16_down {
705707
<$rust_dest_type>::try_from(i32_value)
706708
.map_err(|_| {
707709
cast_overflow(
708-
&format!("{}.{}BD", truncated, decimal),
710+
&format!("{}BD", fmt_str),
709711
&format!("DECIMAL({},{})", $precision, $scale),
710712
$dest_type_str,
711713
)
@@ -755,11 +757,13 @@ macro_rules! cast_decimal_to_int32_up {
755757
.map(|value| match value {
756758
Some(value) => {
757759
let divisor = 10_i128.pow($scale as u32);
758-
let (truncated, decimal) = (value / divisor, (value % divisor).abs());
760+
let truncated = value / divisor;
759761
let is_overflow = truncated.abs() > $max_dest_val.into();
762+
let fmt_str =
763+
format_decimal_str(&value.to_string(), $precision as usize, $scale);
760764
if is_overflow {
761765
return Err(cast_overflow(
762-
&format!("{}.{}BD", truncated, decimal),
766+
&format!("{}BD", fmt_str),
763767
&format!("DECIMAL({},{})", $precision, $scale),
764768
$dest_type_str,
765769
));
@@ -787,6 +791,30 @@ macro_rules! cast_decimal_to_int32_up {
787791
}};
788792
}
789793

794+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
795+
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
796+
let (sign, rest) = match value_str.strip_prefix('-') {
797+
Some(stripped) => ("-", stripped),
798+
None => ("", value_str),
799+
};
800+
let bound = precision.min(rest.len()) + sign.len();
801+
let value_str = &value_str[0..bound];
802+
803+
if scale == 0 {
804+
value_str.to_string()
805+
} else if scale < 0 {
806+
let padding = value_str.len() + scale.unsigned_abs() as usize;
807+
format!("{value_str:0<padding$}")
808+
} else if rest.len() > scale as usize {
809+
// Decimal separator is in the middle of the string
810+
let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
811+
format!("{whole}.{decimal}")
812+
} else {
813+
// String has to be padded
814+
format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
815+
}
816+
}
817+
790818
impl Cast {
791819
pub fn new(
792820
child: Arc<dyn PhysicalExpr>,
@@ -1794,12 +1822,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
17941822
),
17951823
(DataType::Decimal128(precision, scale), DataType::Int8) => {
17961824
cast_decimal_to_int16_down!(
1797-
array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
1825+
array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
17981826
)
17991827
}
18001828
(DataType::Decimal128(precision, scale), DataType::Int16) => {
18011829
cast_decimal_to_int16_down!(
1802-
array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
1830+
array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
18031831
)
18041832
}
18051833
(DataType::Decimal128(precision, scale), DataType::Int32) => {

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10301030

10311031
test("cast between decimals with higher precision than source") {
10321032
// cast between Decimal(10, 2) to Decimal(10,4)
1033-
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 4))
1033+
castTest(
1034+
generateDecimalsPrecision10Scale2(testDecimalToDecimal = true),
1035+
DataTypes.createDecimalType(10, 4))
10341036
}
10351037

10361038
test("cast between decimals with negative precision") {
@@ -1044,7 +1046,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10441046

10451047
test("cast between decimals with zero precision") {
10461048
// cast between Decimal(10, 2) to Decimal(10,0)
1047-
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
1049+
castTest(
1050+
generateDecimalsPrecision10Scale2(testDecimalToDecimal = true),
1051+
DataTypes.createDecimalType(10, 0))
10481052
}
10491053

10501054
test("cast ArrayType to StringType") {
@@ -1116,8 +1120,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11161120
}
11171121
}
11181122

1119-
private def generateDecimalsPrecision10Scale2(): DataFrame = {
1120-
val values = Seq(
1123+
private def generateDecimalsPrecision10Scale2(
1124+
testDecimalToDecimal: Boolean = false): DataFrame = {
1125+
var values = Seq(
11211126
BigDecimal("-99999999.999"),
11221127
BigDecimal("-123456.789"),
11231128
BigDecimal("-32768.678"),
@@ -1135,6 +1140,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11351140
BigDecimal("32768.678"),
11361141
BigDecimal("123456.789"),
11371142
BigDecimal("99999999.999"))
1143+
if (!testDecimalToDecimal) {
1144+
values ++= Seq(BigDecimal("-96833550.07"), BigDecimal("-96833550.7"))
1145+
}
11381146
withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
11391147
}
11401148

0 commit comments

Comments
 (0)