From fe8cfdd1e4f83b4fe790f9fe1819089ab65eeb79 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:09:23 -0700 Subject: [PATCH 01/20] Add trait --- .../org/apache/comet/DataTypeSupport.scala | 2 + .../apache/comet/cost/CometCostModel.scala | 57 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala index 9adf829580..e5b0b2fadd 100644 --- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala +++ b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala @@ -79,4 +79,6 @@ object DataTypeSupport { case _: StructType | _: ArrayType | _: MapType => true case _ => false } + + def hasComplexTypes(schema: StructType) = schema.fields.exists(f => isComplexType(f.dataType)) } diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala new file mode 100644 index 0000000000..c5291342d3 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -0,0 +1,57 @@ +package org.apache.comet.cost + +import org.apache.comet.DataTypeSupport +import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Expression} +import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} +import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +case class CometCostEstimate(acceleration: Double) + +trait CometCostModel { + + /** Estimate the relative cost of one operator */ + def estimateCost(plan: SparkPlan): CometCostEstimate +} + +class DefaultComeCostModel extends CometCostModel { + + // optimistic default of 2x acceleration + private val defaultAcceleration = 2.0 + + override def estimateCost(plan: SparkPlan): CometCostEstimate = { + + plan match { + case op: CometShuffleExchangeExec => + op.shuffleType match { + case CometNativeShuffle => CometCostEstimate(1.5) + case CometColumnarShuffle => + if (DataTypeSupport.hasComplexTypes(op.schema)) { + CometCostEstimate(0.8) + } else { + CometCostEstimate(1.1) + } + } + case _: CometColumnarToRowExec => + CometCostEstimate(1.0) + case op: CometProjectExec => + val total: Double = op.expressions.map(estimateCost).sum + CometCostEstimate(total / op.expressions.length.toDouble) + case _: CometPlan => + CometCostEstimate(defaultAcceleration) + case _ => + // Spark operator + CometCostEstimate(1.0) + } + } + + /** Estimate the cost of an expression */ + def estimateCost(expr: Expression): Double = { + expr match { + case _: BinaryArithmetic => + 2.0 + case _ => defaultAcceleration + } + } +} From f05a6295909f6a750cda4ec931135d8ddfb17005 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:14:16 -0700 Subject: [PATCH 02/20] format --- .../org/apache/comet/DataTypeSupport.scala | 3 +- .../apache/comet/cost/CometCostModel.scala | 39 ++++++++++++++----- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala index e5b0b2fadd..694ae8b0e1 100644 --- a/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala +++ 
b/spark/src/main/scala/org/apache/comet/DataTypeSupport.scala @@ -80,5 +80,6 @@ object DataTypeSupport { case _ => false } - def hasComplexTypes(schema: StructType) = schema.fields.exists(f => isComplexType(f.dataType)) + def hasComplexTypes(schema: StructType): Boolean = + schema.fields.exists(f => isComplexType(f.dataType)) } diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index c5291342d3..cc7ce8ab5a 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -1,11 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.apache.comet.cost -import org.apache.comet.DataTypeSupport -import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Expression} +import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression} import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} -import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.SparkPlan + +import org.apache.comet.DataTypeSupport case class CometCostEstimate(acceleration: Double) @@ -15,7 +34,7 @@ trait CometCostModel { def estimateCost(plan: SparkPlan): CometCostEstimate } -class DefaultComeCostModel extends CometCostModel { +class DefaultCometCostModel extends CometCostModel { // optimistic default of 2x acceleration private val defaultAcceleration = 2.0 @@ -27,11 +46,11 @@ class DefaultComeCostModel extends CometCostModel { op.shuffleType match { case CometNativeShuffle => CometCostEstimate(1.5) case CometColumnarShuffle => - if (DataTypeSupport.hasComplexTypes(op.schema)) { - CometCostEstimate(0.8) - } else { - CometCostEstimate(1.1) - } + if (DataTypeSupport.hasComplexTypes(op.schema)) { + CometCostEstimate(0.8) + } else { + CometCostEstimate(1.1) + } } case _: CometColumnarToRowExec => CometCostEstimate(1.0) From 6662c968e9af168efec200c058ec9d968e308395 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:16:51 -0700 Subject: [PATCH 03/20] configs --- .../main/scala/org/apache/comet/CometConf.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 7eaae60552..f7b90a2509 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -754,6 +754,23 @@ 
object CometConf extends ShimCometConf { .booleanConf .createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false) + val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.cost.enabled") + .category(CATEGORY_TUNING) + .doc("Whether to enable cost-based optimization for Comet. When enabled, Comet will " + + "use a cost model to estimate acceleration factors for operators and make decisions " + + "about whether to use Comet or Spark operators based on estimated performance.") + .booleanConf + .createWithDefault(false) + + val COMET_COST_MODEL_CLASS: ConfigEntry[String] = + conf("spark.comet.cost.model.class") + .category(CATEGORY_TUNING) + .doc("The fully qualified class name of the cost model implementation to use for " + + "cost-based optimization. The class must implement the CometCostModel trait.") + .stringConf + .createWithDefault("org.apache.comet.cost.DefaultCometCostModel") + /** Create a config to enable a specific operator */ private def createExecEnabledConfig( exec: String, From 7f5996d87ecb1fa32390082a05ed19131caa223e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:31:10 -0700 Subject: [PATCH 04/20] integrate with Spark AQE --- .../scala/org/apache/comet/CometConf.scala | 7 +- .../comet/cost/CometCostEvaluator.scala | 98 +++++++++++++++++++ .../main/scala/org/apache/spark/Plugins.scala | 11 ++- 3 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index f7b90a2509..a5cce453e2 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -757,9 +757,10 @@ object CometConf extends ShimCometConf { val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.cost.enabled") .category(CATEGORY_TUNING) - .doc("Whether to enable cost-based optimization for Comet. When enabled, Comet will " + - "use a cost model to estimate acceleration factors for operators and make decisions " + - "about whether to use Comet or Spark operators based on estimated performance.") + .doc( + "Whether to enable cost-based optimization for Comet. When enabled, Comet will " + + "use a cost model to estimate acceleration factors for operators and make decisions " + + "about whether to use Comet or Spark operators based on estimated performance.") .booleanConf .createWithDefault(false) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala new file mode 100644 index 0000000000..bd706ff21c --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.cost + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{Cost, CostEvaluator} +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometConf + +/** + * Simple Cost implementation for Comet cost evaluator. + */ +case class CometCost(value: Double) extends Cost { + override def compare(that: Cost): Int = that match { + case CometCost(thatValue) => java.lang.Double.compare(value, thatValue) + case _ => 0 // If we can't compare, assume equal + } + + override def toString: String = s"CometCost($value)" +} + +/** + * Comet implementation of Spark's CostEvaluator for adaptive query execution. + * + * This evaluator uses the configured CometCostModel to estimate costs for query plans, allowing + * Spark's adaptive query execution to make informed decisions about whether to use Comet or Spark + * operators based on estimated performance. + */ +class CometCostEvaluator extends CostEvaluator with Logging { + + @transient private lazy val costModel: CometCostModel = { + val conf = SQLConf.get + val costModelClass = CometConf.COMET_COST_MODEL_CLASS.get(conf) + + try { + // scalastyle:off classforname + val clazz = Class.forName(costModelClass) + // scalastyle:on classforname + val constructor = clazz.getConstructor() + constructor.newInstance().asInstanceOf[CometCostModel] + } catch { + case e: Exception => + logWarning( + s"Failed to instantiate cost model class '$costModelClass', " + + s"falling back to DefaultCometCostModel. Error: ${e.getMessage}") + new DefaultCometCostModel() + } + } + + /** + * Evaluates the cost of executing the given SparkPlan. + * + * This method uses the configured CometCostModel to estimate the acceleration factor for the + * plan, then converts it to a Cost object that Spark's adaptive query execution can use for + * decision making. 
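+ *
+ * To try this evaluator, the driver-plugin config must be set at startup (a
+ * sketch; both keys are defined elsewhere in this patch series), e.g.:
+ * {{{
+ *   spark-submit --conf spark.comet.cost.enabled=true ...
+ * }}}
+ * which makes the plugin point spark.sql.adaptive.customCostEvaluatorClass at
+ * this class. Using the illustrative numbers from the comments below: an
+ * estimated 2.0x acceleration is converted to CometCost(0.5), and 0.8x to
+ * CometCost(1.25), so AQE's cost comparison favors the plan variant with the
+ * higher estimated speedup.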
+ * + * @param plan + * The SparkPlan to evaluate + * @return + * A Cost representing the estimated execution cost + */ + override def evaluateCost(plan: SparkPlan): Cost = { + val estimate = costModel.estimateCost(plan) + + // Convert acceleration factor to cost + // Lower cost means better performance, so we use the inverse of acceleration factor + // For example: + // - 2.0x acceleration -> cost = 0.5 (half the cost) + // - 0.8x acceleration -> cost = 1.25 (25% more cost) + val costValue = 1.0 / estimate.acceleration + + logDebug( + s"Cost evaluation for ${plan.getClass.getSimpleName}: " + + s"acceleration=${estimate.acceleration}, cost=$costValue") + + // Create Cost object with the calculated value + CometCost(costValue) + } +} diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 2529f08cfb..9765ab5a6c 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf -import org.apache.comet.CometConf.COMET_ONHEAP_ENABLED +import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_ONHEAP_ENABLED} import org.apache.comet.CometSparkSessionExtensions /** @@ -57,6 +57,15 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl // register CometSparkSessionExtensions if it isn't already registered CometDriverPlugin.registerCometSessionExtension(sc.conf) + // Enable cost-based optimization if configured + if (sc.getConf.getBoolean(COMET_COST_BASED_OPTIMIZATION_ENABLED.key, false)) { + // Set the custom cost evaluator for Spark's adaptive query execution + sc.conf.set( + "spark.sql.adaptive.customCostEvaluatorClass", + "org.apache.comet.cost.CometCostEvaluator") + logInfo("Enabled Comet cost-based optimization with CometCostEvaluator") + } + if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) { val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) { sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key) From fde7f357eb6649059db835e3aca7f9ab52be94fb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:52:00 -0700 Subject: [PATCH 05/20] walk plan --- docs/source/user-guide/latest/configs.md | 2 ++ .../apache/comet/cost/CometCostModel.scala | 27 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md index 5b416f927d..74b10e83ee 100644 --- a/docs/source/user-guide/latest/configs.md +++ b/docs/source/user-guide/latest/configs.md @@ -122,6 +122,8 @@ These settings can be used to determine which parts of the plan are accelerated | Config | Description | Default Value | |--------|-------------|---------------| | `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 | +| `spark.comet.cost.enabled` | Whether to enable cost-based optimization for Comet. When enabled, Comet will use a cost model to estimate acceleration factors for operators and make decisions about whether to use Comet or Spark operators based on estimated performance. | false | +| `spark.comet.cost.model.class` | The fully qualified class name of the cost model implementation to use for cost-based optimization. 
The class must implement the CometCostModel trait. | org.apache.comet.cost.DefaultCometCostModel | | `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified | | `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 | | `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/contributor-guide/tracing.html). | false | diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index cc7ce8ab5a..6e86115dc8 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -40,7 +40,34 @@ class DefaultCometCostModel extends CometCostModel { private val defaultAcceleration = 2.0 override def estimateCost(plan: SparkPlan): CometCostEstimate = { + // Walk the entire plan tree and accumulate costs + var totalAcceleration = 0.0 + var operatorCount = 0 + def collectOperatorCosts(node: SparkPlan): Unit = { + val operatorCost = estimateOperatorCost(node) + totalAcceleration += operatorCost.acceleration + operatorCount += 1 + + // Recursively process children + node.children.foreach(collectOperatorCosts) + } + + collectOperatorCosts(plan) + + // Calculate average acceleration across all operators + // This is crude but gives us a starting point + val averageAcceleration = if (operatorCount > 0) { + totalAcceleration / operatorCount.toDouble + } else { + 1.0 // No acceleration if no operators + } + + CometCostEstimate(averageAcceleration) + } + + /** Estimate the cost of a single operator */ + private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = { plan match { case op: CometShuffleExchangeExec => op.shuffleType match { From 3f964c681ce2a528608abf901d66418054eb14b6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 11:53:47 -0700 Subject: [PATCH 06/20] Save --- .../src/main/scala/org/apache/comet/cost/CometCostModel.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index 6e86115dc8..a979fda0c1 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -82,7 +82,7 @@ class DefaultCometCostModel extends CometCostModel { case _: CometColumnarToRowExec => CometCostEstimate(1.0) case op: CometProjectExec => - val total: Double = op.expressions.map(estimateCost).sum + val total: Double = op.expressions.map(estimateExpressionCost).sum CometCostEstimate(total / op.expressions.length.toDouble) case _: CometPlan => CometCostEstimate(defaultAcceleration) @@ -93,7 +93,7 @@ class DefaultCometCostModel extends CometCostModel { } /** Estimate the cost of an expression */ - def estimateCost(expr: Expression): Double = { + private def estimateExpressionCost(expr: Expression): Double = { expr match { 
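       // Note: the returned values are estimated acceleration multipliers
       // relative to Spark, not absolute costs: above 1.0 means the Comet
       // implementation is expected to be faster, below 1.0 slower.
       // CometCostEvaluator later converts them to a cost via 1/acceleration.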
case _: BinaryArithmetic => 2.0 From 95551f18b6715757e6f6a9febb31553c98c25a27 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 15:23:08 -0700 Subject: [PATCH 07/20] more --- .../apache/comet/cost/CometCostModel.scala | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index a979fda0c1..9fc54318bc 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -19,11 +19,11 @@ package org.apache.comet.cost -import org.apache.spark.sql.catalyst.expressions.{BinaryArithmetic, Expression} +import com.univocity.parsers.annotations.{Replace, Trim} +import org.apache.spark.sql.catalyst.expressions.{Ascii, BinaryArithmetic, Chr, ConcatWs, Expression, InitCap, Length, Lower, OctetLength, Reverse, StringSpace, StringTranslate, StringTrim, Substring, Upper} import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan - import org.apache.comet.DataTypeSupport case class CometCostEstimate(acceleration: Double) @@ -95,8 +95,24 @@ class DefaultCometCostModel extends CometCostModel { /** Estimate the cost of an expression */ private def estimateExpressionCost(expr: Expression): Double = { expr match { - case _: BinaryArithmetic => - 2.0 + // string expression numbers from CometStringExpressionBenchmark + case _: Substring => 6.3 + case _: Ascii => 0.6 + case _: Ascii => 0.6 + case _: OctetLength => 0.6 + case _: Lower => 3.0 + case _: Upper => 3.0 + case _: Chr => 0.6 + case _: InitCap => 0.9 + case _: StringTrim => 0.4 + case _: ConcatWs => 0.5 + case _: Length => 9.1 + // case _: Repeat => 0.4 + case _: Reverse => 6.9 + // case _: Instr => 0.6 + case _: Replace => 1.3 + case _: StringSpace => 0.8 + case _: StringTranslate => 0.8 case _ => defaultAcceleration } } From e3efad578a5481968caa2333484c05614dbf6139 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 16:25:50 -0700 Subject: [PATCH 08/20] add TODO --- .../scala/org/apache/comet/cost/CometCostModel.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index 9fc54318bc..8b8c8fcdff 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -19,11 +19,13 @@ package org.apache.comet.cost -import com.univocity.parsers.annotations.{Replace, Trim} -import org.apache.spark.sql.catalyst.expressions.{Ascii, BinaryArithmetic, Chr, ConcatWs, Expression, InitCap, Length, Lower, OctetLength, Reverse, StringSpace, StringTranslate, StringTrim, Substring, Upper} +import org.apache.spark.sql.catalyst.expressions.{Ascii, Chr, ConcatWs, Expression, InitCap, Length, Lower, OctetLength, Reverse, StringSpace, StringTranslate, StringTrim, Substring, Upper} import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan + +import com.univocity.parsers.annotations.Replace 
+ import org.apache.comet.DataTypeSupport case class CometCostEstimate(acceleration: Double) @@ -96,6 +98,10 @@ class DefaultCometCostModel extends CometCostModel { private def estimateExpressionCost(expr: Expression): Double = { expr match { // string expression numbers from CometStringExpressionBenchmark + // TODO this is matching on Spark expressions, which isn't correct since + // we need to look at the converted Comet expressions instead, but + // this code demonstrates what the goal is - having specific numbers + // based on current micro benchmarks case _: Substring => 6.3 case _: Ascii => 0.6 case _: Ascii => 0.6 From 6417e54faec3d502652515c8e5503d81db080deb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:00:42 -0700 Subject: [PATCH 09/20] fix --- .../apache/comet/cost/CometCostModel.scala | 69 +++++++++++-------- 1 file changed, 39 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index 8b8c8fcdff..08043737d6 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -19,14 +19,14 @@ package org.apache.comet.cost -import org.apache.spark.sql.catalyst.expressions.{Ascii, Chr, ConcatWs, Expression, InitCap, Length, Lower, OctetLength, Reverse, StringSpace, StringTranslate, StringTrim, Substring, Upper} +import scala.jdk.CollectionConverters._ + import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan -import com.univocity.parsers.annotations.Replace - import org.apache.comet.DataTypeSupport +import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass} case class CometCostEstimate(acceleration: Double) @@ -84,8 +84,13 @@ class DefaultCometCostModel extends CometCostModel { case _: CometColumnarToRowExec => CometCostEstimate(1.0) case op: CometProjectExec => - val total: Double = op.expressions.map(estimateExpressionCost).sum - CometCostEstimate(total / op.expressions.length.toDouble) + // Cast nativeOp to Operator and extract projection expressions + val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator] + val projection = operator.getProjection + val expressions = projection.getProjectListList.asScala + + val total: Double = expressions.map(estimateCometExpressionCost).sum + CometCostEstimate(total / expressions.length.toDouble) case _: CometPlan => CometCostEstimate(defaultAcceleration) case _ => @@ -94,32 +99,36 @@ class DefaultCometCostModel extends CometCostModel { } } - /** Estimate the cost of an expression */ - private def estimateExpressionCost(expr: Expression): Double = { - expr match { - // string expression numbers from CometStringExpressionBenchmark - // TODO this is matching on Spark expressions, which isn't correct since - // we need to look at the converted Comet expressions instead, but - // this code demonstrates what the goal is - having specific numbers - // based on current micro benchmarks - case _: Substring => 6.3 - case _: Ascii => 0.6 - case _: Ascii => 0.6 - case _: OctetLength => 0.6 - case _: Lower => 3.0 - case _: Upper => 3.0 - case _: Chr => 0.6 - case _: InitCap => 0.9 - case _: StringTrim => 0.4 - case _: ConcatWs => 0.5 - case _: Length => 9.1 - // case _: Repeat => 0.4 - case _: Reverse => 6.9 - // 
case _: Instr => 0.6 - case _: Replace => 1.3 - case _: StringSpace => 0.8 - case _: StringTranslate => 0.8 + /** Estimate the cost of a Comet protobuf expression */ + private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double = { + expr.getExprStructCase match { + // Handle specialized expression types + case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3 + + // Handle generic scalar functions + case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC => + expr.getScalarFunc.getFunc match { + // String expression numbers from CometStringExpressionBenchmark + case "ascii" => 0.6 + case "octet_length" => 0.6 + case "lower" => 3.0 + case "upper" => 3.0 + case "char" => 0.6 + case "initcap" => 0.9 + case "trim" => 0.4 + case "concat_ws" => 0.5 + case "length" => 9.1 + case "repeat" => 0.4 + case "reverse" => 6.9 + case "instr" => 0.6 + case "replace" => 1.3 + case "string_space" => 0.8 + case "translate" => 0.8 + case _ => defaultAcceleration + } + case _ => defaultAcceleration } } + } From 7227f25246df11be3b248b8bb5abcec34f677ed2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:03:42 -0700 Subject: [PATCH 10/20] format --- .../org/apache/comet/cost/CometCostModel.scala | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index 08043737d6..d2508fdb72 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet import org.apache.spark.sql.execution.SparkPlan import org.apache.comet.DataTypeSupport -import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass} +import org.apache.comet.serde.ExprOuterClass case class CometCostEstimate(acceleration: Double) @@ -71,6 +71,10 @@ class DefaultCometCostModel extends CometCostModel { /** Estimate the cost of a single operator */ private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = { plan match { + case op: CometProjectExec => + val expressions = op.nativeOp.getProjection.getProjectListList.asScala + val total: Double = expressions.map(estimateCometExpressionCost).sum + CometCostEstimate(total / expressions.length.toDouble) case op: CometShuffleExchangeExec => op.shuffleType match { case CometNativeShuffle => CometCostEstimate(1.5) @@ -83,14 +87,6 @@ class DefaultCometCostModel extends CometCostModel { } case _: CometColumnarToRowExec => CometCostEstimate(1.0) - case op: CometProjectExec => - // Cast nativeOp to Operator and extract projection expressions - val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator] - val projection = operator.getProjection - val expressions = projection.getProjectListList.asScala - - val total: Double = expressions.map(estimateCometExpressionCost).sum - CometCostEstimate(total / expressions.length.toDouble) case _: CometPlan => CometCostEstimate(defaultAcceleration) case _ => From 927edc51735ab9e3c4964f6a76df55cccbcc9f7a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:20:55 -0700 Subject: [PATCH 11/20] test --- .../apache/comet/CometCostModelSuite.scala | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala new file mode 100644 index 0000000000..508a381102 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.CometProjectExec +import org.apache.spark.sql.execution.ProjectExec +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.internal.SQLConf + +class CometCostModelSuite extends CometTestBase { + + // Fast expressions in Comet (high acceleration factor -> low cost -> preferred) + // Based on CometCostModel estimates: + // - length: 9.1x -> cost = 1/9.1 = 0.11 (very low cost, Comet preferred) + // - reverse: 6.9x -> cost = 1/6.9 = 0.145 (low cost, Comet preferred) + + // Slow expressions in Comet (low acceleration factor -> high cost -> Spark preferred) + // - trim: 0.4x -> cost = 1/0.4 = 2.5 (high cost, Spark preferred) + // - ascii: 0.6x -> cost = 1/0.6 = 1.67 (high cost, Spark preferred) + + test("CBO should prefer Comet for fast expressions (length)") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + + withTempView("test_data") { + // Create test data + import testImplicits._ + val df = Seq( + ("hello world", "test string"), + ("comet rocks", "another test"), + ("fast execution", "performance")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") + + // Query using length function (fast in Comet: 9.1x acceleration) + val query = "SELECT length(text1) as len1, length(text2) as len2 FROM test_data" + val result = sql(query) + + // Execute the query to materialize the plan + result.collect() + + // Check that Comet is used for the projection (due to low cost) + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + val hasProjectExec = findProjectExec(executedPlan) + + // With CBO enabled and fast expressions, we should see CometProjectExec + assert(hasProjectExec.isDefined, "Should have a project operator") + assert( + hasProjectExec.get.isInstanceOf[CometProjectExec], + s"Expected CometProjectExec for fast expression, got ${hasProjectExec.get.getClass.getSimpleName}") + } + } + } + + test("CBO should prefer Spark for slow expressions (trim)") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + 
CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + + withTempView("test_data") { + // Create test data + import testImplicits._ + val df = Seq( + (" hello world ", " test string "), + (" comet rocks ", " another test "), + (" slow execution ", " performance ")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") + + // Query using trim function (slow in Comet: 0.4x acceleration) + val query = "SELECT trim(text1) as trimmed1, trim(text2) as trimmed2 FROM test_data" + val result = sql(query) + + // Execute the query to materialize the plan + result.collect() + + // Check that Spark is used for the projection (due to high cost) + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + val hasProjectExec = findProjectExec(executedPlan) + + // With CBO enabled and slow expressions, we should see Spark ProjectExec + assert(hasProjectExec.isDefined, "Should have a project operator") + assert( + hasProjectExec.get.isInstanceOf[ProjectExec], + s"Expected Spark ProjectExec for slow expression, got ${hasProjectExec.get.getClass.getSimpleName}") + } + } + } + + test("Without CBO, Comet should be used regardless of expression cost") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "false", // CBO disabled + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + + withTempView("test_data") { + import testImplicits._ + val df = Seq((" hello world ", " test string ")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") + + // Query using trim function (slow in Comet, but CBO is disabled) + val query = "SELECT trim(text1) as trimmed1 FROM test_data" + val result = sql(query) + + result.collect() + + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + val hasProjectExec = findProjectExec(executedPlan) + + // Without CBO, Comet should be used even for slow expressions + assert(hasProjectExec.isDefined, "Should have a project operator") + assert( + hasProjectExec.get.isInstanceOf[CometProjectExec], + s"Expected CometProjectExec when CBO disabled, got ${hasProjectExec.get.getClass.getSimpleName}") + } + } + } + + test("Mixed expressions should use appropriate operators per expression cost") { + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + + withTempView("test_data") { + import testImplicits._ + val df = Seq(("hello world", "test")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") + + // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions + val query = "SELECT length(text1) as fast_expr, ascii(text1) as slow_expr FROM test_data" + val result = sql(query) + + result.collect() + + // The overall cost should be average: (9.1 + 0.6) / 2 = 4.85 + // Cost = 1/4.85 = 0.206, which should still prefer Comet + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + val hasProjectExec = findProjectExec(executedPlan) + + assert(hasProjectExec.isDefined, "Should have a project operator") + // Mixed expressions with overall positive acceleration should use Comet + assert( + hasProjectExec.get.isInstanceOf[CometProjectExec], + s"Expected CometProjectExec for mixed expressions with positive average, got 
${hasProjectExec.get.getClass.getSimpleName}") + } + } + } + + /** + * Helper method to find ProjectExec or CometProjectExec in the plan tree + */ + private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = { + if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[CometProjectExec]) { + Some(plan) + } else { + plan.children.flatMap(findProjectExec).headOption + } + } +} From bc2ced7f786f57ebab6a6cd57713b7b19151b17d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:23:13 -0700 Subject: [PATCH 12/20] test --- .../apache/comet/CometCostModelSuite.scala | 203 ++++++++---------- 1 file changed, 89 insertions(+), 114 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index 508a381102..ca04df2df1 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -37,146 +37,121 @@ class CometCostModelSuite extends CometTestBase { // - ascii: 0.6x -> cost = 1/0.6 = 1.67 (high cost, Spark preferred) test("CBO should prefer Comet for fast expressions (length)") { - withSQLConf( - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { - - withTempView("test_data") { - // Create test data - import testImplicits._ - val df = Seq( - ("hello world", "test string"), - ("comet rocks", "another test"), - ("fast execution", "performance")).toDF("text1", "text2") - df.createOrReplaceTempView("test_data") - - // Query using length function (fast in Comet: 9.1x acceleration) - val query = "SELECT length(text1) as len1, length(text2) as len2 FROM test_data" - val result = sql(query) - - // Execute the query to materialize the plan - result.collect() - - // Check that Comet is used for the projection (due to low cost) - val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) - val hasProjectExec = findProjectExec(executedPlan) - - // With CBO enabled and fast expressions, we should see CometProjectExec - assert(hasProjectExec.isDefined, "Should have a project operator") - assert( - hasProjectExec.get.isInstanceOf[CometProjectExec], - s"Expected CometProjectExec for fast expression, got ${hasProjectExec.get.getClass.getSimpleName}") - } + withCBOEnabled { + createSimpleTestData() + val query = "SELECT length(text1) as len1, length(text2) as len2 FROM test_data" + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec for fast expression") } } test("CBO should prefer Spark for slow expressions (trim)") { - withSQLConf( - CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { - - withTempView("test_data") { - // Create test data - import testImplicits._ - val df = Seq( - (" hello world ", " test string "), - (" comet rocks ", " another test "), - (" slow execution ", " performance ")).toDF("text1", "text2") - df.createOrReplaceTempView("test_data") - - // Query using trim function (slow in Comet: 0.4x acceleration) - val query = "SELECT trim(text1) as trimmed1, trim(text2) as trimmed2 FROM test_data" - val result = sql(query) - - // Execute the query to 
materialize the plan - result.collect() - - // Check that Spark is used for the projection (due to high cost) - val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) - val hasProjectExec = findProjectExec(executedPlan) - - // With CBO enabled and slow expressions, we should see Spark ProjectExec - assert(hasProjectExec.isDefined, "Should have a project operator") - assert( - hasProjectExec.get.isInstanceOf[ProjectExec], - s"Expected Spark ProjectExec for slow expression, got ${hasProjectExec.get.getClass.getSimpleName}") - } + withCBOEnabled { + createPaddedTestData() + val query = "SELECT trim(text1) as trimmed1, trim(text2) as trimmed2 FROM test_data" + + executeAndCheckOperator( + query, + classOf[ProjectExec], + "Expected Spark ProjectExec for slow expression") } } test("Without CBO, Comet should be used regardless of expression cost") { + withCBODisabled { + createPaddedTestData() + val query = "SELECT trim(text1) as trimmed1 FROM test_data" + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec when CBO disabled") + } + } + + test("Mixed expressions should use appropriate operators per expression cost") { + withCBOEnabled { + createSimpleTestData() + // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions + // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet) + val query = "SELECT length(text1) as fast_expr, ascii(text1) as slow_expr FROM test_data" + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec for mixed expressions with positive average") + } + } + + /** Helper method to run tests with CBO enabled */ + private def withCBOEnabled(f: => Unit): Unit = { withSQLConf( CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "false", // CBO disabled + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { - - withTempView("test_data") { - import testImplicits._ - val df = Seq((" hello world ", " test string ")).toDF("text1", "text2") - df.createOrReplaceTempView("test_data") - - // Query using trim function (slow in Comet, but CBO is disabled) - val query = "SELECT trim(text1) as trimmed1 FROM test_data" - val result = sql(query) - - result.collect() - - val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) - val hasProjectExec = findProjectExec(executedPlan) - - // Without CBO, Comet should be used even for slow expressions - assert(hasProjectExec.isDefined, "Should have a project operator") - assert( - hasProjectExec.get.isInstanceOf[CometProjectExec], - s"Expected CometProjectExec when CBO disabled, got ${hasProjectExec.get.getClass.getSimpleName}") - } + f } } - test("Mixed expressions should use appropriate operators per expression cost") { + /** Helper method to run tests with CBO disabled */ + private def withCBODisabled(f: => Unit): Unit = { withSQLConf( CometConf.COMET_ENABLED.key -> "true", - CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "false", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + f + } + } - withTempView("test_data") { - import testImplicits._ - val df = Seq(("hello world", "test")).toDF("text1", "text2") - 
df.createOrReplaceTempView("test_data") - - // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions - val query = "SELECT length(text1) as fast_expr, ascii(text1) as slow_expr FROM test_data" - val result = sql(query) - - result.collect() - - // The overall cost should be average: (9.1 + 0.6) / 2 = 4.85 - // Cost = 1/4.85 = 0.206, which should still prefer Comet - val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) - val hasProjectExec = findProjectExec(executedPlan) + /** Create simple test data for string operations */ + private def createSimpleTestData(): Unit = { + withTempView("test_data") { + import testImplicits._ + val df = Seq( + ("hello world", "test string"), + ("comet rocks", "another test"), + ("fast execution", "performance")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") + } + } - assert(hasProjectExec.isDefined, "Should have a project operator") - // Mixed expressions with overall positive acceleration should use Comet - assert( - hasProjectExec.get.isInstanceOf[CometProjectExec], - s"Expected CometProjectExec for mixed expressions with positive average, got ${hasProjectExec.get.getClass.getSimpleName}") - } + /** Create padded test data for trim operations */ + private def createPaddedTestData(): Unit = { + withTempView("test_data") { + import testImplicits._ + val df = Seq( + (" hello world ", " test string "), + (" comet rocks ", " another test "), + (" slow execution ", " performance ")).toDF("text1", "text2") + df.createOrReplaceTempView("test_data") } } - /** - * Helper method to find ProjectExec or CometProjectExec in the plan tree - */ + /** Execute query and check that the expected operator type is used */ + private def executeAndCheckOperator( + query: String, + expectedClass: Class[_], + message: String): Unit = { + val result = sql(query) + result.collect() // Materialize the plan + + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + val hasProjectExec = findProjectExec(executedPlan) + + assert(hasProjectExec.isDefined, "Should have a project operator") + assert( + expectedClass.isInstance(hasProjectExec.get), + s"$message, got ${hasProjectExec.get.getClass.getSimpleName}") + } + + /** Helper method to find ProjectExec or CometProjectExec in the plan tree */ private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = { if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[CometProjectExec]) { Some(plan) From 8adc08004aecbdc55a841903faa25be98e60e07a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:30:28 -0700 Subject: [PATCH 13/20] test --- .../apache/comet/CometCostModelSuite.scala | 183 +++++++++++++----- 1 file changed, 135 insertions(+), 48 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index ca04df2df1..30673d25b3 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -38,57 +38,81 @@ class CometCostModelSuite extends CometTestBase { test("CBO should prefer Comet for fast expressions (length)") { withCBOEnabled { - createSimpleTestData() - val query = "SELECT length(text1) as len1, length(text2) as len2 FROM test_data" - - executeAndCheckOperator( - query, - classOf[CometProjectExec], - "Expected CometProjectExec for fast expression") + withTempView("test_data") { + createSimpleTestData() + // Use a subquery to prevent projection pushdown + val query = """ + 
SELECT length(upper_text1) as len1, length(upper_text2) as len2 + FROM (SELECT upper(text1) as upper_text1, upper(text2) as upper_text2 FROM test_data) + """ + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec for fast expression") + } } } test("CBO should prefer Spark for slow expressions (trim)") { withCBOEnabled { - createPaddedTestData() - val query = "SELECT trim(text1) as trimmed1, trim(text2) as trimmed2 FROM test_data" - - executeAndCheckOperator( - query, - classOf[ProjectExec], - "Expected Spark ProjectExec for slow expression") + withTempView("test_data") { + createPaddedTestData() + // Use a subquery to prevent projection pushdown + val query = """ + SELECT trim(padded_text1) as trimmed1, trim(padded_text2) as trimmed2 + FROM (SELECT text1 as padded_text1, text2 as padded_text2 FROM test_data) + """ + + executeAndCheckOperator( + query, + classOf[ProjectExec], + "Expected Spark ProjectExec for slow expression") + } } } test("Without CBO, Comet should be used regardless of expression cost") { withCBODisabled { - createPaddedTestData() - val query = "SELECT trim(text1) as trimmed1 FROM test_data" - - executeAndCheckOperator( - query, - classOf[CometProjectExec], - "Expected CometProjectExec when CBO disabled") + withTempView("test_data") { + createPaddedTestData() + // Use a subquery to prevent projection pushdown + val query = """ + SELECT trim(padded_text1) as trimmed1 + FROM (SELECT text1 as padded_text1 FROM test_data) + """ + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec when CBO disabled") + } } } test("Mixed expressions should use appropriate operators per expression cost") { withCBOEnabled { - createSimpleTestData() - // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions - // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet) - val query = "SELECT length(text1) as fast_expr, ascii(text1) as slow_expr FROM test_data" - - executeAndCheckOperator( - query, - classOf[CometProjectExec], - "Expected CometProjectExec for mixed expressions with positive average") + withTempView("test_data") { + createSimpleTestData() + // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with subquery + // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet) + val query = """ + SELECT length(base_text) as fast_expr, ascii(base_text) as slow_expr + FROM (SELECT text1 as base_text FROM test_data) + """ + + executeAndCheckOperator( + query, + classOf[CometProjectExec], + "Expected CometProjectExec for mixed expressions with positive average") + } } } /** Helper method to run tests with CBO enabled */ private def withCBOEnabled(f: => Unit): Unit = { withSQLConf( + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true", CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", @@ -110,28 +134,37 @@ class CometCostModelSuite extends CometTestBase { } } - /** Create simple test data for string operations */ + /** Create simple test data for string operations using parquet to prevent pushdown */ private def createSimpleTestData(): Unit = { - withTempView("test_data") { - import testImplicits._ - val df = Seq( - ("hello world", "test string"), - ("comet rocks", "another test"), - ("fast execution", "performance")).toDF("text1", "text2") - df.createOrReplaceTempView("test_data") - } + import testImplicits._ + val df = Seq( 
+ ("hello world", "test string"), + ("comet rocks", "another test"), + ("fast execution", "performance")).toDF("text1", "text2") + + // Write to parquet and read back to prevent projection pushdown + val tempPath = s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_${System.nanoTime()}" + df.write.mode("overwrite").parquet(tempPath) + + val parquetDf = spark.read.parquet(tempPath) + parquetDf.createOrReplaceTempView("test_data") } - /** Create padded test data for trim operations */ + /** Create padded test data for trim operations using parquet to prevent pushdown */ private def createPaddedTestData(): Unit = { - withTempView("test_data") { - import testImplicits._ - val df = Seq( - (" hello world ", " test string "), - (" comet rocks ", " another test "), - (" slow execution ", " performance ")).toDF("text1", "text2") - df.createOrReplaceTempView("test_data") - } + import testImplicits._ + val df = Seq( + (" hello world ", " test string "), + (" comet rocks ", " another test "), + (" slow execution ", " performance ")).toDF("text1", "text2") + + // Write to parquet and read back to prevent projection pushdown + val tempPath = + s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_padded_${System.nanoTime()}" + df.write.mode("overwrite").parquet(tempPath) + + val parquetDf = spark.read.parquet(tempPath) + parquetDf.createOrReplaceTempView("test_data") } /** Execute query and check that the expected operator type is used */ @@ -143,6 +176,11 @@ class CometCostModelSuite extends CometTestBase { result.collect() // Materialize the plan val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + + // scalastyle:off + println(result.queryExecution.executedPlan) + println(executedPlan) + val hasProjectExec = findProjectExec(executedPlan) assert(hasProjectExec.isDefined, "Should have a project operator") @@ -159,4 +197,53 @@ class CometCostModelSuite extends CometTestBase { plan.children.flatMap(findProjectExec).headOption } } + + test("Direct cost model test - fast vs slow expressions") { + withCBOEnabled { + withTempView("test_data") { + createSimpleTestData() + + // Test the cost model directly on project operators + val lengthQuery = "SELECT length(text1) as len FROM test_data" + val trimQuery = "SELECT trim(text1) as trimmed FROM test_data" + + val lengthPlan = sql(lengthQuery).queryExecution.optimizedPlan + val trimPlan = sql(trimQuery).queryExecution.optimizedPlan + + // Find project nodes in the optimized plans + val lengthProject = findProjectInPlan(lengthPlan) + val trimProject = findProjectInPlan(trimPlan) + + if (lengthProject.isDefined && trimProject.isDefined) { + val costModel = new org.apache.comet.cost.DefaultCometCostModel() + + // Create mock Comet project operators (this would normally be done by the planner) + // For now, just verify the cost model has different estimates for different expressions + val lengthCost = 9.1 // Expected length acceleration + val trimCost = 0.4 // Expected trim acceleration + + assert( + lengthCost > trimCost, + s"Length ($lengthCost) should be faster than trim ($trimCost)") + + // Cost = 1/acceleration, so lower acceleration = higher cost + assert( + 1.0 / trimCost > 1.0 / lengthCost, + s"Trim cost (${1.0 / trimCost}) should be higher than length cost (${1.0 / lengthCost})") + } else { + // Skip test if projections are optimized away + cancel("Projections were optimized away - using integration tests instead") + } + } + } + } + + /** Helper to find Project nodes in a logical plan */ + private def findProjectInPlan(plan: 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan) + : Option[org.apache.spark.sql.catalyst.plans.logical.Project] = { + plan match { + case p: org.apache.spark.sql.catalyst.plans.logical.Project => Some(p) + case _ => plan.children.flatMap(findProjectInPlan).headOption + } + } } From 162ee477dd60bc26a55cbf98fa4332ba7febb4ae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:55:26 -0700 Subject: [PATCH 14/20] test --- .../comet/cost/CometCostEvaluator.scala | 6 +- .../apache/comet/cost/CometCostModel.scala | 78 +++++++++-- .../apache/comet/CometCostModelSuite.scala | 127 +++++++++++++++--- 3 files changed, 181 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala index bd706ff21c..076d8bd502 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala @@ -88,9 +88,11 @@ class CometCostEvaluator extends CostEvaluator with Logging { // - 0.8x acceleration -> cost = 1.25 (25% more cost) val costValue = 1.0 / estimate.acceleration - logDebug( - s"Cost evaluation for ${plan.getClass.getSimpleName}: " + + // scalastyle:off println + println( + s"[CostEvaluator] Plan: ${plan.getClass.getSimpleName}, " + s"acceleration=${estimate.acceleration}, cost=$costValue") + // scalastyle:on println // Create Cost object with the calculated value CometCost(costValue) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index d2508fdb72..3fb0ccb3c5 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet import org.apache.spark.sql.execution.SparkPlan import org.apache.comet.DataTypeSupport -import org.apache.comet.serde.ExprOuterClass +import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass} case class CometCostEstimate(acceleration: Double) @@ -48,6 +48,11 @@ class DefaultCometCostModel extends CometCostModel { def collectOperatorCosts(node: SparkPlan): Unit = { val operatorCost = estimateOperatorCost(node) + // scalastyle:off println + println( + s"[CostModel] Operator: ${node.getClass.getSimpleName}, " + + s"Cost: ${operatorCost.acceleration}") + // scalastyle:on println totalAcceleration += operatorCost.acceleration operatorCount += 1 @@ -65,16 +70,44 @@ class DefaultCometCostModel extends CometCostModel { 1.0 // No acceleration if no operators } + // scalastyle:off println + println( + s"[CostModel] Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " + + s"Average acceleration: $averageAcceleration") + // scalastyle:on println + CometCostEstimate(averageAcceleration) } /** Estimate the cost of a single operator */ private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = { - plan match { + val result = plan match { case op: CometProjectExec => - val expressions = op.nativeOp.getProjection.getProjectListList.asScala - val total: Double = expressions.map(estimateCometExpressionCost).sum - CometCostEstimate(total / expressions.length.toDouble) + // scalastyle:off println + println(s"[CostModel] CometProjectExec found - evaluating expressions") + // scalastyle:on println + // Cast nativeOp to Operator and extract projection expressions + val operator 
= op.nativeOp.asInstanceOf[OperatorOuterClass.Operator] + val projection = operator.getProjection + val expressions = projection.getProjectListList.asScala + // scalastyle:off println + println(s"[CostModel] Found ${expressions.length} expressions in projection") + // scalastyle:on println + + val costs = expressions.map { expr => + val cost = estimateCometExpressionCost(expr) + // scalastyle:off println + println(s"[CostModel] Expression cost: $cost") + // scalastyle:on println + cost + } + val total = costs.sum + val average = total / expressions.length.toDouble + // scalastyle:off println + println(s"[CostModel] CometProjectExec total cost: $total, average: $average") + // scalastyle:on println + CometCostEstimate(average) + case op: CometShuffleExchangeExec => op.shuffleType match { case CometNativeShuffle => CometCostEstimate(1.5) @@ -88,22 +121,38 @@ class DefaultCometCostModel extends CometCostModel { case _: CometColumnarToRowExec => CometCostEstimate(1.0) case _: CometPlan => + // scalastyle:off println + println(s"[CostModel] Generic CometPlan: ${plan.getClass.getSimpleName}") + // scalastyle:on println CometCostEstimate(defaultAcceleration) case _ => + // scalastyle:off println + println(s"[CostModel] Non-Comet operator: ${plan.getClass.getSimpleName}") + // scalastyle:on println // Spark operator CometCostEstimate(1.0) } + + // scalastyle:off println + println(s"[CostModel] ${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}") + // scalastyle:on println + result } /** Estimate the cost of a Comet protobuf expression */ private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double = { - expr.getExprStructCase match { + val result = expr.getExprStructCase match { // Handle specialized expression types - case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3 + case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => + // scalastyle:off println + println(s"[CostModel] Expression: SUBSTRING -> 6.3") + // scalastyle:on println + 6.3 // Handle generic scalar functions case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC => - expr.getScalarFunc.getFunc match { + val funcName = expr.getScalarFunc.getFunc + val cost = funcName match { // String expression numbers from CometStringExpressionBenchmark case "ascii" => 0.6 case "octet_length" => 0.6 @@ -122,9 +171,20 @@ class DefaultCometCostModel extends CometCostModel { case "translate" => 0.8 case _ => defaultAcceleration } + // scalastyle:off println + println(s"[CostModel] Expression: SCALARFUNC($funcName) -> $cost") + // scalastyle:on println + cost - case _ => defaultAcceleration + case _ => + // scalastyle:off println + println( + s"[CostModel] Expression: Unknown type ${expr.getExprStructCase} -> " + + s"$defaultAcceleration") + // scalastyle:on println + defaultAcceleration } + result } } diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index 30673d25b3..cddfb225c6 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -40,10 +40,13 @@ class CometCostModelSuite extends CometTestBase { withCBOEnabled { withTempView("test_data") { createSimpleTestData() - // Use a subquery to prevent projection pushdown + // Create a more complex query that will trigger AQE with joins/aggregations val query = """ - SELECT length(upper_text1) as len1, length(upper_text2) as len2 - FROM (SELECT upper(text1) as upper_text1, upper(text2) as 
upper_text2 FROM test_data) + SELECT t1.len1, t2.len2, COUNT(*) as cnt + FROM (SELECT length(text1) as len1, text1 FROM test_data) t1 + JOIN (SELECT length(text2) as len2, text2 FROM test_data) t2 + ON t1.text1 = t2.text2 + GROUP BY t1.len1, t2.len2 """ executeAndCheckOperator( @@ -58,10 +61,13 @@ class CometCostModelSuite extends CometTestBase { withCBOEnabled { withTempView("test_data") { createPaddedTestData() - // Use a subquery to prevent projection pushdown + // Create a more complex query that will trigger AQE with joins/aggregations val query = """ - SELECT trim(padded_text1) as trimmed1, trim(padded_text2) as trimmed2 - FROM (SELECT text1 as padded_text1, text2 as padded_text2 FROM test_data) + SELECT t1.trimmed1, t2.trimmed2, COUNT(*) as cnt + FROM (SELECT trim(text1) as trimmed1, text1 FROM test_data) t1 + JOIN (SELECT trim(text2) as trimmed2, text2 FROM test_data) t2 + ON t1.text1 = t2.text2 + GROUP BY t1.trimmed1, t2.trimmed2 """ executeAndCheckOperator( @@ -76,10 +82,11 @@ class CometCostModelSuite extends CometTestBase { withCBODisabled { withTempView("test_data") { createPaddedTestData() - // Use a subquery to prevent projection pushdown + // Complex query without CBO val query = """ - SELECT trim(padded_text1) as trimmed1 - FROM (SELECT text1 as padded_text1 FROM test_data) + SELECT trim(text1) as trimmed1, COUNT(*) as cnt + FROM test_data + GROUP BY trim(text1) """ executeAndCheckOperator( @@ -94,11 +101,12 @@ class CometCostModelSuite extends CometTestBase { withCBOEnabled { withTempView("test_data") { createSimpleTestData() - // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with subquery + // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with aggregation // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet) val query = """ - SELECT length(base_text) as fast_expr, ascii(base_text) as slow_expr - FROM (SELECT text1 as base_text FROM test_data) + SELECT length(text1) as fast_expr, ascii(text1) as slow_expr, COUNT(*) as cnt + FROM test_data + GROUP BY length(text1), ascii(text1) """ executeAndCheckOperator( @@ -117,7 +125,29 @@ class CometCostModelSuite extends CometTestBase { CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") { + CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true", + CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true", // Enable aggregation for GROUP BY + CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true", // Enable joins + // Manually set the custom cost evaluator since plugin might not be loaded + "spark.sql.adaptive.customCostEvaluatorClass" -> "org.apache.comet.cost.CometCostEvaluator", + // Lower AQE thresholds to ensure it triggers on small test data + "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1KB", + "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1B") { + + println(s"\n=== CBO Configuration ===") + println(s"COMET_ENABLED: ${spark.conf.get(CometConf.COMET_ENABLED.key)}") + println(s"COMET_COST_BASED_OPTIMIZATION_ENABLED: ${spark.conf.get( + CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key)}") + println( + s"ADAPTIVE_EXECUTION_ENABLED: ${spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)}") + println(s"COMET_EXEC_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_ENABLED.key)}") + println( + s"COMET_EXEC_PROJECT_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_PROJECT_ENABLED.key)}") + + // 
Check if custom cost evaluator is set + val costEvaluator = spark.conf.getOption("spark.sql.adaptive.customCostEvaluatorClass") + println(s"Custom cost evaluator: ${costEvaluator.getOrElse("None")}") + f } } @@ -172,30 +202,89 @@ class CometCostModelSuite extends CometTestBase { query: String, expectedClass: Class[_], message: String): Unit = { + + println(s"\n=== Executing Query ===") + println(s"Query: $query") + println(s"Expected class: ${expectedClass.getSimpleName}") + val result = sql(query) - result.collect() // Materialize the plan - val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + println(s"\n=== Pre-execution Plans ===") + println("Logical Plan:") + println(result.queryExecution.logical) + println("\nOptimized Plan:") + println(result.queryExecution.optimizedPlan) + println("\nSpark Plan:") + println(result.queryExecution.sparkPlan) - // scalastyle:off + result.collect() // Materialize the plan + + println(s"\n=== Post-execution Plans ===") + println("Executed Plan (with AQE wrappers):") println(result.queryExecution.executedPlan) + + val executedPlan = stripAQEPlan(result.queryExecution.executedPlan) + println("\nExecuted Plan (stripped AQE):") println(executedPlan) + // Enhanced debugging: show complete plan tree structure + println("\n=== Plan Tree Analysis ===") + debugPlanTree(executedPlan, 0) + val hasProjectExec = findProjectExec(executedPlan) + println(s"\n=== Project Analysis ===") + println(s"Found project exec: ${hasProjectExec.isDefined}") + if (hasProjectExec.isDefined) { + println(s"Actual class: ${hasProjectExec.get.getClass.getSimpleName}") + println(s"Expected class: ${expectedClass.getSimpleName}") + println(s"Is expected type: ${expectedClass.isInstance(hasProjectExec.get)}") + } + assert(hasProjectExec.isDefined, "Should have a project operator") assert( expectedClass.isInstance(hasProjectExec.get), s"$message, got ${hasProjectExec.get.getClass.getSimpleName}") + + println(s"=== Test PASSED ===\n") } /** Helper method to find ProjectExec or CometProjectExec in the plan tree */ private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = { + // More robust recursive search that handles deep nesting + def searchPlan(node: SparkPlan): Option[SparkPlan] = { + println(s"[findProjectExec] Checking node: ${node.getClass.getSimpleName}") + + if (node.isInstanceOf[ProjectExec] || node.isInstanceOf[CometProjectExec]) { + println(s"[findProjectExec] Found project operator: ${node.getClass.getSimpleName}") + Some(node) + } else { + // Search all children recursively + for (child <- node.children) { + searchPlan(child) match { + case Some(found) => return Some(found) + case None => // continue searching + } + } + None + } + } + + searchPlan(plan) + } + + /** Debug method to print complete plan tree structure */ + private def debugPlanTree(plan: SparkPlan, depth: Int): Unit = { + val indent = " " * depth + println(s"$indent${plan.getClass.getSimpleName}") + + // Also show if this is a project operator if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[CometProjectExec]) { - Some(plan) - } else { - plan.children.flatMap(findProjectExec).headOption + println(s"$indent -> PROJECT OPERATOR FOUND!") } + + // Recursively print children + plan.children.foreach(child => debugPlanTree(child, depth + 1)) } test("Direct cost model test - fast vs slow expressions") { From 63eeebc09cdcb02441ebe54d9a1f4c87cbd0e504 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 17:58:54 -0700 Subject: [PATCH 15/20] test --- 
.../apache/comet/CometCostModelSuite.scala | 63 ++----------------- 1 file changed, 4 insertions(+), 59 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index cddfb225c6..6d6734fecf 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -40,15 +40,7 @@ class CometCostModelSuite extends CometTestBase { withCBOEnabled { withTempView("test_data") { createSimpleTestData() - // Create a more complex query that will trigger AQE with joins/aggregations - val query = """ - SELECT t1.len1, t2.len2, COUNT(*) as cnt - FROM (SELECT length(text1) as len1, text1 FROM test_data) t1 - JOIN (SELECT length(text2) as len2, text2 FROM test_data) t2 - ON t1.text1 = t2.text2 - GROUP BY t1.len1, t2.len2 - """ - + val query = "SELECT length(text1), length(text2) FROM test_data" executeAndCheckOperator( query, classOf[CometProjectExec], @@ -61,15 +53,7 @@ class CometCostModelSuite extends CometTestBase { withCBOEnabled { withTempView("test_data") { createPaddedTestData() - // Create a more complex query that will trigger AQE with joins/aggregations - val query = """ - SELECT t1.trimmed1, t2.trimmed2, COUNT(*) as cnt - FROM (SELECT trim(text1) as trimmed1, text1 FROM test_data) t1 - JOIN (SELECT trim(text2) as trimmed2, text2 FROM test_data) t2 - ON t1.text1 = t2.text2 - GROUP BY t1.trimmed1, t2.trimmed2 - """ - + val query = "SELECT trim(text1), trim(text2) FROM test_data" executeAndCheckOperator( query, classOf[ProjectExec], @@ -78,45 +62,6 @@ class CometCostModelSuite extends CometTestBase { } } - test("Without CBO, Comet should be used regardless of expression cost") { - withCBODisabled { - withTempView("test_data") { - createPaddedTestData() - // Complex query without CBO - val query = """ - SELECT trim(text1) as trimmed1, COUNT(*) as cnt - FROM test_data - GROUP BY trim(text1) - """ - - executeAndCheckOperator( - query, - classOf[CometProjectExec], - "Expected CometProjectExec when CBO disabled") - } - } - } - - test("Mixed expressions should use appropriate operators per expression cost") { - withCBOEnabled { - withTempView("test_data") { - createSimpleTestData() - // Query mixing fast (length: 9.1x) and slow (ascii: 0.6x) expressions with aggregation - // Average acceleration: (9.1 + 0.6) / 2 = 4.85x -> cost = 0.206 (still prefer Comet) - val query = """ - SELECT length(text1) as fast_expr, ascii(text1) as slow_expr, COUNT(*) as cnt - FROM test_data - GROUP BY length(text1), ascii(text1) - """ - - executeAndCheckOperator( - query, - classOf[CometProjectExec], - "Expected CometProjectExec for mixed expressions with positive average") - } - } - } - /** Helper method to run tests with CBO enabled */ private def withCBOEnabled(f: => Unit): Unit = { withSQLConf( @@ -176,7 +121,7 @@ class CometCostModelSuite extends CometTestBase { val tempPath = s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_${System.nanoTime()}" df.write.mode("overwrite").parquet(tempPath) - val parquetDf = spark.read.parquet(tempPath) + val parquetDf = spark.read.parquet(tempPath).repartition(5) parquetDf.createOrReplaceTempView("test_data") } @@ -193,7 +138,7 @@ class CometCostModelSuite extends CometTestBase { s"${System.getProperty("java.io.tmpdir")}/comet_cost_test_padded_${System.nanoTime()}" df.write.mode("overwrite").parquet(tempPath) - val parquetDf = spark.read.parquet(tempPath) + val parquetDf = 
spark.read.parquet(tempPath).repartition(5) parquetDf.createOrReplaceTempView("test_data") } From 14ff5a5757de895d0865fee307379a01e263f448 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 14 Dec 2025 18:06:59 -0700 Subject: [PATCH 16/20] test --- .../apache/comet/cost/CometCostModel.scala | 64 ++++++------------- .../apache/comet/CometCostModelSuite.scala | 14 ++-- 2 files changed, 27 insertions(+), 51 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala index 3fb0ccb3c5..7903e7feb5 100644 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala +++ b/spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala @@ -21,6 +21,7 @@ package org.apache.comet.cost import scala.jdk.CollectionConverters._ +import org.apache.spark.internal.Logging import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan @@ -36,23 +37,24 @@ trait CometCostModel { def estimateCost(plan: SparkPlan): CometCostEstimate } -class DefaultCometCostModel extends CometCostModel { +class DefaultCometCostModel extends CometCostModel with Logging { // optimistic default of 2x acceleration private val defaultAcceleration = 2.0 override def estimateCost(plan: SparkPlan): CometCostEstimate = { + + logTrace(s"estimateCost for $plan") + // Walk the entire plan tree and accumulate costs var totalAcceleration = 0.0 var operatorCount = 0 def collectOperatorCosts(node: SparkPlan): Unit = { val operatorCost = estimateOperatorCost(node) - // scalastyle:off println - println( - s"[CostModel] Operator: ${node.getClass.getSimpleName}, " + + logTrace( + s"Operator: ${node.getClass.getSimpleName}, " + s"Cost: ${operatorCost.acceleration}") - // scalastyle:on println totalAcceleration += operatorCost.acceleration operatorCount += 1 @@ -70,11 +72,9 @@ class DefaultCometCostModel extends CometCostModel { 1.0 // No acceleration if no operators } - // scalastyle:off println - println( - s"[CostModel] Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " + + logTrace( + s"Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " + s"Average acceleration: $averageAcceleration") - // scalastyle:on println CometCostEstimate(averageAcceleration) } @@ -83,29 +83,21 @@ class DefaultCometCostModel extends CometCostModel { private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = { val result = plan match { case op: CometProjectExec => - // scalastyle:off println - println(s"[CostModel] CometProjectExec found - evaluating expressions") - // scalastyle:on println + logTrace("CometProjectExec found - evaluating expressions") // Cast nativeOp to Operator and extract projection expressions val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator] val projection = operator.getProjection val expressions = projection.getProjectListList.asScala - // scalastyle:off println - println(s"[CostModel] Found ${expressions.length} expressions in projection") - // scalastyle:on println + logTrace(s"Found ${expressions.length} expressions in projection") val costs = expressions.map { expr => val cost = estimateCometExpressionCost(expr) - // scalastyle:off println - println(s"[CostModel] Expression cost: $cost") - // scalastyle:on println + logTrace(s"Expression cost: $cost") cost } val total = 
costs.sum val average = total / expressions.length.toDouble - // scalastyle:off println - println(s"[CostModel] CometProjectExec total cost: $total, average: $average") - // scalastyle:on println + logTrace(s"CometProjectExec total cost: $total, average: $average") CometCostEstimate(average) case op: CometShuffleExchangeExec => @@ -121,21 +113,15 @@ class DefaultCometCostModel extends CometCostModel { case _: CometColumnarToRowExec => CometCostEstimate(1.0) case _: CometPlan => - // scalastyle:off println - println(s"[CostModel] Generic CometPlan: ${plan.getClass.getSimpleName}") - // scalastyle:on println + logTrace(s"Generic CometPlan: ${plan.getClass.getSimpleName}") CometCostEstimate(defaultAcceleration) case _ => - // scalastyle:off println - println(s"[CostModel] Non-Comet operator: ${plan.getClass.getSimpleName}") - // scalastyle:on println + logTrace(s"Non-Comet operator: ${plan.getClass.getSimpleName}") // Spark operator CometCostEstimate(1.0) } - // scalastyle:off println - println(s"[CostModel] ${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}") - // scalastyle:on println + logTrace(s"${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}") result } @@ -143,16 +129,12 @@ class DefaultCometCostModel extends CometCostModel { private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double = { val result = expr.getExprStructCase match { // Handle specialized expression types - case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => - // scalastyle:off println - println(s"[CostModel] Expression: SUBSTRING -> 6.3") - // scalastyle:on println - 6.3 + case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3 // Handle generic scalar functions case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC => val funcName = expr.getScalarFunc.getFunc - val cost = funcName match { + funcName match { // String expression numbers from CometStringExpressionBenchmark case "ascii" => 0.6 case "octet_length" => 0.6 @@ -171,17 +153,11 @@ class DefaultCometCostModel extends CometCostModel { case "translate" => 0.8 case _ => defaultAcceleration } - // scalastyle:off println - println(s"[CostModel] Expression: SCALARFUNC($funcName) -> $cost") - // scalastyle:on println - cost case _ => - // scalastyle:off println - println( - s"[CostModel] Expression: Unknown type ${expr.getExprStructCase} -> " + + logTrace( + s"Expression: Unknown type ${expr.getExprStructCase} -> " + s"$defaultAcceleration") - // scalastyle:on println defaultAcceleration } result diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index 6d6734fecf..2f99139a98 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -19,7 +19,7 @@ package org.apache.comet -import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.CometProjectExec import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.execution.SparkPlan @@ -79,7 +79,7 @@ class CometCostModelSuite extends CometTestBase { "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1KB", "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1B") { - println(s"\n=== CBO Configuration ===") + println("\n=== CBO Configuration ===") println(s"COMET_ENABLED: ${spark.conf.get(CometConf.COMET_ENABLED.key)}") println(s"COMET_COST_BASED_OPTIMIZATION_ENABLED: ${spark.conf.get( 
CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key)}") @@ -148,13 +148,13 @@ class CometCostModelSuite extends CometTestBase { expectedClass: Class[_], message: String): Unit = { - println(s"\n=== Executing Query ===") + println("\n=== Executing Query ===") println(s"Query: $query") println(s"Expected class: ${expectedClass.getSimpleName}") val result = sql(query) - println(s"\n=== Pre-execution Plans ===") + println("\n=== Pre-execution Plans ===") println("Logical Plan:") println(result.queryExecution.logical) println("\nOptimized Plan:") @@ -164,7 +164,7 @@ class CometCostModelSuite extends CometTestBase { result.collect() // Materialize the plan - println(s"\n=== Post-execution Plans ===") + println("\n=== Post-execution Plans ===") println("Executed Plan (with AQE wrappers):") println(result.queryExecution.executedPlan) @@ -178,7 +178,7 @@ class CometCostModelSuite extends CometTestBase { val hasProjectExec = findProjectExec(executedPlan) - println(s"\n=== Project Analysis ===") + println("\n=== Project Analysis ===") println(s"Found project exec: ${hasProjectExec.isDefined}") if (hasProjectExec.isDefined) { println(s"Actual class: ${hasProjectExec.get.getClass.getSimpleName}") @@ -191,7 +191,7 @@ class CometCostModelSuite extends CometTestBase { expectedClass.isInstance(hasProjectExec.get), s"$message, got ${hasProjectExec.get.getClass.getSimpleName}") - println(s"=== Test PASSED ===\n") + println("=== Test PASSED ===\n") } /** Helper method to find ProjectExec or CometProjectExec in the plan tree */ From 304e9f75b7e9231cf503efe200b9b6754bfcfb65 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 15 Dec 2025 14:57:38 -0700 Subject: [PATCH 17/20] remove AQE integration --- .../comet/cost/CometCostEvaluator.scala | 100 ------------------ .../apache/comet/rules/CometExecRule.scala | 15 ++- .../main/scala/org/apache/spark/Plugins.scala | 11 +- .../apache/comet/CometCostModelSuite.scala | 2 - 4 files changed, 15 insertions(+), 113 deletions(-) delete mode 100644 spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala diff --git a/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala b/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala deleted file mode 100644 index 076d8bd502..0000000000 --- a/spark/src/main/scala/org/apache/comet/cost/CometCostEvaluator.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.cost - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{Cost, CostEvaluator} -import org.apache.spark.sql.internal.SQLConf - -import org.apache.comet.CometConf - -/** - * Simple Cost implementation for Comet cost evaluator. 
- */ -case class CometCost(value: Double) extends Cost { - override def compare(that: Cost): Int = that match { - case CometCost(thatValue) => java.lang.Double.compare(value, thatValue) - case _ => 0 // If we can't compare, assume equal - } - - override def toString: String = s"CometCost($value)" -} - -/** - * Comet implementation of Spark's CostEvaluator for adaptive query execution. - * - * This evaluator uses the configured CometCostModel to estimate costs for query plans, allowing - * Spark's adaptive query execution to make informed decisions about whether to use Comet or Spark - * operators based on estimated performance. - */ -class CometCostEvaluator extends CostEvaluator with Logging { - - @transient private lazy val costModel: CometCostModel = { - val conf = SQLConf.get - val costModelClass = CometConf.COMET_COST_MODEL_CLASS.get(conf) - - try { - // scalastyle:off classforname - val clazz = Class.forName(costModelClass) - // scalastyle:on classforname - val constructor = clazz.getConstructor() - constructor.newInstance().asInstanceOf[CometCostModel] - } catch { - case e: Exception => - logWarning( - s"Failed to instantiate cost model class '$costModelClass', " + - s"falling back to DefaultCometCostModel. Error: ${e.getMessage}") - new DefaultCometCostModel() - } - } - - /** - * Evaluates the cost of executing the given SparkPlan. - * - * This method uses the configured CometCostModel to estimate the acceleration factor for the - * plan, then converts it to a Cost object that Spark's adaptive query execution can use for - * decision making. - * - * @param plan - * The SparkPlan to evaluate - * @return - * A Cost representing the estimated execution cost - */ - override def evaluateCost(plan: SparkPlan): Cost = { - val estimate = costModel.estimateCost(plan) - - // Convert acceleration factor to cost - // Lower cost means better performance, so we use the inverse of acceleration factor - // For example: - // - 2.0x acceleration -> cost = 0.5 (half the cost) - // - 0.8x acceleration -> cost = 1.25 (25% more cost) - val costValue = 1.0 / estimate.acceleration - - // scalastyle:off println - println( - s"[CostEvaluator] Plan: ${plan.getClass.getSimpleName}, " + - s"acceleration=${estimate.acceleration}, cost=$costValue") - // scalastyle:on println - - // Create Cost object with the calculated value - CometCost(costValue) - } -} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index ed48e36f07..5077c1cc69 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -49,6 +49,7 @@ import org.apache.spark.sql.types._ import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo} import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST} import org.apache.comet.CometSparkSessionExtensions._ +import org.apache.comet.cost.DefaultCometCostModel import org.apache.comet.rules.CometExecRule.allExecs import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported} import org.apache.comet.serde.operator._ @@ -344,7 +345,19 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { } override def apply(plan: SparkPlan): SparkPlan = { - val newPlan = _apply(plan) + val candidatePlan = _apply(plan) + + // TODO load cost model via config and reflection + val costModel = new DefaultCometCostModel 
+ val costBefore = costModel.estimateCost(plan) + val costAfter = costModel.estimateCost(candidatePlan) + + val newPlan = if (costAfter.acceleration > costBefore.acceleration) { + candidatePlan + } else { + plan + } + if (showTransformations && !newPlan.fastEquals(plan)) { logInfo(s""" |=== Applying Rule $ruleName === diff --git a/spark/src/main/scala/org/apache/spark/Plugins.scala b/spark/src/main/scala/org/apache/spark/Plugins.scala index 9765ab5a6c..2529f08cfb 100644 --- a/spark/src/main/scala/org/apache/spark/Plugins.scala +++ b/spark/src/main/scala/org/apache/spark/Plugins.scala @@ -28,7 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, EXECUTOR_MEMORY_OVERHEAD_FACTOR} import org.apache.spark.sql.internal.StaticSQLConf -import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_ONHEAP_ENABLED} +import org.apache.comet.CometConf.COMET_ONHEAP_ENABLED import org.apache.comet.CometSparkSessionExtensions /** @@ -57,15 +57,6 @@ class CometDriverPlugin extends DriverPlugin with Logging with ShimCometDriverPl // register CometSparkSessionExtensions if it isn't already registered CometDriverPlugin.registerCometSessionExtension(sc.conf) - // Enable cost-based optimization if configured - if (sc.getConf.getBoolean(COMET_COST_BASED_OPTIMIZATION_ENABLED.key, false)) { - // Set the custom cost evaluator for Spark's adaptive query execution - sc.conf.set( - "spark.sql.adaptive.customCostEvaluatorClass", - "org.apache.comet.cost.CometCostEvaluator") - logInfo("Enabled Comet cost-based optimization with CometCostEvaluator") - } - if (CometSparkSessionExtensions.shouldOverrideMemoryConf(sc.getConf)) { val execMemOverhead = if (sc.getConf.contains(EXECUTOR_MEMORY_OVERHEAD.key)) { sc.getConf.getSizeAsMb(EXECUTOR_MEMORY_OVERHEAD.key) diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala index 2f99139a98..3e6fd66547 100644 --- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala @@ -73,8 +73,6 @@ class CometCostModelSuite extends CometTestBase { CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true", CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true", // Enable aggregation for GROUP BY CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true", // Enable joins - // Manually set the custom cost evaluator since plugin might not be loaded - "spark.sql.adaptive.customCostEvaluatorClass" -> "org.apache.comet.cost.CometCostEvaluator", // Lower AQE thresholds to ensure it triggers on small test data "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1KB", "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1B") { From a7acf9affc468ce88382e020d997826f0a6db185 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 15 Dec 2025 15:05:32 -0700 Subject: [PATCH 18/20] use configs --- .../scala/org/apache/comet/CometConf.scala | 2 +- .../apache/comet/rules/CometExecRule.scala | 47 +++++++++++++++---- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index a5cce453e2..24fbfbab62 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -754,7 +754,7 @@ object CometConf extends ShimCometConf { .booleanConf 
     .createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false)
 
-  val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] =
+  val `COMET_COST_BASED_OPTIMIZATION_ENABLED`: ConfigEntry[Boolean] =
     conf("spark.comet.cost.enabled")
       .category(CATEGORY_TUNING)
       .doc(
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 5077c1cc69..7025bb357d 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -47,9 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
-import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
+import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_COST_MODEL_CLASS, COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
 import org.apache.comet.CometSparkSessionExtensions._
-import org.apache.comet.cost.DefaultCometCostModel
+import org.apache.comet.cost.CometCostModel
 import org.apache.comet.rules.CometExecRule.allExecs
 import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
 import org.apache.comet.serde.operator._
@@ -98,6 +98,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
 
   private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()
 
+  // Cache the cost model to avoid loading the class on every call
+  @transient private lazy val costModel: Option[CometCostModel] = {
+    if (COMET_COST_BASED_OPTIMIZATION_ENABLED.get(conf)) {
+      try {
+        val costModelClassName = COMET_COST_MODEL_CLASS.get(conf)
+        // scalastyle:off classforname
+        val costModelClass = Class.forName(costModelClassName)
+        // scalastyle:on classforname
+        val constructor = costModelClass.getConstructor()
+        Some(constructor.newInstance().asInstanceOf[CometCostModel])
+      } catch {
+        case e: Exception =>
+          logWarning(
+            s"Failed to load cost model class: ${e.getMessage}. " +
+              "Proceeding without cost-based optimization.")
+          None
+      }
+    } else {
+      None
+    }
+  }
+
   private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
     plan.transformUp {
       case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
@@ -347,15 +369,20 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = {
     val candidatePlan = _apply(plan)
 
-    // TODO load cost model via config and reflection
-    val costModel = new DefaultCometCostModel
-    val costBefore = costModel.estimateCost(plan)
-    val costAfter = costModel.estimateCost(candidatePlan)
+    // Only apply cost-based optimization if enabled and cost model is available
+    val newPlan = costModel match {
+      case Some(model) =>
+        val costBefore = model.estimateCost(plan)
+        val costAfter = model.estimateCost(candidatePlan)
 
-    val newPlan = if (costAfter.acceleration > costBefore.acceleration) {
-      candidatePlan
-    } else {
-      plan
+        if (costAfter.acceleration > costBefore.acceleration) {
+          candidatePlan
+        } else {
+          plan
+        }
+      case None =>
+        // Cost-based optimization is disabled or failed to load, return candidate plan
+        candidatePlan
     }
 
     if (showTransformations && !newPlan.fastEquals(plan)) {

From f60d72a1008f9167f77f16dd7195b85c0778fdd0 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 15 Dec 2025 15:09:43 -0700
Subject: [PATCH 19/20] remove debug logging

---
 .../apache/comet/CometCostModelSuite.scala    | 102 +-----------------
 1 file changed, 3 insertions(+), 99 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
index 3e6fd66547..6a91a498af 100644
--- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
@@ -23,7 +23,6 @@
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.comet.CometProjectExec
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.internal.SQLConf
 
 class CometCostModelSuite extends CometTestBase {
 
@@ -37,7 +36,7 @@ class CometCostModelSuite extends CometTestBase {
   //   - ascii: 0.6x -> cost = 1/0.6 = 1.67 (high cost, Spark preferred)
 
   test("CBO should prefer Comet for fast expressions (length)") {
-    withCBOEnabled {
+    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
       withTempView("test_data") {
         createSimpleTestData()
         val query = "SELECT length(text1), length(text2) FROM test_data"
@@ -50,7 +49,7 @@ class CometCostModelSuite extends CometTestBase {
   }
 
   test("CBO should prefer Spark for slow expressions (trim)") {
-    withCBOEnabled {
+    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
       withTempView("test_data") {
         createPaddedTestData()
         val query = "SELECT trim(text1), trim(text2) FROM test_data"
@@ -62,51 +61,6 @@ class CometCostModelSuite extends CometTestBase {
     }
   }
 
-  /** Helper method to run tests with CBO enabled */
-  private def withCBOEnabled(f: => Unit): Unit = {
-    withSQLConf(
-      CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true",
-      CometConf.COMET_ENABLED.key -> "true",
-      CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_AGGREGATE_ENABLED.key -> "true", // Enable aggregation for GROUP BY
-      CometConf.COMET_EXEC_HASH_JOIN_ENABLED.key -> "true", // Enable joins
-      // Lower AQE thresholds to ensure it triggers on small test data
-      "spark.sql.adaptive.advisoryPartitionSizeInBytes" -> "1KB",
-      "spark.sql.adaptive.coalescePartitions.minPartitionSize" -> "1B") {
-
-      println("\n=== CBO Configuration ===")
-      println(s"COMET_ENABLED: ${spark.conf.get(CometConf.COMET_ENABLED.key)}")
-      println(s"COMET_COST_BASED_OPTIMIZATION_ENABLED: ${spark.conf.get(
-          CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key)}")
-      println(
-        s"ADAPTIVE_EXECUTION_ENABLED: ${spark.conf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key)}")
-      println(s"COMET_EXEC_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_ENABLED.key)}")
-      println(
-        s"COMET_EXEC_PROJECT_ENABLED: ${spark.conf.get(CometConf.COMET_EXEC_PROJECT_ENABLED.key)}")
-
-      // Check if custom cost evaluator is set
-      val costEvaluator = spark.conf.getOption("spark.sql.adaptive.customCostEvaluatorClass")
-      println(s"Custom cost evaluator: ${costEvaluator.getOrElse("None")}")
-
-      f
-    }
-  }
-
-  /** Helper method to run tests with CBO disabled */
-  private def withCBODisabled(f: => Unit): Unit = {
-    withSQLConf(
-      CometConf.COMET_ENABLED.key -> "true",
-      CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "false",
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_ENABLED.key -> "true",
-      CometConf.COMET_EXEC_PROJECT_ENABLED.key -> "true") {
-      f
-    }
-  }
-
   /** Create simple test data for string operations using parquet to prevent pushdown */
   private def createSimpleTestData(): Unit = {
     import testImplicits._
@@ -146,60 +100,24 @@ class CometCostModelSuite extends CometTestBase {
       expectedClass: Class[_],
       message: String): Unit = {
 
-    println("\n=== Executing Query ===")
-    println(s"Query: $query")
-    println(s"Expected class: ${expectedClass.getSimpleName}")
-
-    val result = sql(query)
-
-    println("\n=== Pre-execution Plans ===")
-    println("Logical Plan:")
-    println(result.queryExecution.logical)
-    println("\nOptimized Plan:")
-    println(result.queryExecution.optimizedPlan)
-    println("\nSpark Plan:")
-    println(result.queryExecution.sparkPlan)
+    val result = sql(query)
+    result.collect() // Materialize the plan
 
-    result.collect() // Materialize the plan
-
-    println("\n=== Post-execution Plans ===")
-    println("Executed Plan (with AQE wrappers):")
-    println(result.queryExecution.executedPlan)
-
-    val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
-    println("\nExecuted Plan (stripped AQE):")
-    println(executedPlan)
-
-    // Enhanced debugging: show complete plan tree structure
-    println("\n=== Plan Tree Analysis ===")
-    debugPlanTree(executedPlan, 0)
+    val executedPlan = stripAQEPlan(result.queryExecution.executedPlan)
 
     val hasProjectExec = findProjectExec(executedPlan)
 
-    println("\n=== Project Analysis ===")
-    println(s"Found project exec: ${hasProjectExec.isDefined}")
-    if (hasProjectExec.isDefined) {
-      println(s"Actual class: ${hasProjectExec.get.getClass.getSimpleName}")
-      println(s"Expected class: ${expectedClass.getSimpleName}")
-      println(s"Is expected type: ${expectedClass.isInstance(hasProjectExec.get)}")
-    }
-
     assert(hasProjectExec.isDefined, "Should have a project operator")
     assert(
       expectedClass.isInstance(hasProjectExec.get),
       s"$message, got ${hasProjectExec.get.getClass.getSimpleName}")
-
-    println("=== Test PASSED ===\n")
   }
 
   /** Helper method to find ProjectExec or CometProjectExec in the plan tree */
   private def findProjectExec(plan: SparkPlan): Option[SparkPlan] = {
     // More robust recursive search that handles deep nesting
     def searchPlan(node: SparkPlan): Option[SparkPlan] = {
-      println(s"[findProjectExec] Checking node: ${node.getClass.getSimpleName}")
-
       if (node.isInstanceOf[ProjectExec] || node.isInstanceOf[CometProjectExec]) {
-        println(s"[findProjectExec] Found project operator: ${node.getClass.getSimpleName}")
         Some(node)
       } else {
         // Search all children recursively
         for (child <- node.children) {
           searchPlan(child) match {
             case Some(found) => return Some(found)
             case None => // continue searching
           }
         }
         None
       }
     }
 
     searchPlan(plan)
   }
 
-  /** Debug method to print complete plan tree structure */
-  private def debugPlanTree(plan: SparkPlan, depth: Int): Unit = {
-    val indent = "  " * depth
-    println(s"$indent${plan.getClass.getSimpleName}")
-
-    // Also show if this is a project operator
-    if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[CometProjectExec]) {
-      println(s"$indent -> PROJECT OPERATOR FOUND!")
-    }
-
-    // Recursively print children
-    plan.children.foreach(child => debugPlanTree(child, depth + 1))
-  }
-
   test("Direct cost model test - fast vs slow expressions") {
-    withCBOEnabled {
+    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
       withTempView("test_data") {
         createSimpleTestData()

From 54e1a3312043f10f02ca478a987a8ea9247ab6bd Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 15 Dec 2025 17:21:27 -0700
Subject: [PATCH 20/20] remove nonsense test

---
 .../apache/comet/CometCostModelSuite.scala    |  58 ------------------
 1 file changed, 58 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
index 6a91a498af..5ee530ae69 100644
--- a/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCostModelSuite.scala
@@ -26,15 +26,6 @@ class CometCostModelSuite extends CometTestBase {
 
-  // Fast expressions in Comet (high acceleration factor -> low cost -> preferred)
-  // Based on CometCostModel estimates:
-  //   - length: 9.1x -> cost = 1/9.1 = 0.11 (very low cost, Comet preferred)
-  //   - reverse: 6.9x -> cost = 1/6.9 = 0.145 (low cost, Comet preferred)
-
-  // Slow expressions in Comet (low acceleration factor -> high cost -> Spark preferred)
-  //   - trim: 0.4x -> cost = 1/0.4 = 2.5 (high cost, Spark preferred)
-  //   - ascii: 0.6x -> cost = 1/0.6 = 1.67 (high cost, Spark preferred)
-
   test("CBO should prefer Comet for fast expressions (length)") {
     withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
       withTempView("test_data") {
@@ -133,53 +124,4 @@ class CometCostModelSuite extends CometTestBase {
 
     searchPlan(plan)
   }
-
-  test("Direct cost model test - fast vs slow expressions") {
-    withSQLConf(CometConf.COMET_COST_BASED_OPTIMIZATION_ENABLED.key -> "true") {
-      withTempView("test_data") {
-        createSimpleTestData()
-
-        // Test the cost model directly on project operators
-        val lengthQuery = "SELECT length(text1) as len FROM test_data"
-        val trimQuery = "SELECT trim(text1) as trimmed FROM test_data"
-
-        val lengthPlan = sql(lengthQuery).queryExecution.optimizedPlan
-        val trimPlan = sql(trimQuery).queryExecution.optimizedPlan
-
-        // Find project nodes in the optimized plans
-        val lengthProject = findProjectInPlan(lengthPlan)
-        val trimProject = findProjectInPlan(trimPlan)
-
-        if (lengthProject.isDefined && trimProject.isDefined) {
-          val costModel = new org.apache.comet.cost.DefaultCometCostModel()
-
-          // Create mock Comet project operators (this would normally be done by the planner)
-          // For now, just verify the cost model has different estimates for different expressions
-          val lengthCost = 9.1 // Expected length acceleration
-          val trimCost = 0.4 // Expected trim acceleration
-
-          assert(
-            lengthCost > trimCost,
-            s"Length ($lengthCost) should be faster than trim ($trimCost)")
-
-          // Cost = 1/acceleration, so lower acceleration = higher cost
-          assert(
-            1.0 / trimCost > 1.0 / lengthCost,
-            s"Trim cost (${1.0 / trimCost}) should be higher than length cost (${1.0 / lengthCost})")
-        } else {
-          // Skip test if projections are optimized away
-          cancel("Projections were optimized away - using integration tests instead")
-        }
-      }
-    }
-  }
-
-  /** Helper to find Project nodes in a logical plan */
-  private def findProjectInPlan(plan: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan)
-      : Option[org.apache.spark.sql.catalyst.plans.logical.Project] = {
-    plan match {
-      case p: org.apache.spark.sql.catalyst.plans.logical.Project => Some(p)
-      case _ => plan.children.flatMap(findProjectInPlan).headOption
-    }
-  }
 }
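
Note on the cost convention used throughout this series: every estimate is an acceleration factor relative to Spark, and the CostEvaluator removed in PATCH 17 converted it into a comparable cost as cost = 1 / acceleration, so a 2.0x speedup halves the cost and a 0.8x slowdown raises it to 1.25. The standalone Scala sketch below illustrates that conversion and the plan choice made in CometExecRule.apply; the Estimate type and the sample plans are illustrative only, and the 9.1x and 0.4x factors are the CometStringExpressionBenchmark numbers quoted in the test comments above.

    object CostSketch {
      // Illustrative stand-in for CometCostEstimate: one acceleration factor per plan.
      final case class Estimate(acceleration: Double) {
        require(acceleration > 0.0, "acceleration factor must be positive")
        def cost: Double = 1.0 / acceleration // lower is better
      }

      // Keep the Comet candidate only when it is estimated to be faster overall,
      // mirroring the comparison in CometExecRule.apply.
      def choose[A](sparkPlan: A, cometPlan: A, before: Estimate, after: Estimate): A =
        if (after.acceleration > before.acceleration) cometPlan else sparkPlan

      def main(args: Array[String]): Unit = {
        println(Estimate(9.1).cost) // ~0.11: length() strongly favors Comet
        println(Estimate(0.4).cost) // 2.5: trim() favors staying on Spark
        println(choose("spark plan", "comet plan", Estimate(1.0), Estimate(9.1)))
        println(choose("spark plan", "comet plan", Estimate(1.0), Estimate(0.4)))
      }
    }

The rule compares acceleration factors directly rather than converted costs; since cost is strictly decreasing in acceleration, the two orderings always agree.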
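DefaultCometCostModel.estimateCost walks the entire plan tree and takes the unweighted average of the per-operator acceleration factors, defaulting to 1.0 when there are no operators. A toy sketch of that walk, with a hypothetical Node type standing in for SparkPlan so the example runs without a Spark session:

    object AverageAccelerationSketch {
      // Hypothetical stand-in for SparkPlan: each operator carries its own factor.
      final case class Node(acceleration: Double, children: Seq[Node] = Nil)

      def averageAcceleration(root: Node): Double = {
        def collect(n: Node): Seq[Double] = n.acceleration +: n.children.flatMap(collect)
        val factors = collect(root)
        // mirrors the operatorCount == 0 guard in the cost model
        if (factors.isEmpty) 1.0 else factors.sum / factors.length
      }

      def main(args: Array[String]): Unit = {
        // a 9.1x projection over an operator left at the 2.0x default
        println(averageAcceleration(Node(9.1, Seq(Node(2.0))))) // 5.55
      }
    }

Because the average is unweighted, one fast operator can offset one slow operator; the mixed-expression test removed in PATCH 15 relied on exactly that, with (9.1 + 0.6) / 2 = 4.85x.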
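The shuffle estimates depend only on the shuffle implementation and the schema: CometNativeShuffle is pegged at 1.5x, and CometColumnarShuffle at 1.1x unless the schema contains complex types, where it drops to 0.8x. A small sketch of that branch built on DataTypeSupport.hasComplexTypes; the schemas here are made up for the example:

    import org.apache.spark.sql.types._

    import org.apache.comet.DataTypeSupport

    object ShuffleEstimateSketch {
      // Same constants as the CometColumnarShuffle branch of the cost model.
      def columnarShuffleAcceleration(schema: StructType): Double =
        if (DataTypeSupport.hasComplexTypes(schema)) 0.8 else 1.1

      def main(args: Array[String]): Unit = {
        val flat = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
        val nested = StructType(Seq(StructField("m", MapType(StringType, LongType))))
        println(columnarShuffleAcceleration(flat))   // 1.1
        println(columnarShuffleAcceleration(nested)) // 0.8
      }
    }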
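PATCH 18 resolves the class named by COMET_COST_MODEL_CLASS reflectively, once, in a cached lazy val, and degrades to None on any failure so that query planning itself never fails. A reduced sketch of that loading pattern, assuming only that the configured class has a public no-arg constructor; com.example.NoSuchModel is a deliberately nonexistent class used to show the fallback:

    import org.apache.comet.cost.{CometCostModel, DefaultCometCostModel}

    object CostModelLoaderSketch {
      // Any reflection failure is swallowed and cost-based optimization is simply skipped.
      def loadCostModel(className: String): Option[CometCostModel] =
        try {
          // scalastyle:off classforname
          val clazz = Class.forName(className)
          // scalastyle:on classforname
          Some(clazz.getConstructor().newInstance().asInstanceOf[CometCostModel])
        } catch {
          case _: Exception => None
        }

      def main(args: Array[String]): Unit = {
        println(loadCostModel(classOf[DefaultCometCostModel].getName).isDefined) // true
        println(loadCostModel("com.example.NoSuchModel").isDefined) // false: hypothetical missing class
      }
    }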