native/spark-expr/src/conversion_funcs/cast.rs (36 changes: 31 additions, 5 deletions)
@@ -691,11 +691,13 @@ macro_rules! cast_decimal_to_int16_down {
                     .map(|value| match value {
                         Some(value) => {
                             let divisor = 10_i128.pow($scale as u32);
-                            let (truncated, decimal) = (value / divisor, (value % divisor).abs());
+                            let truncated = value / divisor;
                             let is_overflow = truncated.abs() > i32::MAX.into();
+                            let fmt_str =
+                                format_decimal_str(&value.to_string(), $precision as usize, $scale);
                             if is_overflow {
                                 return Err(cast_overflow(
-                                    &format!("{}.{}BD", truncated, decimal),
+                                    &format!("{}BD", fmt_str),
                                     &format!("DECIMAL({},{})", $precision, $scale),
                                     $dest_type_str,
                                 ));
@@ -704,7 +706,7 @@
                             <$rust_dest_type>::try_from(i32_value)
                                 .map_err(|_| {
                                     cast_overflow(
-                                        &format!("{}.{}BD", truncated, decimal),
+                                        &format!("{}BD", fmt_str),
                                         &format!("DECIMAL({},{})", $precision, $scale),
                                         $dest_type_str,
                                     )
@@ -786,6 +788,30 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+// copied from arrow::datatypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
+    }
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -1799,12 +1825,12 @@ fn spark_cast_nonintegral_numeric_to_integral(
         ),
         (DataType::Decimal128(precision, scale), DataType::Int8) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale
+                array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int16) => {
             cast_decimal_to_int16_down!(
-                array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale
+                array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale
             )
         }
         (DataType::Decimal128(precision, scale), DataType::Int32) => {
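For context, the standalone sketch below is not part of the diff; it contrasts the old and new overflow-message text for the regression value used in the Scala test further down (DECIMAL(10,2) value -96833550.07 cast to SMALLINT). The old "{}.{}BD" construction drops the leading zero of the fractional digits, which is what the new format_decimal_str path avoids. The helper body is reproduced verbatim from the hunk above so the example runs on its own; the main function and the literal unscaled value are illustrative assumptions, not code from the PR.

// Standalone sketch, not part of the PR.
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
    let (sign, rest) = match value_str.strip_prefix('-') {
        Some(stripped) => ("-", stripped),
        None => ("", value_str),
    };
    let bound = precision.min(rest.len()) + sign.len();
    let value_str = &value_str[0..bound];

    if scale == 0 {
        value_str.to_string()
    } else if scale < 0 {
        let padding = value_str.len() + scale.unsigned_abs() as usize;
        format!("{value_str:0<padding$}")
    } else if rest.len() > scale as usize {
        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
        format!("{whole}.{decimal}")
    } else {
        format!("{}0.{:0>width$}", sign, rest, width = scale as usize)
    }
}

fn main() {
    // Unscaled Decimal128 representation of -96833550.07 at scale 2 (assumed for illustration).
    let value: i128 = -9683355007;
    let divisor = 10_i128.pow(2);

    // Old message fragment: the fractional part loses its leading zero.
    let (truncated, decimal) = (value / divisor, (value % divisor).abs());
    assert_eq!(format!("{}.{}BD", truncated, decimal), "-96833550.7BD");

    // New message fragment: the unscaled string is formatted with the declared
    // precision and scale, so the value renders correctly.
    let fmt_str = format_decimal_str(&value.to_string(), 10, 2);
    assert_eq!(format!("{}BD", fmt_str), "-96833550.07BD");
}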
spark/src/test/scala/org/apache/comet/CometCastSuite.scala (7 changes: 7 additions, 0 deletions)
@@ -529,6 +529,9 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast DecimalType(10,2) to ShortType") {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.ShortType)
+    castTest(
+      generateDecimalsPrecision10Scale2(Seq(BigDecimal("-96833550.07"))),
+      DataTypes.ShortType)
   }
 
   test("cast DecimalType(10,2) to IntegerType") {
@@ -1189,6 +1192,10 @@
       BigDecimal("32768.678"),
       BigDecimal("123456.789"),
       BigDecimal("99999999.999"))
+    generateDecimalsPrecision10Scale2(values)
+  }
+
+  private def generateDecimalsPrecision10Scale2(values: Seq[BigDecimal]): DataFrame = {
     withNulls(values).toDF("b").withColumn("a", col("b").cast(DecimalType(10, 2))).drop("b")
   }
 