diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 60e2234f59..58dd8d6ab0 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -183,7 +183,7 @@ The following cast operations are not compatible with Spark for all inputs and a
 | double | decimal | There can be rounding differences |
 | string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
 | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
-| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
+| string | decimal | Does not support fullwidth Unicode digits (e.g. \uFF10) or strings containing null bytes (e.g. \u0000) |
 | string | timestamp | Not all valid formats are supported |

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 12a147c6e1..6b69c72882 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
 use crate::{EvalMode, SparkError, SparkResult};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
-    StructArray,
+    BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+    PrimitiveBuilder, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{
-    ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
+    i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
+    Schema,
 };
 use arrow::{
     array::{
@@ -224,9 +225,10 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
         }
         Decimal128(_, _) => {
             // https://github.com/apache/datafusion-comet/issues/325
-            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
+            // Does not support fullwidth Unicode digits or strings containing null bytes.
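+            // Illustrative example: the JVM-side parser Spark relies on can accept
+            // non-ASCII digits (e.g. "\u{FF11}\u{FF12}" parsing as 12), while the
+            // native parser below only accepts ASCII digits and returns null instead.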
             options.allow_incompat
         }
         Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
+        (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+        }
         (Int64, Int32)
         | (Int64, Int16)
         | (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
         ),
         DataType::Utf8 if allow_incompat => matches!(
             to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+            DataType::Binary | DataType::Float32 | DataType::Float64
         ),
         DataType::Utf8 => matches!(to_type, DataType::Binary),
         DataType::Date32 => matches!(to_type, DataType::Utf8),
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
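+
+// Note (added for clarity): the Decimal256 path below reuses the i128-based parser and
+// then widens the result, so any mantissa that does not fit in i128 surfaces as
+// Ok(None) and becomes null (or an ANSI error) even though Decimal256 could hold it.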
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut decimal_builder =
+        PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i);
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Widen the parsed i128 value to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to a decimal mantissa following Spark's behavior; returns Ok(None)
+/// where Spark would produce NULL (invalid input or overflow)
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+    let string_bytes = s.as_bytes();
+    let mut start = 0;
+    let mut end = string_bytes.len();
+
+    // Trim leading and trailing ASCII whitespace
+    while start < end && string_bytes[start].is_ascii_whitespace() {
+        start += 1;
+    }
+    while end > start && string_bytes[end - 1].is_ascii_whitespace() {
+        end -= 1;
+    }
+
+    let trimmed = &s[start..end];
+
+    if trimmed.is_empty() {
+        return Ok(None);
+    }
+    // Special floating-point values (inf, nan, etc.) are not valid decimals
+    if trimmed.eq_ignore_ascii_case("inf")
+        || trimmed.eq_ignore_ascii_case("+inf")
+        || trimmed.eq_ignore_ascii_case("infinity")
+        || trimmed.eq_ignore_ascii_case("+infinity")
+        || trimmed.eq_ignore_ascii_case("-inf")
+        || trimmed.eq_ignore_ascii_case("-infinity")
+        || trimmed.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    // Validate and parse into (mantissa, scale)
+    match parse_decimal_str(trimmed) {
+        Ok((mantissa, parsed_scale)) => {
+            // Convert to the target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - parsed_scale;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale); 10^39 overflows i128, so give up
+                if scale_adjustment > 38 {
+                    return Ok(None);
+                }
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale); dividing an i128 mantissa by more
+                // than 10^38 always rounds to zero
+                let abs_scale_adjustment = (-scale_adjustment) as u32;
+                if abs_scale_adjustment > 38 {
+                    return Ok(Some(0));
+                }
+
+                let divisor = 10_i128.pow(abs_scale_adjustment);
+                // checked_div only fails on division by zero, which cannot happen for
+                // a positive power of ten, but stay defensive
+                let quotient_opt = mantissa.checked_div(divisor);
+                if quotient_opt.is_none() {
+                    return Ok(None);
+                }
+                let quotient = quotient_opt.unwrap();
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check that the scaled value fits the target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow while scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
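+
+// Worked example (illustrative) for the scaling logic above: casting "1.235" to
+// DECIMAL(10,2) parses to (mantissa 1235, scale 3); scale_adjustment = 2 - 3 = -1,
+// so divisor = 10, quotient = 123, remainder = 5; |5| >= divisor / 2 rounds away
+// from zero to 124, i.e. 1.24, matching Spark's HALF_UP rounding.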
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+        // Parse the exponent
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|e| format!("Invalid exponent: {}", e))?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
+        return Err("Invalid sign format".to_string());
+    }
+
+    let (integral_part, fractional_part) = match mantissa_str.find('.') {
+        Some(dot_pos) => {
+            if mantissa_str[dot_pos + 1..].contains('.') {
+                return Err("Multiple decimal points".to_string());
+            }
+            (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
+        }
+        None => (mantissa_str, ""),
+    };
+
+    if integral_part.is_empty() && fractional_part.is_empty() {
+        return Err("No digits found".to_string());
+    }
+
+    if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid integral part".to_string());
+    }
+
+    if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
+        return Err("Invalid fractional part".to_string());
+    }
+
+    // Parse the integral part
+    let integral_value: i128 = if integral_part.is_empty() {
+        // Empty integral part is valid (e.g., ".5" or "-.7e9")
+        0
+    } else {
+        integral_part
+            .parse()
+            .map_err(|_| "Invalid integral part".to_string())?
+    };
+
+    // Parse the fractional part
+    let fractional_scale = fractional_part.len() as i32;
+    let fractional_value: i128 = if fractional_part.is_empty() {
+        0
+    } else {
+        fractional_part
+            .parse()
+            .map_err(|_| "Invalid fractional part".to_string())?
+    };
+
+    // Combine: value = integral * 10^fractional_scale + fractional; use checked_pow
+    // so a very long fractional part cannot overflow the power itself
+    let mantissa = 10_i128
+        .checked_pow(fractional_scale as u32)
+        .and_then(|p| integral_value.checked_mul(p))
+        .and_then(|v| v.checked_add(fractional_value))
+        .ok_or("Overflow in mantissa calculation")?;
+
+    let final_mantissa = if negative { -mantissa } else { mantissa };
+    // final scale = fractional_scale - exponent
+    // For example: "1.23E-5" has fractional_scale = 2 and exponent = -5, so the
+    // scale is 2 - (-5) = 7
+    let final_scale = fractional_scale - exponent;
+    Ok((final_mantissa, final_scale))
+}
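+
+// Spot-check expectations for parse_decimal_str (illustrative):
+//   "123.45"  -> (12345, 2)
+//   "-0.001"  -> (-1, 3)
+//   "1.23E-5" -> (123, 7)     since 2 - (-5) = 7
+//   "-.7e9"   -> (-7, -8)     i.e. -7 * 10^8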
+
 /// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
 #[inline]
 fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 98ce8ac44d..14db7c2780 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -192,9 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
             "Does not support ANSI mode."))
       case _: DecimalType =>
         // https://github.com/apache/datafusion-comet/issues/325
-        Incompatible(
-          Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
-            "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
+        Incompatible(Some("Does not support fullwidth Unicode digits (e.g. \\uFF10) " +
+          "or strings containing null bytes (e.g. \\u0000)"))
       case DataTypes.DateType =>
         // https://github.com/apache/datafusion-comet/issues/327
         Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 1912e982b9..872a9ca9fa 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -672,7 +672,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     // https://github.com/apache/datafusion-comet/issues/326
     castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
   }
-
   test("cast StringType to DoubleType (partial support)") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
@@ -684,21 +683,88 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  // Kept as an ignored test so that the `all cast combinations are covered` check passes
   ignore("cast StringType to DecimalType(10,2)") {
-    // https://github.com/apache/datafusion-comet/issues/325
-    val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
-    castTest(values, DataTypes.createDecimalType(10, 2))
+    val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+    castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
   }
 
-  test("cast StringType to DecimalType(10,2) (partial support)") {
-    withSQLConf(
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
-      SQLConf.ANSI_ENABLED.key -> "false") {
-      val values = gen
-        .generateStrings(dataSize, "0123456789.", 8)
-        .filter(_.exists(_.isDigit))
-        .toDF("a")
-      castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
+  test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(2,2)") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(10,2) basic values") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      // TODO fix for Spark 4.0.0
+      assume(!isSpark40Plus)
+      val values = Seq(
+        "123.45",
+        "-67.89",
+        "-67.89",
+        "-67.895",
+        "67.895",
+        "0.001",
+        "999.99",
+        "123.456",
+        "123.45D",
+        ".5",
+        "5.",
+        "+123.45",
+        " 123.45 ",
+        "inf",
+        "",
+        "abc",
+        null).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
+    }
+  }
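+
+  // Expected results for a few of the inputs above (illustrative): "123.45" -> 123.45,
+  // "-67.895" -> -67.90 (HALF_UP), "123.45D" -> null (suffix not supported),
+  // " 123.45 " -> 123.45 (whitespace trimmed), "" and "abc" -> null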
"+123.45", + " 123.45 ", + "inf", + "", + "abc", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled)) + } + } + + test("cast StringType to Decimal type scientific notation") { + withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") { + // TODO fix for Spark 4.0.0 + assume(!isSpark40Plus) + val values = Seq( + "1.23E-5", + "1.23e10", + "1.23E+10", + "-1.23e-5", + "1e5", + "1E-2", + "-1.5e3", + "1.23E0", + "0e0", + "1.23e", + "e5", + null).toDF("a") + Seq(true, false).foreach(ansiEnabled => + castTest(values, DataTypes.createDecimalType(23, 8), testAnsi = ansiEnabled)) } }