From 87a4b517a86557e3f084df4c995d13376417b111 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Tue, 16 Dec 2025 16:12:02 -0800
Subject: [PATCH 1/7] init_string_decimal_sep_pr

---
 .../spark-expr/src/conversion_funcs/cast.rs   | 372 +++++++++++++++++-
 .../apache/comet/expressions/CometCast.scala  |   5 -
 2 files changed, 364 insertions(+), 13 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 12a147c6e1..034169bb57 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -222,13 +222,6 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
             // Does not support ANSI mode.
             options.allow_incompat
         }
-        Decimal128(_, _) => {
-            // https://github.com/apache/datafusion-comet/issues/325
-            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
-            options.allow_incompat
-        }
         Date32 | Date64 => {
             // https://github.com/apache/datafusion-comet/issues/327
             // Only supports years between 262143 BC and 262142 AD
@@ -976,6 +969,12 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
+        (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
         (Int64, Int32)
         | (Int64, Int16)
         | (Int64, Int8)
@@ -1187,7 +1186,7 @@ fn is_datafusion_spark_compatible(
         ),
         DataType::Utf8 if allow_incompat => matches!(
             to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Binary | DataType::Float32 | DataType::Float64
         ),
         DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1975,363 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
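+
+// Illustrative note (not part of the upstream API): casting the string column
+// ["1.5", "x", NULL] to DECIMAL(4,2) routes through the helpers below and
+// yields the unscaled values [150, NULL, NULL] in LEGACY mode, while ANSI
+// mode raises an error for "x".
+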
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            string_array.value(i),
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Validates if a string is a valid decimal similar to BigDecimal
+fn is_valid_decimal_format(s: &str) -> bool {
+    if s.is_empty() {
+        return false;
+    }
+
+    let bytes = s.as_bytes();
+    let mut idx = 0;
+    let len = bytes.len();
+
+    // Skip leading +/- signs
+    if bytes[idx] == b'+' || bytes[idx] == b'-' {
+        idx += 1;
+        if idx >= len {
+            // Sign only. Fail early
+            return false;
+        }
+    }
+
+    // Check invalid cases like "++", "+-"
+    if bytes[idx] == b'+' || bytes[idx] == b'-' {
+        return false;
+    }
+
+    // Now we need at least one digit either before or after a decimal point
+    let mut has_digit = false;
+    let mut is_decimal_point_seen = false;
+
+    while idx < len {
+        let ch = bytes[idx];
+
+        if ch.is_ascii_digit() {
+            has_digit = true;
+            idx += 1;
+        } else if ch == b'.' {
+            if is_decimal_point_seen {
+                // Multiple decimal points or decimal after exponent
+                return false;
+            }
+            is_decimal_point_seen = true;
+            idx += 1;
+        } else if ch.eq_ignore_ascii_case(&b'e') {
+            if !has_digit {
+                // Exponent without any digits before it
+                return false;
+            }
+            idx += 1;
+            // Exponent part must have an optional sign followed by at least one digit
+            if idx >= len {
+                return false;
+            }
+
+            if bytes[idx] == b'+' || bytes[idx] == b'-' {
+                idx += 1;
+                if idx >= len {
+                    return false;
+                }
+            }
+
+            // Must have at least one digit in exponent
+            if !bytes[idx].is_ascii_digit() {
+                return false;
+            }
+
+            // Everything after the exponent must be digits
+            while idx < len {
+                if !bytes[idx].is_ascii_digit() {
+                    return false;
+                }
+                idx += 1;
+            }
+            break;
+        } else {
+            // Invalid character found. Fail fast
+            return false;
+        }
+    }
+    has_digit
+}
+
+/// Parse a string to decimal following Spark's behavior
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
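+    // Worked examples of the parsing below (illustrative, not from the
+    // upstream patch; LEGACY mode maps None to null, ANSI callers raise):
+    //   "123.45" as DECIMAL(10,2) -> Some(12345)  (unscaled i128)
+    //   "67.895" as DECIMAL(10,2) -> Some(6790)   (HALF_UP rounding)
+    //   "1e2"    as DECIMAL(10,2) -> Some(10000)  (scientific notation)
+    //   "inf"    as DECIMAL(10,2) -> None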
+    // Handle special values (inf, nan, etc.)
+    if s.eq_ignore_ascii_case("inf")
+        || s.eq_ignore_ascii_case("+inf")
+        || s.eq_ignore_ascii_case("infinity")
+        || s.eq_ignore_ascii_case("+infinity")
+        || s.eq_ignore_ascii_case("-inf")
+        || s.eq_ignore_ascii_case("-infinity")
+        || s.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    if !is_valid_decimal_format(s) {
+        return Ok(None);
+    }
+
+    match parse_decimal_str(s) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale), but return None if the shift cannot fit in i128
+                if scale_adjustment > 38 {
+                    return Ok(None);
+                }
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale); dropping more than 38 digits leaves no significant digits, so the result is 0
+                let abs_scale_adjustment = (-scale_adjustment) as u32;
+                if abs_scale_adjustment > 38 {
+                    return Ok(Some(0));
+                }
+
+                let divisor = 10_i128.pow(abs_scale_adjustment);
+                let quotient_opt = mantissa.checked_div(divisor);
+                // checked_div only fails on division by zero, which cannot happen for a positive power of ten; bail out defensively
+                if quotient_opt.is_none() {
+                    return Ok(None);
+                }
+                let quotient = quotient_opt.unwrap();
+                let remainder = mantissa % divisor;
+
+                // HALF_UP rounding: if abs(remainder) >= divisor/2, round away from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow while scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    let s = s.trim();
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+        // Parse exponent
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    let split_by_dot: Vec<&str> = mantissa_str.split('.').collect();
+
+    if split_by_dot.len() > 2 {
+        return Err("Multiple decimal points".to_string());
+    }
+
+    let integral_part = split_by_dot[0];
+    let fractional_part = if split_by_dot.len() == 2 {
+        split_by_dot[1]
+    } else {
+        ""
+    };
+
+    // Parse integral part
+    let integral_value: i128 = if integral_part.is_empty() {
+        // Empty integral part is valid (e.g., ".5" or "-.7e9")
+        0
+    } else {
+        integral_part
+            .parse()
+            .map_err(|_| "Invalid integral part".to_string())?
+    };
+
+    // Parse fractional part
+    let fractional_scale = fractional_part.len() as i32;
+    let fractional_value: i128 = if fractional_part.is_empty() {
+        0
+    } else {
+        fractional_part
+            .parse()
+            .map_err(|_| "Invalid fractional part".to_string())?
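+        // Worked example for the combination below (illustrative):
+        // "123.45" parses to integral_value = 123, fractional_value = 45,
+        // fractional_scale = 2, so mantissa = 123 * 10^2 + 45 = 12345.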
+ }; + + // Combine: value = integral * 10^fractional_scale + fractional + let mantissa = integral_value + .checked_mul(10_i128.pow(fractional_scale as u32)) + .and_then(|v| v.checked_add(fractional_value)) + .ok_or("Overflow in mantissa calculation")?; + + let final_mantissa = if negative { -mantissa } else { mantissa }; + // final scale = fractional_scale - exponent + // For example : "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7 + let final_scale = fractional_scale - exponent; + Ok((final_mantissa, final_scale)) +} + /// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode #[inline] fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult> { diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 98ce8ac44d..3be7700969 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -190,11 +190,6 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { Some( "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " + "Does not support ANSI mode.")) - case _: DecimalType => - // https://github.com/apache/datafusion-comet/issues/325 - Incompatible( - Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " + - "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits")) case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 Compatible(Some("Only supports years between 262143 BC and 262142 AD")) From 49053223837e9e276891c19d8acd3868150fea30 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 16 Dec 2025 16:20:38 -0800 Subject: [PATCH 2/7] init_string_decimal_sep_pr --- .../org/apache/comet/CometCastSuite.scala | 79 +++++++++++++++---- 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1912e982b9..1b536fec68 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -684,22 +684,73 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - ignore("cast StringType to DecimalType(10,2)") { - // https://github.com/apache/datafusion-comet/issues/325 - val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a") - castTest(values, DataTypes.createDecimalType(10, 2)) + test("cast StringType to DecimalType(10,2)") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") + Seq(true, false).foreach(k => + castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = k)) } - test("cast StringType to DecimalType(10,2) (partial support)") { - withSQLConf( - CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", - SQLConf.ANSI_ENABLED.key -> "false") { - val values = gen - .generateStrings(dataSize, "0123456789.", 8) - .filter(_.exists(_.isDigit)) - .toDF("a") - castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false) - } + test("cast StringType to DecimalType(2,2)") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a") + Seq(true, false).foreach(k => + castTest(values, 
DataTypes.createDecimalType(2, 2), testAnsi = k)) + } + + test("cast StringType to DecimalType(38,10) high precision") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") + Seq(true, false).foreach(k => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = k)) + } + + test("cast StringType to DecimalType(10,2) basic values") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = Seq( + "123.45", + "-67.89", + "-67.89", + "-67.895", + "67.895", + "0.001", + "999.99", + "123.456", + "123.45D", + ".5", + "5.", + "+123.45", + " 123.45 ", + "inf", + "", + "abc", + null).toDF("a") + Seq(true, false).foreach(k => + castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = k)) + } + + test("cast StringType to Decimal type scientific notation") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = Seq( + "1.23E-5", + "1.23e10", + "1.23E+10", + "-1.23e-5", + "1e5", + "1E-2", + "-1.5e3", + "1.23E0", + "0e0", + "1.23e", + "e5", + null).toDF("a") + Seq(true, false).foreach(k => + castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = k)) } test("cast StringType to BinaryType") { From 6896625769967aa8439f4fc1148d8bae89908f8f Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Wed, 17 Dec 2025 09:32:05 -0800 Subject: [PATCH 3/7] init_string_decimal_sep_pr --- native/spark-expr/src/conversion_funcs/cast.rs | 12 +++++++++--- .../org/apache/comet/expressions/CometCast.scala | 4 ++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 034169bb57..637077eb56 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle}; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; use arrow::array::{ - BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray, - StructArray, + BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, + PrimitiveBuilder, StringArray, StructArray, }; use arrow::compute::can_cast_types; use arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema, + i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType, + Schema, }; use arrow::{ array::{ @@ -222,6 +223,11 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool // Does not support ANSI mode. options.allow_incompat } + Decimal128(_, _) => { + // https://github.com/apache/datafusion-comet/issues/325 + // Does not support fullwidth digits and null byte handling. + options.allow_incompat + } Date32 | Date64 => { // https://github.com/apache/datafusion-comet/issues/327 // Only supports years between 262143 BC and 262142 AD diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 3be7700969..9cccfeac2f 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -190,6 +190,10 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { Some( "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. 
" + "Does not support ANSI mode.")) + case _: DecimalType => + // https://github.com/apache/datafusion-comet/issues/325 + Incompatible(Some( + """Does not support fullwidth unicode digits (e.g \\uFF10) or strings containing null bytes (e.g \\u0000)""".stripMargin)) case DataTypes.DateType => // https://github.com/apache/datafusion-comet/issues/327 Compatible(Some("Only supports years between 262143 BC and 262142 AD")) From 408d436202d277226b7dc9f81e75ffe400e70f51 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 18 Dec 2025 23:26:39 -0800 Subject: [PATCH 4/7] address_review_comments --- .../source/user-guide/latest/compatibility.md | 3 +- .../spark-expr/src/conversion_funcs/cast.rs | 68 +++++----- .../apache/comet/expressions/CometCast.scala | 4 +- .../org/apache/comet/CometCastSuite.scala | 125 ++++++++++-------- 4 files changed, 110 insertions(+), 90 deletions(-) diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 60e2234f59..f529ba1e7e 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -159,6 +159,8 @@ The following cast operations are generally compatible with Spark except for the | string | short | | | string | integer | | | string | long | | +| string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10) +or strings containing null bytes (e.g \\u0000) | | string | binary | | | string | date | Only supports years between 262143 BC and 262142 AD | | binary | string | | @@ -183,7 +185,6 @@ The following cast operations are not compatible with Spark for all inputs and a | double | decimal | There can be rounding differences | | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | -| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits | | string | timestamp | Not all valid formats are supported | diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 637077eb56..da48752261 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -2019,7 +2019,7 @@ fn cast_string_to_decimal128_impl( if string_array.is_null(i) { decimal_builder.append_null(); } else { - let str_value = string_array.value(i).trim(); + let str_value = string_array.value(i); match parse_string_to_decimal(str_value, precision, scale) { Ok(Some(decimal_value)) => { decimal_builder.append_value(decimal_value); @@ -2068,7 +2068,7 @@ fn cast_string_to_decimal256_impl( if string_array.is_null(i) { decimal_builder.append_null(); } else { - let str_value = string_array.value(i).trim(); + let str_value = string_array.value(i); match parse_string_to_decimal(str_value, precision, scale) { Ok(Some(decimal_value)) => { // Convert i128 to i256 @@ -2108,12 +2108,12 @@ fn is_valid_decimal_format(s: &str) -> bool { return false; } - let bytes = s.as_bytes(); + let string_bytes = s.as_bytes(); let mut idx = 0; - let len = bytes.len(); + let len = string_bytes.len(); // Skip leading +/- signs - if bytes[idx] == b'+' || bytes[idx] == b'-' { + if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' { idx += 1; if idx >= len { // Sign only. 
Fail early
@@ -2122,7 +2122,7 @@ fn is_valid_decimal_format(s: &str) -> bool {
             return false;
         }
     }
 
     // Check invalid cases like "++", "+-"
-    if bytes[idx] == b'+' || bytes[idx] == b'-' {
+    if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' {
         return false;
     }
@@ -2131,7 +2131,7 @@ fn is_valid_decimal_format(s: &str) -> bool {
     let mut is_decimal_point_seen = false;
 
     while idx < len {
-        let ch = bytes[idx];
+        let ch = string_bytes[idx];
 
         if ch.is_ascii_digit() {
             has_digit = true;
             idx += 1;
@@ -2154,7 +2154,7 @@ fn is_valid_decimal_format(s: &str) -> bool {
             return false;
         }
 
-        if bytes[idx] == b'+' || bytes[idx] == b'-' {
+        if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' {
             idx += 1;
             if idx >= len {
                 return false;
@@ -2162,13 +2162,13 @@ fn is_valid_decimal_format(s: &str) -> bool {
         }
 
         // Must have at least one digit in exponent
-        if !bytes[idx].is_ascii_digit() {
+        if !string_bytes[idx].is_ascii_digit() {
             return false;
         }
 
         // Everything after the exponent must be digits
         while idx < len {
-            if !bytes[idx].is_ascii_digit() {
+            if !string_bytes[idx].is_ascii_digit() {
                 return false;
             }
             idx += 1;
@@ -2184,26 +2184,39 @@
 /// Parse a string to decimal following Spark's behavior
 fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
-    if s.is_empty() {
+    let string_bytes = s.as_bytes();
+    let mut start = 0;
+    let mut end = string_bytes.len();
+
+    while start < end && string_bytes[start].is_ascii_whitespace() {
+        start += 1;
+    }
+    while end > start && string_bytes[end - 1].is_ascii_whitespace() {
+        end -= 1;
+    }
+
+    let trimmed = &s[start..end];
+
+    if trimmed.is_empty() {
         return Ok(None);
     }
     // Handle special values (inf, nan, etc.)
-    if s.eq_ignore_ascii_case("inf")
-        || s.eq_ignore_ascii_case("+inf")
-        || s.eq_ignore_ascii_case("infinity")
-        || s.eq_ignore_ascii_case("+infinity")
-        || s.eq_ignore_ascii_case("-inf")
-        || s.eq_ignore_ascii_case("-infinity")
-        || s.eq_ignore_ascii_case("nan")
+    if trimmed.eq_ignore_ascii_case("inf")
+        || trimmed.eq_ignore_ascii_case("+inf")
+        || trimmed.eq_ignore_ascii_case("infinity")
+        || trimmed.eq_ignore_ascii_case("+infinity")
+        || trimmed.eq_ignore_ascii_case("-inf")
+        || trimmed.eq_ignore_ascii_case("-infinity")
+        || trimmed.eq_ignore_ascii_case("nan")
     {
         return Ok(None);
    }
 
-    if !is_valid_decimal_format(s) {
+    if !is_valid_decimal_format(trimmed) {
         return Ok(None);
     }
 
-    match parse_decimal_str(s) {
+    match parse_decimal_str(trimmed) {
         Ok((mantissa, exponent)) => {
             // Convert to target scale
             let target_scale = scale as i32;
             let scale_adjustment = target_scale - exponent;
@@ -2267,7 +2280,6 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>>
 /// Parse a decimal string into mantissa and scale
 /// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
 fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
-    let s = s.trim();
     if s.is_empty() {
         return Err("Empty string".to_string());
     }
@@ -2292,17 +2304,9 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
         mantissa_str
     };
 
-    let split_by_dot: Vec<&str> = mantissa_str.split('.').collect();
-
-    if split_by_dot.len() > 2 {
-        return Err("Multiple decimal points".to_string());
-    }
-
-    let integral_part = split_by_dot[0];
-    let fractional_part = if split_by_dot.len() == 2 {
-        split_by_dot[1]
-    } else {
-        ""
+    let (integral_part, fractional_part) = match mantissa_str.find('.') {
+        Some(dot_pos) => (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..]),
+        None => (mantissa_str, ""),
     };
 
     // Parse integral part
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 
9cccfeac2f..14db7c2780 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -192,8 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         "Does not support ANSI mode."))
       case _: DecimalType =>
         // https://github.com/apache/datafusion-comet/issues/325
-        Incompatible(Some(
-          """Does not support fullwidth unicode digits (e.g \\uFF10) or strings containing null bytes (e.g \\u0000)""".stripMargin))
+        Incompatible(Some(
+          "Does not support fullwidth unicode digits (e.g. \\uFF10) or strings containing null bytes (e.g. \\u0000)"))
       case DataTypes.DateType =>
         // https://github.com/apache/datafusion-comet/issues/327
         Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1b536fec68..872a9ca9fa 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -672,7 +672,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     // https://github.com/apache/datafusion-comet/issues/326
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
   }
-
   test("cast StringType to DoubleType (partial support)") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
@@ -684,73 +683,89 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("cast StringType to DecimalType(10,2)") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
+  // Keep an ignored placeholder so the `all cast combinations are covered` check still passes
+  ignore("cast StringType to DecimalType(10,2)") {
     val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
-    Seq(true, false).foreach(k =>
-      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = k))
+    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+  }
+
+  test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+    }
   }
 
   test("cast StringType to DecimalType(2,2)") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
-    val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
-    Seq(true, false).foreach(k =>
-      castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = k))
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
   }
 
   test("cast StringType to DecimalType(38,10) high precision") {
-    // TODO fix for Spark 4.0.0
-    assume(!isSpark40Plus)
-    val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
-    Seq(true, false).foreach(k =>
-      castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = k))
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      
assume(!isSpark40Plus) + val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled)) + } } test("cast StringType to DecimalType(10,2) basic values") { - // TODO fix for Spark 4.0.0 - assume(!isSpark40Plus) - val values = Seq( - "123.45", - "-67.89", - "-67.89", - "-67.895", - "67.895", - "0.001", - "999.99", - "123.456", - "123.45D", - ".5", - "5.", - "+123.45", - " 123.45 ", - "inf", - "", - "abc", - null).toDF("a") - Seq(true, false).foreach(k => - castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = k)) + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = Seq( + "123.45", + "-67.89", + "-67.89", + "-67.895", + "67.895", + "0.001", + "999.99", + "123.456", + "123.45D", + ".5", + "5.", + "+123.45", + " 123.45 ", + "inf", + "", + "abc", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) + } } test("cast StringType to Decimal type scientific notation") { - // TODO fix for Spark 4.0.0 - assume(!isSpark40Plus) - val values = Seq( - "1.23E-5", - "1.23e10", - "1.23E+10", - "-1.23e-5", - "1e5", - "1E-2", - "-1.5e3", - "1.23E0", - "0e0", - "1.23e", - "e5", - null).toDF("a") - Seq(true, false).foreach(k => - castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = k)) + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = Seq( + "1.23E-5", + "1.23e10", + "1.23E+10", + "-1.23e-5", + "1e5", + "1E-2", + "-1.5e3", + "1.23E0", + "0e0", + "1.23e", + "e5", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled)) + } } test("cast StringType to BinaryType") { From bd5f3894060bfedc2ee91e27f519a2b14210f966 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 18 Dec 2025 23:38:13 -0800 Subject: [PATCH 5/7] address_review_comments --- docs/source/user-guide/latest/compatibility.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index f529ba1e7e..58dd8d6ab0 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -159,8 +159,6 @@ The following cast operations are generally compatible with Spark except for the | string | short | | | string | integer | | | string | long | | -| string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10) -or strings containing null bytes (e.g \\u0000) | | string | binary | | | string | date | Only supports years between 262143 BC and 262142 AD | | binary | string | | @@ -185,6 +183,8 @@ The following cast operations are not compatible with Spark for all inputs and a | double | decimal | There can be rounding differences | | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. 
|
+| string | decimal | Does not support fullwidth unicode digits (e.g. \\uFF10) or strings containing null bytes (e.g. \\u0000) |
 | string | timestamp | Not all valid formats are supported |
 

From fe26709b57c1dd3d04a27af540b11daa9b0f3b90 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Sun, 21 Dec 2025 18:57:42 -0800
Subject: [PATCH 6/7] remove_double_iteration

---
 .../spark-expr/src/conversion_funcs/cast.rs | 27 +++++++++++++++----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index da48752261..c5ec0659e2 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -2212,10 +2212,6 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
         return Ok(None);
     }
 
-    if !is_valid_decimal_format(trimmed) {
-        return Ok(None);
-    }
-
     match parse_decimal_str(trimmed) {
         Ok((mantissa, exponent)) => {
             // Convert to target scale
@@ -2304,11 +2300,32 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
         mantissa_str
     };
 
+    if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
+        return Err("Invalid sign format".to_string());
+    }
+
     let (integral_part, fractional_part) = match mantissa_str.find('.') {
-        Some(dot_pos) => (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..]),
+        Some(dot_pos) => {
+            if mantissa_str[dot_pos + 1..].contains('.') {
+                return Err("Multiple decimal points".to_string());
+            }
+            (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
+        }
         None => (mantissa_str, ""),
     };
 
+    if integral_part.is_empty() && fractional_part.is_empty() {
+        return Err("No digits found".to_string());
+    }
+
+    if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid integral part".to_string());
+    }
+
+    if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid fractional part".to_string());
+    }
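+
+    // Sketch of inputs the single-pass checks above now reject (illustrative
+    // examples only): "+-5" fails the sign check, "1.2.3" the second-dot
+    // check, "12a" the integral-digit check, "1.2f" the fractional-digit check.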
+
     // Parse integral part

From 6a6c80d0c66e93c8f08ba4fd7d72096049c04ae0 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Sun, 21 Dec 2025 19:02:01 -0800
Subject: [PATCH 7/7] remove_double_iteration

---
 .../spark-expr/src/conversion_funcs/cast.rs | 82 +-------------------
 1 file changed, 2 insertions(+), 80 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index c5ec0659e2..6b69c72882 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -2102,92 +2102,13 @@ fn cast_string_to_decimal256_impl(
     ))
 }
 
-/// Validates if a string is a valid decimal similar to BigDecimal
-fn is_valid_decimal_format(s: &str) -> bool {
-    if s.is_empty() {
-        return false;
-    }
-
-    let string_bytes = s.as_bytes();
-    let mut idx = 0;
-    let len = string_bytes.len();
-
-    // Skip leading +/- signs
-    if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' {
-        idx += 1;
-        if idx >= len {
-            // Sign only. Fail early
-            return false;
-        }
-    }
-
-    // Check invalid cases like "++", "+-"
-    if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' {
-        return false;
-    }
-
-    // Now we need at least one digit either before or after a decimal point
-    let mut has_digit = false;
-    let mut is_decimal_point_seen = false;
-
-    while idx < len {
-        let ch = string_bytes[idx];
-
-        if ch.is_ascii_digit() {
-            has_digit = true;
-            idx += 1;
-        } else if ch == b'.' {
-            if is_decimal_point_seen {
-                // Multiple decimal points or decimal after exponent
-                return false;
-            }
-            is_decimal_point_seen = true;
-            idx += 1;
-        } else if ch.eq_ignore_ascii_case(&b'e') {
-            if !has_digit {
-                // Exponent without any digits before it
-                return false;
-            }
-            idx += 1;
-            // Exponent part must have an optional sign followed by at least one digit
-            if idx >= len {
-                return false;
-            }
-
-            if string_bytes[idx] == b'+' || string_bytes[idx] == b'-' {
-                idx += 1;
-                if idx >= len {
-                    return false;
-                }
-            }
-
-            // Must have at least one digit in exponent
-            if !string_bytes[idx].is_ascii_digit() {
-                return false;
-            }
-
-            // Everything after the exponent must be digits
-            while idx < len {
-                if !string_bytes[idx].is_ascii_digit() {
-                    return false;
-                }
-                idx += 1;
-            }
-            break;
-        } else {
-            // Invalid character found. Fail fast
-            return false;
-        }
-    }
-    has_digit
-}
-
 /// Parse a string to decimal following Spark's behavior
 fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
     let string_bytes = s.as_bytes();
     let mut start = 0;
     let mut end = string_bytes.len();
 
+    // trim whitespace
     while start < end && string_bytes[start].is_ascii_whitespace() {
         start += 1;
     }
@@ -2212,6 +2133,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
     match parse_decimal_str(trimmed) {
         Ok((mantissa, exponent)) => {
             // Convert to target scale
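
For reference, a minimal sketch of the end-state semantics of parse_string_to_decimal after this series (values drawn from the test suite above; this snippet is illustrative and not part of the patches):

#[cfg(test)]
mod string_to_decimal_examples {
    use super::*;

    #[test]
    fn parse_string_to_decimal_examples() -> SparkResult<()> {
        // Results are unscaled i128 values: 12345 at scale 2 means 123.45
        assert_eq!(parse_string_to_decimal(" 123.45 ", 10, 2)?, Some(12345));
        // HALF_UP rounding away from zero when scaling down
        assert_eq!(parse_string_to_decimal("67.895", 10, 2)?, Some(6790));
        assert_eq!(parse_string_to_decimal("-67.895", 10, 2)?, Some(-6790));
        // Scientific notation: 1.23E-5 at scale 8 is 0.00001230
        assert_eq!(parse_string_to_decimal("1.23E-5", 23, 8)?, Some(1230));
        // Special values and malformed input yield None (null in LEGACY mode)
        assert_eq!(parse_string_to_decimal("inf", 10, 2)?, None);
        assert_eq!(parse_string_to_decimal("1.23e", 10, 2)?, None);
        // Values that fit the precision succeed; overflowing ones yield None
        assert_eq!(parse_string_to_decimal("99999999.99", 10, 2)?, Some(9_999_999_999));
        assert_eq!(parse_string_to_decimal("999999999.99", 10, 2)?, None);
        Ok(())
    }
}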