2 changes: 1 addition & 1 deletion docs/source/user-guide/latest/compatibility.md
@@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is enabled
`spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See
the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.

- Average
- Average (supports all numeric inputs except decimal types)
- Cast (in some cases)

There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support.
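For reference, a minimal sketch (not part of this diff) of opting in to the documented incompatibility for Average under ANSI mode. The config key follows the `spark.comet.expression.EXPRNAME.allowIncompatible` pattern described above with `Average` as the expression class name; the session setup itself is illustrative.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: enable ANSI mode and allow Comet's native Average
// despite the documented incompatibility (key pattern from the guide above).
val spark = SparkSession
  .builder()
  .master("local[1]")                // hypothetical local setup
  .appName("comet-ansi-avg-example") // hypothetical app name
  .config("spark.sql.ansi.enabled", "true")
  .config("spark.comet.expression.Average.allowIncompatible", "true")
  .getOrCreate()

spark.sql("SELECT avg(x) FROM VALUES (1), (2), (3) AS t(x)").show()
```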
10 changes: 5 additions & 5 deletions native/core/src/execution/planner.rs
@@ -1840,19 +1840,19 @@ impl PhysicalPlanner {
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());

let builder = match datatype {
DataType::Decimal128(_, _) => {
let func =
AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
_ => {
// cast to the result data type of AVG if the result data type is different
// from the input type, e.g. AVG(Int32). We should not expect a cast
// failure since it should have already been checked at Spark side.
// For all other numeric types (Int8/16/32/64, Float32/64):
// Cast to Float64 for accumulation
let child: Arc<dyn PhysicalExpr> =
Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64));
AggregateExprBuilder::new(Arc::new(func), vec![child])
}
};
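For context, and not part of the patch: Spark resolves AVG over integral inputs to a Double result, which is what the non-decimal branch above mirrors by casting the child to Float64 and accumulating in Float64. A minimal sketch, assuming a local Spark session:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// avg over a bigint column is planned with a DoubleType result in Spark,
// matching the Float64 cast and accumulation in the branch above.
val df = Seq(1L, 2L, 4L).toDF("x")
df.selectExpr("avg(x)").printSchema() // avg(x): double (nullable = true)
df.selectExpr("avg(x)").show()        // 2.3333333333333335
```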
2 changes: 1 addition & 1 deletion native/proto/src/proto/expr.proto
@@ -138,7 +138,7 @@ message Avg {
Expr child = 1;
DataType datatype = 2;
DataType sum_datatype = 3;
bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
EvalMode eval_mode = 4;
}

message First {
14 changes: 7 additions & 7 deletions native/spark-expr/src/agg_funcs/avg.rs
@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
// instantiate specialized accumulator based for the type
// All numeric types use Float64 accumulation after casting
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
_ => not_impl_err!(
@@ -115,7 +115,6 @@
&self,
_args: AccumulatorArgs,
) -> Result<Box<dyn GroupsAccumulator>> {
// instantiate specialized accumulator based for the type
match (&self.input_data_type, &self.result_data_type) {
(Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
&self.input_data_type,
@@ -172,7 +171,7 @@
// counts are summed
self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();

// sums are summed
// sums are summed - no overflow checking in any eval mode
if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
let v = self.sum.get_or_insert(0.);
*v += x;
@@ -182,7 +181,7 @@

fn evaluate(&mut self) -> Result<ScalarValue> {
if self.count == 0 {
// If all input are nulls, count will be 0 and we will get null after the division.
// If all input are nulls, count will be 0, and we will get null after the division.
// This is consistent with Spark Average implementation.
Ok(ScalarValue::Float64(None))
} else {
@@ -198,7 +197,8 @@
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types, and does overflow checking
/// Stores values as native types (no overflow checking in any eval mode,
/// since infinity is a perfectly valid value per the Spark implementation)
///
/// F: Function that calculates the average value from a sum of
/// T::Native and a total count
@@ -260,6 +260,7 @@ where
if values.null_count() == 0 {
for (&group_index, &value) in iter {
let sum = &mut self.sums[group_index];
// No overflow checking - Infinity is a valid result
*sum = (*sum).add_wrapping(value);
self.counts[group_index] += 1;
}
@@ -296,7 +297,7 @@ where
self.counts[group_index] += partial_count;
}

// update sums
// update sums - no overflow checking (in all eval modes)
self.sums.resize(total_num_groups, T::default_value());
let iter2 = group_indices.iter().zip(partial_sums.values().iter());
for (&group_index, &new_value) in iter2 {
@@ -325,7 +326,6 @@ where
Ok(Arc::new(array))
}

// return arrays for sums and counts
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let counts = emit_to.take_needed(&mut self.counts);
let counts = Int64Array::new(counts.into(), None);
15 changes: 2 additions & 13 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -151,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {

object CometAverage extends CometAggregateExpressionSerde[Average] {

override def getSupportLevel(avg: Average): SupportLevel = {
avg.evalMode match {
case EvalMode.ANSI =>
Incompatible(Some("ANSI mode is not supported"))
case EvalMode.TRY =>
Incompatible(Some("TRY mode is not supported"))
case _ =>
Compatible()
}
}

override def convert(
aggExpr: AggregateExpression,
avg: Average,
@@ -193,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] {
val builder = ExprOuterClass.Avg.newBuilder()
builder.setChild(childExpr.get)
builder.setDatatype(dataType.get)
builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
builder.setSumDatatype(sumDataType.get)

Some(
@@ -1471,6 +1471,89 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("AVG and try_avg - basic functionality") {
withParquetTable(
Seq(
(10L, 1),
(20L, 1),
(null.asInstanceOf[Long], 1),
(100L, 2),
(200L, 2),
(null.asInstanceOf[Long], 3)),
"tbl") {

Seq(true, false).foreach({ ansiMode =>
// without GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
val res = sql("SELECT avg(_1) FROM tbl")
checkSparkAnswerAndOperator(res)
[Review comment, Contributor]: The last nit: can we remove res* variables here and refer to the sql text?
}

// with GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
checkSparkAnswerAndOperator(res)
}
})

// try_avg without GROUP BY
val resTry = sql("SELECT try_avg(_1) FROM tbl")
checkSparkAnswerAndOperator(resTry)

// try_avg with GROUP BY
val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
checkSparkAnswerAndOperator(resTryGroup)

}
}

test("AVG and try_avg - special numbers") {

val negativeNumbers: Seq[(Long, Int)] = Seq(
(-1L, 1),
(-123L, 1),
(-456L, 1),
(Long.MinValue, 1),
(Long.MinValue, 1),
(Long.MinValue, 2),
(Long.MinValue, 2),
(null.asInstanceOf[Long], 3))

val zeroSeq: Seq[(Long, Int)] =
Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3))

val highValNumbers: Seq[(Long, Int)] = Seq(
(Long.MaxValue, 1),
(Long.MaxValue, 1),
(Long.MaxValue, 2),
(Long.MaxValue, 2),
(null.asInstanceOf[Long], 3))

val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq)
inputs.foreach(inputSeq => {
withParquetTable(inputSeq, "tbl") {
Seq(true, false).foreach({ ansiMode =>
// without GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl")
}

// with GROUP BY
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
}
})

// try_avg without GROUP BY
checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl")

// try_avg with GROUP BY
checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")

}
})
}

test("ANSI support for sum - null test") {
Seq(true, false).foreach { ansiEnabled =>
withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
@@ -29,7 +29,6 @@ import org.apache.spark.SparkContext
import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
import org.apache.spark.sql.TPCDSBase
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
import org.apache.spark.sql.catalyst.util.resourceToString
import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,8 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBase {
CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
// Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
// as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {