@@ -22,7 +22,7 @@ package org.apache.comet
2222import java .io .File
2323
2424import scala .collection .mutable .ListBuffer
25- import scala .util .Random
25+ import scala .util .{ Failure , Random , Success , Try }
2626import scala .util .matching .Regex
2727
2828import org .apache .hadoop .fs .Path
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions.col
3333import org .apache .spark .sql .internal .SQLConf
3434import org .apache .spark .sql .types .{ArrayType , BooleanType , ByteType , DataType , DataTypes , DecimalType , IntegerType , LongType , ShortType , StringType , StructField , StructType }
3535
36- import org .apache .comet .CometSparkSessionExtensions .isSpark40Plus
36+ import org .apache .comet .CometSparkSessionExtensions .{ hasExplainInfo , isSpark40Plus }
3737import org .apache .comet .expressions .{CometCast , CometEvalMode }
3838import org .apache .comet .rules .CometScanTypeChecker
3939import org .apache .comet .serde .Compatible
@@ -1072,6 +1072,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10721072 }
10731073 }
10741074
1075+ test(" cast ArrayType to ArrayType" ) {
1076+ val types = Seq (
1077+ BooleanType ,
1078+ StringType ,
1079+ ByteType ,
1080+ IntegerType ,
1081+ LongType ,
1082+ ShortType ,
1083+ DecimalType (10 , 2 ),
1084+ DecimalType (38 , 18 ))
1085+ for (fromType <- types) {
1086+ for (toType <- types) {
1087+ if (fromType != toType &&
1088+ ! tags
1089+ .get(s " cast $fromType to $toType" )
1090+ .exists(s => s.contains(" org.scalatest.Ignore" )) &&
1091+ Cast .canCast(fromType, toType) &&
1092+ CometCast .isSupported(fromType, toType, None , CometEvalMode .LEGACY ) == Compatible ()) {
1093+ castTest(generateArrays(100 , fromType), ArrayType (toType))
1094+ }
1095+ }
1096+ }
1097+ }
1098+
10751099 private def generateFloats (): DataFrame = {
10761100 withNulls(gen.generateFloats(dataSize)).toDF(" a" )
10771101 }
@@ -1100,10 +1124,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11001124 withNulls(gen.generateLongs(dataSize)).toDF(" a" )
11011125 }
11021126
1103- private def generateArrays (rowSize : Int , elementType : DataType ): DataFrame = {
1127+ private def generateArrays (rowNum : Int , elementType : DataType ): DataFrame = {
11041128 import scala .collection .JavaConverters ._
11051129 val schema = StructType (Seq (StructField (" a" , ArrayType (elementType), true )))
1106- spark.createDataFrame(gen.generateRows(rowSize , schema).asJava, schema)
1130+ spark.createDataFrame(gen.generateRows(rowNum , schema).asJava, schema)
11071131 }
11081132
11091133 // https://github.com/apache/datafusion-comet/issues/2038
0 commit comments