diff --git a/.github/workflows/spark_sql_test.yml b/.github/workflows/spark_sql_test.yml
index 3d7aa2e2f9..cdfb233e97 100644
--- a/.github/workflows/spark_sql_test.yml
+++ b/.github/workflows/spark_sql_test.yml
@@ -50,7 +50,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-24.04]
- spark-version: [{short: '3.4', full: '3.4.3', java: 11}, {short: '3.5', full: '3.5.7', java: 11}, {short: '4.0', full: '4.0.1', java: 17}]
+ spark-version: [{short: '3.4', full: '3.4.3', java: 11}, {short: '3.5', full: '3.5.7', java: 11}, {short: '4.1', full: '4.1.0', java: 17}]
module:
- {name: "catalyst", args1: "catalyst/test", args2: ""}
- {name: "sql_core-1", args1: "", args2: sql/testOnly * -- -l org.apache.spark.tags.ExtendedSQLTest -l org.apache.spark.tags.SlowSQLTest}
@@ -59,9 +59,9 @@ jobs:
- {name: "sql_hive-1", args1: "", args2: "hive/testOnly * -- -l org.apache.spark.tags.ExtendedHiveTest -l org.apache.spark.tags.SlowHiveTest"}
- {name: "sql_hive-2", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.ExtendedHiveTest"}
- {name: "sql_hive-3", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.SlowHiveTest"}
- # Skip sql_hive-1 for Spark 4.0 due to https://github.com/apache/datafusion-comet/issues/2946
+ # Skip sql_hive-1 for Spark 4.1 due to https://github.com/apache/datafusion-comet/issues/2946
exclude:
- - spark-version: {short: '4.0', full: '4.0.1', java: 17}
+ - spark-version: {short: '4.1', full: '4.1.0', java: 17}
module: {name: "sql_hive-1", args1: "", args2: "hive/testOnly * -- -l org.apache.spark.tags.ExtendedHiveTest -l org.apache.spark.tags.SlowHiveTest"}
fail-fast: false
name: spark-sql-${{ matrix.module.name }}/${{ matrix.os }}/spark-${{ matrix.spark-version.full }}/java-${{ matrix.spark-version.java }}
diff --git a/dev/diffs/4.0.1.diff b/dev/diffs/4.1.0.diff
similarity index 94%
rename from dev/diffs/4.0.1.diff
rename to dev/diffs/4.1.0.diff
index a9315db005..241c864683 100644
--- a/dev/diffs/4.0.1.diff
+++ b/dev/diffs/4.1.0.diff
@@ -1,8 +1,8 @@
diff --git a/pom.xml b/pom.xml
-index 22922143fc3..477d4ec4194 100644
+index 1824a28614b..0d0ce6d27ce 100644
--- a/pom.xml
+++ b/pom.xml
-@@ -148,6 +148,8 @@
+@@ -152,6 +152,8 @@
4.0.3
2.5.3
2.0.8
@@ -11,7 +11,7 @@ index 22922143fc3..477d4ec4194 100644
- spark-4.0
+ spark-4.1
2.13.16
2.13
- 4.0.1
- 4.0
+ 4.1.0
+ 4.1
1.15.2
4.13.6
2.0.16
- spark-4.0
+ spark-4.1
not-needed-yet
17
diff --git a/spark/pom.xml b/spark/pom.xml
index 3b832e37a2..13f7772fe6 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -233,7 +233,7 @@ under the License.
- spark-4.0
+ spark-4.1
org.apache.iceberg
@@ -241,7 +241,7 @@ under the License.
1.10.0
test
-
+
org.eclipse.jetty
jetty-server
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
index a58ec7851b..6dc3fd6d7c 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometBypassMergeSortShuffleWriter.java
@@ -41,7 +41,6 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
@@ -54,6 +53,7 @@
import org.apache.spark.shuffle.comet.CometShuffleMemoryAllocatorTrait;
import org.apache.spark.shuffle.sort.CometShuffleExternalSorter;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.comet.shims.ShimMapStatus$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.FileSegment;
@@ -172,7 +172,7 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
.commitAllPartitions(ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE)
.getPartitionLengths();
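+ // Build the MapStatus through the version shim: the spark-4.1 shim passes
+ // the extra fourth argument that Spark 4.1's MapStatus.apply expects, while
+ // the spark-3.x shim keeps the three-argument call.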
mapStatus =
- MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ ShimMapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
return;
}
final long openStartTime = System.nanoTime();
@@ -261,7 +261,8 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
// TODO: We probably can move checksum generation here when concatenating partition files
partitionLengths = writePartitionedData(mapOutputWriter);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapStatus =
+ ShimMapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
diff --git a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
index a845e743d4..d736512a86 100644
--- a/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
+++ b/spark/src/main/java/org/apache/spark/sql/comet/execution/shuffle/CometUnsafeShuffleWriter.java
@@ -50,7 +50,6 @@
import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.BaseShuffleHandle;
@@ -67,6 +66,7 @@
import org.apache.spark.shuffle.sort.SortShuffleManager;
import org.apache.spark.shuffle.sort.UnsafeShuffleWriter;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.comet.shims.ShimMapStatus$;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
@@ -288,7 +288,8 @@ void closeAndWriteOutput() throws IOException {
}
}
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
+ mapStatus =
+ ShimMapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, mapId);
}
@VisibleForTesting
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 8ab568dc83..7e3e978d8f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -30,7 +30,7 @@ import org.apache.comet.CometConf
import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
-import org.apache.comet.shims.CometEvalModeUtil
+import org.apache.comet.shims.{CometAggShim, CometEvalModeUtil}
object CometMin extends CometAggregateExpressionSerde[Min] {
@@ -214,7 +214,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
object CometSum extends CometAggregateExpressionSerde[Sum] {
override def getSupportLevel(sum: Sum): SupportLevel = {
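+ // Spark 4.1 moves Sum's eval mode into sum.evalContext (see CometAggShim),
+ // so read it through the shim instead of sum.evalMode directly.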
- sum.evalMode match {
+ CometAggShim.getSumEvalMode(sum) match {
case EvalMode.ANSI if !sum.dataType.isInstanceOf[DecimalType] =>
Incompatible(Some("ANSI mode for non decimal inputs is not supported"))
case EvalMode.TRY if !sum.dataType.isInstanceOf[DecimalType] =>
@@ -243,7 +243,8 @@ object CometSum extends CometAggregateExpressionSerde[Sum] {
val builder = ExprOuterClass.Sum.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
- builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(sum.evalMode)))
+ val evalMode = CometEvalModeUtil.fromSparkEvalMode(CometAggShim.getSumEvalMode(sum))
+ builder.setEvalMode(evalModeToProto(evalMode))
Some(
ExprOuterClass.AggExpr
diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/CometAggShim.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/CometAggShim.scala
new file mode 100644
index 0000000000..2c311a6373
--- /dev/null
+++ b/spark/src/main/spark-3.x/org/apache/comet/shims/CometAggShim.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.EvalMode
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+
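+/**
+ * Shim over per-version aggregate APIs. Through Spark 4.0, Sum exposes its
+ * LEGACY/ANSI/TRY mode directly as `sum.evalMode`; the matching spark-4.1
+ * shim reads `sum.evalContext.evalMode` instead.
+ */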
+object CometAggShim {
+ def getSumEvalMode(sum: Sum): EvalMode.Value = sum.evalMode
+}
diff --git a/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimMapStatus.scala b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimMapStatus.scala
new file mode 100644
index 0000000000..280a086602
--- /dev/null
+++ b/spark/src/main/spark-3.x/org/apache/spark/sql/comet/shims/ShimMapStatus.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
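+/**
+ * Shim over the MapStatus factory: the Spark versions covered by this source
+ * set use the three-argument `MapStatus.apply`, while the spark-4.1 shim
+ * passes an additional fourth argument.
+ */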
+object ShimMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long): MapStatus = {
+ MapStatus(loc, uncompressedSizes, mapTaskId)
+ }
+}
diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometAggShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometAggShim.scala
new file mode 100644
index 0000000000..d34915389c
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometAggShim.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.shims
+
+import org.apache.spark.sql.catalyst.expressions.EvalMode
+import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
+
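+/**
+ * Spark 4.1 variant of CometAggShim: Sum's eval mode now lives on the
+ * expression's evalContext rather than on the expression itself.
+ */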
+object CometAggShim {
+ def getSumEvalMode(sum: Sum): EvalMode.Value = sum.evalContext.evalMode
+}
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala
similarity index 97%
rename from spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
rename to spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala
index fc3db183b3..21d18f1e1e 100644
--- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala
+++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala
@@ -39,9 +39,7 @@ trait CometExprShim extends CommonStringExprs {
CometEvalModeUtil.fromSparkEvalMode(c.evalMode)
protected def binaryOutputStyle: BinaryOutputStyle = {
- SQLConf.get
- .getConf(SQLConf.BINARY_OUTPUT_STYLE)
- .map(SQLConf.BinaryOutputStyle.withName) match {
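+ // In Spark 4.1, getConf already yields the typed BinaryOutputStyle value
+ // (wrapped in an Option), so the string-to-enum withName mapping is dropped.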
+ SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE) match {
case Some(SQLConf.BinaryOutputStyle.UTF8) => BinaryOutputStyle.UTF8
case Some(SQLConf.BinaryOutputStyle.BASIC) => BinaryOutputStyle.BASIC
case Some(SQLConf.BinaryOutputStyle.BASE64) => BinaryOutputStyle.BASE64
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
rename to spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometBroadcastExchangeExec.scala
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
rename to spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometShuffleExchangeExec.scala
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
rename to spark/src/main/spark-4.1/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala
diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimSQLConf.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/comet/shims/ShimSQLConf.scala
rename to spark/src/main/spark-4.1/org/apache/comet/shims/ShimSQLConf.scala
diff --git a/spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala b/spark/src/main/spark-4.1/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
rename to spark/src/main/spark-4.1/org/apache/spark/comet/shims/ShimCometDriverPlugin.scala
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
rename to spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometScanExec.scala
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
rename to spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimCometShuffleWriteProcessor.scala
diff --git a/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimMapStatus.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimMapStatus.scala
new file mode 100644
index 0000000000..f805a9b542
--- /dev/null
+++ b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimMapStatus.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.shims
+
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.storage.BlockManagerId
+
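+/**
+ * Spark 4.1 variant of ShimMapStatus: `MapStatus.apply` takes a fourth
+ * argument in this version; the shim passes 0L for it.
+ */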
+object ShimMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long): MapStatus = {
+ MapStatus(loc, uncompressedSizes, mapTaskId, 0L)
+ }
+}
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala b/spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala
similarity index 100%
rename from spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala
rename to spark/src/main/spark-4.1/org/apache/spark/sql/comet/shims/ShimStreamSourceAwareSparkPlan.scala
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index a7bd6febf8..c20b4ef824 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
-import org.apache.comet.expressions.{CometCast, CometEvalMode}
+import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
+import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible