diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 480eafdcb7..128de4544f 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -365,11 +365,27 @@ object CometConf extends ShimCometConf {
   val COMET_REPLACE_SMJ: ConfigEntry[Boolean] =
     conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin")
       .category(CATEGORY_EXEC)
-      .doc("Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin " +
-        s"for improved performance. This feature is not stable yet. $TUNING_GUIDE.")
+      .doc("Experimental feature to replace SortMergeJoin with ShuffledHashJoin " +
+        "for improved performance. When enabled, Comet will only rewrite a " +
+        "SortMergeJoin to a ShuffledHashJoin if the build side is estimated to " +
+        "fit in memory (using Spark's autoBroadcastJoinThreshold * " +
+        "numShufflePartitions) and is significantly smaller than the probe side " +
+        s"(controlled by replaceSortMergeJoin.sizeRatio). $TUNING_GUIDE.")
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_REPLACE_SMJ_SIZE_RATIO: ConfigEntry[Int] =
+    conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.sizeRatio")
+      .category(CATEGORY_EXEC)
+      .doc("Minimum ratio of probe side size to build side size required when " +
+        "replacing SortMergeJoin with ShuffledHashJoin. Building a hash table is " +
+        "more expensive than sorting, so the build side should be significantly " +
+        "smaller than the probe side. Matches Spark's " +
+        s"spark.sql.shuffledHashJoinFactor behavior. $TUNING_GUIDE.")
+      .intConf
+      .checkValue(_ >= 1, "The SortMergeJoin replacement size ratio must be at least 1.")
+      .createWithDefault(3)
+
   val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.native.shuffle.partitioning.hash.enabled")
       .category(CATEGORY_SHUFFLE)
diff --git a/docs/source/user-guide/latest/tuning.md b/docs/source/user-guide/latest/tuning.md
index 5939e89ef3..1a5ee93724 100644
--- a/docs/source/user-guide/latest/tuning.md
+++ b/docs/source/user-guide/latest/tuning.md
@@ -116,6 +116,21 @@ to test with both for your specific workloads.
 
 To configure Comet to convert `SortMergeJoin` to `ShuffledHashJoin`, set `spark.comet.exec.replaceSortMergeJoin=true`.
 
+When this feature is enabled, Comet uses statistics-based guards to only rewrite joins where the build side is
+estimated to fit in memory. Specifically:
+
+- The build side must be small enough that each partition's hash table fits in memory, using Spark's
+  `autoBroadcastJoinThreshold * numShufflePartitions` as the threshold (the same bound used by Spark's own
+  `canBuildLocalHashMapBySize` logic). If `spark.sql.autoBroadcastJoinThreshold` is zero or negative, this check
+  is skipped and only the size ratio check below applies.
+- The build side must be significantly smaller than the probe side. The required ratio is controlled by
+  `spark.comet.exec.replaceSortMergeJoin.sizeRatio` (default: 3), matching Spark's
+  `spark.sql.shuffledHashJoinFactor`.
+
+If either check fails, or if no logical plan statistics are available for the join, the `SortMergeJoin` is kept
+as-is. To understand why a specific join was not rewritten, enable `spark.comet.explainFallback.enabled=true` and
+check the logs.
+
 ## Shuffle
 
 Comet provides accelerated shuffle implementations that can be used to improve the performance of your queries.
diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
index a4d31a59ac..a970f70717 100644
--- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
@@ -24,13 +24,16 @@ import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{SortExec, SparkPlan}
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.internal.SQLConf
 
+import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 
 /**
  * Adapted from equivalent rule in Apache Gluten.
  *
- * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]].
+ * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]] when the build side
+ * is estimated to fit in memory and is significantly smaller than the probe side.
  */
 object RewriteJoin extends JoinSelectionHelper {
 
@@ -64,6 +67,66 @@ object RewriteJoin extends JoinSelectionHelper {
     case _ => plan
   }
 
+  /**
+   * Check whether the build side of a join is small enough to rewrite from SortMergeJoin
+   * to ShuffledHashJoin. Uses two checks mirroring Spark's JoinSelection logic:
+   *
+   *   1. Per-partition size: build side size / numShufflePartitions must be less than
+   *      autoBroadcastJoinThreshold (i.e., each partition's hash table fits in memory).
+   *      This uses the same bound as Spark's canBuildLocalHashMapBySize(), except that
+   *      the check is skipped when autoBroadcastJoinThreshold is zero or negative.
+   *
+   *   2. Size ratio: the build side must be significantly smaller than the probe side,
+   *      controlled by spark.comet.exec.replaceSortMergeJoin.sizeRatio (default 3).
+   *      This matches Spark's muchSmaller() / SHUFFLE_HASH_JOIN_FACTOR.
+   *
+   * If no logical link with statistics is available, the rewrite is skipped to be safe.
+   *
+   * @param smj
+   *   the sort-merge join being considered for replacement
+   * @param buildSide
+   *   the side of the join that would be used to build the hash table
+   * @return
+   *   None if the rewrite is allowed, otherwise Some(reason) explaining why the join is kept
+   */
+  private def canRewriteToHashJoin(
+      smj: SortMergeJoinExec,
+      buildSide: BuildSide): Option[String] = {
+    val conf = SQLConf.get
+    val sizeRatio = CometConf.COMET_REPLACE_SMJ_SIZE_RATIO.get()
+    val broadcastThreshold = conf.autoBroadcastJoinThreshold
+    val numPartitions = conf.numShufflePartitions
+
+    // A zero or negative broadcast threshold (e.g. -1, which disables broadcast joins) skips
+    // the per-partition check; the size ratio check is still enforced. The bound is computed
+    // as a BigInt so that a very large threshold cannot overflow Long and wrap negative.
+    val maxBuildSize: Option[BigInt] =
+      if (broadcastThreshold > 0) Some(BigInt(broadcastThreshold) * numPartitions) else None
+
+    smj.logicalLink match {
+      case Some(join: Join) =>
+        val (buildSize, probeSize) = buildSide match {
+          case BuildLeft => (join.left.stats.sizeInBytes, join.right.stats.sizeInBytes)
+          case BuildRight => (join.right.stats.sizeInBytes, join.left.stats.sizeInBytes)
+        }
+
+        maxBuildSize.filter(buildSize >= _) match {
+          case Some(limit) =>
+            Some(s"build side too large: $buildSize bytes >= " +
+              s"autoBroadcastJoinThreshold($broadcastThreshold) * " +
+              s"numShufflePartitions($numPartitions) = $limit bytes")
+          case None if buildSize * sizeRatio > probeSize =>
+            Some(s"build side not much smaller than probe side: " +
+              s"buildSize($buildSize) * sizeRatio($sizeRatio) > probeSize($probeSize)")
+          case None =>
+            None // OK to rewrite
+        }
+
+      case _ =>
+        Some("no logical plan statistics available to estimate join sizes")
+    }
+  }
+
   def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case smj: SortMergeJoinExec =>
       getSmjBuildSide(smj) match {
@@ -76,15 +139,23 @@ object RewriteJoin extends JoinSelectionHelper {
               s"BuildRight with ${smj.joinType} is not supported")
           plan
         case Some(buildSide) =>
-          ShuffledHashJoinExec(
-            smj.leftKeys,
-            smj.rightKeys,
-            smj.joinType,
-            buildSide,
-            smj.condition,
-            removeSort(smj.left),
-            removeSort(smj.right),
-            smj.isSkewJoin)
+          canRewriteToHashJoin(smj, buildSide) match {
+            case Some(reason) =>
+              withInfo(
+                smj,
+                s"Not rewriting SortMergeJoin to HashJoin: $reason")
+              plan
+            case None =>
+              ShuffledHashJoinExec(
+                smj.leftKeys,
+                smj.rightKeys,
+                smj.joinType,
+                buildSide,
+                smj.condition,
+                removeSort(smj.left),
+                removeSort(smj.right),
+                smj.isSkewJoin)
+          }
         case _ => plan
       }
     case _ => plan