Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,6 @@ impl PhysicalPlanner {
OpStruct::HashAgg(agg) => {
assert_eq!(children.len(), 1);
let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;

let group_exprs: PhyExprResult = agg
.grouping_exprs
.iter()
Expand All @@ -916,12 +915,16 @@ impl PhysicalPlanner {
let group_by = PhysicalGroupBy::new_single(group_exprs?);
let schema = child.schema();

let mode = if agg.mode == 0 {
DFAggregateMode::Partial
} else {
// dbg!(agg);

let mode = if agg.mode.contains(&2) {
DFAggregateMode::Final
} else {
DFAggregateMode::Partial
};

dbg!(&schema);

let agg_exprs: PhyAggResult = agg
.agg_exprs
.iter()
Expand Down
8 changes: 6 additions & 2 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,8 @@ message HashAggregate {
repeated spark.spark_expression.Expr grouping_exprs = 1;
repeated spark.spark_expression.AggExpr agg_exprs = 2;
repeated spark.spark_expression.Expr result_exprs = 3;
AggregateMode mode = 5;
// Spark can have both partial and partialMerge together
repeated AggregateMode mode = 5;
}

message Limit {
Expand Down Expand Up @@ -253,7 +254,10 @@ message ParquetWriter {

enum AggregateMode {
Partial = 0;
Final = 1;
PartialMerge = 1;
Final = 2;
// Spark supports the COMPLETE but it a stub for now
Complete = 3;
}

message Expand {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ object CometAttributeReference extends CometExpressionSerde[AttributeReference]
if (dataType.isDefined) {
if (binding) {
// Spark may produce unresolvable attributes in some cases,
// for example https://github.com/apache/datafusion-comet/issues/925.
// So, we allow the binding to fail.
// for example partial aggregation or https://github.com/apache/datafusion-comet/issues/925.
val boundRef: Any = BindReferences
.bindReference(attr, inputs, allowFailures = true)

Expand Down
64 changes: 32 additions & 32 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Complete, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -1066,12 +1066,11 @@ trait CometBaseAggregate {

val modes = aggregate.aggregateExpressions.map(_.mode).distinct
// In distinct aggregates there can be a combination of modes
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty

if (multiMode || sparkFinalMode) {
if (sparkFinalMode) {
return None
}

Expand Down Expand Up @@ -1144,33 +1143,34 @@ trait CometBaseAggregate {
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
Some(builder.setHashAgg(hashAggBuilder).build())
} else {
val modes = aggregateExpressions.map(_.mode).distinct

if (modes.size != 1) {
// This shouldn't happen as all aggregation expressions should share the same mode.
// Fallback to Spark nevertheless here.
withInfo(aggregate, "All aggregate expressions do not have the same mode")
return None
}

val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ =>
withInfo(aggregate, s"Unsupported aggregation mode ${modes.head}")
return None
}
// `output` is only used when `binding` is true (i.e., non-Final)
val output = child.output

// In final mode, the aggregate expressions are bound to the output of the
// child and partial aggregate expressions buffer attributes produced by partial
// aggregation. This is done in Spark `HashAggregateExec` internally. In Comet,
// we don't have to do this because we don't use the merging expression.
val binding = mode != CometAggregateMode.Final
// `output` is only used when `binding` is true (i.e., non-Final)
val output = child.output

val aggExprs =
aggregateExpressions.map(aggExprToProto(_, output, binding, aggregate.conf))
//
// It is possible to have multiple modes for queries with DISTINCT agg expression
// So Spark can Partial and PartialMerge at the same time
val (aggExprs, aggModes) =
aggregateExpressions
.map(a =>
(
aggExprToProto(
a,
output,
a.mode != PartialMerge && a.mode != Final,
aggregate.conf),
a.mode match {
case Partial => CometAggregateMode.Partial
case PartialMerge => CometAggregateMode.PartialMerge
case Final => CometAggregateMode.Final
case mode =>
withInfo(aggregate, s"Unsupported Aggregation Mode $mode")
return None
}))
.unzip

if (aggExprs.exists(_.isEmpty)) {
withInfo(
Expand All @@ -1185,7 +1185,9 @@ trait CometBaseAggregate {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
// Spark sending Final separately only,
// so if any entry is Final means everything else is also Final
if (modes.contains(Final)) {
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
Expand All @@ -1197,7 +1199,7 @@ trait CometBaseAggregate {
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
hashAggBuilder.addAllMode(aggModes.asJava)
Some(builder.setHashAgg(hashAggBuilder).build())
} else {
val allChildren: Seq[Expression] =
Expand Down Expand Up @@ -1323,8 +1325,6 @@ case class CometHashAggregateExec(
// modes is empty too. If aggExprs is not empty, we need to verify all the
// aggregates have the same mode.
val modes: Seq[AggregateMode] = aggregateExpressions.map(_.mode).distinct
assert(modes.length == 1 || modes.isEmpty)
val mode = modes.headOption

override def producedAttributes: AttributeSet = outputSet ++ AttributeSet(resultExpressions)

Expand All @@ -1341,7 +1341,7 @@ case class CometHashAggregateExec(
}

override def stringArgs: Iterator[Any] =
Iterator(input, mode, groupingExpressions, aggregateExpressions, child)
Iterator(input, modes, groupingExpressions, aggregateExpressions, child)

override def equals(obj: Any): Boolean = {
obj match {
Expand All @@ -1350,7 +1350,7 @@ case class CometHashAggregateExec(
this.groupingExpressions == other.groupingExpressions &&
this.aggregateExpressions == other.aggregateExpressions &&
this.input == other.input &&
this.mode == other.mode &&
this.modes == other.modes &&
this.child == other.child &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
Expand All @@ -1359,7 +1359,7 @@ case class CometHashAggregateExec(
}

override def hashCode(): Int =
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, mode, child)
Objects.hashCode(output, groupingExpressions, aggregateExpressions, input, modes, child)

override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,12 @@ class CometExecSuite extends CometTestBase {
case s: CometHashAggregateExec => s
}.get

assert(agg.mode.isDefined && agg.mode.get.isInstanceOf[AggregateMode])
assert(
agg.modes.headOption.isDefined && agg.modes.headOption.get.isInstanceOf[AggregateMode])
val newAgg = agg.cleanBlock().asInstanceOf[CometHashAggregateExec]
assert(newAgg.mode.isDefined && newAgg.mode.get.isInstanceOf[AggregateMode])
assert(
newAgg.modes.headOption.isDefined && newAgg.modes.headOption.get
.isInstanceOf[AggregateMode])
}
}

Expand Down
Loading