Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
38 changes: 35 additions & 3 deletions dev/diffs/3.5.7.diff
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,14 @@ index e5494726695..00937f025c2 100644

test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 6f3090d8908..c08a60fb0c2 100644
index 6f3090d8908..d4208f1d642 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Expand
@@ -25,10 +25,11 @@ import org.scalatest.matchers.must.Matchers.the

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.plans.logical.Expand
+import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometHashAggregateExec}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
Expand All @@ -251,7 +255,35 @@ index 6f3090d8908..c08a60fb0c2 100644
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -793,7 +793,7 @@ class DataFrameAggregateSuite extends QueryTest
@@ -726,10 +727,15 @@ class DataFrameAggregateSuite extends QueryTest
if (wholeStage) {
assert(find(hashAggPlan) {
case WholeStageCodegenExec(_: HashAggregateExec) => true
+ case _: CometHashAggregateExec => true
case _ => false
}.isDefined)
} else {
- assert(stripAQEPlan(hashAggPlan).isInstanceOf[HashAggregateExec])
+ assert(stripAQEPlan(hashAggPlan) match {
+ case _: HashAggregateExec => true
+ case CometColumnarToRowExec(_: CometHashAggregateExec) => true
+ case _ => false
+ })
}

// test case for ObjectHashAggregate and SortAggregate
@@ -738,7 +744,9 @@ class DataFrameAggregateSuite extends QueryTest
val objHashAggOrSortAggPlan =
stripAQEPlan(objHashAggOrSortAggDF.queryExecution.executedPlan)
if (useObjectHashAgg) {
- assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+ assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec] ||
+ objHashAggOrSortAggPlan.isInstanceOf[CometHashAggregateExec]
+ )
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
@@ -793,7 +801,7 @@ class DataFrameAggregateSuite extends QueryTest
assert(objHashAggPlans.nonEmpty)

val exchangePlans = collect(aggPlan) {
Expand Down
48 changes: 32 additions & 16 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, BloomFilterAggregate, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -1064,17 +1064,35 @@ trait CometBaseAggregate {
builder: Operator.Builder,
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {

val modes = aggregate.aggregateExpressions.map(_.mode).distinct
// In distinct aggregates there can be a combination of modes
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty

if (multiMode || sparkFinalMode) {
val modes = aggregate.aggregateExpressions.map(_.mode).distinct
if (modes.size > 1) {
return None
}

if (modes.contains(Final)) {
// in most cases, Comet partial aggregates are compatible with Spark final
// aggregates, but there are some exceptions
findPartialAgg(aggregate.child) match {
case Some(agg: HashAggregateExec) if agg.conf.ansiEnabled =>
withInfo(
aggregate,
"Cannot perform final aggregate in Comet because " +
"incompatible partial aggregate ran in Spark")
return None
case Some(child: ObjectHashAggregateExec) =>
if (child.aggregateExpressions.exists(
_.aggregateFunction.isInstanceOf[BloomFilterAggregate])) {
withInfo(
aggregate,
"Cannot perform final aggregate in Comet because " +
"incompatible partial aggregate ran in Spark")
return None
}
case _ =>
}
}

val groupingExpressions = aggregate.groupingExpressions
val aggregateExpressions = aggregate.aggregateExpressions
val aggregateAttributes = aggregate.aggregateAttributes
Expand Down Expand Up @@ -1210,18 +1228,16 @@ trait CometBaseAggregate {
}

/**
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
* partial mode, it will return None.
* Find the first partial aggregate in the plan.
*/
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
private def findPartialAgg(plan: SparkPlan): Option[SparkPlan] = {
plan.collectFirst {
case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
Some(agg)
case agg: HashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) => None
case agg: ObjectHashAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
None
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
case s: ShuffleQueryStageExec => findCometPartialAgg(s.plan)
case agg: BaseAggregateExec if agg.aggregateExpressions.forall(_.mode == Partial) =>
Some(agg)
case a: AQEShuffleReadExec => findPartialAgg(a.child)
case s: ShuffleQueryStageExec => findPartialAgg(s.plan)
}.flatten
}

Expand Down
Loading
Loading