From 92a78676ef13b06a38ed713f95738fe03802f7cc Mon Sep 17 00:00:00 2001 From: manuzhang Date: Tue, 9 Dec 2025 23:21:18 +0800 Subject: [PATCH] fix: array to array cast --- .gitignore | 1 + .../spark-expr/src/conversion_funcs/cast.rs | 18 ++++++++++-- .../org/apache/comet/CometCastSuite.scala | 28 +++++++++++++++++-- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 94877ced70..cdb16e9369 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ apache-rat-*.jar venv dev/release/comet-rm/workdir spark/benchmarks +/comet-event-trace.json diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 12a147c6e1..88a7e60cd2 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -1031,8 +1031,22 @@ fn cast_array( cast_options, )?), (List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?), - (List(_), List(_)) if can_cast_types(from_type, to_type) => { - Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ (List(from), List(to)) + if can_cast_types(from_type, to_type) + || (matches!(from.data_type(), Decimal128(_, _)) + && matches!(to.data_type(), Boolean)) => + { + let list_array = array.as_list::<i32>(); + Ok(Arc::new(ListArray::new( + Arc::clone(to), + list_array.offsets().clone(), + cast_array( + Arc::clone(list_array.values()), + to.data_type(), + cast_options, + )?, + list_array.nulls().cloned(), + )) as ArrayRef) } (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) if cast_options.allow_cast_unsigned_ints => diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 1912e982b9..c07d86095d 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -1072,6 +1072,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("cast ArrayType to ArrayType") { + val types = Seq( + BooleanType, + StringType, + ByteType, + IntegerType, + LongType, + ShortType, + DecimalType(10, 2), + DecimalType(38, 18)) + for (fromType <- types) { + for (toType <- types) { + if (fromType != toType && + !tags + .get(s"cast $fromType to $toType") + .exists(s => s.contains("org.scalatest.Ignore")) && + Cast.canCast(fromType, toType) && + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) { + castTest(generateArrays(100, fromType), ArrayType(toType)) + } + } + } + } + private def generateFloats(): DataFrame = { withNulls(gen.generateFloats(dataSize)).toDF("a") } @@ -1100,10 +1124,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { withNulls(gen.generateLongs(dataSize)).toDF("a") } - private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = { + private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = { import scala.collection.JavaConverters._ val schema = StructType(Seq(StructField("a",
ArrayType(elementType), true))) - spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema) + spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema) } // https://github.com/apache/datafusion-comet/issues/2038