Skip to content
Draft
18 changes: 18 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,24 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false)

// Master switch for cost-based plan selection. When false (the default) no cost model is
// loaded and the Comet-transformed plan is used without a cost comparison.
val COMET_COST_BASED_OPTIMIZATION_ENABLED: ConfigEntry[Boolean] =
  conf("spark.comet.cost.enabled")
    .category(CATEGORY_TUNING)
    .doc(
      "Whether to enable cost-based optimization for Comet. When enabled, Comet will " +
        "use a cost model to estimate acceleration factors for operators and make decisions " +
        "about whether to use Comet or Spark operators based on estimated performance.")
    .booleanConf
    .createWithDefault(false)

// Name of the pluggable cost model implementation. CometExecRule loads this class
// reflectively and instantiates it through its public no-argument constructor, so custom
// implementations must provide one. Only consulted when spark.comet.cost.enabled is true.
val COMET_COST_MODEL_CLASS: ConfigEntry[String] =
conf("spark.comet.cost.model.class")
.category(CATEGORY_TUNING)
.doc("The fully qualified class name of the cost model implementation to use for " +
"cost-based optimization. The class must implement the CometCostModel trait.")
.stringConf
.createWithDefault("org.apache.comet.cost.DefaultCometCostModel")

/** Create a config to enable a specific operator */
private def createExecEnabledConfig(
exec: String,
Expand Down
2 changes: 2 additions & 0 deletions docs/source/user-guide/latest/configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ These settings can be used to determine which parts of the plan are accelerated
| Config | Description | Default Value |
|--------|-------------|---------------|
| `spark.comet.batchSize` | The columnar batch size, i.e., the maximum number of rows that a batch can contain. | 8192 |
| `spark.comet.cost.enabled` | Whether to enable cost-based optimization for Comet. When enabled, Comet will use a cost model to estimate acceleration factors for operators and make decisions about whether to use Comet or Spark operators based on estimated performance. | false |
| `spark.comet.cost.model.class` | The fully qualified class name of the cost model implementation to use for cost-based optimization. The class must implement the CometCostModel trait. | org.apache.comet.cost.DefaultCometCostModel |
| `spark.comet.exec.memoryPool` | The type of memory pool to be used for Comet native execution when running Spark in off-heap mode. Available pool types are `greedy_unified` and `fair_unified`. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | fair_unified |
| `spark.comet.exec.memoryPool.fraction` | Fraction of off-heap memory pool that is available to Comet. Only applies to off-heap mode. For more information, refer to the [Comet Tuning Guide](https://datafusion.apache.org/comet/user-guide/tuning.html). | 1.0 |
| `spark.comet.tracing.enabled` | Enable fine-grained tracing of events and memory usage. For more information, refer to the [Comet Tracing Guide](https://datafusion.apache.org/comet/contributor-guide/tracing.html). | false |
Expand Down
3 changes: 3 additions & 0 deletions spark/src/main/scala/org/apache/comet/DataTypeSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,7 @@ object DataTypeSupport {
case _: StructType | _: ArrayType | _: MapType => true
case _ => false
}

/**
 * Returns true when at least one top-level field of `schema` has a complex data type, as
 * decided by `isComplexType`.
 */
def hasComplexTypes(schema: StructType): Boolean = {
  val fieldTypes = schema.fields.iterator.map(_.dataType)
  fieldTypes.exists(dataType => isComplexType(dataType))
}
}
166 changes: 166 additions & 0 deletions spark/src/main/scala/org/apache/comet/cost/CometCostModel.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.cost

import scala.jdk.CollectionConverters._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometPlan, CometProjectExec}
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.SparkPlan

import org.apache.comet.DataTypeSupport
import org.apache.comet.serde.{ExprOuterClass, OperatorOuterClass}

/**
 * Result of a cost estimation.
 *
 * @param acceleration
 *   estimated acceleration factor. The default model assigns 1.0 to non-Comet (Spark)
 *   operators, values above 1.0 where Comet is expected to be faster, and values below 1.0
 *   where it is expected to be slower.
 */
case class CometCostEstimate(acceleration: Double)

/**
 * Pluggable cost model used for cost-based optimization. The implementation is selected by
 * the `spark.comet.cost.model.class` configuration and is instantiated reflectively through a
 * public no-argument constructor by CometExecRule.
 */
trait CometCostModel {

/**
 * Estimate the acceleration of the given plan.
 *
 * CometExecRule calls this with whole plans (the original plan and the Comet candidate
 * plan) and keeps the candidate only when its estimated acceleration is strictly greater.
 * DefaultCometCostModel walks the entire tree rooted at `plan` and averages the
 * per-operator estimates.
 *
 * @param plan
 *   root of the physical plan to estimate
 * @return
 *   the estimated acceleration for the plan
 */
def estimateCost(plan: SparkPlan): CometCostEstimate
}

class DefaultCometCostModel extends CometCostModel with Logging {

// Optimistic default of 2x acceleration. Used for Comet operators, scalar functions and
// expression types that have no specific estimate in this model.
private val defaultAcceleration = 2.0

override def estimateCost(plan: SparkPlan): CometCostEstimate = {

logTrace(s"estimateCost for $plan")

// Walk the entire plan tree and accumulate costs
var totalAcceleration = 0.0
var operatorCount = 0

def collectOperatorCosts(node: SparkPlan): Unit = {
val operatorCost = estimateOperatorCost(node)
logTrace(
s"Operator: ${node.getClass.getSimpleName}, " +
s"Cost: ${operatorCost.acceleration}")
totalAcceleration += operatorCost.acceleration
operatorCount += 1

// Recursively process children
node.children.foreach(collectOperatorCosts)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could remove the use of vars and have the function return the total acceleration and operator count itself?

Something like :

def countItems(list: List[Any], accumulator: Int = 0): Int = {
  list match {
    case head :: tail => countItems(tail, accumulator + 1)
    case Nil => accumulator
  }
}

val myList = List(1, 2, 3, 4)
val count = countItems(myList)  // result: 4

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know how much of this code will still exist by the time the proof-of-concept is working and ready for detailed code review, so I'll hold off from making these changes now.

I am really looking for high-level feedback on the general approach at the moment.

}

collectOperatorCosts(plan)

// Calculate average acceleration across all operators
// This is crude but gives us a starting point
val averageAcceleration = if (operatorCount > 0) {
totalAcceleration / operatorCount.toDouble
} else {
1.0 // No acceleration if no operators
}

logTrace(
s"Plan: ${plan.getClass.getSimpleName}, Total operators: $operatorCount, " +
s"Average acceleration: $averageAcceleration")

CometCostEstimate(averageAcceleration)
}

/**
 * Estimate the cost of a single operator (children are not visited here).
 *
 *   - `CometProjectExec`: average of the per-expression estimates of its projection list, or
 *     the default acceleration when the projection list is empty.
 *   - `CometShuffleExchangeExec`: fixed factor depending on the shuffle type and, for columnar
 *     shuffle, on whether the schema contains complex types.
 *   - `CometColumnarToRowExec` and non-Comet operators: 1.0.
 *   - any other `CometPlan`: the default acceleration.
 *
 * @param plan
 *   the operator to estimate
 * @return
 *   the estimated acceleration; never NaN for the cases handled here
 */
private def estimateOperatorCost(plan: SparkPlan): CometCostEstimate = {
  val result = plan match {
    case op: CometProjectExec =>
      logTrace("CometProjectExec found - evaluating expressions")
      // Cast nativeOp to Operator and extract projection expressions
      val operator = op.nativeOp.asInstanceOf[OperatorOuterClass.Operator]
      val projection = operator.getProjection
      val expressions = projection.getProjectListList.asScala
      logTrace(s"Found ${expressions.length} expressions in projection")

      if (expressions.isEmpty) {
        // An empty projection list would otherwise compute 0.0 / 0 = NaN. The NaN would
        // propagate into the plan-wide sum/average and make every comparison against it
        // false, so treat the operator like a generic Comet operator instead.
        logTrace(
          "CometProjectExec has no expressions - " +
            s"using default acceleration $defaultAcceleration")
        CometCostEstimate(defaultAcceleration)
      } else {
        val costs = expressions.map { expr =>
          val cost = estimateCometExpressionCost(expr)
          logTrace(s"Expression cost: $cost")
          cost
        }
        val total = costs.sum
        val average = total / expressions.length.toDouble
        logTrace(s"CometProjectExec total cost: $total, average: $average")
        CometCostEstimate(average)
      }

    case op: CometShuffleExchangeExec =>
      op.shuffleType match {
        case CometNativeShuffle => CometCostEstimate(1.5)
        case CometColumnarShuffle =>
          if (DataTypeSupport.hasComplexTypes(op.schema)) {
            CometCostEstimate(0.8)
          } else {
            CometCostEstimate(1.1)
          }
      }
    // Must stay ahead of the generic CometPlan case so it is not given the default factor.
    case _: CometColumnarToRowExec =>
      CometCostEstimate(1.0)
    case _: CometPlan =>
      logTrace(s"Generic CometPlan: ${plan.getClass.getSimpleName}")
      CometCostEstimate(defaultAcceleration)
    case _ =>
      logTrace(s"Non-Comet operator: ${plan.getClass.getSimpleName}")
      // Spark operator
      CometCostEstimate(1.0)
  }

  logTrace(s"${plan.getClass.getSimpleName} -> acceleration: ${result.acceleration}")
  result
}

/**
 * Estimate the acceleration factor of a single Comet protobuf expression.
 *
 * SUBSTRING and a fixed set of scalar string functions have dedicated factors; every other
 * scalar function or expression type gets the default acceleration.
 */
private def estimateCometExpressionCost(expr: ExprOuterClass.Expr): Double =
  expr.getExprStructCase match {
    // Specialized expression types
    case ExprOuterClass.Expr.ExprStructCase.SUBSTRING => 6.3

    // Generic scalar functions, looked up by function name
    case ExprOuterClass.Expr.ExprStructCase.SCALARFUNC =>
      // String expression numbers from CometStringExpressionBenchmark
      val scalarFunctionFactors: Map[String, Double] = Map(
        "ascii" -> 0.6,
        "octet_length" -> 0.6,
        "lower" -> 3.0,
        "upper" -> 3.0,
        "char" -> 0.6,
        "initcap" -> 0.9,
        "trim" -> 0.4,
        "concat_ws" -> 0.5,
        "length" -> 9.1,
        "repeat" -> 0.4,
        "reverse" -> 6.9,
        "instr" -> 0.6,
        "replace" -> 1.3,
        "string_space" -> 0.8,
        "translate" -> 0.8)
      scalarFunctionFactors.getOrElse(expr.getScalarFunc.getFunc, defaultAcceleration)

    // Anything else: no specific number available
    case _ =>
      logTrace(
        s"Expression: Unknown type ${expr.getExprStructCase} -> " +
          s"$defaultAcceleration")
      defaultAcceleration
  }

}
44 changes: 42 additions & 2 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
import org.apache.comet.CometConf.{COMET_COST_BASED_OPTIMIZATION_ENABLED, COMET_COST_MODEL_CLASS, COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
import org.apache.comet.CometSparkSessionExtensions._
import org.apache.comet.cost.CometCostModel
import org.apache.comet.rules.CometExecRule.allExecs
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}
import org.apache.comet.serde.operator._
Expand Down Expand Up @@ -97,6 +98,28 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {

private lazy val showTransformations = CometConf.COMET_EXPLAIN_TRANSFORMATIONS.get()

// Cost model used by `apply` to compare the original plan with the Comet candidate plan.
// Cached in a lazy val to avoid loading the class on every call. It is `None` when
// cost-based optimization is disabled or the configured class cannot be instantiated; in
// both cases `apply` uses the Comet candidate plan without a cost comparison.
@transient private lazy val costModel: Option[CometCostModel] = {
  if (COMET_COST_BASED_OPTIMIZATION_ENABLED.get(conf)) {
    val costModelClassName = COMET_COST_MODEL_CLASS.get(conf)
    try {
      // Prefer the thread context class loader so that a custom model shipped in a user jar
      // can be resolved even when Comet itself was loaded by a parent class loader.
      val classLoader = Option(Thread.currentThread().getContextClassLoader)
        .getOrElse(getClass.getClassLoader)
      // scalastyle:off classforname
      val costModelClass = Class.forName(costModelClassName, true, classLoader)
      // scalastyle:on classforname
      val constructor = costModelClass.getConstructor()
      Some(constructor.newInstance().asInstanceOf[CometCostModel])
    } catch {
      case e: Exception =>
        // Log the exception itself (getMessage can be null, e.g. for reflection wrappers)
        // and pass it along so the stack trace is preserved.
        logWarning(
          s"Failed to load cost model class $costModelClassName: $e. " +
            "Continuing without cost-based optimization; " +
            "the Comet plan will be used without a cost comparison.",
          e)
        None
    }
  } else {
    None
  }
}

private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case s: ShuffleExchangeExec if CometShuffleExchangeExec.nativeShuffleSupported(s) =>
Expand Down Expand Up @@ -344,7 +367,24 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
}

override def apply(plan: SparkPlan): SparkPlan = {
val newPlan = _apply(plan)
val candidatePlan = _apply(plan)

// Only apply cost-based optimization if enabled and cost model is available
val newPlan = costModel match {
case Some(model) =>
val costBefore = model.estimateCost(plan)
val costAfter = model.estimateCost(candidatePlan)

if (costAfter.acceleration > costBefore.acceleration) {
candidatePlan
} else {
plan
}
case None =>
// Cost-based optimization is disabled or failed to load, return candidate plan
candidatePlan
}

if (showTransformations && !newPlan.fastEquals(plan)) {
logInfo(s"""
|=== Applying Rule $ruleName ===
Expand Down
Loading
Loading