diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md
index 3d2c9a7b55..5f5d94d9e1 100644
--- a/docs/source/user-guide/latest/compatibility.md
+++ b/docs/source/user-guide/latest/compatibility.md
@@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is en
 `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name.
 See the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting.
 
-- Average
+- Average (supports all numeric inputs except decimal types)
 - Cast (in some cases)
 
 There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support.
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 8e8191dd0e..93fbb59c11 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1840,6 +1840,7 @@ impl PhysicalPlanner {
                 let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
                 let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
                 let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
+
                 let builder = match datatype {
                     DataType::Decimal128(_, _) => {
                         let func =
@@ -1847,12 +1848,11 @@ impl PhysicalPlanner {
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                     _ => {
-                        // cast to the result data type of AVG if the result data type is different
-                        // from the input type, e.g. AVG(Int32). We should not expect a cast
-                        // failure since it should have already been checked at Spark side.
+                        // For all other numeric types (Int8/16/32/64, Float32/64):
+                        // cast to Float64 for accumulation
                         let child: Arc<dyn PhysicalExpr> =
-                            Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None));
-                        let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype));
+                            Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None));
+                        let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64));
                         AggregateExprBuilder::new(Arc::new(func), vec![child])
                     }
                 };
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 1c453b6336..5f258fd677 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -138,7 +138,7 @@ message Avg {
   Expr child = 1;
   DataType datatype = 2;
   DataType sum_datatype = 3;
-  bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode)
+  EvalMode eval_mode = 4;
 }
 
 message First {
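Note (illustrative, not part of the patch): the planner change above routes every non-decimal numeric input through a CastExpr to Float64 before the Avg UDF accumulates it. The standalone Rust sketch below, which uses plain slices rather than the actual planner or Arrow types, shows the consequence for long inputs such as the Long.MaxValue rows exercised in the test suite further down: once values are f64, integer overflow cannot occur during summation, at the cost of rounding for magnitudes beyond 2^53.

    fn main() {
        // Mirror of the cast-then-accumulate order used by the planner: the cast
        // to f64 happens before any addition, so i64 overflow cannot occur.
        let longs = [i64::MAX, i64::MAX, i64::MAX];

        let sum: f64 = longs.iter().map(|&v| v as f64).sum();
        let avg = sum / longs.len() as f64;

        // The result is finite and equals i64::MAX as f64 (values beyond 2^53
        // are rounded, which is the price of Float64 accumulation).
        assert_eq!(avg, i64::MAX as f64);
        println!("avg = {avg}");
    }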
diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs
index e746aaf6e5..d1d71cca21 100644
--- a/native/spark-expr/src/agg_funcs/avg.rs
+++ b/native/spark-expr/src/agg_funcs/avg.rs
@@ -73,7 +73,7 @@ impl AggregateUDFImpl for Avg {
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        // instantiate specialized accumulator based for the type
+        // All numeric types use Float64 accumulation after casting
        match (&self.input_data_type, &self.result_data_type) {
            (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()),
            _ => not_impl_err!(
@@ -115,7 +115,6 @@
         &self,
         _args: AccumulatorArgs,
     ) -> Result<Box<dyn GroupsAccumulator>> {
-        // instantiate specialized accumulator based for the type
         match (&self.input_data_type, &self.result_data_type) {
             (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new(
                 &self.input_data_type,
@@ -172,7 +171,7 @@ impl Accumulator for AvgAccumulator {
         // counts are summed
         self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default();
-        // sums are summed
+        // sums are summed - no overflow checking in any eval mode
         if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) {
             let v = self.sum.get_or_insert(0.);
             *v += x;
         }
@@ -182,7 +181,7 @@
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         if self.count == 0 {
-            // If all input are nulls, count will be 0 and we will get null after the division.
+            // If all inputs are null, count will be 0, and we will get null after the division.
             // This is consistent with Spark Average implementation.
             Ok(ScalarValue::Float64(None))
         } else {
@@ -198,7 +197,8 @@
 }
 
 /// An accumulator to compute the average of `[PrimitiveArray<T>]`.
-/// Stores values as native types, and does overflow checking
+/// Stores values as native types with no overflow checking in any eval mode,
+/// since infinity is a perfectly valid value per the Spark implementation
 ///
 /// F: Function that calculates the average value from a sum of
 /// T::Native and a total count
@@ -260,6 +260,7 @@ where
         if values.null_count() == 0 {
             for (&group_index, &value) in iter {
                 let sum = &mut self.sums[group_index];
+                // No overflow checking - Infinity is a valid result
                 *sum = (*sum).add_wrapping(value);
                 self.counts[group_index] += 1;
             }
@@ -296,7 +297,7 @@ where
             self.counts[group_index] += partial_count;
         }
 
-        // update sums
+        // update sums - no overflow checking (in any eval mode)
         self.sums.resize(total_num_groups, T::default_value());
         let iter2 = group_indices.iter().zip(partial_sums.values().iter());
         for (&group_index, &new_value) in iter2 {
@@ -325,7 +326,6 @@ where
         Ok(Arc::new(array))
     }
 
-    // return arrays for sums and counts
     fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ScalarValue>> {
         let counts = emit_to.take_needed(&mut self.counts);
         let counts = Int64Array::new(counts.into(), None);
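Note (illustrative, not part of the patch): the new comments in avg.rs state that overflow checking is skipped in every eval mode because infinity is an acceptable result. A minimal standalone Rust sketch of that behaviour, using plain f64 addition rather than the accumulator types in the file above:

    fn main() {
        // f64 addition saturates to +inf instead of wrapping or panicking,
        // so an unchecked running sum can never produce a wrong finite value.
        let values = [f64::MAX, f64::MAX, 1.0_f64];

        let mut sum = 0.0_f64;
        let mut count = 0_i64;
        for v in values {
            sum += v; // overflows to +inf on the second iteration
            count += 1;
        }

        let avg = sum / count as f64;
        assert!(avg.is_infinite());
        println!("sum = {sum}, count = {count}, avg = {avg}");
    }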
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index a05efaebbc..8e58c08740 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType}
@@ -151,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] {
 
 object CometAverage extends CometAggregateExpressionSerde[Average] {
 
-  override def getSupportLevel(avg: Average): SupportLevel = {
-    avg.evalMode match {
-      case EvalMode.ANSI =>
-        Incompatible(Some("ANSI mode is not supported"))
-      case EvalMode.TRY =>
-        Incompatible(Some("TRY mode is not supported"))
-      case _ =>
-        Compatible()
-    }
-  }
-
   override def convert(
       aggExpr: AggregateExpression,
       avg: Average,
@@ -193,7 +182,7 @@
      val builder = ExprOuterClass.Avg.newBuilder()
      builder.setChild(childExpr.get)
      builder.setDatatype(dataType.get)
-      builder.setFailOnError(avg.evalMode == EvalMode.ANSI)
+      builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode)))
      builder.setSumDatatype(sumDataType.get)
 
      Some(
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9b2816c2fd..14b5dc3092 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -1471,6 +1471,89 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("AVG and try_avg - basic functionality") {
+    withParquetTable(
+      Seq(
+        (10L, 1),
+        (20L, 1),
+        (null.asInstanceOf[Long], 1),
+        (100L, 2),
+        (200L, 2),
+        (null.asInstanceOf[Long], 3)),
+      "tbl") {
+
+      Seq(true, false).foreach({ ansiMode =>
+        // without GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT avg(_1) FROM tbl")
+          checkSparkAnswerAndOperator(res)
+        }
+
+        // with GROUP BY
+        withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+          val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          checkSparkAnswerAndOperator(res)
+        }
+      })
+
+      // try_avg without GROUP BY
+      val resTry = sql("SELECT try_avg(_1) FROM tbl")
+      checkSparkAnswerAndOperator(resTry)
+
+      // try_avg with GROUP BY
+      val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+      checkSparkAnswerAndOperator(resTryGroup)
+
+    }
+  }
+
+  test("AVG and try_avg - special numbers") {
+
+    val negativeNumbers: Seq[(Long, Int)] = Seq(
+      (-1L, 1),
+      (-123L, 1),
+      (-456L, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 1),
+      (Long.MinValue, 2),
+      (Long.MinValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val zeroSeq: Seq[(Long, Int)] =
+      Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3))
+
+    val highValNumbers: Seq[(Long, Int)] = Seq(
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 1),
+      (Long.MaxValue, 2),
+      (Long.MaxValue, 2),
+      (null.asInstanceOf[Long], 3))
+
+    val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq)
+    inputs.foreach(inputSeq => {
+      withParquetTable(inputSeq, "tbl") {
+        Seq(true, false).foreach({ ansiMode =>
+          // without GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl")
+          }
+
+          // with GROUP BY
+          withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+            checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY _2")
+          }
+        })
+
+        // try_avg without GROUP BY
+        checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl")
+
+        // try_avg with GROUP BY
+        checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2")
+
+      }
+    })
+  }
+
   test("ANSI support for sum - null test") {
     Seq(true, false).foreach { ansiEnabled =>
       withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) {
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
index b1848ff513..adf74ba549 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.sql.TPCDSBase
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast}
-import org.apache.spark.sql.catalyst.expressions.aggregate.Average
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec}
 import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
@@ -226,8 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa
       CometConf.COMET_DPP_FALLBACK_ENABLED.key -> "false",
       CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
       CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true",
-      // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support
-      CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true",
       // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64
       CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") {