From d9050480b7ea8827f3e14a2a483a745fbdafddc3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 08:44:32 -0700 Subject: [PATCH 01/13] basic fix --- .../apache/comet/rules/CometExecRule.scala | 69 ++++++++++++++++++- .../comet/rules/CometExecRuleSuite.scala | 4 +- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index f490202537..a58a8d67a5 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -21,6 +21,7 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide @@ -103,6 +104,64 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { private def isCometNative(op: SparkPlan): Boolean = op.isInstanceOf[CometNativeExec] + /** + * Pre-processes the plan to ensure coordination between partial and final hash aggregates. + * + * This method walks the plan top-down to identify final hash aggregates that cannot be + * converted to Comet. For such cases, it finds and tags any corresponding partial aggregates + * with fallback reasons to prevent mixed Comet partial + Spark final aggregation. + * + * @param plan + * The input plan to pre-process + * @return + * The plan with appropriate fallback tags added + */ + private def tagUnsupportedPartialAggregates(plan: SparkPlan): SparkPlan = { + plan.transformDown { + case finalAgg: HashAggregateExec if hasFinalMode(finalAgg) => + // Check if this final aggregate can be converted to Comet + val handler = allExecs + .get(finalAgg.getClass) + .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) + + handler match { + case Some(serde) if !isOperatorEnabled(serde, finalAgg) => + // Final aggregate cannot be converted, so tag corresponding partial aggregates + val reason = "Cannot convert final hash aggregate to Comet, " + + "so partial aggregates must also use Spark to avoid mixed execution" + tagRelatedPartialAggregates(finalAgg, reason) + case _ => + finalAgg + } + case other => other + } + } + + /** + * Helper method to check if a hash aggregate has Final mode expressions. + */ + private def hasFinalMode(agg: HashAggregateExec): Boolean = { + agg.aggregateExpressions.exists(_.mode == Final) + } + + /** + * Tags related partial aggregates in the subtree with fallback reasons. + */ + private def tagRelatedPartialAggregates(plan: SparkPlan, reason: String): SparkPlan = { + plan.transformDown { + case partialAgg: HashAggregateExec if hasPartialMode(partialAgg) => + withInfo(partialAgg, reason) + case other => other + } + } + + /** + * Helper method to check if a hash aggregate has Partial or PartialMerge mode expressions. 
+ */ + private def hasPartialMode(agg: HashAggregateExec): Boolean = { + agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge) + } + // spotless:off /** @@ -239,6 +298,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { convertToComet(s, CometShuffleExchangeExec).getOrElse(s) case op => + // Check if this operator has already been tagged with fallback reasons + if (hasExplainInfo(op)) { + return op + } + // if all children are native (or if this is a leaf node) then see if there is a // registered handler for creating a fully native plan if (op.children.forall(_.isInstanceOf[CometNativeExec])) { @@ -365,7 +429,10 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { normalizedPlan } - var newPlan = transform(planWithJoinRewritten) + // Pre-process the plan to ensure coordination between partial and final hash aggregates + val planWithAggregateCoordination = tagUnsupportedPartialAggregates(planWithJoinRewritten) + + var newPlan = transform(planWithAggregateCoordination) // if the plan cannot be run fully natively then explain why (when appropriate // config is enabled) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index cf6f8918f4..84b3783d28 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -131,9 +131,7 @@ class CometExecRuleSuite extends CometTestBase { } } - // TODO this test exposes the bug described in - // https://github.com/apache/datafusion-comet/issues/1389 - ignore("CometExecRule should not allow Comet partial and Spark final hash aggregate") { + test("CometExecRule should not allow Comet partial and Spark final hash aggregate") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") From ddcfaf5a6b4c3d9e90ce62f93a213208525151b6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 08:49:19 -0700 Subject: [PATCH 02/13] refine --- .../apache/comet/rules/CometExecRule.scala | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index a58a8d67a5..1c152c6252 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -145,14 +145,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } /** - * Tags related partial aggregates in the subtree with fallback reasons. + * Tags the first related partial aggregate in the subtree with fallback reasons. Stops + * transforming after finding and tagging the first partial aggregate to avoid affecting + * unrelated aggregates elsewhere in the tree. 
*/ private def tagRelatedPartialAggregates(plan: SparkPlan, reason: String): SparkPlan = { - plan.transformDown { - case partialAgg: HashAggregateExec if hasPartialMode(partialAgg) => - withInfo(partialAgg, reason) - case other => other + var found = false + + def transformOnce(node: SparkPlan): SparkPlan = { + if (found) { + node + } else { + node match { + case partialAgg: HashAggregateExec if hasPartialMode(partialAgg) => + found = true + withInfo(partialAgg, reason) + case other => + other.withNewChildren(other.children.map(transformOnce)) + } + } } + + transformOnce(plan) } /** From 59d5ead100b427759bf3232a625afdc7b5091182 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 09:08:55 -0700 Subject: [PATCH 03/13] object hash agg test passes --- .../apache/comet/rules/CometExecRule.scala | 16 ++-- .../apache/spark/sql/comet/operators.scala | 14 +++ .../comet/rules/CometExecRuleSuite.scala | 95 ++++++++++++++++++- 3 files changed, 116 insertions(+), 9 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 1c152c6252..58b58e8eb9 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.v2.V2CommandExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} @@ -118,7 +118,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { */ private def tagUnsupportedPartialAggregates(plan: SparkPlan): SparkPlan = { plan.transformDown { - case finalAgg: HashAggregateExec if hasFinalMode(finalAgg) => + case finalAgg: BaseAggregateExec if hasFinalMode(finalAgg) => // Check if this final aggregate can be converted to Comet val handler = allExecs .get(finalAgg.getClass) @@ -127,7 +127,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { handler match { case Some(serde) if !isOperatorEnabled(serde, finalAgg) => // Final aggregate cannot be converted, so tag corresponding partial aggregates - val reason = "Cannot convert final hash aggregate to Comet, " + + val reason = "Cannot convert final aggregate to Comet, " + "so partial aggregates must also use Spark to avoid mixed execution" tagRelatedPartialAggregates(finalAgg, reason) case _ => @@ -138,9 +138,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } /** - * Helper method to check if a hash aggregate has Final mode expressions. + * Helper method to check if an aggregate has Final mode expressions. 
*/ - private def hasFinalMode(agg: HashAggregateExec): Boolean = { + private def hasFinalMode(agg: BaseAggregateExec): Boolean = { agg.aggregateExpressions.exists(_.mode == Final) } @@ -157,7 +157,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { node } else { node match { - case partialAgg: HashAggregateExec if hasPartialMode(partialAgg) => + case partialAgg: BaseAggregateExec if hasPartialMode(partialAgg) => found = true withInfo(partialAgg, reason) case other => @@ -170,9 +170,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } /** - * Helper method to check if a hash aggregate has Partial or PartialMerge mode expressions. + * Helper method to check if an aggregate has Partial or PartialMerge mode expressions. */ - private def hasPartialMode(agg: HashAggregateExec): Boolean = { + private def hasPartialMode(agg: BaseAggregateExec): Boolean = { agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 0a435e5b7a..a10cf1a7b9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1276,6 +1276,20 @@ object CometObjectHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) + override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { + // some unit tests need to disable partial or final hash aggregate support to test that + // CometExecRule does not allow mixed Spark/Comet aggregates + if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { + return Unsupported(Some("Partial aggregates disabled via test config")) + } + if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(_.mode == Final)) { + return Unsupported(Some("Final aggregates disabled via test config")) + } + Compatible() + } + override def convert( aggregate: ObjectHashAggregateExec, builder: Operator.Builder, diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 84b3783d28..460ee5dca8 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,11 +22,15 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.QueryStageExec -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -179,6 +183,95 @@ 
class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should not allow Comet partial and Spark final object hash aggregate") { + val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId_bloom_filter_agg, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + try { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = + createSparkPlan( + spark, + "SELECT bloom_filter_agg(cast(id as long)) FROM test_data GROUP BY (id % 3)") + + // Count original Spark operators - bloom filter should generate ObjectHashAggregateExec + val originalObjectHashAggCount = + countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) + assert(originalObjectHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // if the final object aggregate cannot be converted to Comet, then neither should be + assert( + countOperators( + transformedPlan, + classOf[ObjectHashAggregateExec]) == originalObjectHashAggCount) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg) + } + } + + test("CometExecRule should not allow Spark partial and Comet final object hash aggregate") { + val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId_bloom_filter_agg, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + try { + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = + createSparkPlan( + spark, + "SELECT bloom_filter_agg(cast(id as long)) FROM test_data GROUP BY (id % 3)") + + // Count original Spark operators - bloom filter should generate ObjectHashAggregateExec + val originalObjectHashAggCount = + countOperators(sparkPlan, classOf[ObjectHashAggregateExec]) + assert(originalObjectHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", // ObjectHashAggregateExec requires shuffle + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // if the partial object aggregate cannot be converted to Comet, then neither should be + assert( + countOperators( + transformedPlan, + classOf[ObjectHashAggregateExec]) == originalObjectHashAggCount) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg) + } + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { 
createTestDataFrame.createOrReplaceTempView("test_data") From 2204ec51dbb6ca320e4a14805a4ad0622ec22e57 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 09:09:52 -0700 Subject: [PATCH 04/13] format --- .../test/scala/org/apache/comet/rules/CometExecRuleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 460ee5dca8..07f5d73568 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -22,7 +22,7 @@ package org.apache.comet.rules import scala.util.Random import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate From 0de73bfd26c847d4b36e288a299777294fddb878 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 09:20:25 -0700 Subject: [PATCH 05/13] tests now check for fallback message --- .../comet/rules/CometExecRuleSuite.scala | 31 +++++++++++++++++-- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 07f5d73568..2026f5334b 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -25,16 +25,16 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, Partial, PartialMerge} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.QueryStageExec -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExplainInfo} import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator} /** @@ -77,6 +77,22 @@ class CometExecRuleSuite extends CometTestBase { }.sum } + /** Helper method to find all partial aggregates in a plan */ + private def findPartialAggregates(plan: SparkPlan): Seq[BaseAggregateExec] = { + plan.collect { + case agg: BaseAggregateExec + if agg.aggregateExpressions.exists(expr => + expr.mode == Partial || expr.mode == PartialMerge) => + agg + } + } + + /** Helper method to check if an operator has a specific fallback message */ + private def hasFallbackMessage(op: SparkPlan, expectedMessage: String): Boolean = { + op.getTagValue(CometExplainInfo.EXTENSION_INFO) + 
.exists(_.contains(expectedMessage)) + } + test( "CometExecRule should apply basic operator transformations, but only when Comet is enabled") { withTempView("test_data") { @@ -220,6 +236,15 @@ class CometExecRuleSuite extends CometTestBase { transformedPlan, classOf[ObjectHashAggregateExec]) == originalObjectHashAggCount) assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) + + // Verify that the partial aggregate has the expected fallback message + val partialAggregates = findPartialAggregates(transformedPlan) + assert(partialAggregates.nonEmpty, "Should have found at least one partial aggregate") + val expectedMessage = "Cannot convert final aggregate to Comet, " + + "so partial aggregates must also use Spark to avoid mixed execution" + assert( + partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), + s"Partial aggregate should have fallback message: $expectedMessage") } } } finally { From da690817acb363c7108ad435d28d206d7a6ff9e8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 09:26:23 -0700 Subject: [PATCH 06/13] improve fallback reporting --- .../apache/comet/rules/CometExecRule.scala | 23 +++++++++++++++---- .../comet/rules/CometExecRuleSuite.scala | 5 ++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 58b58e8eb9..bd0d3be068 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -125,11 +125,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { .map(_.asInstanceOf[CometOperatorSerde[SparkPlan]]) handler match { - case Some(serde) if !isOperatorEnabled(serde, finalAgg) => - // Final aggregate cannot be converted, so tag corresponding partial aggregates - val reason = "Cannot convert final aggregate to Comet, " + - "so partial aggregates must also use Spark to avoid mixed execution" - tagRelatedPartialAggregates(finalAgg, reason) + case Some(serde) => + // Get the actual support level and reason for the final aggregate + serde.getSupportLevel(finalAgg) match { + case Unsupported(reasonOpt) => + // Final aggregate cannot be converted, extract the actual reason + val actualReason = reasonOpt.getOrElse("Final aggregate not supported by Comet") + val reason = s"Cannot convert final aggregate to Comet ($actualReason), " + + "so partial aggregates must also use Spark to avoid mixed execution" + tagRelatedPartialAggregates(finalAgg, reason) + case Incompatible(reasonOpt) => + // Final aggregate cannot be converted, extract the actual reason + val actualReason = reasonOpt.getOrElse("Final aggregate incompatible with Comet") + val reason = s"Cannot convert final aggregate to Comet ($actualReason), " + + "so partial aggregates must also use Spark to avoid mixed execution" + tagRelatedPartialAggregates(finalAgg, reason) + case Compatible(_) => + finalAgg + } case _ => finalAgg } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 2026f5334b..e2f87e99cf 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -240,8 +240,9 @@ class CometExecRuleSuite extends CometTestBase { // Verify that the partial aggregate has the expected fallback message val partialAggregates = 
findPartialAggregates(transformedPlan) assert(partialAggregates.nonEmpty, "Should have found at least one partial aggregate") - val expectedMessage = "Cannot convert final aggregate to Comet, " + - "so partial aggregates must also use Spark to avoid mixed execution" + val expectedMessage = + "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + + "so partial aggregates must also use Spark to avoid mixed execution" assert( partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), s"Partial aggregate should have fallback message: $expectedMessage") From 8802d4f58d63db53512bf90664685aec8e5d0547 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 10:36:38 -0700 Subject: [PATCH 07/13] save --- .../comet/rules/CometExecRuleSuite.scala | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index e2f87e99cf..8888a636e3 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -298,6 +298,142 @@ class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should coordinate across AQE stages for ObjectHashAggregateExec") { + val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId_bloom_filter_agg, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + try { + withTempView("test_data1", "test_data2") { + // Create smaller datasets to avoid memory issues + val testSchema = new StructType( + Array( + StructField("id", DataTypes.IntegerType, nullable = true), + StructField("group_key", DataTypes.IntegerType, nullable = true), + StructField("value", DataTypes.StringType, nullable = true))) + + // Create two smaller tables for join to encourage AQE stage creation + val data1 = FuzzDataGenerator.generateDataFrame( + new Random(42), + spark, + testSchema, + 200, + DataGenOptions()) + val data2 = FuzzDataGenerator.generateDataFrame( + new Random(43), + spark, + testSchema, + 200, + DataGenOptions()) + + data1.createOrReplaceTempView("test_data1") + data2.createOrReplaceTempView("test_data2") + + // Use moderate AQE settings to avoid memory issues while encouraging stage creation + withSQLConf( + "spark.sql.adaptive.enabled" -> "true", + "spark.sql.adaptive.coalescePartitions.enabled" -> "true", + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "64KB", // Less aggressive + "spark.sql.adaptive.skewJoin.enabled" -> "true", + "spark.default.parallelism" -> "4", // Fewer partitions + "spark.sql.shuffle.partitions" -> "4", + CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + + // Use a join + aggregation to encourage AQE stage boundaries + // Join often creates natural stage boundaries in AQE + val df = spark.sql(""" + |SELECT + | t1.group_key, + | bloom_filter_agg(cast(t1.id as long)) as bloom_result, + | count(*) as cnt + |FROM test_data1 t1 + |JOIN test_data2 t2 ON t1.group_key = t2.group_key 
+ |WHERE t1.id % 2 = 0 + |GROUP BY t1.group_key + |""".stripMargin) + + // Execute the plan to trigger AQE stage creation + val result = df.collect() + logInfo(s"Query executed successfully, returned ${result.length} rows") + + // Get the executed plan which might have AQE stages + val executedPlan = df.queryExecution.executedPlan + + // Check if we have QueryStageExec nodes (indicating AQE created stages) + val queryStages = executedPlan.collect { case qs: QueryStageExec => qs } + + if (queryStages.nonEmpty) { + // We have AQE stages - test the cross-stage coordination + logInfo( + s"AQE created ${queryStages.length} stages - testing cross-stage coordination") + + // Verify that we have ObjectHashAggregateExec operators in the plan + val objectHashAggs = stripAQEPlan(executedPlan).collect { + case agg: ObjectHashAggregateExec => agg + } + assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") + + // Verify coordination worked - no mixed Comet/Spark aggregation + val cometHashAggs = stripAQEPlan(executedPlan).collect { + case agg: CometHashAggregateExec => agg + } + assert( + cometHashAggs.isEmpty, + "Should have no CometHashAggregateExec - coordination should prevent mixed execution") + + // Verify that partial aggregates have the expected fallback message + val partialAggregates = queryStages.flatMap { stage => + findPartialAggregates(stage.plan) + } + if (partialAggregates.nonEmpty) { + val expectedMessage = + "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + + "so partial aggregates must also use Spark to avoid mixed execution" + assert( + partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), + s"Partial aggregate should have fallback message: $expectedMessage") + } + + logInfo(s"AQE cross-stage coordination test passed with ${queryStages.length} stages") + } else { + // AQE didn't create stages - fall back to testing single-stage coordination + logInfo( + "AQE did not create separate stages - testing single-stage coordination instead") + + // scalastyle:off + println(executedPlan) + // Verify that we have ObjectHashAggregateExec operators + val objectHashAggs = stripAQEPlan(executedPlan).collect { case agg: ObjectHashAggregateExec => + agg + } + assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") + + // Verify coordination worked - no mixed Comet/Spark aggregation + val cometHashAggs = executedPlan.collect { case agg: CometHashAggregateExec => + agg + } + assert( + cometHashAggs.isEmpty, + "Should have no CometHashAggregateExec - coordination should prevent mixed execution") + + logInfo("Single-stage coordination test passed") + } + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg) + } + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") From 818cdd0e4d52e4b1a35b0505cd23e54cbbd2a6c3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 10:46:53 -0700 Subject: [PATCH 08/13] Save --- .../comet/rules/CometExecRuleSuite.scala | 76 +++++++++++++------ 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 8888a636e3..89f38659fb 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ 
b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -311,61 +311,78 @@ class CometExecRuleSuite extends CometTestBase { }) try { - withTempView("test_data1", "test_data2") { - // Create smaller datasets to avoid memory issues + withTempView("test_data1", "test_data2", "test_data3") { + // Create datasets large enough to force AQE stage creation val testSchema = new StructType( Array( StructField("id", DataTypes.IntegerType, nullable = true), StructField("group_key", DataTypes.IntegerType, nullable = true), StructField("value", DataTypes.StringType, nullable = true))) - // Create two smaller tables for join to encourage AQE stage creation + // Create multiple tables with larger datasets to force shuffle stages val data1 = FuzzDataGenerator.generateDataFrame( new Random(42), spark, testSchema, - 200, + 1000, DataGenOptions()) val data2 = FuzzDataGenerator.generateDataFrame( new Random(43), spark, testSchema, - 200, + 1000, + DataGenOptions()) + val data3 = FuzzDataGenerator.generateDataFrame( + new Random(44), + spark, + testSchema, + 1000, DataGenOptions()) data1.createOrReplaceTempView("test_data1") data2.createOrReplaceTempView("test_data2") + data3.createOrReplaceTempView("test_data3") - // Use moderate AQE settings to avoid memory issues while encouraging stage creation + // More aggressive AQE settings to force stage creation withSQLConf( "spark.sql.adaptive.enabled" -> "true", "spark.sql.adaptive.coalescePartitions.enabled" -> "true", - "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "64KB", // Less aggressive + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "4KB", // Very small to force stages "spark.sql.adaptive.skewJoin.enabled" -> "true", - "spark.default.parallelism" -> "4", // Fewer partitions - "spark.sql.shuffle.partitions" -> "4", + "spark.default.parallelism" -> "8", + "spark.sql.shuffle.partitions" -> "8", CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.key -> "false", CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { - // Use a join + aggregation to encourage AQE stage boundaries - // Join often creates natural stage boundaries in AQE + // Use a complex query with multiple joins and subqueries to force stage boundaries + // This pattern is more likely to create distinct partial/final aggregate stages val df = spark.sql(""" + |WITH combined_data AS ( + | SELECT t1.id, t1.group_key, t1.value + | FROM test_data1 t1 + | JOIN test_data2 t2 ON t1.group_key = t2.group_key + | WHERE t1.id % 3 = 0 + | UNION ALL + | SELECT t3.id, t3.group_key, t3.value + | FROM test_data3 t3 + | WHERE t3.id % 5 = 0 + |) |SELECT - | t1.group_key, - | bloom_filter_agg(cast(t1.id as long)) as bloom_result, + | group_key, + | bloom_filter_agg(cast(id as long)) as bloom_result, | count(*) as cnt - |FROM test_data1 t1 - |JOIN test_data2 t2 ON t1.group_key = t2.group_key - |WHERE t1.id % 2 = 0 - |GROUP BY t1.group_key + |FROM combined_data + |GROUP BY group_key + |HAVING count(*) > 1 + |ORDER BY group_key |""".stripMargin) // Execute the plan to trigger AQE stage creation val result = df.collect() logInfo(s"Query executed successfully, returned ${result.length} rows") - // Get the executed plan which might have AQE stages + // Get the executed plan which should have AQE stages val executedPlan = df.queryExecution.executedPlan // Check if we have QueryStageExec nodes (indicating AQE created stages) @@ -405,15 +422,16 @@ class CometExecRuleSuite extends CometTestBase { logInfo(s"AQE cross-stage coordination test 
passed with ${queryStages.length} stages") } else { - // AQE didn't create stages - fall back to testing single-stage coordination + // AQE didn't create stages - test that single-stage coordination still works logInfo( - "AQE did not create separate stages - testing single-stage coordination instead") + "AQE did not create separate stages - verifying single-stage coordination works") // scalastyle:off println(executedPlan) // Verify that we have ObjectHashAggregateExec operators - val objectHashAggs = stripAQEPlan(executedPlan).collect { case agg: ObjectHashAggregateExec => - agg + val objectHashAggs = stripAQEPlan(executedPlan).collect { + case agg: ObjectHashAggregateExec => + agg } assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") @@ -425,7 +443,19 @@ class CometExecRuleSuite extends CometTestBase { cometHashAggs.isEmpty, "Should have no CometHashAggregateExec - coordination should prevent mixed execution") - logInfo("Single-stage coordination test passed") + // Verify partial aggregates have fallback messages even in single-stage case + val partialAggregates = findPartialAggregates(stripAQEPlan(executedPlan)) + if (partialAggregates.nonEmpty) { + val expectedMessage = + "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + + "so partial aggregates must also use Spark to avoid mixed execution" + assert( + partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), + s"Partial aggregate should have fallback message even in single-stage: $expectedMessage") + } + + logInfo( + "Single-stage coordination test passed - coordination works within single stage") } } } From fab662495cd10b29c398654ca1da3efea7404454 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 10:57:09 -0700 Subject: [PATCH 09/13] fix test --- .../comet/rules/CometExecRuleSuite.scala | 134 +++++++++--------- 1 file changed, 70 insertions(+), 64 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 89f38659fb..0d1e230c8e 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.QueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -93,6 +93,11 @@ class CometExecRuleSuite extends CometTestBase { .exists(_.contains(expectedMessage)) } + /** Helper method to check if an aggregate has Partial or PartialMerge mode expressions */ + private def hasPartialMode(agg: BaseAggregateExec): Boolean = { + agg.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge) + } + test( "CometExecRule should apply basic operator transformations, but only when Comet is enabled") { withTempView("test_data") { @@ -386,77 +391,78 @@ class CometExecRuleSuite extends CometTestBase { val 
executedPlan = df.queryExecution.executedPlan // Check if we have QueryStageExec nodes (indicating AQE created stages) - val queryStages = executedPlan.collect { case qs: QueryStageExec => qs } - - if (queryStages.nonEmpty) { - // We have AQE stages - test the cross-stage coordination - logInfo( - s"AQE created ${queryStages.length} stages - testing cross-stage coordination") - - // Verify that we have ObjectHashAggregateExec operators in the plan - val objectHashAggs = stripAQEPlan(executedPlan).collect { - case agg: ObjectHashAggregateExec => agg + val queryStages = stripAQEPlan(executedPlan).collect { case qs: QueryStageExec => qs } + + assert (queryStages.nonEmpty) + // We have AQE stages - test the cross-stage coordination + logInfo( + s"AQE created ${queryStages.length} stages - testing cross-stage coordination") + + // Verify that we have ObjectHashAggregateExec operators in the plan + // Need to recursively search through AQE stages + def findObjectHashAggs(plan: SparkPlan): Seq[ObjectHashAggregateExec] = { + val buffer = scala.collection.mutable.ListBuffer[ObjectHashAggregateExec]() + def collect(p: SparkPlan): Unit = { + p match { + case agg: ObjectHashAggregateExec => buffer += agg + case stage: ShuffleQueryStageExec => collect(stage.plan) + case stage: BroadcastQueryStageExec => collect(stage.plan) + case _ => p.children.foreach(collect) + } } - assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") - - // Verify coordination worked - no mixed Comet/Spark aggregation - val cometHashAggs = stripAQEPlan(executedPlan).collect { - case agg: CometHashAggregateExec => agg - } - assert( - cometHashAggs.isEmpty, - "Should have no CometHashAggregateExec - coordination should prevent mixed execution") + collect(plan) + buffer.toSeq + } - // Verify that partial aggregates have the expected fallback message - val partialAggregates = queryStages.flatMap { stage => - findPartialAggregates(stage.plan) - } - if (partialAggregates.nonEmpty) { - val expectedMessage = - "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + - "so partial aggregates must also use Spark to avoid mixed execution" - assert( - partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), - s"Partial aggregate should have fallback message: $expectedMessage") + val objectHashAggs = findObjectHashAggs(stripAQEPlan(executedPlan)) + assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") + + // Verify coordination worked - no mixed Comet/Spark aggregation + def findCometHashAggs(plan: SparkPlan): Seq[CometHashAggregateExec] = { + val buffer = scala.collection.mutable.ListBuffer[CometHashAggregateExec]() + def collect(p: SparkPlan): Unit = { + p match { + case agg: CometHashAggregateExec => buffer += agg + case stage: ShuffleQueryStageExec => collect(stage.plan) + case stage: BroadcastQueryStageExec => collect(stage.plan) + case _ => p.children.foreach(collect) + } } + collect(plan) + buffer.toSeq + } - logInfo(s"AQE cross-stage coordination test passed with ${queryStages.length} stages") - } else { - // AQE didn't create stages - test that single-stage coordination still works - logInfo( - "AQE did not create separate stages - verifying single-stage coordination works") - - // scalastyle:off - println(executedPlan) - // Verify that we have ObjectHashAggregateExec operators - val objectHashAggs = stripAQEPlan(executedPlan).collect { - case agg: ObjectHashAggregateExec => - agg + val cometHashAggs = findCometHashAggs(executedPlan) + 
assert( + cometHashAggs.isEmpty, + "Should have no CometHashAggregateExec - coordination should prevent mixed execution") + + // Verify that partial aggregates have the expected fallback message + def findPartialAggsInAQE(plan: SparkPlan): Seq[BaseAggregateExec] = { + val buffer = scala.collection.mutable.ListBuffer[BaseAggregateExec]() + def collect(p: SparkPlan): Unit = { + p match { + case agg: BaseAggregateExec if hasPartialMode(agg) => buffer += agg + case stage: ShuffleQueryStageExec => collect(stage.plan) + case stage: BroadcastQueryStageExec => collect(stage.plan) + case _ => p.children.foreach(collect) + } } - assert(objectHashAggs.nonEmpty, "Should have ObjectHashAggregateExec operators") + collect(plan) + buffer.toSeq + } - // Verify coordination worked - no mixed Comet/Spark aggregation - val cometHashAggs = executedPlan.collect { case agg: CometHashAggregateExec => - agg - } + val partialAggregates = findPartialAggsInAQE(executedPlan) + if (partialAggregates.nonEmpty) { + val expectedMessage = + "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + + "so partial aggregates must also use Spark to avoid mixed execution" assert( - cometHashAggs.isEmpty, - "Should have no CometHashAggregateExec - coordination should prevent mixed execution") - - // Verify partial aggregates have fallback messages even in single-stage case - val partialAggregates = findPartialAggregates(stripAQEPlan(executedPlan)) - if (partialAggregates.nonEmpty) { - val expectedMessage = - "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + - "so partial aggregates must also use Spark to avoid mixed execution" - assert( - partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), - s"Partial aggregate should have fallback message even in single-stage: $expectedMessage") - } - - logInfo( - "Single-stage coordination test passed - coordination works within single stage") + partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), + s"Partial aggregate should have fallback message: $expectedMessage") } + + logInfo(s"AQE cross-stage coordination test passed with ${queryStages.length} stages") } } } finally { From 8b62ba6f0718e11982b87c5e6c0013671e7225cd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 10:58:03 -0700 Subject: [PATCH 10/13] format --- .../scala/org/apache/comet/rules/CometExecRuleSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 0d1e230c8e..802435709c 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -393,10 +393,9 @@ class CometExecRuleSuite extends CometTestBase { // Check if we have QueryStageExec nodes (indicating AQE created stages) val queryStages = stripAQEPlan(executedPlan).collect { case qs: QueryStageExec => qs } - assert (queryStages.nonEmpty) + assert(queryStages.nonEmpty) // We have AQE stages - test the cross-stage coordination - logInfo( - s"AQE created ${queryStages.length} stages - testing cross-stage coordination") + logInfo(s"AQE created ${queryStages.length} stages - testing cross-stage coordination") // Verify that we have ObjectHashAggregateExec operators in the plan // Need to recursively search through AQE stages From 2686ea88b0d2e8c2e213a0aa58f8be3df0e82290 Mon Sep 17 00:00:00 2001 From: 
Andy Grove Date: Wed, 10 Dec 2025 10:58:31 -0700 Subject: [PATCH 11/13] remove debug logging --- .../scala/org/apache/comet/rules/CometExecRuleSuite.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 802435709c..53fa90a3a3 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -385,7 +385,6 @@ class CometExecRuleSuite extends CometTestBase { // Execute the plan to trigger AQE stage creation val result = df.collect() - logInfo(s"Query executed successfully, returned ${result.length} rows") // Get the executed plan which should have AQE stages val executedPlan = df.queryExecution.executedPlan @@ -394,8 +393,6 @@ class CometExecRuleSuite extends CometTestBase { val queryStages = stripAQEPlan(executedPlan).collect { case qs: QueryStageExec => qs } assert(queryStages.nonEmpty) - // We have AQE stages - test the cross-stage coordination - logInfo(s"AQE created ${queryStages.length} stages - testing cross-stage coordination") // Verify that we have ObjectHashAggregateExec operators in the plan // Need to recursively search through AQE stages @@ -460,8 +457,6 @@ class CometExecRuleSuite extends CometTestBase { partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), s"Partial aggregate should have fallback message: $expectedMessage") } - - logInfo(s"AQE cross-stage coordination test passed with ${queryStages.length} stages") } } } finally { From efcaa39f4cabc3563c95f9c4c4a590b2dadec125 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 11:10:11 -0700 Subject: [PATCH 12/13] improve test to check for cross-stage boundary --- .../comet/rules/CometExecRuleSuite.scala | 63 ++++++++++++++++++- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index 53fa90a3a3..a7a5331f36 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.aggregate.{BloomFilterAggregate, Final, Partial, PartialMerge} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ @@ -448,14 +448,71 @@ class CometExecRuleSuite extends CometTestBase { buffer.toSeq } + // Create a mapping from aggregate operators to their containing QueryStageExec + def buildStageMapping(plan: SparkPlan): Map[BaseAggregateExec, QueryStageExec] = { + val mapping = scala.collection.mutable.Map[BaseAggregateExec, QueryStageExec]() + def collect(p: SparkPlan, currentStage: Option[QueryStageExec]): Unit = { + p match { + case stage: QueryStageExec => + collect(stage.plan, Some(stage)) + case agg: BaseAggregateExec if currentStage.isDefined => + mapping += (agg -> currentStage.get) + p.children.foreach(collect(_, currentStage)) + case _ => + 
p.children.foreach(collect(_, currentStage)) + } + } + collect(plan, None) + mapping.toMap + } + val partialAggregates = findPartialAggsInAQE(executedPlan) if (partialAggregates.nonEmpty) { val expectedMessage = "Cannot convert final aggregate to Comet (Final aggregates disabled via test config), " + "so partial aggregates must also use Spark to avoid mixed execution" + + val partialAggsWithFallback = + partialAggregates.filter(hasFallbackMessage(_, expectedMessage)) assert( - partialAggregates.exists(hasFallbackMessage(_, expectedMessage)), - s"Partial aggregate should have fallback message: $expectedMessage") + partialAggsWithFallback.nonEmpty, + s"Should have partial aggregates with fallback message: $expectedMessage") + + // Find final aggregates to verify cross-stage coordination + def findFinalAggsInAQE(plan: SparkPlan): Seq[BaseAggregateExec] = { + val buffer = scala.collection.mutable.ListBuffer[BaseAggregateExec]() + def collect(p: SparkPlan): Unit = { + p match { + case agg: BaseAggregateExec + if agg.aggregateExpressions.exists(_.mode == Final) => + buffer += agg + case stage: ShuffleQueryStageExec => collect(stage.plan) + case stage: BroadcastQueryStageExec => collect(stage.plan) + case _ => p.children.foreach(collect) + } + } + collect(plan) + buffer.toSeq + } + + val finalAggregates = findFinalAggsInAQE(executedPlan) + val stageMapping = buildStageMapping(stripAQEPlan(executedPlan)) + + if (finalAggregates.nonEmpty && partialAggsWithFallback.nonEmpty) { + // Verify that partial and final aggregates are in different stages + val partialStages = partialAggsWithFallback.flatMap(stageMapping.get).distinct + val finalStages = finalAggregates.flatMap(stageMapping.get).distinct + + assert( + partialStages.nonEmpty && finalStages.nonEmpty, + "Should find both partial and final aggregates within QueryStageExec nodes") + + assert( + partialStages.intersect(finalStages).isEmpty, + s"Partial aggregates (stages: ${partialStages.map(_.id)}) and " + + s"final aggregates (stages: ${finalStages.map(_.id)}) should be in different AQE stages " + + "to prove cross-stage coordination is working") + } } } } From d8ed6ee9a009e33d5012ad309555e3a2c11f741c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Dec 2025 12:01:53 -0700 Subject: [PATCH 13/13] scalastyle --- .../test/scala/org/apache/comet/rules/CometExecRuleSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index a7a5331f36..67cbe982bd 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -384,7 +384,7 @@ class CometExecRuleSuite extends CometTestBase { |""".stripMargin) // Execute the plan to trigger AQE stage creation - val result = df.collect() + df.collect() // Get the executed plan which should have AQE stages val executedPlan = df.queryExecution.executedPlan
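
The series boils down to a single coordination rule: before CometExecRule converts anything, a pre-pass walks the plan top-down, and whenever a Final-mode aggregate cannot go native it tags the matching Partial-mode aggregate beneath it so both stay on Spark. The sketch below is a self-contained model of that pre-pass, not Comet code: the toy Plan, Agg, Exchange, Scan, tagFirstPartial, and coordinate names are illustrative stand-ins for Spark's SparkPlan, BaseAggregateExec, and the withInfo tagging used in the patches, and it mirrors the stop-after-first-match behavior of the transformOnce helper from PATCH 02.

object PartialFinalCoordinationSketch {

  sealed trait Mode
  case object Partial extends Mode
  case object Final extends Mode

  // Toy plan tree standing in for Spark's SparkPlan.
  sealed trait Plan {
    def children: Seq[Plan]
    // Stands in for the explain-info tag that withInfo() attaches in Comet.
    var fallbackReason: Option[String] = None
  }
  final case class Agg(mode: Mode, child: Plan, nativeSupported: Boolean) extends Plan {
    def children: Seq[Plan] = Seq(child)
  }
  final case class Exchange(child: Plan) extends Plan {
    def children: Seq[Plan] = Seq(child)
  }
  final case class Scan() extends Plan {
    def children: Seq[Plan] = Nil
  }

  // Tag only the first Partial aggregate under `plan`, then stop walking,
  // so unrelated aggregates deeper in the tree are left untouched.
  def tagFirstPartial(plan: Plan, reason: String): Unit = {
    var found = false
    def walk(node: Plan): Unit = {
      if (!found) {
        node match {
          case agg: Agg if agg.mode == Partial =>
            found = true
            agg.fallbackReason = Some(reason)
          case other =>
            other.children.foreach(walk)
        }
      }
    }
    walk(plan)
  }

  // Top-down pre-pass: a Final aggregate that cannot go native forces its
  // matching Partial aggregate back to Spark, avoiding mixed execution.
  def coordinate(plan: Plan): Unit = {
    plan match {
      case agg: Agg if agg.mode == Final && !agg.nativeSupported =>
        tagFirstPartial(agg.child, "final aggregate stays on Spark, so partial must too")
      case _ => ()
    }
    plan.children.foreach(coordinate)
  }

  def main(args: Array[String]): Unit = {
    val partial = Agg(Partial, Scan(), nativeSupported = true)
    val plan = Agg(Final, Exchange(partial), nativeSupported = false)
    coordinate(plan)
    // Prints: Some(final aggregate stays on Spark, so partial must too)
    println(partial.fallbackReason)
  }
}

Stopping at the first Partial-mode aggregate matters because a subtree can contain further, unrelated aggregates (for example, from a nested aggregation); tagging only the nearest one keeps the fallback scoped to the partial/final pair that would otherwise mix Comet and Spark execution, which is exactly the property the AQE cross-stage test in PATCH 12 asserts.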