docs/source/user-guide/latest/compatibility.md (3 changes: 2 additions & 1 deletion)
@@ -183,7 +183,8 @@ The following cast operations are not compatible with Spark for all inputs and a
| double | decimal | There can be rounding differences |
| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
- | string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
+ | string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10)
+ or strings containing null bytes (e.g \\u0000) |
| string | timestamp | Not all valid formats are supported |
<!-- prettier-ignore-end -->
<!--END:INCOMPAT_CAST_TABLE-->
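The root cause of the fullwidth-digit limitation is visible in the native code below: the new parser validates digits with `is_ascii_digit`, which rejects fullwidth forms such as U+FF10 even though the JVM side treats them as digits. A minimal sketch (hypothetical, not part of this PR):

```rust
fn main() {
    let fullwidth_zero = '\u{FF10}'; // FULLWIDTH DIGIT ZERO
    assert!('0'.is_ascii_digit());
    // Rejected by the native parser's ASCII-only digit check...
    assert!(!fullwidth_zero.is_ascii_digit());
    // ...even though it is a Unicode digit, which Spark's JVM parsing may accept.
    assert!(fullwidth_zero.is_numeric());
}
```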
native/spark-expr/src/conversion_funcs/cast.rs (319 changes: 312 additions & 7 deletions)
@@ -20,12 +20,13 @@ use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
- BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
- StructArray,
+ BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+ PrimitiveBuilder, StringArray, StructArray,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
- ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
+ i256, ArrowDictionaryKeyType, ArrowNativeType, DataType, Decimal256Type, GenericBinaryType,
+ Schema,
};
use arrow::{
array::{
@@ -224,9 +225,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
}
Decimal128(_, _) => {
// https://github.com/apache/datafusion-comet/issues/325
- // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
- // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
-
+ // Does not support fullwidth unicode digits or strings containing null bytes.
options.allow_incompat
}
Date32 | Date64 => {
@@ -976,6 +975,12 @@ fn cast_array(
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
}
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+ (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
+ cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+ }
+ (Utf8 | LargeUtf8, Decimal256(precision, scale)) => {
+ cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
+ }
(Int64, Int32)
| (Int64, Int16)
| (Int64, Int8)
@@ -1187,7 +1192,7 @@ fn is_datafusion_spark_compatible(
),
DataType::Utf8 if allow_incompat => matches!(
to_type,
- DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+ DataType::Binary | DataType::Float32 | DataType::Float64
),
DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
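The effect of dropping `Decimal128(_, _)` from this list appears to be that string-to-decimal casts are no longer handed off to DataFusion's built-in cast kernel; they now go through the Comet-native `cast_string_to_decimal` path added below.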
@@ -1976,6 +1981,306 @@ fn do_cast_string_to_int<
Ok(Some(result))
}

fn cast_string_to_decimal(
array: &ArrayRef,
to_type: &DataType,
precision: &u8,
scale: &i8,
eval_mode: EvalMode,
) -> SparkResult<ArrayRef> {
match to_type {
DataType::Decimal128(_, _) => {
cast_string_to_decimal128_impl(array, eval_mode, *precision, *scale)
}
DataType::Decimal256(_, _) => {
cast_string_to_decimal256_impl(array, eval_mode, *precision, *scale)
}
_ => Err(SparkError::Internal(format!(
"Unexpected type in cast_string_to_decimal: {:?}",
to_type
))),
}
}
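A hedged usage sketch of the dispatcher above (hypothetical, not part of the diff; assumes crate-internal access to `cast_string_to_decimal` plus the `EvalMode` and `SparkResult` imports already at the top of this file):

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::DataType;

// Cast a string column to DECIMAL(10, 2) in legacy (non-ANSI) mode.
fn demo() -> SparkResult<ArrayRef> {
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some("123.45"),       // -> 123.45
        Some("not a number"), // -> null in legacy mode
        None,                 // -> null
    ]));
    cast_string_to_decimal(&input, &DataType::Decimal128(10, 2), &10, &2, EvalMode::Legacy)
}
```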

fn cast_string_to_decimal128_impl(
array: &ArrayRef,
eval_mode: EvalMode,
precision: u8,
scale: i8,
) -> SparkResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;

let mut decimal_builder = Decimal128Builder::with_capacity(string_array.len());

for i in 0..string_array.len() {
if string_array.is_null(i) {
decimal_builder.append_null();
} else {
let str_value = string_array.value(i);
match parse_string_to_decimal(str_value, precision, scale) {
Ok(Some(decimal_value)) => {
decimal_builder.append_value(decimal_value);
}
Ok(None) => {
if eval_mode == EvalMode::Ansi {
return Err(invalid_value(
str_value,
"STRING",
&format!("DECIMAL({},{})", precision, scale),
));
}
decimal_builder.append_null();
}
Err(e) => {
if eval_mode == EvalMode::Ansi {
return Err(e);
}
decimal_builder.append_null();
}
}
}
}

Ok(Arc::new(
decimal_builder
.with_precision_and_scale(precision, scale)?
.finish(),
))
}
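Behavior at the error boundary, as a test-style sketch (hypothetical; assumes the imports already present in this file):

```rust
#[test]
fn legacy_nulls_ansi_errors() {
    let bad: ArrayRef = Arc::new(StringArray::from(vec![Some("abc")]));
    // Legacy mode swallows the parse failure and produces a null row...
    assert!(cast_string_to_decimal128_impl(&bad, EvalMode::Legacy, 10, 2).is_ok());
    // ...while ANSI mode surfaces it as a SparkError.
    assert!(cast_string_to_decimal128_impl(&bad, EvalMode::Ansi, 10, 2).is_err());
}
```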

fn cast_string_to_decimal256_impl(
array: &ArrayRef,
eval_mode: EvalMode,
precision: u8,
scale: i8,
) -> SparkResult<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;

let mut decimal_builder = PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());

for i in 0..string_array.len() {
if string_array.is_null(i) {
decimal_builder.append_null();
} else {
let str_value = string_array.value(i);
match parse_string_to_decimal(str_value, precision, scale) {
Ok(Some(decimal_value)) => {
// Convert i128 to i256
let i256_value = i256::from_i128(decimal_value);
decimal_builder.append_value(i256_value);
}
Ok(None) => {
if eval_mode == EvalMode::Ansi {
return Err(invalid_value(
str_value,
"STRING",
&format!("DECIMAL({},{})", precision, scale),
));
}
decimal_builder.append_null();
}
Err(e) => {
if eval_mode == EvalMode::Ansi {
return Err(e);
}
decimal_builder.append_null();
}
}
}
}

Ok(Arc::new(
decimal_builder
.with_precision_and_scale(precision, scale)?
.finish(),
))
}
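One thing worth noting (my observation, not a claim made in the PR): `parse_string_to_decimal` produces an `i128`, so even for `Decimal256` targets the unscaled value must fit in `i128`; wider inputs come back as null (or an error under ANSI). A test-style sketch:

```rust
#[test]
fn decimal256_is_bounded_by_i128_parsing() {
    // 40 nines overflow i128 while parsing the mantissa, so the parser
    // yields Ok(None) even though DECIMAL(40, 0) could hold the value as i256.
    let wide = "9999999999999999999999999999999999999999"; // 40 digits
    assert_eq!(parse_string_to_decimal(wide, 40, 0).unwrap(), None);
}
```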

/// Parse a string to decimal following Spark's behavior
fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
let string_bytes = s.as_bytes();
let mut start = 0;
let mut end = string_bytes.len();

// trim whitespaces
while start < end && string_bytes[start].is_ascii_whitespace() {
start += 1;
}
while end > start && string_bytes[end - 1].is_ascii_whitespace() {
end -= 1;
}

let trimmed = &s[start..end];

if trimmed.is_empty() {
return Ok(None);
}
// Handle special values (inf, nan, etc.)
if trimmed.eq_ignore_ascii_case("inf")
|| trimmed.eq_ignore_ascii_case("+inf")
|| trimmed.eq_ignore_ascii_case("infinity")
|| trimmed.eq_ignore_ascii_case("+infinity")
|| trimmed.eq_ignore_ascii_case("-inf")
|| trimmed.eq_ignore_ascii_case("-infinity")
|| trimmed.eq_ignore_ascii_case("nan")
{
return Ok(None);
}

// validate and parse mantissa and scale
match parse_decimal_str(trimmed) {
Ok((mantissa, parsed_scale)) => {
// Convert to the target scale
let target_scale = scale as i32;
let scale_adjustment = target_scale - parsed_scale;

let scaled_value = if scale_adjustment >= 0 {
// Need to multiply (increase scale) but return None if scale is too high to fit i128
if scale_adjustment > 38 {
return Ok(None);
}
mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
} else {
// Need to divide (decrease scale); 10^39 exceeds any i128 value, so the result rounds to 0
let abs_scale_adjustment = (-scale_adjustment) as u32;
if abs_scale_adjustment > 38 {
return Ok(Some(0));
}

let divisor = 10_i128.pow(abs_scale_adjustment);
let quotient_opt = mantissa.checked_div(divisor);
// checked_div only returns None on division by zero, which cannot happen for divisor >= 10; kept as a defensive guard
if quotient_opt.is_none() {
return Ok(None);
}
let quotient = quotient_opt.unwrap();
let remainder = mantissa % divisor;

// Round half up: if abs(remainder) >= divisor/2, round away from zero
let half_divisor = divisor / 2;
let rounded = if remainder.abs() >= half_divisor {
if mantissa >= 0 {
quotient + 1
} else {
quotient - 1
}
} else {
quotient
};
Some(rounded)
};

match scaled_value {
Some(value) => {
// Check if it fits target precision
if is_validate_decimal_precision(value, precision) {
Ok(Some(value))
} else {
Ok(None)
}
}
None => {
// Overflow while scaling
Ok(None)
}
}
}
Err(_) => Ok(None),
}
}
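To make the scale adjustment and half-up rounding concrete, a test-style walkthrough (hypothetical, not part of the diff):

```rust
#[test]
fn parse_string_to_decimal_walkthrough() {
    // "1.005" -> DECIMAL(10, 2): parse_decimal_str gives (1005, 3);
    // scale_adjustment = 2 - 3 = -1, so divide by 10 -> quotient 100, remainder 5;
    // |5| >= divisor / 2, so round away from zero -> 101, i.e. 1.01.
    assert_eq!(parse_string_to_decimal("1.005", 10, 2).unwrap(), Some(101));

    // "2.5e1" -> DECIMAL(10, 0): (25, 0) after folding in the exponent; no adjustment needed.
    assert_eq!(parse_string_to_decimal("2.5e1", 10, 0).unwrap(), Some(25));

    // Whitespace is trimmed; special values such as "inf" map to None (null).
    assert_eq!(parse_string_to_decimal(" inf ", 10, 2).unwrap(), None);
}
```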

/// Parse a decimal string into mantissa and scale
/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
if s.is_empty() {
return Err("Empty string".to_string());
}

let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
let mantissa_part = &s[..e_pos];
let exponent_part = &s[e_pos + 1..];
// Parse exponent
let exp: i32 = exponent_part
.parse()
.map_err(|e| format!("Invalid exponent: {}", e))?;

(mantissa_part, exp)
} else {
(s, 0)
};

let negative = mantissa_str.starts_with('-');
let mantissa_str = if negative || mantissa_str.starts_with('+') {
&mantissa_str[1..]
} else {
mantissa_str
};

if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
return Err("Invalid sign format".to_string());
}

let (integral_part, fractional_part) = match mantissa_str.find('.') {
Some(dot_pos) => {
if mantissa_str[dot_pos + 1..].contains('.') {
return Err("Multiple decimal points".to_string());
}
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
}
None => (mantissa_str, ""),
};

if integral_part.is_empty() && fractional_part.is_empty() {
return Err("No digits found".to_string());
}

if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid integral part".to_string());
}

if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
return Err("Invalid fractional part".to_string());
}

// Parse integral part
let integral_value: i128 = if integral_part.is_empty() {
// Empty integral part is valid (e.g., ".5" or "-.7e9")
0
} else {
integral_part
.parse()
.map_err(|_| "Invalid integral part".to_string())?
};

// Parse fractional part
let fractional_scale = fractional_part.len() as i32;
let fractional_value: i128 = if fractional_part.is_empty() {
0
} else {
fractional_part
.parse()
.map_err(|_| "Invalid fractional part".to_string())?
};

// Combine: value = integral * 10^fractional_scale + fractional
let mantissa = integral_value
.checked_mul(10_i128.pow(fractional_scale as u32))
.and_then(|v| v.checked_add(fractional_value))
.ok_or("Overflow in mantissa calculation")?;

let final_mantissa = if negative { -mantissa } else { mantissa };
// final scale = fractional_scale - exponent
// For example : "1.23E-5" has fractional_scale=2, exponent=-5, so scale = 2 - (-5) = 7
let final_scale = fractional_scale - exponent;
Ok((final_mantissa, final_scale))
}
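And a few decompositions of `parse_decimal_str` itself, matching its doc comment (hypothetical test):

```rust
#[test]
fn parse_decimal_str_examples() {
    assert_eq!(parse_decimal_str("123.45"), Ok((12345, 2)));
    assert_eq!(parse_decimal_str("-0.001"), Ok((-1, 3)));
    assert_eq!(parse_decimal_str("1.23E-5"), Ok((123, 7))); // scale = 2 - (-5)
    assert_eq!(parse_decimal_str(".5"), Ok((5, 1)));        // empty integral part is allowed
    assert!(parse_decimal_str("1.2.3").is_err());           // a second '.' is rejected
}
```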

/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode
#[inline]
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult<Option<T>> {
@@ -192,9 +192,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
"Does not support ANSI mode."))
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
- Incompatible(
- Some("Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
- "Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits"))
+ Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10)
+ |or strings containing null bytes (e.g \\u0000)""".stripMargin))
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
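As far as I can tell, this `Incompatible` message is the source of the generated row in compatibility.md above (the `INCOMPAT_CAST_TABLE` block), which would explain why that markdown row spans two lines: the `stripMargin` string embeds a newline.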