diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 58dd8d6ab0..acc123e355 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -159,6 +159,8 @@ The following cast operations are generally compatible with Spark except for the | string | short | | | string | integer | | | string | long | | +| string | float | | +| string | double | | | string | binary | | | string | date | Only supports years between 262143 BC and 262142 AD | | binary | string | | @@ -181,8 +183,6 @@ The following cast operations are not compatible with Spark for all inputs and a |-|-|-| | float | decimal | There can be rounding differences | | double | decimal | There can be rounding differences | -| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | -| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. 
| | string | decimal | Does not support fullwidth unicode digits (e.g \\uFF10) or strings containing null bytes (e.g \\u0000) | | string | timestamp | Not all valid formats are supported | diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 6b69c72882..5011917082 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -45,6 +45,7 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; +use base64::prelude::*; use chrono::{DateTime, NaiveDate, TimeZone, Timelike}; use datafusion::common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult, @@ -66,8 +67,6 @@ use std::{ sync::Arc, }; -use base64::prelude::*; - static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); const MICROS_PER_SECOND: i64 = 1000000; @@ -217,12 +216,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool use DataType::*; match to_type { Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true, - Float32 | Float64 => { - // https://github.com/apache/datafusion-comet/issues/326 - // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. - // Does not support ANSI mode. - options.allow_incompat - } + Float32 | Float64 => true, Decimal128(_, _) => { // https://github.com/apache/datafusion-comet/issues/325 // Does not support fullwidth digits and null byte handling. 
@@ -975,6 +969,7 @@ fn cast_array(
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
         (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Utf8, Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode),
         (Utf8 | LargeUtf8, Decimal128(precision, scale)) => {
            cast_string_to_decimal(&array, to_type, precision, scale, eval_mode)
        }
@@ -1046,7 +1041,7 @@ fn cast_array(
        }
        (Binary, Utf8) => Ok(cast_binary_to_string::<i32>(&array, cast_options)?),
        _ if cast_options.is_adapting_schema
-            || is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) =>
+            || is_datafusion_spark_compatible(from_type, to_type) =>
        {
            // use DataFusion cast only when we know that it is compatible with Spark
            Ok(cast_with_options(&array, to_type, &native_cast_options)?)
@@ -1063,6 +1058,86 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_string_to_float(
+    array: &ArrayRef,
+    to_type: &DataType,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Float32 => cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT"),
+        DataType::Float64 => cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE"),
+        _ => Err(SparkError::Internal(format!(
+            "Unsupported cast to float type: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    type_name: &str,
+) -> SparkResult<ArrayRef>
+where
+    T::Native: FromStr + num::Float,
+{
+    let arr = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+    let mut builder = PrimitiveBuilder::<T>::with_capacity(arr.len());
+
+    for i in 0..arr.len() {
+        if arr.is_null(i) {
+            builder.append_null();
+        } else {
+            let str_value = arr.value(i).trim();
+            match parse_string_to_float(str_value) {
+                Some(v) => builder.append_value(v),
+                None => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(arr.value(i), "STRING", type_name));
+                    }
+                    builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+/// helper to parse floats from string inputs
+fn parse_string_to_float<F>(s: &str) -> Option<F>
+where
+    F: FromStr + num::Float,
+{
+    // Handle +inf / -inf
+    if s.eq_ignore_ascii_case("inf")
+        || s.eq_ignore_ascii_case("+inf")
+        || s.eq_ignore_ascii_case("infinity")
+        || s.eq_ignore_ascii_case("+infinity")
+    {
+        return Some(F::infinity());
+    }
+    if s.eq_ignore_ascii_case("-inf") || s.eq_ignore_ascii_case("-infinity") {
+        return Some(F::neg_infinity());
+    }
+    if s.eq_ignore_ascii_case("nan") {
+        return Some(F::nan());
+    }
+    // Remove D/F suffix if present
+    let pruned_float_str =
+        if s.ends_with("d") || s.ends_with("D") || s.ends_with('f') || s.ends_with('F') {
+            &s[..s.len() - 1]
+        } else {
+            s
+        };
+    // Rust's parse logic already handles scientific notations so we just rely on it
+    pruned_float_str.parse::<F>().ok()
+}
+
 fn cast_binary_to_string<O: OffsetSizeTrait>(
     array: &dyn Array,
     spark_cast_options: &SparkCastOptions,
@@ -1133,11 +1208,7 @@ fn cast_binary_formatter(value: &[u8]) -> String {
 
 /// Determines if DataFusion supports the given cast in a way that is
 /// compatible with Spark
-fn is_datafusion_spark_compatible(
-    from_type: &DataType,
-    to_type: &DataType,
-    allow_incompat: bool,
-) -> bool {
+fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
     if from_type == to_type {
         return true;
     }
@@ -1190,10 +1261,6 @@
                 | DataType::Decimal256(_, _)
                 | DataType::Utf8 // note that there can be formatting differences
         ),
-        DataType::Utf8 if allow_incompat => matches!(
-            to_type,
-            DataType::Binary | DataType::Float32 | DataType::Float64
-        ),
        DataType::Utf8 => matches!(to_type, DataType::Binary),
        DataType::Date32 => matches!(to_type, DataType::Utf8),
        DataType::Timestamp(_, _) => {
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index
14db7c2780..9fc4b3afdf 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -185,11 +185,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { case DataTypes.BinaryType => Compatible() case DataTypes.FloatType | DataTypes.DoubleType => - // https://github.com/apache/datafusion-comet/issues/326 - Incompatible( - Some( - "Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " + - "Does not support ANSI mode.")) + Compatible() case _: DecimalType => // https://github.com/apache/datafusion-comet/issues/325 Incompatible(Some("""Does not support fullwidth unicode digits (e.g \\uFF10) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 6f58811f80..1892749bec 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -642,34 +642,50 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType) } - ignore("cast StringType to FloatType") { + test("cast StringType to DoubleType") { // https://github.com/apache/datafusion-comet/issues/326 + castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType) + } + + test("cast StringType to FloatType") { castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType) } - test("cast StringType to FloatType (partial support)") { - withSQLConf( - CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", - SQLConf.ANSI_ENABLED.key -> "false") { - castTest( - gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"), - DataTypes.FloatType, - testAnsi = false) + val specialValues: Seq[String] = Seq( + "1.5f", + "1.5F", + "2.0d", + "2.0D", + "3.14159265358979d", + "inf", 
+ "Inf", + "INF", + "+inf", + "+Infinity", + "-inf", + "-Infinity", + "NaN", + "nan", + "NAN", + "1.23e4", + "1.23E4", + "-1.23e-4", + " 123.456789 ", + "0.0", + "-0.0", + "", + "xyz", + null) + + test("cast StringType to FloatType special values") { + Seq(true, false).foreach { ansiMode => + castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = ansiMode) } } - ignore("cast StringType to DoubleType") { - // https://github.com/apache/datafusion-comet/issues/326 - castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType) - } - test("cast StringType to DoubleType (partial support)") { - withSQLConf( - CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", - SQLConf.ANSI_ENABLED.key -> "false") { - castTest( - gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"), - DataTypes.DoubleType, - testAnsi = false) + test("cast StringType to DoubleType special values") { + Seq(true, false).foreach { ansiMode => + castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = ansiMode) } }