Skip to content

Commit 2f8d42f

Browse files
committed
fix: array to array cast
1 parent 1b3354b commit 2f8d42f

File tree

3 files changed

+41
-6
lines changed

3 files changed

+41
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,4 @@ apache-rat-*.jar
1818
venv
1919
dev/release/comet-rm/workdir
2020
spark/benchmarks
21+
/comet-event-trace.json

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1031,8 +1031,18 @@ fn cast_array(
10311031
cast_options,
10321032
)?),
10331033
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
1034-
(List(_), List(_)) if can_cast_types(from_type, to_type) => {
1035-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
1034+
(List(from), List(to))
1035+
if can_cast_types(from_type, to_type)
1036+
|| (matches!(from.data_type(), Decimal128(_, _))
1037+
&& matches!(to.data_type(), Boolean)) =>
1038+
{
1039+
let list_array = array.as_list::<i32>();
1040+
Ok(Arc::new(ListArray::new(
1041+
Arc::clone(to),
1042+
list_array.offsets().clone(),
1043+
cast_array(Arc::clone(list_array.values()), to.data_type(), cast_options)?,
1044+
list_array.nulls().cloned(),
1045+
)) as ArrayRef)
10361046
}
10371047
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
10381048
if cast_options.allow_cast_unsigned_ints =>

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ package org.apache.comet
2222
import java.io.File
2323

2424
import scala.collection.mutable.ListBuffer
25-
import scala.util.Random
25+
import scala.util.{Failure, Random, Success, Try}
2626
import scala.util.matching.Regex
2727

2828
import org.apache.hadoop.fs.Path
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions.col
3333
import org.apache.spark.sql.internal.SQLConf
3434
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
3535

36-
import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
36+
import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isSpark40Plus}
3737
import org.apache.comet.expressions.{CometCast, CometEvalMode}
3838
import org.apache.comet.rules.CometScanTypeChecker
3939
import org.apache.comet.serde.Compatible
@@ -1072,6 +1072,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
10721072
}
10731073
}
10741074

1075+
test("cast ArrayType to ArrayType") {
1076+
val types = Seq(
1077+
BooleanType,
1078+
StringType,
1079+
ByteType,
1080+
IntegerType,
1081+
LongType,
1082+
ShortType,
1083+
DecimalType(10, 2),
1084+
DecimalType(38, 18))
1085+
for (fromType <- types) {
1086+
for (toType <- types) {
1087+
if (fromType != toType &&
1088+
!tags
1089+
.get(s"cast $fromType to $toType")
1090+
.exists(s => s.contains("org.scalatest.Ignore")) &&
1091+
Cast.canCast(fromType, toType) &&
1092+
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) {
1093+
castTest(generateArrays(100, fromType), ArrayType(toType))
1094+
}
1095+
}
1096+
}
1097+
}
1098+
10751099
private def generateFloats(): DataFrame = {
10761100
withNulls(gen.generateFloats(dataSize)).toDF("a")
10771101
}
@@ -1100,10 +1124,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11001124
withNulls(gen.generateLongs(dataSize)).toDF("a")
11011125
}
11021126

1103-
private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
1127+
private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
11041128
import scala.collection.JavaConverters._
11051129
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
1106-
spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
1130+
spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema)
11071131
}
11081132

11091133
// https://github.com/apache/datafusion-comet/issues/2038

0 commit comments

Comments
 (0)