diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index ed48e36f07..bb4ce879d7 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case op if shouldApplySparkToColumnar(conf, op) =>
         convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
 
+      // AQE re-optimization looks for a `DataWritingCommandExec` or `WriteFilesExec`;
+      // if it finds none, it reinserts the write nodes. Since Comet remaps those nodes
+      // to their Comet counterparts, the write nodes would then appear twice in the plan.
+      // Check whether AQE inserted another write command on top of an existing one.
+      case _ @DataWritingCommandExec(_, w: WriteFilesExec)
+          if w.child.isInstanceOf[CometNativeWriteExec] =>
+        w.child
+
       case op: DataWritingCommandExec =>
         convertToComet(op, CometDataWritingCommand).getOrElse(op)
 
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index 2ea697fd4d..3ae7f949ab 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase {
 
   private def writeWithCometNativeWriteExec(
       inputPath: String,
-      outputPath: String): Option[QueryExecution] = {
+      outputPath: String,
+      num_partitions: Option[Int] = None): Option[QueryExecution] = {
     val df = spark.read.parquet(inputPath)
 
     // Use a listener to capture the execution plan during write
@@ -77,8 +78,8 @@
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write
-      df.write.parquet(outputPath)
+      // Perform native write with optional repartitioning
+      num_partitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
       capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeWrite = executedPlan.exists {
-          case _: CometNativeWriteExec => true
+        val executedPlan = stripAQEPlan(qe.executedPlan)
+
+        // Count CometNativeWriteExec instances in the plan
+        var nativeWriteCount = 0
+        executedPlan.foreach {
+          case _: CometNativeWriteExec =>
+            nativeWriteCount += 1
           case d: DataWritingCommandExec =>
-            d.child.exists {
-              case _: CometNativeWriteExec => true
-              case _ => false
+            d.child.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case _ =>
             }
-          case _ => false
+          case _ =>
         }
 
         assert(
-          hasNativeWrite,
-          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+          nativeWriteCount == 1,
+          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
       }
     } finally {
       spark.listenerManager.unregister(listener)
@@ -197,4 +203,29 @@
       }
     }
   }
+
+  test("basic parquet write with repartition") {
+    withTempPath { dir =>
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+        Seq(true, false).foreach(adaptive => {
+          // Create a new output path for each AQE value
+          val outputPath = new File(dir, s"output_aqe_$adaptive.parquet").getAbsolutePath
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            "spark.sql.adaptive.enabled" -> adaptive.toString,
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+            writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
+            verifyWrittenFile(outputPath)
+          }
+        })
+      }
+    }
+  }
 }
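
For readers unfamiliar with the plan shape involved, the following is a minimal, self-contained sketch of the duplication that the new CometExecRule case guards against. The case classes below are simplified, hypothetical stand-ins that only borrow the names of the real Spark/Comet operators; they are not the actual classes or the actual rule implementation.

// Sketch only: simplified stand-ins, not the real Spark/Comet plan nodes.
object DuplicateWriteNodeSketch {

  sealed trait Plan
  case class Scan(table: String) extends Plan
  case class CometNativeWriteExec(child: Plan) extends Plan
  case class WriteFilesExec(child: Plan) extends Plan
  case class DataWritingCommandExec(cmd: String, child: Plan) extends Plan

  // Mirrors the new rule case: if a write command sits on top of a plan that already
  // ends in a native write (the shape AQE re-optimization can produce), keep only the child.
  def dropDuplicateWrite(plan: Plan): Plan = plan match {
    case DataWritingCommandExec(_, w: WriteFilesExec) if w.child.isInstanceOf[CometNativeWriteExec] =>
      w.child
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val alreadyConverted = CometNativeWriteExec(Scan("input"))
    val reinsertedByAqe = DataWritingCommandExec("write-cmd", WriteFilesExec(alreadyConverted))
    // Prints: CometNativeWriteExec(Scan(input)) -- a single write node again
    println(dropDuplicateWrite(reinsertedByAqe))
  }
}

The new test asserts exactly one CometNativeWriteExec after stripping the AQE wrapper, which is what the sketch's collapsed plan corresponds to.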