diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 76e741e3bf..1b0177d7ef 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -483,15 +483,16 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
         val serde = handler.asInstanceOf[CometOperatorSerde[SparkPlan]]
         if (isOperatorEnabled(serde, op)) {
           // For operators that require native children (like writes), check if all data-producing
-          // children are CometNativeExec. This prevents runtime failures when the native operator
-          // expects Arrow arrays but receives non-Arrow data (e.g., OnHeapColumnVector).
+          // children are CometExec (which includes CometNativeExec and sink operators like
+          // CometUnionExec, CometCoalesceExec, etc.). This prevents runtime failures when the
+          // native operator expects Arrow arrays but receives non-Arrow data.
           if (serde.requiresNativeChildren && op.children.nonEmpty) {
             // Get the actual data-producing children (unwrap WriteFilesExec if present)
             val dataProducingChildren = op.children.flatMap {
               case writeFiles: WriteFilesExec => Seq(writeFiles.child)
               case other => Seq(other)
             }
-            if (!dataProducingChildren.forall(_.isInstanceOf[CometNativeExec])) {
+            if (!dataProducingChildren.forall(_.isInstanceOf[CometExec])) {
               withInfo(op, "Cannot perform native operation because input is not in Arrow format")
               return None
             }
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index e4c405c003..40e9ffdff9 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -140,6 +140,165 @@ class CometParquetWriterSuite extends CometTestBase {
     }
   }
 
+  // Test for issue #3429: CTAS with UNION fails in Spark 4.x with native writer
+  test("parquet write with union - CTAS style") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        // Create a DataFrame using UNION - simulating the CTAS-with-UNION pattern
+        val df1 = spark.range(1, 5).toDF("id")
+        val df2 = spark.range(10, 15).toDF("id")
+        val unionDf = df1.union(df2)
+
+        // Write using parquet - this is similar to CTAS
+        val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)
+
+        // Verify the write completed and the data is correct
+        val result = spark.read.parquet(outputPath)
+        assert(result.count() == 9, "Expected 9 rows (4 + 5)")
+
+        // Verify the native write was used
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
+  // Corner case: UNION with multiple (3+) DataFrames
+  test("parquet write with multiple unions") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val df1 = spark.range(1, 4).toDF("id")
+        val df2 = spark.range(10, 13).toDF("id")
+        val df3 = spark.range(20, 23).toDF("id")
+        val df4 = spark.range(30, 33).toDF("id")
+        val unionDf = df1.union(df2).union(df3).union(df4)
+
+        val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)
+
+        val result = spark.read.parquet(outputPath)
+        assert(result.count() == 12, "Expected 12 rows (3 + 3 + 3 + 3)")
+
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
+  // Corner case: UNION followed by coalesce
+  test("parquet write with union and coalesce") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val df1 = spark.range(1, 50).toDF("id")
+        val df2 = spark.range(100, 150).toDF("id")
+        val unionDf = df1.union(df2).coalesce(2)
+
+        val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)
+
+        val result = spark.read.parquet(outputPath)
+        assert(result.count() == 99, "Expected 99 rows (49 + 50)")
+
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
+  // Corner case: UNION with filter
+  test("parquet write with union and filter") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val df1 = spark.range(1, 10).toDF("id")
+        val df2 = spark.range(20, 30).toDF("id")
+        val unionDf = df1.union(df2).filter("id % 2 = 0")
+
+        val plan = captureWritePlan(path => unionDf.write.parquet(path), outputPath)
+
+        val result = spark.read.parquet(outputPath)
+        // Even numbers: 2,4,6,8 from df1 and 20,22,24,26,28 from df2 = 9 rows
+        assert(result.count() == 9, "Expected 9 even rows")
+
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
+  // Corner case: UNION with complex types (struct)
+  test("parquet write with union of structs") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val df = spark.sql("""
+          SELECT 1 as id, named_struct('name', 'Alice', 'age', 30) as person
+          UNION ALL
+          SELECT 2 as id, named_struct('name', 'Bob', 'age', 25) as person
+          """)
+
+        val plan = captureWritePlan(path => df.write.parquet(path), outputPath)
+
+        val result = spark.read.parquet(outputPath)
+        assert(result.count() == 2)
+
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
+  // Corner case: nested UNION (UNION inside subqueries)
+  test("parquet write with nested union in SQL") {
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true") {
+
+        val df = spark.sql("""
+          SELECT * FROM (
+            SELECT 1 as id UNION ALL SELECT 2 as id
+          )
+          UNION ALL
+          SELECT * FROM (
+            SELECT 3 as id UNION ALL SELECT 4 as id
+          )
+          """)
+
+        val plan = captureWritePlan(path => df.write.parquet(path), outputPath)
+
+        val result = spark.read.parquet(outputPath)
+        assert(result.count() == 4)
+
+        assertHasCometNativeWriteExec(plan)
+      }
+    }
+  }
+
   test("parquet write with map type") {
     withTempPath { dir =>
       val outputPath = new File(dir, "output.parquet").getAbsolutePath