From c60d2db38331745a7b978f80497c66b3ccf16d76 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 20 Nov 2025 11:29:02 -0800 Subject: [PATCH 01/16] support_ansi_avg_wip --- native/core/src/execution/planner.rs | 8 +- native/proto/src/proto/expr.proto | 2 +- native/spark-expr/src/agg_funcs/avg_int.rs | 143 ++++++++++++++++++ native/spark-expr/src/agg_funcs/mod.rs | 2 + .../org/apache/comet/serde/aggregates.scala | 16 +- 5 files changed, 156 insertions(+), 15 deletions(-) create mode 100644 native/spark-expr/src/agg_funcs/avg_int.rs diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 0fe04a5a41..9f6815ffa4 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -118,7 +118,7 @@ use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, - ToJson, UnboundColumn, Variance, + ToJson, UnboundColumn, Variance, AvgInt }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -1894,6 +1894,12 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); let builder = match datatype { + + DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 | DataType::Int32 => { + let func = + AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype)); + AggregateExprBuilder::new(Arc::new(func), vec![child]) + } DataType::Decimal128(_, _) => { let func = AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype)); diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index c9037dcd69..7ec4a9aebe 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -137,7 +137,7 @@ message Avg { Expr child = 1; DataType datatype = 2; DataType sum_datatype = 3; - bool fail_on_error = 4; // currently unused (useful for deciding Ansi vs Legacy mode) + EvalMode eval_mode = 4; } message First { diff --git a/native/spark-expr/src/agg_funcs/avg_int.rs b/native/spark-expr/src/agg_funcs/avg_int.rs new file mode 100644 index 0000000000..c48e84ff78 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/avg_int.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion::logical_expr::type_coercion::aggregates::avg_return_type; +use datafusion::logical_expr::Volatility::Immutable; +use crate::{AvgDecimal, EvalMode}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct AvgInt { + signature: Signature, + eval_mode: EvalMode, +} + +impl AvgInt { + pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { + match data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(Self { + signature: Signature::user_defined(Immutable), + eval_mode + }) + }, + _ => {Err(DataFusionError::Internal("inalid data type for AvgInt".to_string()))} + } + } +} + +impl AggregateUDFImpl for AvgInt { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "avg" + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result { + avg_return_type(self.name(), &arg_types[0]) + } + + fn is_nullable(&self) -> bool { + true + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> datafusion::common::Result> { + todo!() + } + + fn state_fields(&self, args: StateFieldsArgs) -> datafusion::common::Result> { + todo!() + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + false + } + + fn create_groups_accumulator(&self, _args: AccumulatorArgs) -> datafusion::common::Result> { + Ok(Box::new(AvgIntGroupsAccumulator::new(self.eval_mode))) + } + + fn default_value(&self, data_type: &DataType) -> datafusion::common::Result { + todo!() + } +} + +struct AvgIntegerAccumulator{ + sum: Option, + count: u64, + eval_mode: EvalMode, +} + +impl AvgIntegerAccumulator { + fn new(eval_mode: EvalMode) -> Self { + Self{ + sum : Some(0), + count: 0, + eval_mode + } + } +} + +impl Accumulator for AvgIntegerAccumulator { + +} + +struct AvgIntGroupsAccumulator { + +} + +impl AvgIntGroupsAccumulator { + +} + + +impl GroupsAccumulator for AvgIntGroupsAccumulator { + fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> { + todo!() + } + + fn evaluate(&mut self, emit_to: EmitTo) -> datafusion::common::Result { + todo!() + } + + fn state(&mut self, emit_to: EmitTo) -> datafusion::common::Result> { + todo!() + } + + fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> { + todo!() + } + + fn size(&self) -> usize { + todo!() + } +} \ No newline at end of file diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 252da78890..19398cd1e2 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -22,6 +22,7 @@ mod covariance; mod stddev; mod sum_decimal; mod variance; +mod avg_int; pub use avg::Avg; pub use avg_decimal::AvgDecimal; @@ -30,3 +31,4 @@ pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use variance::Variance; +pub use 
avg_int::AvgInt; diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index d00bbf4dfa..c1760a8b67 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType import org.apache.comet.CometConf import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.QueryPlanSerde.{exprToProto, serializeDataType} +import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType} +import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { @@ -150,17 +151,6 @@ object CometCount extends CometAggregateExpressionSerde[Count] { object CometAverage extends CometAggregateExpressionSerde[Average] { - override def getSupportLevel(avg: Average): SupportLevel = { - avg.evalMode match { - case EvalMode.ANSI => - Incompatible(Some("ANSI mode is not supported")) - case EvalMode.TRY => - Incompatible(Some("TRY mode is not supported")) - case _ => - Compatible() - } - } - override def convert( aggExpr: AggregateExpression, avg: Average, @@ -192,7 +182,7 @@ object CometAverage extends CometAggregateExpressionSerde[Average] { val builder = ExprOuterClass.Avg.newBuilder() builder.setChild(childExpr.get) builder.setDatatype(dataType.get) - builder.setFailOnError(avg.evalMode == EvalMode.ANSI) + builder.setEvalMode(evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(avg.evalMode))) builder.setSumDatatype(sumDataType.get) Some( From 07c6acda11c3be6c5a38fe407481fcb7edc3a7a1 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Thu, 20 Nov 2025 11:29:27 -0800 Subject: [PATCH 02/16] support_ansi_avg_wip --- native/core/src/execution/planner.rs | 15 +++-- native/spark-expr/src/agg_funcs/avg_int.rs | 71 +++++++++++++--------- native/spark-expr/src/agg_funcs/mod.rs | 4 +- 3 files changed, 53 insertions(+), 37 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 9f6815ffa4..317455d68b 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -115,10 +115,10 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, - GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, - RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, - ToJson, UnboundColumn, Variance, AvgInt + ArrayInsert, Avg, AvgDecimal, AvgInt, Cast, CheckOverflow, Correlation, Covariance, + CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, + NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, + SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -1894,8 +1894,11 @@ impl PhysicalPlanner { let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); let builder = match datatype { - - DataType::Int8 | DataType::UInt8 | DataType::Int16 | DataType::UInt16 | 
DataType::Int32 => { + DataType::Int8 + | DataType::UInt8 + | DataType::Int16 + | DataType::UInt16 + | DataType::Int32 => { let func = AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype)); AggregateExprBuilder::new(Arc::new(func), vec![child]) diff --git a/native/spark-expr/src/agg_funcs/avg_int.rs b/native/spark-expr/src/agg_funcs/avg_int.rs index c48e84ff78..103c2ac19e 100644 --- a/native/spark-expr/src/agg_funcs/avg_int.rs +++ b/native/spark-expr/src/agg_funcs/avg_int.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; +use crate::{AvgDecimal, EvalMode}; use arrow::array::{ArrayRef, BooleanArray}; use arrow::datatypes::{DataType, FieldRef}; use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; -use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::type_coercion::aggregates::avg_return_type; use datafusion::logical_expr::Volatility::Immutable; -use crate::{AvgDecimal, EvalMode}; +use datafusion::logical_expr::{ + Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, +}; +use std::any::Any; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AvgInt { @@ -34,13 +36,13 @@ pub struct AvgInt { impl AvgInt { pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { match data_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(Self { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { signature: Signature::user_defined(Immutable), - eval_mode - }) - }, - _ => {Err(DataFusionError::Internal("inalid data type for AvgInt".to_string()))} + eval_mode, + }), + _ => Err(DataFusionError::Internal( + "inalid data type for AvgInt".to_string(), + )), } } } @@ -58,7 +60,7 @@ impl AggregateUDFImpl for AvgInt { ReversedUDAF::Identical } - fn signature(&self) -> &Signature { + fn signature(&self) -> &Signature { &self.signature } @@ -70,7 +72,10 @@ impl AggregateUDFImpl for AvgInt { true } - fn accumulator(&self, acc_args: AccumulatorArgs) -> datafusion::common::Result> { + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> datafusion::common::Result> { todo!() } @@ -82,7 +87,10 @@ impl AggregateUDFImpl for AvgInt { false } - fn create_groups_accumulator(&self, _args: AccumulatorArgs) -> datafusion::common::Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> datafusion::common::Result> { Ok(Box::new(AvgIntGroupsAccumulator::new(self.eval_mode))) } @@ -91,7 +99,7 @@ impl AggregateUDFImpl for AvgInt { } } -struct AvgIntegerAccumulator{ +struct AvgIntegerAccumulator { sum: Option, count: u64, eval_mode: EvalMode, @@ -99,29 +107,28 @@ struct AvgIntegerAccumulator{ impl AvgIntegerAccumulator { fn new(eval_mode: EvalMode) -> Self { - Self{ - sum : Some(0), + Self { + sum: Some(0), count: 0, - eval_mode + eval_mode, } } } -impl Accumulator for AvgIntegerAccumulator { - -} +impl Accumulator for AvgIntegerAccumulator {} -struct AvgIntGroupsAccumulator { - -} - -impl AvgIntGroupsAccumulator { - -} +struct AvgIntGroupsAccumulator {} +impl AvgIntGroupsAccumulator {} impl GroupsAccumulator for AvgIntGroupsAccumulator { - fn update_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> { + fn 
update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion::common::Result<()> { todo!() } @@ -133,11 +140,17 @@ impl GroupsAccumulator for AvgIntGroupsAccumulator { todo!() } - fn merge_batch(&mut self, values: &[ArrayRef], group_indices: &[usize], opt_filter: Option<&BooleanArray>, total_num_groups: usize) -> datafusion::common::Result<()> { + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> datafusion::common::Result<()> { todo!() } fn size(&self) -> usize { todo!() } -} \ No newline at end of file +} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 19398cd1e2..8025fc7a08 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -17,18 +17,18 @@ mod avg; mod avg_decimal; +mod avg_int; mod correlation; mod covariance; mod stddev; mod sum_decimal; mod variance; -mod avg_int; pub use avg::Avg; pub use avg_decimal::AvgDecimal; +pub use avg_int::AvgInt; pub use correlation::Correlation; pub use covariance::Covariance; pub use stddev::Stddev; pub use sum_decimal::SumDecimal; pub use variance::Variance; -pub use avg_int::AvgInt; From d1b474c225d6a16e645d9b96d43f61f549173c63 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Sun, 23 Nov 2025 16:11:46 -0800 Subject: [PATCH 03/16] support_ansi_avg --- .../source/user-guide/latest/compatibility.md | 4 - docs/source/user-guide/latest/configs.md | 29 ---- native/core/src/execution/planner.rs | 32 ++-- native/spark-expr/src/agg_funcs/avg.rs | 52 +++--- native/spark-expr/src/agg_funcs/avg_int.rs | 156 ------------------ native/spark-expr/src/agg_funcs/mod.rs | 2 - .../comet/exec/CometAggregateSuite.scala | 36 ++++ 7 files changed, 82 insertions(+), 229 deletions(-) delete mode 100644 native/spark-expr/src/agg_funcs/avg_int.rs diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 60e2234f59..cb5366bd73 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -89,7 +89,6 @@ The following cast operations are generally compatible with Spark except for the - | From Type | To Type | Notes | |-|-|-| | boolean | byte | | @@ -166,7 +165,6 @@ The following cast operations are generally compatible with Spark except for the | timestamp | long | | | timestamp | string | | | timestamp | date | | - ### Incompatible Casts @@ -176,7 +174,6 @@ The following cast operations are not compatible with Spark for all inputs and a - | From Type | To Type | Notes | |-|-|-| | float | decimal | There can be rounding differences | @@ -185,7 +182,6 @@ The following cast operations are not compatible with Spark for all inputs and a | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | | string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. 
Returns 0.0 instead of null if input contains no digits | | string | timestamp | Not all valid formats are supported | - ### Unsupported Casts diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 1e77032f7d..df6ff0e313 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -25,23 +25,19 @@ Comet provides the following configuration settings. - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.scan.allowIncompatible` | Some Comet scan implementations are not currently fully compatible with Spark for all datatypes. Set this config to true to allow them anyway. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | `spark.comet.scan.enabled` | Whether to enable native scans. When this is turned on, Spark will use Comet to read supported data sources (currently only Parquet is supported natively). Note that to enable native vectorized execution, both this config and `spark.comet.exec.enabled` need to be enabled. | true | -| `spark.comet.scan.icebergNative.enabled` | Whether to enable native Iceberg table scan using iceberg-rust. When enabled, Iceberg tables are read directly through native execution, bypassing Spark's DataSource V2 API for better performance. | false | | `spark.comet.scan.preFetch.enabled` | Whether to enable pre-fetching feature of CometScan. | false | | `spark.comet.scan.preFetch.threadNum` | The number of threads running pre-fetching for CometScan. Effective if spark.comet.scan.preFetch.enabled is enabled. Note that more pre-fetching threads means more memory requirement to store pre-fetched row groups. | 2 | | `spark.hadoop.fs.comet.libhdfs.schemes` | Defines filesystem schemes (e.g., hdfs, webhdfs) that the native side accesses via libhdfs, separated by commas. Valid only when built with hdfs feature enabled. | | - ## Parquet Reader Configuration Settings - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.parquet.enable.directBuffer` | Whether to use Java direct byte buffer when reading Parquet. | false | @@ -51,14 +47,12 @@ Comet provides the following configuration settings. | `spark.comet.parquet.read.parallel.io.enabled` | Whether to enable Comet's parallel reader for Parquet files. The parallel reader reads ranges of consecutive data in a file in parallel. It is faster for large files and row groups but uses more resources. | true | | `spark.comet.parquet.read.parallel.io.thread-pool.size` | The maximum number of parallel threads the parallel reader will use in a single executor. For executors configured with a smaller number of cores, use a smaller number. | 16 | | `spark.comet.parquet.respectFilterPushdown` | Whether to respect Spark's PARQUET_FILTER_PUSHDOWN_ENABLED config. This needs to be respected when running the Spark SQL test suite but the default setting results in poor performance in Comet when using the new native scans, disabled by default | false | - ## Query Execution Settings - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.caseConversion.enabled` | Java uses locale-specific rules when converting strings to upper or lower case and Rust does not, so we disable upper and lower by default. | false | @@ -73,7 +67,6 @@ Comet provides the following configuration settings. 
| `spark.comet.metrics.updateInterval` | The interval in milliseconds to update metrics. If interval is negative, metrics will be updated upon task completion. | 3000 | | `spark.comet.nativeLoadRequired` | Whether to require Comet native library to load successfully when Comet is enabled. If not, Comet will silently fallback to Spark when it fails to load the native lib. Otherwise, an error will be thrown and the Spark job will be aborted. | false | | `spark.comet.regexp.allowIncompatible` | Comet is not currently fully compatible with Spark for all regular expressions. Set this config to true to allow them anyway. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | - ## Viewing Explain Plan & Fallback Reasons @@ -82,7 +75,6 @@ These settings can be used to determine which parts of the plan are accelerated - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.explain.format` | Choose extended explain output. The default format of 'verbose' will provide the full query plan annotated with fallback reasons as well as a summary of how much of the plan was accelerated by Comet. The format 'fallback' provides a list of fallback reasons instead. | verbose | @@ -90,14 +82,12 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.explain.rules` | When this setting is enabled, Comet will log all plan transformations performed in physical optimizer rules. Default: false | false | | `spark.comet.explainFallback.enabled` | When this setting is enabled, Comet will provide logging explaining the reason(s) why a query stage cannot be executed natively. Set this to false to reduce the amount of logging. | false | | `spark.comet.logFallbackReasons.enabled` | When this setting is enabled, Comet will log warnings for all fallback reasons. It can be overridden by the environment variable `ENABLE_COMET_LOG_FALLBACK_REASONS`. | false | - ## Shuffle Configuration Settings - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.columnar.shuffle.async.enabled` | Whether to enable asynchronous shuffle for Arrow-based shuffle. | false | @@ -111,28 +101,24 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.native.shuffle.partitioning.range.enabled` | Whether to enable range partitioning for Comet native shuffle. | true | | `spark.comet.shuffle.preferDictionary.ratio` | The ratio of total values to distinct values in a string column to decide whether to prefer dictionary encoding when shuffling the column. If the ratio is higher than this config, dictionary encoding will be used on shuffling string column. This config is effective if it is higher than 1.0. Note that this config is only used when `spark.comet.exec.shuffle.mode` is `jvm`. | 10.0 | | `spark.comet.shuffle.sizeInBytesMultiplier` | Comet reports smaller sizes for shuffle due to using Arrow's columnar memory format and this can result in Spark choosing a different join strategy due to the estimated size of the exchange being smaller. Comet will multiple sizeInBytes by this amount to avoid regressions in join strategy. | 1.0 | - ## Memory & Tuning Configuration Settings - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. 
| 8192 | | `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified | | `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 | | `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/user-guide/tracing.html). | false | - ## Development & Testing Settings - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.columnar.shuffle.memory.factor` | Fraction of Comet memory to be allocated per executor process for columnar shuffle when running in on-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 | @@ -145,14 +131,12 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.sparkToColumnar.enabled` | Whether to enable Spark to Arrow columnar conversion. When this is turned on, Comet will convert operators in `spark.comet.sparkToColumnar.supportedOperatorList` into Arrow columnar format before processing. This is an experimental feature and has known issues with non-UTC timezones. | false | | `spark.comet.sparkToColumnar.supportedOperatorList` | A comma-separated list of operators that will be converted to Arrow columnar format when `spark.comet.sparkToColumnar.enabled` is true. | Range,InMemoryTableScan,RDDScan | | `spark.comet.testing.strict` | Experimental option to enable strict testing, which will fail tests that could be more comprehensive, such as checking for a specific fallback reason. It can be overridden by the environment variable `ENABLE_COMET_STRICT_TESTING`. | false | - ## Enabling or Disabling Individual Operators - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | @@ -173,14 +157,12 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | | `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | | `spark.comet.exec.window.enabled` | Whether to enable window by default. 
| true | - ## Enabling or Disabling Individual Scalar Expressions - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true | @@ -215,7 +197,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.BitwiseNot.enabled` | Enable Comet acceleration for `BitwiseNot` | true | | `spark.comet.expression.BitwiseOr.enabled` | Enable Comet acceleration for `BitwiseOr` | true | | `spark.comet.expression.BitwiseXor.enabled` | Enable Comet acceleration for `BitwiseXor` | true | -| `spark.comet.expression.BloomFilterMightContain.enabled` | Enable Comet acceleration for `BloomFilterMightContain` | true | | `spark.comet.expression.CaseWhen.enabled` | Enable Comet acceleration for `CaseWhen` | true | | `spark.comet.expression.Cast.enabled` | Enable Comet acceleration for `Cast` | true | | `spark.comet.expression.Ceil.enabled` | Enable Comet acceleration for `Ceil` | true | @@ -226,7 +207,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.ConcatWs.enabled` | Enable Comet acceleration for `ConcatWs` | true | | `spark.comet.expression.Contains.enabled` | Enable Comet acceleration for `Contains` | true | | `spark.comet.expression.Cos.enabled` | Enable Comet acceleration for `Cos` | true | -| `spark.comet.expression.Cosh.enabled` | Enable Comet acceleration for `Cosh` | true | | `spark.comet.expression.Cot.enabled` | Enable Comet acceleration for `Cot` | true | | `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true | | `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true | @@ -261,7 +241,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.IsNaN.enabled` | Enable Comet acceleration for `IsNaN` | true | | `spark.comet.expression.IsNotNull.enabled` | Enable Comet acceleration for `IsNotNull` | true | | `spark.comet.expression.IsNull.enabled` | Enable Comet acceleration for `IsNull` | true | -| `spark.comet.expression.KnownFloatingPointNormalized.enabled` | Enable Comet acceleration for `KnownFloatingPointNormalized` | true | | `spark.comet.expression.Length.enabled` | Enable Comet acceleration for `Length` | true | | `spark.comet.expression.LessThan.enabled` | Enable Comet acceleration for `LessThan` | true | | `spark.comet.expression.LessThanOrEqual.enabled` | Enable Comet acceleration for `LessThanOrEqual` | true | @@ -271,7 +250,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Log10.enabled` | Enable Comet acceleration for `Log10` | true | | `spark.comet.expression.Log2.enabled` | Enable Comet acceleration for `Log2` | true | | `spark.comet.expression.Lower.enabled` | Enable Comet acceleration for `Lower` | true | -| `spark.comet.expression.MakeDecimal.enabled` | Enable Comet acceleration for `MakeDecimal` | true | | `spark.comet.expression.MapEntries.enabled` | Enable Comet acceleration for `MapEntries` | true | | `spark.comet.expression.MapFromArrays.enabled` | Enable Comet acceleration for `MapFromArrays` | true | | `spark.comet.expression.MapKeys.enabled` | Enable Comet acceleration for `MapKeys` | true | @@ -294,7 +272,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Remainder.enabled` | Enable Comet 
acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | | `spark.comet.expression.Round.enabled` | Enable Comet acceleration for `Round` | true | -| `spark.comet.expression.ScalarSubquery.enabled` | Enable Comet acceleration for `ScalarSubquery` | true | | `spark.comet.expression.Second.enabled` | Enable Comet acceleration for `Second` | true | | `spark.comet.expression.Sha1.enabled` | Enable Comet acceleration for `Sha1` | true | | `spark.comet.expression.Sha2.enabled` | Enable Comet acceleration for `Sha2` | true | @@ -302,7 +279,6 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.ShiftRight.enabled` | Enable Comet acceleration for `ShiftRight` | true | | `spark.comet.expression.Signum.enabled` | Enable Comet acceleration for `Signum` | true | | `spark.comet.expression.Sin.enabled` | Enable Comet acceleration for `Sin` | true | -| `spark.comet.expression.Sinh.enabled` | Enable Comet acceleration for `Sinh` | true | | `spark.comet.expression.SortOrder.enabled` | Enable Comet acceleration for `SortOrder` | true | | `spark.comet.expression.SparkPartitionID.enabled` | Enable Comet acceleration for `SparkPartitionID` | true | | `spark.comet.expression.Sqrt.enabled` | Enable Comet acceleration for `Sqrt` | true | @@ -323,25 +299,21 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Substring.enabled` | Enable Comet acceleration for `Substring` | true | | `spark.comet.expression.Subtract.enabled` | Enable Comet acceleration for `Subtract` | true | | `spark.comet.expression.Tan.enabled` | Enable Comet acceleration for `Tan` | true | -| `spark.comet.expression.Tanh.enabled` | Enable Comet acceleration for `Tanh` | true | | `spark.comet.expression.TruncDate.enabled` | Enable Comet acceleration for `TruncDate` | true | | `spark.comet.expression.TruncTimestamp.enabled` | Enable Comet acceleration for `TruncTimestamp` | true | | `spark.comet.expression.UnaryMinus.enabled` | Enable Comet acceleration for `UnaryMinus` | true | | `spark.comet.expression.Unhex.enabled` | Enable Comet acceleration for `Unhex` | true | -| `spark.comet.expression.UnscaledValue.enabled` | Enable Comet acceleration for `UnscaledValue` | true | | `spark.comet.expression.Upper.enabled` | Enable Comet acceleration for `Upper` | true | | `spark.comet.expression.WeekDay.enabled` | Enable Comet acceleration for `WeekDay` | true | | `spark.comet.expression.WeekOfYear.enabled` | Enable Comet acceleration for `WeekOfYear` | true | | `spark.comet.expression.XxHash64.enabled` | Enable Comet acceleration for `XxHash64` | true | | `spark.comet.expression.Year.enabled` | Enable Comet acceleration for `Year` | true | - ## Enabling or Disabling Individual Aggregate Expressions - | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.expression.Average.enabled` | Enable Comet acceleration for `Average` | true | @@ -362,5 +334,4 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Sum.enabled` | Enable Comet acceleration for `Sum` | true | | `spark.comet.expression.VariancePop.enabled` | Enable Comet acceleration for `VariancePop` | true | | `spark.comet.expression.VarianceSamp.enabled` | Enable Comet acceleration for `VarianceSamp` | true | - diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 
317455d68b..bdb9ff611c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -115,10 +115,10 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId; use datafusion_comet_spark_expr::{ - ArrayInsert, Avg, AvgDecimal, AvgInt, Cast, CheckOverflow, Correlation, Covariance, - CreateNamedStruct, GetArrayStructFields, GetStructField, IfExpr, ListExtract, - NormalizeNaNAndZero, RLike, RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, - SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, Variance, + ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct, + GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RLike, + RandExpr, RandnExpr, SparkCastOptions, Stddev, SubstringExpr, SumDecimal, TimestampTruncExpr, + ToJson, UnboundColumn, Variance, }; use itertools::Itertools; use jni::objects::GlobalRef; @@ -1893,28 +1893,24 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); + let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; + let builder = match datatype { - DataType::Int8 - | DataType::UInt8 - | DataType::Int16 - | DataType::UInt16 - | DataType::Int32 => { - let func = - AggregateUDF::new_from_impl(AvgInt::new(datatype, input_datatype)); - AggregateExprBuilder::new(Arc::new(func), vec![child]) - } DataType::Decimal128(_, _) => { let func = AggregateUDF::new_from_impl(AvgDecimal::new(datatype, input_datatype)); AggregateExprBuilder::new(Arc::new(func), vec![child]) } _ => { - // cast to the result data type of AVG if the result data type is different - // from the input type, e.g. AVG(Int32). We should not expect a cast - // failure since it should have already been checked at Spark side. + // For all other numeric types (Int8/16/32/64, Float32/64): + // Cast to Float64 for accumulation let child: Arc = - Arc::new(CastExpr::new(Arc::clone(&child), datatype.clone(), None)); - let func = AggregateUDF::new_from_impl(Avg::new("avg", datatype)); + Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None)); + let func = AggregateUDF::new_from_impl(Avg::new( + "avg", + DataType::Float64, + eval_mode, + )); AggregateExprBuilder::new(Arc::new(func), vec![child]) } }; diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs index e8b90b4f46..9850c02605 100644 --- a/native/spark-expr/src/agg_funcs/avg.rs +++ b/native/spark-expr/src/agg_funcs/avg.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+use crate::EvalMode; use arrow::array::{ builder::PrimitiveBuilder, cast::AsArray, types::{Float64Type, Int64Type}, - Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray, }; use arrow::compute::sum; use arrow::datatypes::{DataType, Field, FieldRef}; @@ -31,24 +32,22 @@ use datafusion::logical_expr::{ use datafusion::physical_expr::expressions::format_state_name; use std::{any::Any, sync::Arc}; -use arrow::array::ArrowNativeTypeOp; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use DataType::*; -/// AVG aggregate expression #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Avg { name: String, signature: Signature, - // expr: Arc, input_data_type: DataType, result_data_type: DataType, + eval_mode: EvalMode, } impl Avg { /// Create a new AVG aggregate function - pub fn new(name: impl Into, data_type: DataType) -> Self { + pub fn new(name: impl Into, data_type: DataType, eval_mode: EvalMode) -> Self { let result_data_type = avg_return_type("avg", &data_type).unwrap(); Self { @@ -56,20 +55,20 @@ impl Avg { signature: Signature::user_defined(Immutable), input_data_type: data_type, result_data_type, + eval_mode, } } } impl AggregateUDFImpl for Avg { - /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - // instantiate specialized accumulator based for the type + // All numeric types use Float64 accumulation after casting match (&self.input_data_type, &self.result_data_type) { - (Float64, Float64) => Ok(Box::::default()), + (Float64, Float64) => Ok(Box::new(AvgAccumulator::new(self.eval_mode))), _ => not_impl_err!( "AvgAccumulator for ({} --> {})", self.input_data_type, @@ -109,10 +108,10 @@ impl AggregateUDFImpl for Avg { &self, _args: AccumulatorArgs, ) -> Result> { - // instantiate specialized accumulator based for the type match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::::new( &self.input_data_type, + self.eval_mode, |sum: f64, count: i64| Ok(sum / count as f64), ))), @@ -137,11 +136,22 @@ impl AggregateUDFImpl for Avg { } } -/// An accumulator to compute the average -#[derive(Debug, Default)] +#[derive(Debug)] pub struct AvgAccumulator { sum: Option, count: i64, + #[allow(dead_code)] + eval_mode: EvalMode, +} + +impl AvgAccumulator { + pub fn new(eval_mode: EvalMode) -> Self { + Self { + sum: None, + count: 0, + eval_mode, + } + } } impl Accumulator for AvgAccumulator { @@ -166,7 +176,7 @@ impl Accumulator for AvgAccumulator { // counts are summed self.count += sum(states[1].as_primitive::()).unwrap_or_default(); - // sums are summed + // sums are summed - no overflow checking if let Some(x) = sum(states[0].as_primitive::()) { let v = self.sum.get_or_insert(0.); *v += x; @@ -176,8 +186,6 @@ impl Accumulator for AvgAccumulator { fn evaluate(&mut self) -> Result { if self.count == 0 { - // If all input are nulls, count will be 0 and we will get null after the division. - // This is consistent with Spark Average implementation. Ok(ScalarValue::Float64(None)) } else { Ok(ScalarValue::Float64( @@ -192,7 +200,7 @@ impl Accumulator for AvgAccumulator { } /// An accumulator to compute the average of `[PrimitiveArray]`. -/// Stores values as native types, and does overflow checking +/// Stores values as native types. 
/// /// F: Function that calculates the average value from a sum of /// T::Native and a total count @@ -211,6 +219,10 @@ where /// Sums per group, stored as the native type sums: Vec, + /// Evaluation mode (stored but not used for Float64) + #[allow(dead_code)] + eval_mode: EvalMode, + /// Function that computes the final average (value / count) avg_fn: F, } @@ -220,11 +232,12 @@ where T: ArrowNumericType + Send, F: Fn(T::Native, i64) -> Result + Send, { - pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { + pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self { Self { return_data_type: return_data_type.clone(), counts: vec![], sums: vec![], + eval_mode, avg_fn, } } @@ -254,6 +267,7 @@ where if values.null_count() == 0 { for (&group_index, &value) in iter { let sum = &mut self.sums[group_index]; + // No overflow checking - INFINITY is a valid result *sum = (*sum).add_wrapping(value); self.counts[group_index] += 1; } @@ -264,7 +278,6 @@ where } let sum = &mut self.sums[group_index]; *sum = (*sum).add_wrapping(value); - self.counts[group_index] += 1; } } @@ -280,9 +293,9 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is partial sums, second is counts let partial_sums = values[0].as_primitive::(); let partial_counts = values[1].as_primitive::(); + // update counts with partial counts self.counts.resize(total_num_groups, 0); let iter1 = group_indices.iter().zip(partial_counts.values().iter()); @@ -290,7 +303,7 @@ where self.counts[group_index] += partial_count; } - // update sums + // update sums - no overflow checking self.sums.resize(total_num_groups, T::default_value()); let iter2 = group_indices.iter().zip(partial_sums.values().iter()); for (&group_index, &new_value) in iter2 { @@ -319,7 +332,6 @@ where Ok(Arc::new(array)) } - // return arrays for sums and counts fn state(&mut self, emit_to: EmitTo) -> Result> { let counts = emit_to.take_needed(&mut self.counts); let counts = Int64Array::new(counts.into(), None); diff --git a/native/spark-expr/src/agg_funcs/avg_int.rs b/native/spark-expr/src/agg_funcs/avg_int.rs deleted file mode 100644 index 103c2ac19e..0000000000 --- a/native/spark-expr/src/agg_funcs/avg_int.rs +++ /dev/null @@ -1,156 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::{AvgDecimal, EvalMode}; -use arrow::array::{ArrayRef, BooleanArray}; -use arrow::datatypes::{DataType, FieldRef}; -use datafusion::common::{DataFusionError, Result as DFResult, ScalarValue}; -use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion::logical_expr::type_coercion::aggregates::avg_return_type; -use datafusion::logical_expr::Volatility::Immutable; -use datafusion::logical_expr::{ - Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, -}; -use std::any::Any; - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct AvgInt { - signature: Signature, - eval_mode: EvalMode, -} - -impl AvgInt { - pub fn try_new(data_type: DataType, eval_mode: EvalMode) -> DFResult { - match data_type { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => Ok(Self { - signature: Signature::user_defined(Immutable), - eval_mode, - }), - _ => Err(DataFusionError::Internal( - "inalid data type for AvgInt".to_string(), - )), - } - } -} - -impl AggregateUDFImpl for AvgInt { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "avg" - } - - fn reverse_expr(&self) -> ReversedUDAF { - ReversedUDAF::Identical - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> datafusion::common::Result { - avg_return_type(self.name(), &arg_types[0]) - } - - fn is_nullable(&self) -> bool { - true - } - - fn accumulator( - &self, - acc_args: AccumulatorArgs, - ) -> datafusion::common::Result> { - todo!() - } - - fn state_fields(&self, args: StateFieldsArgs) -> datafusion::common::Result> { - todo!() - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - false - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> datafusion::common::Result> { - Ok(Box::new(AvgIntGroupsAccumulator::new(self.eval_mode))) - } - - fn default_value(&self, data_type: &DataType) -> datafusion::common::Result { - todo!() - } -} - -struct AvgIntegerAccumulator { - sum: Option, - count: u64, - eval_mode: EvalMode, -} - -impl AvgIntegerAccumulator { - fn new(eval_mode: EvalMode) -> Self { - Self { - sum: Some(0), - count: 0, - eval_mode, - } - } -} - -impl Accumulator for AvgIntegerAccumulator {} - -struct AvgIntGroupsAccumulator {} - -impl AvgIntGroupsAccumulator {} - -impl GroupsAccumulator for AvgIntGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> datafusion::common::Result<()> { - todo!() - } - - fn evaluate(&mut self, emit_to: EmitTo) -> datafusion::common::Result { - todo!() - } - - fn state(&mut self, emit_to: EmitTo) -> datafusion::common::Result> { - todo!() - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> datafusion::common::Result<()> { - todo!() - } - - fn size(&self) -> usize { - todo!() - } -} diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs index 8025fc7a08..252da78890 100644 --- a/native/spark-expr/src/agg_funcs/mod.rs +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -17,7 +17,6 @@ mod avg; mod avg_decimal; -mod avg_int; mod correlation; mod covariance; mod stddev; @@ -26,7 +25,6 @@ mod variance; pub use avg::Avg; pub use avg_decimal::AvgDecimal; -pub use avg_int::AvgInt; pub use correlation::Correlation; pub use 
covariance::Covariance; pub use stddev::Stddev; diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 7e577c5fda..a9af3bc4f1 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1471,6 +1471,42 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("AVG and try_avg - basic functionality") { + withParquetTable( + Seq( + (10L, 1), + (20L, 1), + (null.asInstanceOf[Long], 1), + (100L, 2), + (200L, 2), + (null.asInstanceOf[Long], 3)), + "tbl") { + + Seq(true, false).foreach({ k => + // without GROUP BY + withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) { + val res = sql("SELECT avg(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + } + + // with GROUP BY + withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) { + val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + + }) + + // try_avg without GROUP BY + val resTry = sql("SELECT try_avg(_1) FROM tbl") + checkSparkAnswerAndOperator(resTry) + + // try_avg with GROUP BY + val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(resTryGroup) + } + } + protected def checkSparkAnswerAndNumOfAggregates(query: String, numAggregates: Int): Unit = { val df = sql(query) checkSparkAnswer(df) From 33b0b2fb2f24520e1000f96463d47186cd5b317e Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 24 Nov 2025 12:34:39 -0800 Subject: [PATCH 04/16] support_ansi_avg --- .../org/apache/spark/sql/comet/CometPlanStabilitySuite.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 8f260e2ca8..3b494cc6d1 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -225,7 +225,6 @@ trait CometPlanStabilitySuite extends DisableAdaptiveExecutionSuite with TPCDSBa CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key -> "true", // Allow Incompatible is needed for Sum + Average for Spark 4.0.0 / ANSI support - CometConf.getExprAllowIncompatConfigKey(classOf[Average]) -> "true", CometConf.getExprAllowIncompatConfigKey(classOf[Sum]) -> "true", // as well as for v1.4/q9, v1.4/q44, v2.7.0/q6, v2.7.0/q64 CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true", From 444ce0e2e9eff9a4aa172a7a80c0cb44ecff6375 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 24 Nov 2025 13:04:34 -0800 Subject: [PATCH 05/16] support_ansi_avg --- .../source/user-guide/latest/compatibility.md | 4 +++ docs/source/user-guide/latest/configs.md | 29 +++++++++++++++++++ .../sql/comet/CometPlanStabilitySuite.scala | 2 +- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index cb5366bd73..60e2234f59 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -89,6 +89,7 @@ The following cast operations are generally compatible with Spark except for the + | From Type | To Type | Notes | |-|-|-| | boolean | byte | | @@ -165,6 +166,7 @@ The following cast operations are 
generally compatible with Spark except for the | timestamp | long | | | timestamp | string | | | timestamp | date | | + ### Incompatible Casts @@ -174,6 +176,7 @@ The following cast operations are not compatible with Spark for all inputs and a + | From Type | To Type | Notes | |-|-|-| | float | decimal | There can be rounding differences | @@ -182,6 +185,7 @@ The following cast operations are not compatible with Spark for all inputs and a | string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. | | string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits | | string | timestamp | Not all valid formats are supported | + ### Unsupported Casts diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index df6ff0e313..1e77032f7d 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -25,19 +25,23 @@ Comet provides the following configuration settings. + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.scan.allowIncompatible` | Some Comet scan implementations are not currently fully compatible with Spark for all datatypes. Set this config to true to allow them anyway. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | | `spark.comet.scan.enabled` | Whether to enable native scans. When this is turned on, Spark will use Comet to read supported data sources (currently only Parquet is supported natively). Note that to enable native vectorized execution, both this config and `spark.comet.exec.enabled` need to be enabled. | true | +| `spark.comet.scan.icebergNative.enabled` | Whether to enable native Iceberg table scan using iceberg-rust. When enabled, Iceberg tables are read directly through native execution, bypassing Spark's DataSource V2 API for better performance. | false | | `spark.comet.scan.preFetch.enabled` | Whether to enable pre-fetching feature of CometScan. | false | | `spark.comet.scan.preFetch.threadNum` | The number of threads running pre-fetching for CometScan. Effective if spark.comet.scan.preFetch.enabled is enabled. Note that more pre-fetching threads means more memory requirement to store pre-fetched row groups. | 2 | | `spark.hadoop.fs.comet.libhdfs.schemes` | Defines filesystem schemes (e.g., hdfs, webhdfs) that the native side accesses via libhdfs, separated by commas. Valid only when built with hdfs feature enabled. | | + ## Parquet Reader Configuration Settings + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.parquet.enable.directBuffer` | Whether to use Java direct byte buffer when reading Parquet. | false | @@ -47,12 +51,14 @@ Comet provides the following configuration settings. | `spark.comet.parquet.read.parallel.io.enabled` | Whether to enable Comet's parallel reader for Parquet files. The parallel reader reads ranges of consecutive data in a file in parallel. It is faster for large files and row groups but uses more resources. | true | | `spark.comet.parquet.read.parallel.io.thread-pool.size` | The maximum number of parallel threads the parallel reader will use in a single executor. For executors configured with a smaller number of cores, use a smaller number. 
| 16 | | `spark.comet.parquet.respectFilterPushdown` | Whether to respect Spark's PARQUET_FILTER_PUSHDOWN_ENABLED config. This needs to be respected when running the Spark SQL test suite but the default setting results in poor performance in Comet when using the new native scans, disabled by default | false | + ## Query Execution Settings + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.caseConversion.enabled` | Java uses locale-specific rules when converting strings to upper or lower case and Rust does not, so we disable upper and lower by default. | false | @@ -67,6 +73,7 @@ Comet provides the following configuration settings. | `spark.comet.metrics.updateInterval` | The interval in milliseconds to update metrics. If interval is negative, metrics will be updated upon task completion. | 3000 | | `spark.comet.nativeLoadRequired` | Whether to require Comet native library to load successfully when Comet is enabled. If not, Comet will silently fallback to Spark when it fails to load the native lib. Otherwise, an error will be thrown and the Spark job will be aborted. | false | | `spark.comet.regexp.allowIncompatible` | Comet is not currently fully compatible with Spark for all regular expressions. Set this config to true to allow them anyway. For more information, refer to the [Comet Compatibility Guide](https://datafusion.apache.org/comet/user-guide/compatibility.html). | false | + ## Viewing Explain Plan & Fallback Reasons @@ -75,6 +82,7 @@ These settings can be used to determine which parts of the plan are accelerated + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.explain.format` | Choose extended explain output. The default format of 'verbose' will provide the full query plan annotated with fallback reasons as well as a summary of how much of the plan was accelerated by Comet. The format 'fallback' provides a list of fallback reasons instead. | verbose | @@ -82,12 +90,14 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.explain.rules` | When this setting is enabled, Comet will log all plan transformations performed in physical optimizer rules. Default: false | false | | `spark.comet.explainFallback.enabled` | When this setting is enabled, Comet will provide logging explaining the reason(s) why a query stage cannot be executed natively. Set this to false to reduce the amount of logging. | false | | `spark.comet.logFallbackReasons.enabled` | When this setting is enabled, Comet will log warnings for all fallback reasons. It can be overridden by the environment variable `ENABLE_COMET_LOG_FALLBACK_REASONS`. | false | + ## Shuffle Configuration Settings + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.columnar.shuffle.async.enabled` | Whether to enable asynchronous shuffle for Arrow-based shuffle. | false | @@ -101,24 +111,28 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.native.shuffle.partitioning.range.enabled` | Whether to enable range partitioning for Comet native shuffle. | true | | `spark.comet.shuffle.preferDictionary.ratio` | The ratio of total values to distinct values in a string column to decide whether to prefer dictionary encoding when shuffling the column. If the ratio is higher than this config, dictionary encoding will be used on shuffling string column. This config is effective if it is higher than 1.0. 
Note that this config is only used when `spark.comet.exec.shuffle.mode` is `jvm`. | 10.0 | | `spark.comet.shuffle.sizeInBytesMultiplier` | Comet reports smaller sizes for shuffle due to using Arrow's columnar memory format and this can result in Spark choosing a different join strategy due to the estimated size of the exchange being smaller. Comet will multiple sizeInBytes by this amount to avoid regressions in join strategy. | 1.0 | + ## Memory & Tuning Configuration Settings + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 | | `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified | | `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 | | `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/user-guide/tracing.html). | false | + ## Development & Testing Settings + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.columnar.shuffle.memory.factor` | Fraction of Comet memory to be allocated per executor process for columnar shuffle when running in on-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 | @@ -131,12 +145,14 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.sparkToColumnar.enabled` | Whether to enable Spark to Arrow columnar conversion. When this is turned on, Comet will convert operators in `spark.comet.sparkToColumnar.supportedOperatorList` into Arrow columnar format before processing. This is an experimental feature and has known issues with non-UTC timezones. | false | | `spark.comet.sparkToColumnar.supportedOperatorList` | A comma-separated list of operators that will be converted to Arrow columnar format when `spark.comet.sparkToColumnar.enabled` is true. | Range,InMemoryTableScan,RDDScan | | `spark.comet.testing.strict` | Experimental option to enable strict testing, which will fail tests that could be more comprehensive, such as checking for a specific fallback reason. It can be overridden by the environment variable `ENABLE_COMET_STRICT_TESTING`. | false | + ## Enabling or Disabling Individual Operators + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.exec.aggregate.enabled` | Whether to enable aggregate by default. | true | @@ -157,12 +173,14 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.exec.takeOrderedAndProject.enabled` | Whether to enable takeOrderedAndProject by default. | true | | `spark.comet.exec.union.enabled` | Whether to enable union by default. | true | | `spark.comet.exec.window.enabled` | Whether to enable window by default. 
| true | + ## Enabling or Disabling Individual Scalar Expressions + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.expression.Abs.enabled` | Enable Comet acceleration for `Abs` | true | @@ -197,6 +215,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.BitwiseNot.enabled` | Enable Comet acceleration for `BitwiseNot` | true | | `spark.comet.expression.BitwiseOr.enabled` | Enable Comet acceleration for `BitwiseOr` | true | | `spark.comet.expression.BitwiseXor.enabled` | Enable Comet acceleration for `BitwiseXor` | true | +| `spark.comet.expression.BloomFilterMightContain.enabled` | Enable Comet acceleration for `BloomFilterMightContain` | true | | `spark.comet.expression.CaseWhen.enabled` | Enable Comet acceleration for `CaseWhen` | true | | `spark.comet.expression.Cast.enabled` | Enable Comet acceleration for `Cast` | true | | `spark.comet.expression.Ceil.enabled` | Enable Comet acceleration for `Ceil` | true | @@ -207,6 +226,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.ConcatWs.enabled` | Enable Comet acceleration for `ConcatWs` | true | | `spark.comet.expression.Contains.enabled` | Enable Comet acceleration for `Contains` | true | | `spark.comet.expression.Cos.enabled` | Enable Comet acceleration for `Cos` | true | +| `spark.comet.expression.Cosh.enabled` | Enable Comet acceleration for `Cosh` | true | | `spark.comet.expression.Cot.enabled` | Enable Comet acceleration for `Cot` | true | | `spark.comet.expression.CreateArray.enabled` | Enable Comet acceleration for `CreateArray` | true | | `spark.comet.expression.CreateNamedStruct.enabled` | Enable Comet acceleration for `CreateNamedStruct` | true | @@ -241,6 +261,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.IsNaN.enabled` | Enable Comet acceleration for `IsNaN` | true | | `spark.comet.expression.IsNotNull.enabled` | Enable Comet acceleration for `IsNotNull` | true | | `spark.comet.expression.IsNull.enabled` | Enable Comet acceleration for `IsNull` | true | +| `spark.comet.expression.KnownFloatingPointNormalized.enabled` | Enable Comet acceleration for `KnownFloatingPointNormalized` | true | | `spark.comet.expression.Length.enabled` | Enable Comet acceleration for `Length` | true | | `spark.comet.expression.LessThan.enabled` | Enable Comet acceleration for `LessThan` | true | | `spark.comet.expression.LessThanOrEqual.enabled` | Enable Comet acceleration for `LessThanOrEqual` | true | @@ -250,6 +271,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Log10.enabled` | Enable Comet acceleration for `Log10` | true | | `spark.comet.expression.Log2.enabled` | Enable Comet acceleration for `Log2` | true | | `spark.comet.expression.Lower.enabled` | Enable Comet acceleration for `Lower` | true | +| `spark.comet.expression.MakeDecimal.enabled` | Enable Comet acceleration for `MakeDecimal` | true | | `spark.comet.expression.MapEntries.enabled` | Enable Comet acceleration for `MapEntries` | true | | `spark.comet.expression.MapFromArrays.enabled` | Enable Comet acceleration for `MapFromArrays` | true | | `spark.comet.expression.MapKeys.enabled` | Enable Comet acceleration for `MapKeys` | true | @@ -272,6 +294,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Remainder.enabled` | Enable Comet 
acceleration for `Remainder` | true | | `spark.comet.expression.Reverse.enabled` | Enable Comet acceleration for `Reverse` | true | | `spark.comet.expression.Round.enabled` | Enable Comet acceleration for `Round` | true | +| `spark.comet.expression.ScalarSubquery.enabled` | Enable Comet acceleration for `ScalarSubquery` | true | | `spark.comet.expression.Second.enabled` | Enable Comet acceleration for `Second` | true | | `spark.comet.expression.Sha1.enabled` | Enable Comet acceleration for `Sha1` | true | | `spark.comet.expression.Sha2.enabled` | Enable Comet acceleration for `Sha2` | true | @@ -279,6 +302,7 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.ShiftRight.enabled` | Enable Comet acceleration for `ShiftRight` | true | | `spark.comet.expression.Signum.enabled` | Enable Comet acceleration for `Signum` | true | | `spark.comet.expression.Sin.enabled` | Enable Comet acceleration for `Sin` | true | +| `spark.comet.expression.Sinh.enabled` | Enable Comet acceleration for `Sinh` | true | | `spark.comet.expression.SortOrder.enabled` | Enable Comet acceleration for `SortOrder` | true | | `spark.comet.expression.SparkPartitionID.enabled` | Enable Comet acceleration for `SparkPartitionID` | true | | `spark.comet.expression.Sqrt.enabled` | Enable Comet acceleration for `Sqrt` | true | @@ -299,21 +323,25 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Substring.enabled` | Enable Comet acceleration for `Substring` | true | | `spark.comet.expression.Subtract.enabled` | Enable Comet acceleration for `Subtract` | true | | `spark.comet.expression.Tan.enabled` | Enable Comet acceleration for `Tan` | true | +| `spark.comet.expression.Tanh.enabled` | Enable Comet acceleration for `Tanh` | true | | `spark.comet.expression.TruncDate.enabled` | Enable Comet acceleration for `TruncDate` | true | | `spark.comet.expression.TruncTimestamp.enabled` | Enable Comet acceleration for `TruncTimestamp` | true | | `spark.comet.expression.UnaryMinus.enabled` | Enable Comet acceleration for `UnaryMinus` | true | | `spark.comet.expression.Unhex.enabled` | Enable Comet acceleration for `Unhex` | true | +| `spark.comet.expression.UnscaledValue.enabled` | Enable Comet acceleration for `UnscaledValue` | true | | `spark.comet.expression.Upper.enabled` | Enable Comet acceleration for `Upper` | true | | `spark.comet.expression.WeekDay.enabled` | Enable Comet acceleration for `WeekDay` | true | | `spark.comet.expression.WeekOfYear.enabled` | Enable Comet acceleration for `WeekOfYear` | true | | `spark.comet.expression.XxHash64.enabled` | Enable Comet acceleration for `XxHash64` | true | | `spark.comet.expression.Year.enabled` | Enable Comet acceleration for `Year` | true | + ## Enabling or Disabling Individual Aggregate Expressions + | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.expression.Average.enabled` | Enable Comet acceleration for `Average` | true | @@ -334,4 +362,5 @@ These settings can be used to determine which parts of the plan are accelerated | `spark.comet.expression.Sum.enabled` | Enable Comet acceleration for `Sum` | true | | `spark.comet.expression.VariancePop.enabled` | Enable Comet acceleration for `VariancePop` | true | | `spark.comet.expression.VarianceSamp.enabled` | Enable Comet acceleration for `VarianceSamp` | true | + diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala index 3b494cc6d1..1c0c8f8966 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometPlanStabilitySuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.sql.TPCDSBase import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Cast} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Sum} +import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.execution.{FormattedMode, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec, SubqueryExec} import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite From 2845d6294b5e79e138d4a74ede8ca36b97ba2dc9 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Sun, 14 Dec 2025 23:25:35 -0800 Subject: [PATCH 06/16] rebase_main --- .../scala/org/apache/comet/exec/CometAggregateSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 95783098cc..83cb6acdbc 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1495,7 +1495,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(res) } - }) // try_avg without GROUP BY @@ -1505,6 +1504,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // try_avg with GROUP BY val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(resTryGroup) + + } + } + test("ANSI support for decimal sum - null test") { Seq(true, false).foreach { ansiEnabled => withSQLConf( From df4585105e5c05c3f7b38371ab48647006694c89 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 03:13:59 -0800 Subject: [PATCH 07/16] address_review_comments --- native/spark-expr/src/agg_funcs/avg.rs | 22 +++++-------------- .../comet/exec/CometAggregateSuite.scala | 7 +++--- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs index 5fca5b3947..0103a81d8f 100644 --- a/native/spark-expr/src/agg_funcs/avg.rs +++ b/native/spark-expr/src/agg_funcs/avg.rs @@ -48,13 +48,12 @@ pub struct Avg { name: String, signature: Signature, input_data_type: DataType, - result_data_type: DataType, - eval_mode: EvalMode, + result_data_type: DataType } impl Avg { /// Create a new AVG aggregate function - pub fn new(name: impl Into<String>, data_type: DataType, eval_mode: EvalMode) -> Self { + pub fn new(name: impl Into<String>, data_type: DataType) -> Self { let result_data_type = avg_return_type("avg", &data_type).unwrap(); Self { @@ -62,7 +61,6 @@ impl Avg { signature: Signature::user_defined(Immutable), input_data_type: data_type, result_data_type, - eval_mode, } } } @@ -75,7 +73,8 @@ impl AggregateUDFImpl for Avg { fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { // All numeric types use Float64 accumulation after casting match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) =>
Ok(Box::new(AvgAccumulator::new(self.eval_mode))), + (Float64, Float64) => Ok(Box::new(AvgAccumulator:: + new())), _ => not_impl_err!( "AvgAccumulator for ({} --> {})", self.input_data_type, @@ -118,7 +117,6 @@ impl AggregateUDFImpl for Avg { match (&self.input_data_type, &self.result_data_type) { (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::<Float64Type, _>::new( &self.input_data_type, - self.eval_mode, |sum: f64, count: i64| Ok(sum / count as f64), ))), @@ -147,16 +145,13 @@ impl AggregateUDFImpl for Avg { pub struct AvgAccumulator { sum: Option<f64>, count: i64, - #[allow(dead_code)] - eval_mode: EvalMode, } impl AvgAccumulator { - pub fn new(eval_mode: EvalMode) -> Self { + pub fn new() -> Self { Self { sum: None, count: 0, - eval_mode, } } } @@ -226,10 +221,6 @@ where /// Sums per group, stored as the native type sums: Vec<T::Native>, - /// Evaluation mode (stored but not used for Float64) - #[allow(dead_code)] - eval_mode: EvalMode, - /// Function that computes the final average (value / count) avg_fn: F, } @@ -239,12 +230,11 @@ where T: ArrowNumericType + Send, F: Fn(T::Native, i64) -> Result<T::Native> + Send, { - pub fn new(return_data_type: &DataType, eval_mode: EvalMode, avg_fn: F) -> Self { + pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { Self { return_data_type: return_data_type.clone(), counts: vec![], sums: vec![], - eval_mode, avg_fn, } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 83cb6acdbc..f045f2ac15 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1483,15 +1483,15 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (null.asInstanceOf[Long], 3)), "tbl") { - Seq(true, false).foreach({ k => + Seq(true, false).foreach({ ansiMode => // without GROUP BY - withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val res = sql("SELECT avg(_1) FROM tbl") checkSparkAnswerAndOperator(res) } // with GROUP BY - withSQLConf(SQLConf.ANSI_ENABLED.key -> k.toString) { + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") checkSparkAnswerAndOperator(res) } @@ -1520,7 +1520,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { "null_tbl") { val res = sql("SELECT sum(_1) FROM null_tbl") checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) } } } From 512cc6b346b953483fdfc1236d0d37c99842a03b Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 03:21:39 -0800 Subject: [PATCH 08/16] address_review_comments --- native/spark-expr/src/agg_funcs/avg.rs | 38 ++++++++++++-------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs index 0103a81d8f..7a4084d6c8 100644 --- a/native/spark-expr/src/agg_funcs/avg.rs +++ b/native/spark-expr/src/agg_funcs/avg.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License.
-use crate::EvalMode; use arrow::array::{ builder::PrimitiveBuilder, cast::AsArray, types::{Float64Type, Int64Type}, - Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, Int64Array, PrimitiveArray, + Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, }; use arrow::compute::sum; use arrow::datatypes::{DataType, Field, FieldRef}; @@ -31,6 +30,7 @@ use datafusion::logical_expr::{ use datafusion::physical_expr::expressions::format_state_name; use std::{any::Any, sync::Arc}; +use arrow::array::ArrowNativeTypeOp; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::Volatility::Immutable; use DataType::*; @@ -47,8 +47,9 @@ fn avg_return_type(_name: &str, data_type: &DataType) -> Result<DataType> { pub struct Avg { name: String, signature: Signature, + // expr: Arc<dyn PhysicalExpr>, input_data_type: DataType, - result_data_type: DataType + result_data_type: DataType, } impl Avg { @@ -66,6 +67,7 @@ impl Avg { } impl AggregateUDFImpl for Avg { + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self } @@ -73,8 +75,7 @@ impl AggregateUDFImpl for Avg { fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { // All numeric types use Float64 accumulation after casting match (&self.input_data_type, &self.result_data_type) { - (Float64, Float64) => Ok(Box::new(AvgAccumulator:: - new())), + (Float64, Float64) => Ok(Box::<AvgAccumulator>::default()), _ => not_impl_err!( "AvgAccumulator for ({} --> {})", self.input_data_type, @@ -141,21 +142,13 @@ impl AggregateUDFImpl for Avg { } } -#[derive(Debug)] +/// An accumulator to compute the average +#[derive(Debug, Default)] pub struct AvgAccumulator { sum: Option<f64>, count: i64, } -impl AvgAccumulator { - pub fn new() -> Self { - Self { - sum: None, - count: 0, - } - } -} - impl Accumulator for AvgAccumulator { fn state(&mut self) -> Result<Vec<ScalarValue>> { Ok(vec![ @@ -178,7 +171,7 @@ impl Accumulator for AvgAccumulator { // counts are summed self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default(); - // sums are summed - no overflow checking + // sums are summed - no overflow checking in all Eval Modes if let Some(x) = sum(states[0].as_primitive::<Float64Type>()) { let v = self.sum.get_or_insert(0.); *v += x; @@ -188,6 +181,8 @@ impl Accumulator for AvgAccumulator { fn evaluate(&mut self) -> Result<ScalarValue> { if self.count == 0 { + // If all input are nulls, count will be 0, and we will get null after the division. + // This is consistent with Spark Average implementation. Ok(ScalarValue::Float64(None)) } else { Ok(ScalarValue::Float64( @@ -202,7 +197,8 @@ impl Accumulator for AvgAccumulator { } /// An accumulator to compute the average of `[PrimitiveArray<T>]`. -/// Stores values as native types.
+/// Stores values as native types ( +/// no overflow check all eval modes since inf is a perfectly valid value per spark impl) /// /// F: Function that calculates the average value from a sum of /// T::Native and a total count @@ -264,7 +260,7 @@ where if values.null_count() == 0 { for (&group_index, &value) in iter { let sum = &mut self.sums[group_index]; - // No overflow checking - INFINITY is a valid result + // No overflow checking - Infinity is a valid result *sum = (*sum).add_wrapping(value); self.counts[group_index] += 1; } @@ -275,6 +271,7 @@ where } let sum = &mut self.sums[group_index]; *sum = (*sum).add_wrapping(value); + self.counts[group_index] += 1; } } @@ -290,9 +287,9 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is partial sums, second is counts let partial_sums = values[0].as_primitive::<T>(); let partial_counts = values[1].as_primitive::<Int64Type>(); - // update counts with partial counts self.counts.resize(total_num_groups, 0); let iter1 = group_indices.iter().zip(partial_counts.values().iter()); @@ -300,7 +297,7 @@ where self.counts[group_index] += partial_count; } - // update sums - no overflow checking + // update sums - no overflow checking (in all eval modes) self.sums.resize(total_num_groups, T::default_value()); let iter2 = group_indices.iter().zip(partial_sums.values().iter()); for (&group_index, &new_value) in iter2 { @@ -329,6 +326,7 @@ where Ok(Arc::new(array)) } + fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> { let counts = emit_to.take_needed(&mut self.counts); let counts = Int64Array::new(counts.into(), None); From 25bae5480516a9846923cf0ce4aa59a9bab183cb Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 04:00:09 -0800 Subject: [PATCH 09/16] address_review_comments --- .../comet/exec/CometAggregateSuite.scala | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index f045f2ac15..aefbdf22d6 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1488,6 +1488,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val res = sql("SELECT avg(_1) FROM tbl") checkSparkAnswerAndOperator(res) +
assert(res.collect() === Array(Row(null))) + } + + // with GROUP BY + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + }) + + // try_avg without GROUP BY + val resTry = sql("SELECT try_avg(_1) FROM tbl") + checkSparkAnswerAndOperator(resTry) + + // try_avg with GROUP BY + val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(resTryGroup) + + } + }) + } + test("ANSI support for decimal sum - null test") { Seq(true, false).foreach { ansiEnabled => withSQLConf( From 57e3de831cdc54ab749e2df05c0b5945d954a1dd Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 09:47:29 -0800 Subject: [PATCH 10/16] address_review_comments --- native/core/src/execution/planner.rs | 4 +- native/spark-expr/src/agg_funcs/avg.rs | 1 - .../comet/exec/CometAggregateSuite.scala | 66 ++++++++----------- 3 files changed, 29 insertions(+), 42 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 45caa0e694..e7ecf8a8d3 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1863,7 +1863,6 @@ impl PhysicalPlanner { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap()); - let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?; let builder = match datatype { DataType::Decimal128(_, _) => { @@ -1878,8 +1877,7 @@ impl PhysicalPlanner { Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None)); let func = AggregateUDF::new_from_impl(Avg::new( "avg", - DataType::Float64, - eval_mode, + DataType::Float64 )); AggregateExprBuilder::new(Arc::new(func), vec![child]) } diff --git a/native/spark-expr/src/agg_funcs/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs index 7a4084d6c8..d1d71cca21 100644 --- a/native/spark-expr/src/agg_funcs/avg.rs +++ b/native/spark-expr/src/agg_funcs/avg.rs @@ -326,7 +326,6 @@ where Ok(Arc::new(array)) } - fn state(&mut self, emit_to: EmitTo) -> Result> { let counts = emit_to.take_needed(&mut self.counts); let counts = Int64Array::new(counts.into(), None); diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index aefbdf22d6..b34ced485b 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1511,7 +1511,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("AVG and try_avg - special numbers") { - val negativeNumbers = Seq( + val negativeNumbers = Seq( (-1L, 1), (-123L, 1), (-456L, 1), @@ -1519,55 +1519,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (-9223372036854775808L, 1), (-9223372036854775808L, 2), (-9223372036854775807L, 2), - (null.asInstanceOf[String], 3) - ) + (null.asInstanceOf[String], 3)) - val zeroSeq = Seq( - (0, 1), - (-0, 1), - (+0, 2), - (+0, 2), - (null.asInstanceOf[String], 3) - ) + val zeroSeq = Seq((0, 1), (-0, 1), (+0, 2), (+0, 2), (null.asInstanceOf[String], 3)) val highNegNumbers = Seq( (Long.MaxValue, 1), (Long.MaxValue, 1), (Long.MaxValue, 2), (Long.MaxValue, 2), - (null.asInstanceOf[String], 3) - ) + (null.asInstanceOf[String], 3)) val inputs = Seq(negativeNumbers, 
zeroSeq, highNegNumbers) - inputs.foreach( - inputSeq => { - withParquetTable(inputSeq, - "tbl") { + inputs.foreach(inputSeq => { + withParquetTable(inputSeq, "tbl") { - Seq(true, false).foreach({ ansiMode => - // without GROUP BY - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { - val res = sql("SELECT avg(_1) FROM tbl") - checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) - } + Seq(true, false).foreach({ ansiMode => + // without GROUP BY + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val res = sql("SELECT avg(_1) FROM tbl") + checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) + } - // with GROUP BY - withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { - val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) - } - }) + // with GROUP BY + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(res) + } + }) - // try_avg without GROUP BY - val resTry = sql("SELECT try_avg(_1) FROM tbl") - checkSparkAnswerAndOperator(resTry) + // try_avg without GROUP BY + val resTry = sql("SELECT try_avg(_1) FROM tbl") + checkSparkAnswerAndOperator(resTry) - // try_avg with GROUP BY - val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(resTryGroup) - - } - }) + // try_avg with GROUP BY + val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") + checkSparkAnswerAndOperator(resTryGroup) + + } + }) } test("ANSI support for decimal sum - null test") { From 979136860a491d6a722f43b65e261a2b2ac93492 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 09:51:36 -0800 Subject: [PATCH 11/16] address_review_comments --- .../scala/org/apache/comet/exec/CometAggregateSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index b34ced485b..5c2de43f11 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1515,10 +1515,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (-1L, 1), (-123L, 1), (-456L, 1), - (-9223372036854775808L, 1), - (-9223372036854775808L, 1), - (-9223372036854775808L, 2), - (-9223372036854775807L, 2), + (Long.MinValue, 1), + (Long.MinValue, 1), + (Long.MinValue, 2), + (Long.MinValue, 2), (null.asInstanceOf[String], 3)) val zeroSeq = Seq((0, 1), (-0, 1), (+0, 2), (+0, 2), (null.asInstanceOf[String], 3)) From ef4ecdc1db30a669868152ba793beb5059f54806 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 10:57:34 -0800 Subject: [PATCH 12/16] address_review_comments --- .../apache/comet/exec/CometAggregateSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 5c2de43f11..fc2466d2e1 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1511,7 +1511,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("AVG and try_avg - special numbers") { - val negativeNumbers = Seq( + val 
negativeNumbers: Seq[(Long, Int)] = Seq( (-1L, 1), (-123L, 1), (-456L, 1), @@ -1519,26 +1519,26 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (Long.MinValue, 1), (Long.MinValue, 2), (Long.MinValue, 2), - (null.asInstanceOf[String], 3)) + (null.asInstanceOf[Long], 3)) - val zeroSeq = Seq((0, 1), (-0, 1), (+0, 2), (+0, 2), (null.asInstanceOf[String], 3)) + val zeroSeq: Seq[(Long, Int)] = + Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3)) - val highNegNumbers = Seq( + val highNegNumbers: Seq[(Long, Int)] = Seq( (Long.MaxValue, 1), (Long.MaxValue, 1), (Long.MaxValue, 2), (Long.MaxValue, 2), - (null.asInstanceOf[String], 3)) - val inputs = Seq(negativeNumbers, zeroSeq, highNegNumbers) + (null.asInstanceOf[Long], 3)) + + val inputs = Seq(negativeNumbers, highNegNumbers, zeroSeq) inputs.foreach(inputSeq => { withParquetTable(inputSeq, "tbl") { - Seq(true, false).foreach({ ansiMode => // without GROUP BY withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val res = sql("SELECT avg(_1) FROM tbl") checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) } // with GROUP BY From afe5f24a9d53bcddab57ec496d4776ec8c3f5d28 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 11:28:11 -0800 Subject: [PATCH 13/16] address_review_comments --- native/core/src/execution/planner.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e7ecf8a8d3..4a68ba16c7 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1875,10 +1875,7 @@ impl PhysicalPlanner { // Cast to Float64 for accumulation let child: Arc = Arc::new(CastExpr::new(Arc::clone(&child), DataType::Float64, None)); - let func = AggregateUDF::new_from_impl(Avg::new( - "avg", - DataType::Float64 - )); + let func = AggregateUDF::new_from_impl(Avg::new("avg", DataType::Float64)); AggregateExprBuilder::new(Arc::new(func), vec![child]) } }; From a60574faf5b32b8faf5876feae04201167d8d3ea Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 14:46:55 -0800 Subject: [PATCH 14/16] address_review_comments --- .../test/scala/org/apache/comet/exec/CometAggregateSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index b6bb522d18..97742452a0 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1487,7 +1487,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { val res = sql("SELECT avg(_1) FROM tbl") checkSparkAnswerAndOperator(res) - assert(res.collect() === Array(Row(null))) } // with GROUP BY @@ -1557,6 +1556,8 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } }) + } + test("ANSI support for sum - null test") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { From 6781b7c7432bd448b4052ef1faf377366eeed652 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 15:39:13 -0800 Subject: [PATCH 15/16] address_review_comments_fix_clippy --- docs/source/user-guide/latest/compatibility.md | 2 +- .../org/apache/comet/serde/aggregates.scala | 2 +- 
.../apache/comet/exec/CometAggregateSuite.scala | 16 ++++++---------- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/docs/source/user-guide/latest/compatibility.md b/docs/source/user-guide/latest/compatibility.md index 3d2c9a7b55..5f5d94d9e1 100644 --- a/docs/source/user-guide/latest/compatibility.md +++ b/docs/source/user-guide/latest/compatibility.md @@ -36,7 +36,7 @@ Comet will fall back to Spark for the following expressions when ANSI mode is en `spark.comet.expression.EXPRNAME.allowIncompatible=true`, where `EXPRNAME` is the Spark expression class name. See the [Comet Supported Expressions Guide](expressions.md) for more information on this configuration setting. -- Average +- Average (supports all numeric inputs except decimal types) - Cast (in some cases) There is an [epic](https://github.com/apache/datafusion-comet/issues/313) where we are tracking the work to fully implement ANSI support. diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index 660665214d..8e58c08740 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, EvalMode} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, CentralMomentAgg, Corr, Count, Covariance, CovPopulation, CovSample, First, Last, Max, Min, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType} diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 97742452a0..cd479c4158 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1522,37 +1522,33 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { val zeroSeq: Seq[(Long, Int)] = Seq((0L, 1), (-0L, 1), (+0L, 2), (+0L, 2), (null.asInstanceOf[Long], 3)) - val highNegNumbers: Seq[(Long, Int)] = Seq( + val highValNumbers: Seq[(Long, Int)] = Seq( (Long.MaxValue, 1), (Long.MaxValue, 1), (Long.MaxValue, 2), (Long.MaxValue, 2), (null.asInstanceOf[Long], 3)) - val inputs = Seq(negativeNumbers, highNegNumbers, zeroSeq) + val inputs = Seq(negativeNumbers, highValNumbers, zeroSeq) inputs.foreach(inputSeq => { withParquetTable(inputSeq, "tbl") { Seq(true, false).foreach({ ansiMode => // without GROUP BY withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { - val res = sql("SELECT avg(_1) FROM tbl") - checkSparkAnswerAndOperator(res) + checkSparkAnswerAndOperator("SELECT avg(_1) FROM tbl") } // with GROUP BY withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { - val res = sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(res) + checkSparkAnswerAndOperator("SELECT _2, avg(_1) FROM tbl GROUP BY _2") } }) // try_avg without GROUP BY - val resTry = sql("SELECT try_avg(_1) FROM tbl") - checkSparkAnswerAndOperator(resTry) + checkSparkAnswerAndOperator("SELECT try_avg(_1) FROM tbl") // try_avg with GROUP BY - 
val resTryGroup = sql("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") - checkSparkAnswerAndOperator(resTryGroup) + checkSparkAnswerAndOperator("SELECT _2, try_avg(_1) FROM tbl GROUP BY _2") } }) From b444cbac191aafaada4bd14a313405e805fd31e7 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Tue, 23 Dec 2025 15:46:21 -0800 Subject: [PATCH 16/16] address_review_comments_fix_clippy --- .../test/scala/org/apache/comet/exec/CometAggregateSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index cd479c4158..14b5dc3092 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1577,6 +1577,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { "null_tbl") { val res = sql("SELECT sum(_1) FROM null_tbl") checkSparkAnswerAndOperator(res) + assert(res.collect() === Array(Row(null))) } } }
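Note (illustrative only, not part of the patch series): the end-to-end behaviour exercised by the suite changes above can be sketched with a small standalone Spark job. The session wiring below (the plugin class, Comet settings, and local master) is an assumption based on the Comet user guide rather than anything introduced in these patches, and a real deployment typically also needs off-heap memory settings. It demonstrates the point made in the planner.rs and avg.rs hunks: integral inputs are cast to Float64 before accumulation, so averaging rows of Long.MaxValue with ANSI mode enabled is expected to return a finite Double rather than raise an overflow error.

// AnsiAvgSketch.scala, a minimal sketch; configuration keys are assumptions, see note above.
import org.apache.spark.sql.SparkSession

object AnsiAvgSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("comet-ansi-avg-sketch")
      .config("spark.plugins", "org.apache.spark.CometPlugin") // assumed Comet plugin class
      .config("spark.comet.exec.enabled", "true")
      .config("spark.sql.ansi.enabled", "true") // ANSI mode on
      .getOrCreate()

    import spark.implicits._

    // Rows whose sum would overflow Long, plus a null group, mirroring the test data above.
    Seq(
      (Some(Long.MaxValue), 1),
      (Some(Long.MaxValue), 1),
      (Some(Long.MaxValue), 2),
      (None: Option[Long], 3))
      .toDF("_1", "_2")
      .createOrReplaceTempView("tbl")

    // avg accumulates in Double, so no ANSI overflow error is expected here.
    spark.sql("SELECT _2, avg(_1) FROM tbl GROUP BY _2").show()
    spark.sql("SELECT try_avg(_1) FROM tbl").show()

    spark.stop()
  }
}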