From 94908b41e844167424e6fb13b31262aec53b50b5 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Feb 2026 08:28:04 -0700 Subject: [PATCH] Add statistics-based guards to SortMergeJoin-to-HashJoin rewrite Previously, enabling `spark.comet.exec.replaceSortMergeJoin` would unconditionally rewrite all SortMergeJoins to ShuffledHashJoins, which could cause OOM when the build side is too large to fit in memory. This adds two checks mirroring Spark's own JoinSelection logic: 1. Per-partition size check: the build side must fit in memory using Spark's `autoBroadcastJoinThreshold * numShufflePartitions` formula (matching `canBuildLocalHashMapBySize()`). 2. Size ratio check: the build side must be significantly smaller than the probe side, controlled by a new config `spark.comet.exec.replaceSortMergeJoin.sizeRatio` (default: 3), matching Spark's `SHUFFLE_HASH_JOIN_FACTOR`. When either check fails, the SortMergeJoin is kept and the reason is logged via `withInfo` (visible with `explainFallback.enabled=true`). Co-Authored-By: Claude Opus 4.6 --- .../scala/org/apache/comet/CometConf.scala | 19 ++++- docs/source/user-guide/latest/tuning.md | 12 +++ .../org/apache/comet/rules/RewriteJoin.scala | 84 ++++++++++++++++--- 3 files changed, 103 insertions(+), 12 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 480eafdcb7..128de4544f 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -365,11 +365,26 @@ object CometConf extends ShimCometConf { val COMET_REPLACE_SMJ: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin") .category(CATEGORY_EXEC) - .doc("Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin " + - s"for improved performance. This feature is not stable yet. $TUNING_GUIDE.") + .doc("Experimental feature to replace SortMergeJoin with ShuffledHashJoin " + + "for improved performance. When enabled, Comet will only rewrite a " + + "SortMergeJoin to a ShuffledHashJoin if the build side is estimated to " + + "fit in memory (using Spark's autoBroadcastJoinThreshold * " + + "numShufflePartitions) and is significantly smaller than the probe side " + + s"(controlled by replaceSortMergeJoin.sizeRatio). $TUNING_GUIDE.") .booleanConf .createWithDefault(false) + val COMET_REPLACE_SMJ_SIZE_RATIO: ConfigEntry[Int] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.sizeRatio") + .category(CATEGORY_EXEC) + .doc("Minimum ratio of probe side size to build side size required when " + + "replacing SortMergeJoin with ShuffledHashJoin. Building a hash table is " + + "more expensive than sorting, so the build side should be significantly " + + "smaller than the probe side. Matches Spark's " + + s"spark.sql.shuffledHashJoinFactor behavior. $TUNING_GUIDE.") + .intConf + .createWithDefault(3) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md index 5939e89ef3..1a5ee93724 100644 --- a/docs/source/user-guide/latest/tuning.md +++ b/docs/source/user-guide/latest/tuning.md @@ -116,6 +116,18 @@ to test with both for your specific workloads. To configure Comet to convert `SortMergeJoin` to `ShuffledHashJoin`, set `spark.comet.exec.replaceSortMergeJoin=true`. +When this feature is enabled, Comet uses statistics-based guards to only rewrite joins where the build side is +estimated to fit in memory. Specifically: + +- The build side must be small enough that each partition's hash table fits in memory, using Spark's + `autoBroadcastJoinThreshold * numShufflePartitions` as the threshold (matching Spark's own `canBuildLocalHashMapBySize` + logic). +- The build side must be significantly smaller than the probe side. The required ratio is controlled by + `spark.comet.exec.replaceSortMergeJoin.sizeRatio` (default: 3), matching Spark's `SHUFFLE_HASH_JOIN_FACTOR`. + +If either check fails, the `SortMergeJoin` is kept as-is. To understand why a specific join was not rewritten, enable +`spark.comet.explainFallback.enabled=true` and check the logs. + ## Shuffle Comet provides accelerated shuffle implementations that can be used to improve the performance of your queries. diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala index a4d31a59ac..a970f70717 100644 --- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala +++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala @@ -24,13 +24,16 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SortExec, SparkPlan} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo /** * Adapted from equivalent rule in Apache Gluten. * - * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]]. + * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]] when the build side + * is estimated to fit in memory and is significantly smaller than the probe side. */ object RewriteJoin extends JoinSelectionHelper { @@ -64,6 +67,59 @@ object RewriteJoin extends JoinSelectionHelper { case _ => plan } + /** + * Check whether the build side of a join is small enough to rewrite from SortMergeJoin + * to ShuffledHashJoin. Uses two checks mirroring Spark's JoinSelection logic: + * + * 1. Per-partition size: build side size / numShufflePartitions must be less than + * autoBroadcastJoinThreshold (i.e., each partition's hash table fits in memory). + * This matches Spark's canBuildLocalHashMapBySize(). + * + * 2. Size ratio: the build side must be significantly smaller than the probe side, + * controlled by spark.comet.exec.replaceSortMergeJoin.sizeRatio (default 3). + * This matches Spark's muchSmaller() / SHUFFLE_HASH_JOIN_FACTOR. + * + * If no logical link with statistics is available, the rewrite is skipped to be safe. + */ + private def canRewriteToHashJoin( + smj: SortMergeJoinExec, + buildSide: BuildSide): Option[String] = { + val conf = SQLConf.get + val sizeRatio = CometConf.COMET_REPLACE_SMJ_SIZE_RATIO.get() + val broadcastThreshold = conf.autoBroadcastJoinThreshold + val numPartitions = conf.numShufflePartitions + + // If broadcast threshold is -1 (disabled), skip the per-partition check + // but still enforce the size ratio check + val maxBuildSize = if (broadcastThreshold > 0) { + broadcastThreshold * numPartitions + } else { + Long.MaxValue + } + + smj.logicalLink match { + case Some(join: Join) => + val (buildSize, probeSize) = buildSide match { + case BuildLeft => (join.left.stats.sizeInBytes, join.right.stats.sizeInBytes) + case BuildRight => (join.right.stats.sizeInBytes, join.left.stats.sizeInBytes) + } + + if (maxBuildSize != Long.MaxValue && buildSize >= maxBuildSize) { + Some(s"build side too large: $buildSize bytes >= " + + s"autoBroadcastJoinThreshold($broadcastThreshold) * " + + s"numShufflePartitions($numPartitions) = $maxBuildSize bytes") + } else if (buildSize * sizeRatio > probeSize) { + Some(s"build side not much smaller than probe side: " + + s"buildSize($buildSize) * sizeRatio($sizeRatio) > probeSize($probeSize)") + } else { + None // OK to rewrite + } + + case _ => + Some("no logical plan statistics available to estimate join sizes") + } + } + def rewrite(plan: SparkPlan): SparkPlan = plan match { case smj: SortMergeJoinExec => getSmjBuildSide(smj) match { @@ -76,15 +132,23 @@ object RewriteJoin extends JoinSelectionHelper { s"BuildRight with ${smj.joinType} is not supported") plan case Some(buildSide) => - ShuffledHashJoinExec( - smj.leftKeys, - smj.rightKeys, - smj.joinType, - buildSide, - smj.condition, - removeSort(smj.left), - removeSort(smj.right), - smj.isSkewJoin) + canRewriteToHashJoin(smj, buildSide) match { + case Some(reason) => + withInfo( + smj, + s"Not rewriting SortMergeJoin to HashJoin: $reason") + plan + case None => + ShuffledHashJoinExec( + smj.leftKeys, + smj.rightKeys, + smj.joinType, + buildSide, + smj.condition, + removeSort(smj.left), + removeSort(smj.right), + smj.isSkewJoin) + } case _ => plan } case _ => plan