@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
+import org.apache.spark.sql.execution.datasources.WriteFilesExec
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -197,6 +198,14 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       case op if shouldApplySparkToColumnar(conf, op) =>
         convertToComet(op, CometSparkToColumnarExec).getOrElse(op)
 
+      // AQE re-optimization looks for a `DataWritingCommandExec` or `WriteFilesExec`;
+      // if it finds neither, it re-inserts the write nodes. Since Comet remaps those
+      // nodes to their Comet counterparts, the write node would then appear twice in
+      // the plan. Check whether AQE inserted another write command on top of an existing one.
+      case DataWritingCommandExec(_, w: WriteFilesExec)
+          if w.child.isInstanceOf[CometNativeWriteExec] =>
+        w.child
+
       case op: DataWritingCommandExec =>
         convertToComet(op, CometDataWritingCommand).getOrElse(op)
 
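For intuition, here is a minimal sketch of the plan shape the new rule case collapses. The case classes below are toy stand-ins for Spark's operators, not Comet's real code, so this is illustrative only:

sealed trait Plan
case class DataWritingCommand(child: Plan) extends Plan
case class WriteFiles(child: Plan) extends Plan
case class CometNativeWrite(child: Plan) extends Plan
case class Scan(table: String) extends Plan

def dedupWrite(plan: Plan): Plan = plan match {
  // Mirrors `case DataWritingCommandExec(_, w: WriteFilesExec)
  //            if w.child.isInstanceOf[CometNativeWriteExec]` above
  case DataWritingCommand(WriteFiles(native: CometNativeWrite)) => native
  case other => other
}

// AQE has re-inserted a write command on top of the already-converted native write:
// dedupWrite(DataWritingCommand(WriteFiles(CometNativeWrite(Scan("t")))))
// returns CometNativeWrite(Scan("t")), leaving a single write node in the plan.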
@@ -54,7 +54,8 @@ class CometParquetWriterSuite extends CometTestBase
 
   private def writeWithCometNativeWriteExec(
       inputPath: String,
-      outputPath: String): Option[QueryExecution] = {
+      outputPath: String,
+      numPartitions: Option[Int] = None): Option[QueryExecution] = {
     val df = spark.read.parquet(inputPath)
 
     // Use a listener to capture the execution plan during write
@@ -77,8 +77,8 @@ class CometParquetWriterSuite extends CometTestBase
     spark.listenerManager.register(listener)
 
     try {
-      // Perform native write
-      df.write.parquet(outputPath)
+      // Perform native write, optionally repartitioning the input first
+      numPartitions.fold(df)(n => df.repartition(n)).write.parquet(outputPath)
 
       // Wait for listener to be called with timeout
       val maxWaitTimeMs = 15000
@@ -97,20 +98,25 @@ class CometParquetWriterSuite extends CometTestBase
         s"Listener was not called within ${maxWaitTimeMs}ms - no execution plan captured")
 
       capturedPlan.foreach { qe =>
-        val executedPlan = qe.executedPlan
-        val hasNativeWrite = executedPlan.exists {
-          case _: CometNativeWriteExec => true
+        val executedPlan = stripAQEPlan(qe.executedPlan)
+
+        // Count CometNativeWriteExec instances in the plan
+        var nativeWriteCount = 0
+        executedPlan.foreach {
+          case _: CometNativeWriteExec =>
+            nativeWriteCount += 1
           case d: DataWritingCommandExec =>
-            d.child.exists {
-              case _: CometNativeWriteExec => true
-              case _ => false
+            d.child.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case _ =>
             }
-          case _ => false
+          case _ =>
         }
 
         assert(
-          hasNativeWrite,
-          s"Expected CometNativeWriteExec in the plan, but got:\n${executedPlan.treeString}")
+          nativeWriteCount == 1,
+          s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
       }
     } finally {
       spark.listenerManager.unregister(listener)

(Contributor comment on the new assertion: "check the count is just 1.")
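The helper's capture mechanism is a standard QueryExecutionListener; the listener body itself is outside this diff, so the sketch below is only an assumption of roughly what it does (the class name and the unconditional capture are illustrative):

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

class PlanCaptureListener extends QueryExecutionListener {
  @volatile var captured: Option[QueryExecution] = None

  // Invoked once the write command completes; the suite then inspects qe.executedPlan
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    captured = Some(qe)

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}

Two details in the counting code above are worth noting: `stripAQEPlan` (presumably from Spark's `AdaptiveSparkPlanHelper` test mixin) unwraps `AdaptiveSparkPlanExec`, so the same traversal works with AQE on or off; and `foreach` on a `SparkPlan` (inherited from `TreeNode`) visits every node in the tree, so the count catches write nodes anywhere in the plan.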
@@ -197,4 +203,29 @@ class CometParquetWriterSuite extends CometTestBase
       }
     }
   }
+
+  test("basic parquet write with repartition") {
+    withTempPath { dir =>
+      // Create test data and write it to a temp parquet file first
+      withTempPath { inputDir =>
+        val inputPath = createTestData(inputDir)
+        Seq(true, false).foreach(adaptive => {
+          // Create a new output path for each AQE value
+          val outputPath = new File(dir, s"output_aqe_$adaptive.parquet").getAbsolutePath
+
+          withSQLConf(
+            CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+            "spark.sql.adaptive.enabled" -> adaptive.toString,
+            SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+            CometConf.getOperatorAllowIncompatConfigKey(
+              classOf[DataWritingCommandExec]) -> "true",
+            CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+            writeWithCometNativeWriteExec(inputPath, outputPath, Some(10))
+            verifyWrittenFile(outputPath)
+          }
+        })
+      }
+    }
+  }
 }
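A note on the test design: the `repartition(10)` appears to be what makes this a regression test for the rule change above. It introduces a shuffle, so with `spark.sql.adaptive.enabled=true` the write plan goes through AQE re-optimization, the path that previously duplicated the write node; the same write with AQE off serves as the control.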