Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions common/src/main/scala/org/apache/comet/CometConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,26 @@ object CometConf extends ShimCometConf {
val COMET_REPLACE_SMJ: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin")
.category(CATEGORY_EXEC)
.doc("Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin " +
s"for improved performance. This feature is not stable yet. $TUNING_GUIDE.")
.doc("Experimental feature to replace SortMergeJoin with ShuffledHashJoin " +
"for improved performance. When enabled, Comet will only rewrite a " +
"SortMergeJoin to a ShuffledHashJoin if the build side is estimated to " +
"fit in memory (using Spark's autoBroadcastJoinThreshold * " +
"numShufflePartitions) and is significantly smaller than the probe side " +
s"(controlled by replaceSortMergeJoin.sizeRatio). $TUNING_GUIDE.")
.booleanConf
.createWithDefault(false)

// Guard used by RewriteJoin when COMET_REPLACE_SMJ is enabled: the probe side must be at
// least this many times larger than the build side before a SortMergeJoin is rewritten to
// a ShuffledHashJoin. Default of 3 mirrors Spark's spark.sql.shuffledHashJoinFactor.
val COMET_REPLACE_SMJ_SIZE_RATIO: ConfigEntry[Int] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.sizeRatio")
.category(CATEGORY_EXEC)
.doc("Minimum ratio of probe side size to build side size required when " +
"replacing SortMergeJoin with ShuffledHashJoin. Building a hash table is " +
"more expensive than sorting, so the build side should be significantly " +
"smaller than the probe side. Matches Spark's " +
s"spark.sql.shuffledHashJoinFactor behavior. $TUNING_GUIDE.")
.intConf
.createWithDefault(3)

val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
.category(CATEGORY_SHUFFLE)
Expand Down
12 changes: 12 additions & 0 deletions docs/source/user-guide/latest/tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,18 @@ to test with both for your specific workloads.

To configure Comet to convert `SortMergeJoin` to `ShuffledHashJoin`, set `spark.comet.exec.replaceSortMergeJoin=true`.

When this feature is enabled, Comet uses statistics-based guards to only rewrite joins where the build side is
estimated to fit in memory. Specifically:

- The build side must be small enough that each partition's hash table fits in memory, using Spark's
`autoBroadcastJoinThreshold * numShufflePartitions` as the threshold (matching Spark's own `canBuildLocalHashMapBySize`
logic).
- The build side must be significantly smaller than the probe side. The required ratio is controlled by
`spark.comet.exec.replaceSortMergeJoin.sizeRatio` (default: 3), matching Spark's `SHUFFLE_HASH_JOIN_FACTOR`.

If either check fails, the `SortMergeJoin` is kept as-is. To understand why a specific join was not rewritten, enable
`spark.comet.explainFallback.enabled=true` and check the logs.

## Shuffle

Comet provides accelerated shuffle implementations that can be used to improve the performance of your queries.
Expand Down
84 changes: 74 additions & 10 deletions spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo

/**
* Adapted from equivalent rule in Apache Gluten.
*
* This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]].
* This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]] when the build side
* is estimated to fit in memory and is significantly smaller than the probe side.
*/
object RewriteJoin extends JoinSelectionHelper {

Expand Down Expand Up @@ -64,6 +67,59 @@ object RewriteJoin extends JoinSelectionHelper {
case _ => plan
}

/**
 * Check whether the build side of a join is small enough to rewrite from SortMergeJoin
 * to ShuffledHashJoin. Uses two checks mirroring Spark's JoinSelection logic:
 *
 * 1. Per-partition size: build side size / numShufflePartitions must be less than
 * autoBroadcastJoinThreshold (i.e., each partition's hash table fits in memory).
 * This matches Spark's canBuildLocalHashMapBySize().
 *
 * 2. Size ratio: the build side must be significantly smaller than the probe side,
 * controlled by spark.comet.exec.replaceSortMergeJoin.sizeRatio (default 3).
 * This matches Spark's muchSmaller() / SHUFFLE_HASH_JOIN_FACTOR.
 *
 * If no logical link with statistics is available, the rewrite is skipped to be safe.
 *
 * @param smj the candidate SortMergeJoin node
 * @param buildSide the side that would become the hash-join build side
 * @return None when the rewrite is allowed, or Some(reason) explaining why it is not
 */
private def canRewriteToHashJoin(
    smj: SortMergeJoinExec,
    buildSide: BuildSide): Option[String] = {
  val conf = SQLConf.get
  val sizeRatio = CometConf.COMET_REPLACE_SMJ_SIZE_RATIO.get()
  val broadcastThreshold = conf.autoBroadcastJoinThreshold
  val numPartitions = conf.numShufflePartitions

  // If the broadcast threshold is disabled (non-positive, e.g. -1), skip the
  // per-partition check but still enforce the size ratio check. The multiplication
  // is done in BigInt to avoid Long overflow for very large threshold settings
  // (a raw Long product could go negative and reject every rewrite), and so that
  // it compares directly against the BigInt sizeInBytes statistics below.
  val maxBuildSize: Option[BigInt] = if (broadcastThreshold > 0) {
    Some(BigInt(broadcastThreshold) * numPartitions)
  } else {
    None
  }

  smj.logicalLink match {
    case Some(join: Join) =>
      // Size estimates come from logical plan statistics (BigInt byte counts).
      val (buildSize, probeSize) = buildSide match {
        case BuildLeft => (join.left.stats.sizeInBytes, join.right.stats.sizeInBytes)
        case BuildRight => (join.right.stats.sizeInBytes, join.left.stats.sizeInBytes)
      }

      maxBuildSize match {
        case Some(max) if buildSize >= max =>
          Some(s"build side too large: $buildSize bytes >= " +
            s"autoBroadcastJoinThreshold($broadcastThreshold) * " +
            s"numShufflePartitions($numPartitions) = $max bytes")
        case _ if buildSize * sizeRatio > probeSize =>
          Some(s"build side not much smaller than probe side: " +
            s"buildSize($buildSize) * sizeRatio($sizeRatio) > probeSize($probeSize)")
        case _ =>
          None // OK to rewrite
      }

    case _ =>
      // Without a logical link there are no statistics to estimate sizes, so
      // conservatively keep the SortMergeJoin.
      Some("no logical plan statistics available to estimate join sizes")
  }
}

def rewrite(plan: SparkPlan): SparkPlan = plan match {
case smj: SortMergeJoinExec =>
getSmjBuildSide(smj) match {
Expand All @@ -76,15 +132,23 @@ object RewriteJoin extends JoinSelectionHelper {
s"BuildRight with ${smj.joinType} is not supported")
plan
case Some(buildSide) =>
ShuffledHashJoinExec(
smj.leftKeys,
smj.rightKeys,
smj.joinType,
buildSide,
smj.condition,
removeSort(smj.left),
removeSort(smj.right),
smj.isSkewJoin)
canRewriteToHashJoin(smj, buildSide) match {
case Some(reason) =>
withInfo(
smj,
s"Not rewriting SortMergeJoin to HashJoin: $reason")
plan
case None =>
ShuffledHashJoinExec(
smj.leftKeys,
smj.rightKeys,
smj.joinType,
buildSide,
smj.condition,
removeSort(smj.left),
removeSort(smj.right),
smj.isSkewJoin)
}
case _ => plan
}
case _ => plan
Expand Down
Loading