diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 269ded1e48..95e7c33c6a 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -903,7 +903,6 @@ impl PhysicalPlanner {
             OpStruct::HashAgg(agg) => {
                 assert_eq!(children.len(), 1);
                 let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
-
                 let group_exprs: PhyExprResult = agg
                     .grouping_exprs
                     .iter()
@@ -916,12 +915,15 @@ impl PhysicalPlanner {
                 let group_by = PhysicalGroupBy::new_single(group_exprs?);
                 let schema = child.schema();
 
-                let mode = if agg.mode == 0 {
-                    DFAggregateMode::Partial
-                } else {
+                // `agg.mode` holds one entry per aggregate expression; 2 is the wire value
+                // of `AggregateMode.Final` in operator.proto.
+                // NOTE(review): PartialMerge (1) entries fall through to Partial here - confirm.
+                let mode = if agg.mode.contains(&2) {
                     DFAggregateMode::Final
+                } else {
+                    DFAggregateMode::Partial
                 };
 
                 let agg_exprs: PhyAggResult = agg
                     .agg_exprs
                     .iter()
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 9d4435751d..8240699c58 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -214,7 +214,9 @@ message HashAggregate {
   repeated spark.spark_expression.Expr grouping_exprs = 1;
   repeated spark.spark_expression.AggExpr agg_exprs = 2;
   repeated spark.spark_expression.Expr result_exprs = 3;
-  AggregateMode mode = 5;
+  // Spark can mix modes within one aggregate (e.g. Partial and PartialMerge for
+  // DISTINCT queries), so this carries one mode per entry of `agg_exprs`.
+  repeated AggregateMode mode = 5;
 }
 
 message Limit {
@@ -253,7 +255,10 @@ message ParquetWriter {
 
 enum AggregateMode {
   Partial = 0;
-  Final = 1;
+  PartialMerge = 1;
+  Final = 2;
+  // Spark also supports Complete mode, but it is only a stub for now
+  Complete = 3;
 }
 
 message Expand {
diff --git a/spark/src/main/scala/org/apache/comet/serde/namedExpressions.scala b/spark/src/main/scala/org/apache/comet/serde/namedExpressions.scala
index aba52d3624..590bdacab5 100644
--- a/spark/src/main/scala/org/apache/comet/serde/namedExpressions.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/namedExpressions.scala
@@ -47,8 +47,9 @@ object CometAttributeReference extends CometExpressionSerde[AttributeReference]
     if (dataType.isDefined) {
       if (binding) {
         // Spark may produce unresolvable attributes in some cases,
-        // for example https://github.com/apache/datafusion-comet/issues/925.
+        // for example partial aggregation or
+        // https://github.com/apache/datafusion-comet/issues/925.
         // So, we allow the binding to fail.
         val boundRef: Any = BindReferences
           .bindReference(attr, inputs, allowFailures = true)
 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 0a435e5b7a..0b8e1f3ec1 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Complete, Final, Partial, PartialMerge}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -1066,12 +1066,11 @@ trait CometBaseAggregate {
 
     val modes = aggregate.aggregateExpressions.map(_.mode).distinct
     // In distinct aggregates there can be a combination of modes
-    val multiMode = modes.size > 1
     // For a final mode HashAggregate, we only need to transform the HashAggregate
     // if there is Comet partial aggregation.
     val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
 
-    if (multiMode || sparkFinalMode) {
+    if (sparkFinalMode) {
       return None
     }
 
@@ -1144,33 +1143,36 @@ trait CometBaseAggregate {
       hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
       Some(builder.setHashAgg(hashAggBuilder).build())
     } else {
-      val modes = aggregateExpressions.map(_.mode).distinct
-
-      if (modes.size != 1) {
-        // This shouldn't happen as all aggregation expressions should share the same mode.
-        // Fallback to Spark nevertheless here.
-        withInfo(aggregate, "All aggregate expressions do not have the same mode")
-        return None
-      }
-
-      val mode = modes.head match {
-        case Partial => CometAggregateMode.Partial
-        case Final => CometAggregateMode.Final
-        case _ =>
-          withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
-          return None
-      }
+      // Translate every Spark mode up front. `None` marks a mode Comet cannot execute
+      // natively (currently only Complete), in which case we fall back to Spark.
+      def toCometMode(mode: AggregateMode): Option[CometAggregateMode] = mode match {
+        case Partial => Some(CometAggregateMode.Partial)
+        case PartialMerge => Some(CometAggregateMode.PartialMerge)
+        case Final => Some(CometAggregateMode.Final)
+        case Complete => None
+      }
+
+      modes.find(toCometMode(_).isEmpty) match {
+        case Some(mode) =>
+          withInfo(aggregate, s"Unsupported Aggregation Mode $mode")
+          return None
+        case None =>
+      }
 
       // In final mode, the aggregate expressions are bound to the output of the
       // child and partial aggregate expressions buffer attributes produced by partial
       // aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
       // we don't have to do this because we don't use the merging expression.
-      val binding = mode != CometAggregateMode.Final
-      // `output` is only used when `binding` is true (i.e., non-Final)
-      val output = child.output
-
-      val aggExprs =
-        aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
+      //
+      // Queries with DISTINCT aggregate expressions can mix modes (for example Partial
+      // and PartialMerge in the same operator), so binding is decided per expression.
+      // `output` is only used when `binding` is true (i.e., neither PartialMerge nor Final).
+      val output = child.output
+      val aggExprs = aggregateExpressions.map { aggExpr =>
+        val binding = aggExpr.mode != PartialMerge && aggExpr.mode != Final
+        aggExprToProto(aggExpr, output, binding, aggregate.conf)
+      }
+      val aggModes = aggregateExpressions.flatMap(aggExpr => toCometMode(aggExpr.mode))
 
       if (aggExprs.exists(_.isEmpty)) {
         withInfo(
@@ -1185,7 +1187,9 @@ trait CometBaseAggregate {
         val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
         hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
         hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
-        if (mode == CometAggregateMode.Final) {
+        // Spark only sends Final on its own, never mixed with other modes,
+        // so if any entry is Final then every other entry is Final as well.
+        if (modes.contains(Final)) {
           val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
           val resultExprs = resultExpressions.map(exprToProto(_, attributes))
           if (resultExprs.exists(_.isEmpty)) {
@@ -1197,7 +1201,7 @@ trait CometBaseAggregate {
           }
           hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
         }
-        hashAggBuilder.setModeValue(mode.getNumber)
+        hashAggBuilder.addAllMode(aggModes.asJava)
         Some(builder.setHashAgg(hashAggBuilder).build())
       } else {
         val allChildren: Seq[Expression] =
@@ -1323,8 +1327,6 @@ case class CometHashAggregateExec(
-  // modes is empty too. If aggExprs is not empty, we need to verify all the
-  // aggregates have the same mode.
+  // modes is empty too. If aggExprs is not empty, the aggregates may have different
+  // modes (e.g. Partial and PartialMerge for DISTINCT aggregates).
   val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
-  assert(modes.length == 1 || modes.isEmpty)
-  val mode = modes.headOption
 
   override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)
 
@@ -1341,7 +1343,7 @@ case class CometHashAggregateExec(
   }
 
   override def stringArgs: Iterator[Any] =
-    Iterator(input, mode, groupingExpressions, aggregateExpressions, child)
+    Iterator(input, modes, groupingExpressions, aggregateExpressions, child)
 
   override def equals(obj: Any): Boolean = {
     obj match {
@@ -1350,7 +1352,7 @@ case class CometHashAggregateExec(
         this.groupingExpressions == other.groupingExpressions &&
         this.aggregateExpressions == other.aggregateExpressions &&
         this.input == other.input &&
-        this.mode == other.mode &&
+        this.modes == other.modes &&
         this.child == other.child &&
         this.serializedPlanOpt == other.serializedPlanOpt
       case _ =>
@@ -1359,7 +1361,7 @@ case class CometHashAggregateExec(
   }
 
   override def hashCode(): Int =
-    Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)
+    Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child)
 
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 }
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 9f9df73a91..d6f0d4501b 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -481,9 +481,9 @@ class CometExecSuite extends CometTestBase {
         case s: CometHashAggregateExec => s
       }.get
 
-      assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
+      assert(agg.modes.headOption.exists(_.isInstanceOf[AggregateMode]))
       val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
-      assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
+      assert(newAgg.modes.headOption.exists(_.isInstanceOf[AggregateMode]))
     }
   }
 