From 535e933130cda8c5a47c6f62e3981955aceb4152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A6=8F=E6=99=BA?= <14267354+follow__my_heart@user.noreply.gitee.com> Date: Mon, 17 Mar 2025 13:42:21 +0000 Subject: [PATCH 1/5] update spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 王福智 <14267354+follow__my_heart@user.noreply.gitee.com> --- .../adaptive/AdaptiveSparkPlanExec.scala | 130 +++++++++--------- 1 file changed, 67 insertions(+), 63 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 6c7ff9119..fdcdee14f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -40,7 +40,6 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec._ import org.apache.spark.sql.execution.bucketing.DisableUnnecessaryBucketedScan import org.apache.spark.sql.execution.exchange._ -import org.apache.spark.sql.execution.window.TopNPushDownForWindow import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdate, SparkListenerSQLAdaptiveSQLMetricUpdates, SQLPlanMetric} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch @@ -63,11 +62,11 @@ import org.apache.spark.util.{SparkFatalException, ThreadUtils} * the rest of the plan. */ case class AdaptiveSparkPlanExec( - inputPlan: SparkPlan, - @transient context: AdaptiveExecutionContext, - @transient preprocessingRules: Seq[Rule[SparkPlan]], - @transient isSubquery: Boolean, - @transient override val supportsColumnar: Boolean = false) + inputPlan: SparkPlan, + @transient context: AdaptiveExecutionContext, + @transient preprocessingRules: Seq[Rule[SparkPlan]], + @transient isSubquery: Boolean, + @transient override val supportsColumnar: Boolean = false) extends LeafExecNode { @transient private val lock = new Object() @@ -122,27 +121,32 @@ case class AdaptiveSparkPlanExec( ReplaceHashWithSortAgg, RemoveRedundantSorts, DisableUnnecessaryBucketedScan, - TopNPushDownForWindow, OptimizeSkewedJoin(ensureRequirements) ) ++ context.session.sessionState.queryStagePrepRules } // A list of physical optimizer rules to be applied to a new stage before its execution. These // optimizations should be stage-independent. - @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = Seq( - PlanAdaptiveDynamicPruningFilters(this), - ReuseAdaptiveSubquery(context.subqueryCache), - OptimizeSkewInRebalancePartitions, - CoalesceShufflePartitions(context.session), - // `OptimizeShuffleWithLocalRead` needs to make use of 'AQEShuffleReadExec.partitionSpecs' - // added by `CoalesceShufflePartitions`, and must be executed after it. 
- OptimizeShuffleWithLocalRead - ) + @transient private val queryStageOptimizerRules: Seq[Rule[SparkPlan]] = { + val ensureRequirements = EnsureRequirements( + requiredDistribution.isDefined, requiredDistribution) + Seq( + PlanAdaptiveDynamicPruningFilters(this), + ReuseAdaptiveSubquery(context.subqueryCache), + OptimizeSkewInRebalancePartitions, + OptimizeSkewShufflePartition( + ensureRequirements, context.session.sparkContext.defaultParallelism), + CoalesceShufflePartitions(context.session), + // `OptimizeShuffleWithLocalRead` needs to make use of 'AQEShuffleReadExec.partitionSpecs' + // added by `CoalesceShufflePartitions`, and must be executed after it. + OptimizeShuffleWithLocalRead + ) + } // This rule is stateful as it maintains the codegen stage ID. We can't create a fresh one every // time and need to keep it in a variable. @transient private val collapseCodegenStagesRule: Rule[SparkPlan] = - CollapseCodegenStages() + CollapseCodegenStages() // A list of physical optimizer rules to be applied right after a new stage is created. The input // plan to these rules has exchange as its root node. @@ -199,9 +203,9 @@ case class AdaptiveSparkPlanExec( * @param newStages the newly created query stages, including new reused query stages. */ private case class CreateStageResult( - newPlan: SparkPlan, - allChildStagesMaterialized: Boolean, - newStages: Seq[QueryStageExec]) + newPlan: SparkPlan, + allChildStagesMaterialized: Boolean, + newStages: Seq[QueryStageExec]) def executedPlan: SparkPlan = currentPhysicalPlan @@ -380,15 +384,15 @@ case class AdaptiveSparkPlanExec( protected override def stringArgs: Iterator[Any] = Iterator(s"isFinalPlan=$isFinalPlan") override def generateTreeString( - depth: Int, - lastChildren: Seq[Boolean], - append: String => Unit, - verbose: Boolean, - prefix: String = "", - addSuffix: Boolean = false, - maxFields: Int, - printNodeId: Boolean, - indent: Int = 0): Unit = { + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean, + indent: Int = 0): Unit = { super.generateTreeString( depth, lastChildren, @@ -432,13 +436,13 @@ case class AdaptiveSparkPlanExec( private def generateTreeStringWithHeader( - header: String, - plan: SparkPlan, - depth: Int, - append: String => Unit, - verbose: Boolean, - maxFields: Int, - printNodeId: Boolean): Unit = { + header: String, + plan: SparkPlan, + depth: Int, + append: String => Unit, + verbose: Boolean, + maxFields: Int, + printNodeId: Boolean): Unit = { append(" " * depth) append(s"+- == $header ==\n") plan.generateTreeString( @@ -585,35 +589,35 @@ case class AdaptiveSparkPlanExec( * For each query stage in `stagesToReplace`, find their corresponding logical nodes in the * `logicalPlan` and replace them with new [[LogicalQueryStage]] nodes. * 1. If the query stage can be mapped to an integral logical sub-tree, replace the corresponding - * logical sub-tree with a leaf node [[LogicalQueryStage]] referencing this query stage. For - * example: - * Join SMJ SMJ - * / \ / \ / \ - * r1 r2 => Xchg1 Xchg2 => Stage1 Stage2 - * | | - * r1 r2 - * The updated plan node will be: - * Join - * / \ - * LogicalQueryStage1(Stage1) LogicalQueryStage2(Stage2) + * logical sub-tree with a leaf node [[LogicalQueryStage]] referencing this query stage. 
For + * example: + * Join SMJ SMJ + * / \ / \ / \ + * r1 r2 => Xchg1 Xchg2 => Stage1 Stage2 + * | | + * r1 r2 + * The updated plan node will be: + * Join + * / \ + * LogicalQueryStage1(Stage1) LogicalQueryStage2(Stage2) * * 2. Otherwise (which means the query stage can only be mapped to part of a logical sub-tree), - * replace the corresponding logical sub-tree with a leaf node [[LogicalQueryStage]] - * referencing to the top physical node into which this logical node is transformed during - * physical planning. For example: - * Agg HashAgg HashAgg - * | | | - * child => Xchg => Stage1 - * | - * HashAgg - * | - * child - * The updated plan node will be: - * LogicalQueryStage(HashAgg - Stage1) + * replace the corresponding logical sub-tree with a leaf node [[LogicalQueryStage]] + * referencing to the top physical node into which this logical node is transformed during + * physical planning. For example: + * Agg HashAgg HashAgg + * | | | + * child => Xchg => Stage1 + * | + * HashAgg + * | + * child + * The updated plan node will be: + * LogicalQueryStage(HashAgg - Stage1) */ private def replaceWithQueryStagesInLogicalPlan( - plan: LogicalPlan, - stagesToReplace: Seq[QueryStageExec]): LogicalPlan = { + plan: LogicalPlan, + stagesToReplace: Seq[QueryStageExec]): LogicalPlan = { var logicalPlan = plan stagesToReplace.foreach { case stage if currentPhysicalPlan.exists(_.eq(stage)) => @@ -724,8 +728,8 @@ case class AdaptiveSparkPlanExec( * materialization errors and stage cancellation errors. */ private def cleanUpAndThrowException( - errors: Seq[Throwable], - earlyFailedStage: Option[Int]): Unit = { + errors: Seq[Throwable], + earlyFailedStage: Option[Int]): Unit = { currentPhysicalPlan.foreach { // earlyFailedStage is the stage which failed before calling doMaterialize, // so we should avoid calling cancel on it to re-trigger the failure again. -- Gitee From 903189a0c34f908649acd48d410fc69705b59ec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A6=8F=E6=99=BA?= <14267354+follow__my_heart@user.noreply.gitee.com> Date: Mon, 17 Mar 2025 13:43:06 +0000 Subject: [PATCH 2/5] add spark/sql/execution/adaptive/OptimizeSkewedJoin.scala. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 王福智 <14267354+follow__my_heart@user.noreply.gitee.com> --- .../adaptive/OptimizeSkewedJoin.scala | 299 ++++++++++++++++++ 1 file changed, 299 insertions(+) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala new file mode 100644 index 000000000..1f5861f0b --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable + +import org.apache.commons.io.FileUtils + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, SKEW_NORMAL_SHUFFLE, ValidateRequirements} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * A rule to optimize skewed joins to avoid straggler tasks whose share of data are significantly + * larger than those of the rest of the tasks. + * + * The general idea is to divide each skew partition into smaller partitions and replicate its + * matching partition on the other side of the join so that they can run in parallel tasks. + * Note that when matching partitions from the left side and the right side both have skew, + * it will become a cartesian product of splits from left and right joining together. + * + * For example, assume the Sort-Merge join has 4 partitions: + * left: [L1, L2, L3, L4] + * right: [R1, R2, R3, R4] + * + * Let's say L2, L4 and R3, R4 are skewed, and each of them get split into 2 sub-partitions. This + * is scheduled to run 4 tasks at the beginning: (L1, R1), (L2, R2), (L3, R3), (L4, R4). + * This rule expands it to 9 tasks to increase parallelism: + * (L1, R1), + * (L2-1, R2), (L2-2, R2), + * (L3, R3-1), (L3, R3-2), + * (L4-1, R4-1), (L4-2, R4-1), (L4-1, R4-2), (L4-2, R4-2) + */ +case class OmniOptimizeSkewedJoin(ensureRequirements: EnsureRequirements) + extends Rule[SparkPlan] { + + /** + * A partition is considered as a skewed partition if its size is larger than the median + * partition size * SKEW_JOIN_SKEWED_PARTITION_FACTOR and also larger than + * SKEW_JOIN_SKEWED_PARTITION_THRESHOLD. Thus we pick the larger one as the skew threshold. + */ + def getSkewThreshold(medianSize: Long): Long = { + conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD).max( + medianSize * conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR)) + } + + /** + * The goal of skew join optimization is to make the data distribution more even. The target size + * to split skewed partitions is the average size of non-skewed partition, or the + * advisory partition size if avg size is smaller than it. 
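+   * For example (illustrative numbers, not from the patch): with non-skewed partition sizes of
+   * 32MB, 48MB and 64MB and a 64MB advisory size, the target split size is
+   * max(64MB, (32 + 48 + 64) / 3 MB) = 64MB.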
+ */ + private def targetSize(sizes: Array[Long], skewThreshold: Long): Long = { + val advisorySize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES) + val nonSkewSizes = sizes.filter(_ <= skewThreshold) + if (nonSkewSizes.isEmpty) { + advisorySize + } else { + math.max(advisorySize, nonSkewSizes.sum / nonSkewSizes.length) + } + } + + private def canSplitLeftSide(joinType: JoinType) = { + joinType == Inner || joinType == Cross || joinType == LeftSemi || + joinType == LeftAnti || joinType == LeftOuter + } + + private def canSplitRightSide(joinType: JoinType) = { + joinType == Inner || joinType == Cross || joinType == RightOuter + } + + private def getSizeInfo(medianSize: Long, sizes: Array[Long]): String = { + s"median size: $medianSize, max size: ${sizes.max}, min size: ${sizes.min}, avg size: " + + sizes.sum / sizes.length + } + + /* + * This method aim to optimize the skewed join with the following steps: + * 1. Check whether the shuffle partition is skewed based on the median size + * and the skewed partition threshold in origin shuffled join (smj and shj). + * 2. Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4). + * And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] + * based on the map size and the max split number. + * 3. Wrap the join left child with a special shuffle read that loads each mapper range with one + * task, so total 3 tasks. + * 4. Wrap the join right child with a special shuffle read that loads partition0 3 times by + * 3 tasks separately. + */ + private def tryOptimizeJoinChildren( + left: ShuffleQueryStageExec, + right: ShuffleQueryStageExec, + joinType: JoinType): Option[(SparkPlan, SparkPlan)] = { + val canSplitLeft = canSplitLeftSide(joinType) + val canSplitRight = canSplitRightSide(joinType) + if (!canSplitLeft && !canSplitRight) return None + + val leftSizes = left.mapStats.get.bytesByPartitionId + val rightSizes = right.mapStats.get.bytesByPartitionId + assert(leftSizes.length == rightSizes.length) + val numPartitions = leftSizes.length + // We use the median size of the original shuffle partitions to detect skewed partitions. + val leftMedSize = Utils.median(leftSizes, false) + val rightMedSize = Utils.median(rightSizes, false) + logDebug( + s""" + |Optimizing skewed join. 
+ |Left side partitions size info: + |${getSizeInfo(leftMedSize, leftSizes)} + |Right side partitions size info: + |${getSizeInfo(rightMedSize, rightSizes)} + """.stripMargin) + + val leftSkewThreshold = getSkewThreshold(leftMedSize) + val rightSkewThreshold = getSkewThreshold(rightMedSize) + val leftTargetSize = targetSize(leftSizes, leftSkewThreshold) + val rightTargetSize = targetSize(rightSizes, rightSkewThreshold) + + val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + var numSkewedLeft = 0 + var numSkewedRight = 0 + for (partitionIndex <- 0 until numPartitions) { + val leftSize = leftSizes(partitionIndex) + val isLeftSkew = canSplitLeft && leftSize > leftSkewThreshold + val rightSize = rightSizes(partitionIndex) + val isRightSkew = canSplitRight && rightSize > rightSkewThreshold + val leftNoSkewPartitionSpec = + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, leftSize)) + val rightNoSkewPartitionSpec = + Seq(CoalescedPartitionSpec(partitionIndex, partitionIndex + 1, rightSize)) + + val leftParts = if (isLeftSkew) { + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( + left.mapStats.get.shuffleId, partitionIndex, leftTargetSize) + if (skewSpecs.isDefined) { + logDebug(s"Left side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(leftSize)}) is skewed, " + + s"split it into ${skewSpecs.get.length} parts.") + numSkewedLeft += 1 + } + skewSpecs.getOrElse(leftNoSkewPartitionSpec) + } else { + leftNoSkewPartitionSpec + } + + val rightParts = if (isRightSkew) { + val skewSpecs = ShufflePartitionsUtil.createSkewPartitionSpecs( + right.mapStats.get.shuffleId, partitionIndex, rightTargetSize) + if (skewSpecs.isDefined) { + logDebug(s"Right side partition $partitionIndex " + + s"(${FileUtils.byteCountToDisplaySize(rightSize)}) is skewed, " + + s"split it into ${skewSpecs.get.length} parts.") + numSkewedRight += 1 + } + skewSpecs.getOrElse(rightNoSkewPartitionSpec) + } else { + rightNoSkewPartitionSpec + } + + for { + leftSidePartition <- leftParts + rightSidePartition <- rightParts + } { + leftSidePartitions += leftSidePartition + rightSidePartitions += rightSidePartition + } + } + logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight") + if (numSkewedLeft > 0 || numSkewedRight > 0) { + Some(( + SkewJoinChildWrapper(AQEShuffleReadExec(left, leftSidePartitions.toSeq)), + SkewJoinChildWrapper(AQEShuffleReadExec(right, rightSidePartitions.toSeq)) + )) + } else { + None + } + } + + def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { + case smj @ SortMergeJoinExec(_, _, joinType, _, + s1 @ SortExec(_, _, ShuffleStage(left: ShuffleQueryStageExec), _), + s2 @ SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) => + tryOptimizeJoinChildren(left, right, joinType).map { + case (newLeft, newRight) => + smj.copy( + left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true) + }.getOrElse(smj) + + case shj @ ShuffledHashJoinExec(_, _, joinType, _, _, + ShuffleStage(left: ShuffleQueryStageExec), + ShuffleStage(right: ShuffleQueryStageExec), false) => + tryOptimizeJoinChildren(left, right, joinType).map { + case (newLeft, newRight) => + shj.copy(left = newLeft, right = newRight, isSkewJoin = true) + }.getOrElse(shj) + } + + private def rewriteJoinOrigin(plan: SparkPlan): SparkPlan = plan.transformDown { + case smj @ SortMergeJoinExec(_, _, _, _, + SortExec(_, _, 
ShuffleStage(left: ShuffleQueryStageExec), _),
+      SortExec(_, _, ShuffleStage(right: ShuffleQueryStageExec), _), false) => {
+      // In the join case, do not optimize the join's own sort/aggregate here.
+      logWarning(s"skewed join left.shuffleOrigin ${left.shuffleOrigin}")
+      logWarning(s"skewed join right.shuffleOrigin ${right.shuffleOrigin}")
+      logWarning("resetting both sides of the skewed join to SKEW_NORMAL_SHUFFLE")
+      left.shuffleOrigin = SKEW_NORMAL_SHUFFLE
+      right.shuffleOrigin = SKEW_NORMAL_SHUFFLE
+      logWarning(s"skewed join left.shuffleOrigin ${left.shuffleOrigin}")
+      logWarning(s"skewed join right.shuffleOrigin ${right.shuffleOrigin}")
+      smj
+    }
+
+    case shj @ ShuffledHashJoinExec(_, _, _, _, _,
+      ShuffleStage(left: ShuffleQueryStageExec),
+      ShuffleStage(right: ShuffleQueryStageExec), false) => {
+      // In the join case, do not optimize the join's own sort/aggregate here.
+      logWarning(s"skewed join left.shuffleOrigin ${left.shuffleOrigin}")
+      logWarning(s"skewed join right.shuffleOrigin ${right.shuffleOrigin}")
+      logWarning("resetting both sides of the skewed join to SKEW_NORMAL_SHUFFLE")
+      left.shuffleOrigin = SKEW_NORMAL_SHUFFLE
+      right.shuffleOrigin = SKEW_NORMAL_SHUFFLE
+      logWarning(s"skewed join left.shuffleOrigin ${left.shuffleOrigin}")
+      logWarning(s"skewed join right.shuffleOrigin ${right.shuffleOrigin}")
+      shj
+    }
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    rewriteJoinOrigin(plan) // applied for its side effect on the join children's shuffleOrigin
+    if (!conf.getConf(SQLConf.SKEW_JOIN_ENABLED)) {
+      logWarning("skew join optimization is disabled; returning the original plan")
+      return plan
+    }
+
+    // We try to optimize every skewed sort-merge/shuffle-hash join in the query plan. If this
+    // introduces extra shuffles, we give up the optimization and return the original query plan, or
+    // accept the extra shuffles if the force-apply config is true.
+    // TODO: It's possible that only one skewed join in the query plan leads to extra shuffles and
+    //       we only need to skip optimizing that join. We should make the strategy smarter here.
+    val optimized = optimizeSkewJoin(plan)
+    val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) {
+      ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get)
+    } else {
+      ValidateRequirements.validate(optimized)
+    }
+    if (requirementSatisfied) {
+      optimized.transform {
+        case SkewJoinChildWrapper(child) => child
+      }
+    } else if (conf.getConf(SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN)) {
+      ensureRequirements.apply(optimized).transform {
+        case SkewJoinChildWrapper(child) => child
+      }
+    } else {
+      plan
+    }
+  }
+
+  object ShuffleStage {
+    def unapply(plan: SparkPlan): Option[ShuffleQueryStageExec] = plan match {
+      case s: ShuffleQueryStageExec if s.isMaterialized && s.mapStats.isDefined &&
+        s.shuffle.shuffleOrigin == ENSURE_REQUIREMENTS =>
+        Some(s)
+      case _ => None
+    }
+  }
+}
+
+// After optimizing skew joins, we need to run EnsureRequirements again to add necessary shuffles
+// caused by skew join optimization. However, this shouldn't apply to the sub-plan under skew join,
+// as it's guaranteed to satisfy the distribution requirement.
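+// The wrapper is only a temporary marker: `apply` strips it again via
+// `transform { case SkewJoinChildWrapper(child) => child }`, so it never reaches execution
+// (hence its doExecute simply throws).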
+case class SkewJoinChildWrapper(plan: SparkPlan) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = plan.output + override def outputPartitioning: Partitioning = plan.outputPartitioning + override def outputOrdering: Seq[SortOrder] = plan.outputOrdering +} -- Gitee From 14c7c0383eb3603abfa2cc2c3df3beb9ed94cedf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A6=8F=E6=99=BA?= <14267354+follow__my_heart@user.noreply.gitee.com> Date: Mon, 17 Mar 2025 13:44:19 +0000 Subject: [PATCH 3/5] add spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 王福智 <14267354+follow__my_heart@user.noreply.gitee.com> --- .../OptimizeSkewShufflePartition.scala | 129 ++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala new file mode 100644 index 000000000..1af8af40c --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewShufflePartition.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.execution.{CoalescedPartitionSpec, ShufflePartitionSpec, SparkPlan} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, EnsureRequirements, ShuffleOrigin, ValidateRequirements} +import org.apache.spark.sql.internal.SQLConf + +case class OptimizeSkewShufflePartition(ensureRequirements: EnsureRequirements, parallelism: Int) + extends AQEShuffleReadRule { + + override val supportedShuffleOrigins: Seq[ShuffleOrigin] = { + Seq(ENSURE_REQUIREMENTS) + } + + private def optimizeSkewedPartitions( + shuffleId: Int, + bytesByPartitionId: Array[Long], + targetSize: Long + ): Seq[ShufflePartitionSpec] = { + logWarning(s"Enter OptimizeSkewShufflePartition optimizeSkewedPartitions") + + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) + + bytesByPartitionId.indices.flatMap { reduceIndex => + val bytes = bytesByPartitionId(reduceIndex) + if (bytes > targetSize) { + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( + shuffleId, reduceIndex, targetSize, smallPartitionFactor) + if (newPartitionSpec.isEmpty) { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } else { + logWarning(s"For shuffle $shuffleId partition $reduceIndex is skewed, " + + s"split it into ${newPartitionSpec.get.size} parts.") + + newPartitionSpec.get + } + } else { + CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil + } + } + } + + private def tryOptimizeSkewedPartitions(shuffle: ShuffleQueryStageExec): SparkPlan = { + logWarning(s"OptimizeSkewShufflePartition shuffle.shuffleOrigin ${shuffle.shuffleOrigin}") + logWarning(s"Enter OptimizeSkewShufflePartition tryOptimizeSkewedPartitions") + + val mapStats = shuffle.mapStats + if (mapStats.isEmpty) { + return shuffle + } + val partitionCount = mapStats.get.bytesByPartitionId.length + val partitionMean = mapStats.get.bytesByPartitionId.sum / partitionCount + val advisorySize = if (partitionCount < parallelism) { + math.max(partitionMean * 2L, 64L * 1024 * 1024) + } else { + 1L * 1024 * 1024 * 1024 + } + if (mapStats.isEmpty || mapStats.get.bytesByPartitionId.forall(_ <= advisorySize)) { + return shuffle + } + + val newPartitionsSpec = optimizeSkewedPartitions( + mapStats.get.shuffleId, mapStats.get.bytesByPartitionId, advisorySize) + + if (newPartitionsSpec.length == mapStats.get.bytesByPartitionId.length) { + shuffle + } else { + AQEShuffleReadExec(shuffle, newPartitionsSpec) + } + } + + private def optimizedSkewPlan(plan: SparkPlan): SparkPlan = plan.transformUp { + case stage: ShuffleQueryStageExec + if (isSupported(stage.shuffle)) => { + if (stage.shuffleOrigin == ENSURE_REQUIREMENTS) { + tryOptimizeSkewedPartitions(stage) + } else { + logWarning(s"optimizedSkewPlan stage.shuffleOrigin ${stage.shuffleOrigin}") + + stage + } + } + } + + override def apply(plan: SparkPlan): SparkPlan = { + logWarning(s"Enter OptimizeSkewShufflePartition") + + if (!conf.getConf(SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED)) { + return plan + } + + val optimized = optimizedSkewPlan(plan) + + logWarning(s"ensureRequirements.requiredDistribution.isDefined" + + s"${ensureRequirements.requiredDistribution.isDefined}") + + val requirementSatisfied = if (ensureRequirements.requiredDistribution.isDefined) { + ValidateRequirements.validate(optimized, ensureRequirements.requiredDistribution.get) + } else { + ValidateRequirements.validate(optimized) + } + if (requirementSatisfied) { + 
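+      // The split read specs still satisfy the parent's required distribution, so keep them.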
logWarning(s"OptimizeSkewShufflePartition optimized") + + optimized + } else { + logWarning(s"OptimizeSkewShufflePartition ensureRequirements.apply(optimized)") + + ensureRequirements.apply(optimized) + } + } +} -- Gitee From 226ea492b9d8b973e6f86214125d689a959aada8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A6=8F=E6=99=BA?= <14267354+follow__my_heart@user.noreply.gitee.com> Date: Mon, 17 Mar 2025 13:46:16 +0000 Subject: [PATCH 4/5] =?UTF-8?q?=E6=96=B0=E5=BB=BA=20exchange?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/scala/org/apache/spark/sql/execution/exchange/.keep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep new file mode 100644 index 000000000..e69de29bb -- Gitee From c2ce231dd20d69955b3857ed414123b68c9bbd04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E7=A6=8F=E6=99=BA?= <14267354+follow__my_heart@user.noreply.gitee.com> Date: Mon, 17 Mar 2025 13:47:07 +0000 Subject: [PATCH 5/5] rename spark/sql/execution/exchange/.keep to spark/sql/execution/exchange/ShuffleExchangeExec.scala. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 王福智 <14267354+follow__my_heart@user.noreply.gitee.com> --- .../apache/spark/sql/execution/exchange/.keep | 0 .../exchange/ShuffleExchangeExec.scala | 425 ++++++++++++++++++ 2 files changed, 425 insertions(+) delete mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala new file mode 100644 index 000000000..2add187bf --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.exchange + +import java.util.function.Supplier + +import scala.concurrent.Future + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor} +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} +import org.apache.spark.util.random.XORShiftRandom + +/** + * Common trait for all shuffle exchange implementations to facilitate pattern matching. + */ +trait ShuffleExchangeLike extends Exchange { + + /** + * Returns the number of mappers of this shuffle. + */ + def numMappers: Int + + /** + * Returns the shuffle partition number. + */ + def numPartitions: Int + + /** + * The origin of this shuffle operator. + */ + def shuffleOrigin: ShuffleOrigin + + /** + * The asynchronous job that materializes the shuffle. It also does the preparations work, + * such as waiting for the subqueries. + */ + final def submitShuffleJob: Future[MapOutputStatistics] = executeQuery { + mapOutputStatisticsFuture + } + + protected def mapOutputStatisticsFuture: Future[MapOutputStatistics] + + /** + * Returns the shuffle RDD with specified partition specs. + */ + def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_] + + /** + * Returns the runtime statistics after shuffle materialization. + */ + def runtimeStatistics: Statistics +} + +// Describes where the shuffle operator comes from. +sealed trait ShuffleOrigin + +// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It +// means that the shuffle operator is used to ensure internal data partitioning requirements and +// Spark is free to optimize it as long as the requirements are still ensured. +case object ENSURE_REQUIREMENTS extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator. Spark +// can still optimize it via changing shuffle partition number, as data partitioning won't change. +case object REPARTITION_BY_COL extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified repartition operator with +// a certain partition number. Spark can't optimize it. +case object REPARTITION_BY_NUM extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified rebalance operator. +// Spark will try to rebalance partitions that make per-partition size not too small and not +// too big. Local shuffle read will be used if possible to reduce network traffic. +case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin + +// Indicates that the shuffle operator was added by the user-specified rebalance operator with +// columns. 
Spark will try to rebalance partitions that make per-partition size not too small and +// not too big. +// Different from `REBALANCE_PARTITIONS_BY_NONE`, local shuffle read cannot be used for it as +// the output needs to be partitioned by the given columns. +case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin + +case object SKEW_NORMAL_SHUFFLE extends ShuffleOrigin + +/** + * Performs a shuffle that will result in the desired partitioning. + */ +case class ShuffleExchangeExec( + override val outputPartitioning: Partitioning, + child: SparkPlan, + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + extends ShuffleExchangeLike { + + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private[sql] lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics = Map( + "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions") + ) ++ readMetrics ++ writeMetrics + + override def nodeName: String = "Exchange" + + private lazy val serializer: Serializer = + new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + + @transient lazy val inputRDD: RDD[InternalRow] = child.execute() + + // 'mapOutputStatisticsFuture' is only needed when enable AQE. + @transient + override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = { + if (inputRDD.getNumPartitions == 0) { + Future.successful(null) + } else { + sparkContext.submitMapStage(shuffleDependency) + } + } + + override def numMappers: Int = shuffleDependency.rdd.getNumPartitions + + override def numPartitions: Int = shuffleDependency.partitioner.numPartitions + + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs) + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + + /** + * A [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. + */ + @transient + lazy val shuffleDependency : ShuffleDependency[Int, InternalRow, InternalRow] = { + val dep = ShuffleExchangeExec.prepareShuffleDependency( + inputRDD, + child.output, + outputPartitioning, + serializer, + writeMetrics) + metrics("numPartitions").set(dep.partitioner.numPartitions) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates( + sparkContext, executionId, metrics("numPartitions") :: Nil) + dep + } + + /** + * Caches the created ShuffleRowRDD so we can reuse that. + */ + private var cachedShuffleRDD: ShuffledRowRDD = null + + protected override def doExecute(): RDD[InternalRow] = { + // Returns the same ShuffleRowRDD if this plan is used by multiple plans. + if (cachedShuffleRDD == null) { + cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics) + } + cachedShuffleRDD + } + + override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = + copy(child = newChild) +} + +object ShuffleExchangeExec { + + /** + * Determines whether records must be defensively copied before being sent to the shuffle. 
+ * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The + * shuffle code assumes that objects are immutable and hence does not perform its own defensive + * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In + * order to properly shuffle the output of these operators, we need to perform our own copying + * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it + * whenever possible. This method encapsulates the logic for choosing when to copy. + * + * In the long run, we might want to push this logic into core's shuffle APIs so that we don't + * have to rely on knowledge of core internals here in SQL. + * + * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue. + * + * @param partitioner the partitioner for the shuffle + * @return true if rows should be copied before being shuffled, false otherwise + */ + private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = { + // Note: even though we only use the partitioner's `numPartitions` field, we require it to be + // passed instead of directly passing the number of partitions in order to guard against + // corner-cases where a partitioner constructed with `numPartitions` partitions may output + // fewer partitions (like RangePartitioner, for example). + val conf = SparkEnv.get.conf + val shuffleManager = SparkEnv.get.shuffleManager + val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] + val bypassMergeThreshold = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD) + val numParts = partitioner.numPartitions + if (sortBasedShuffleOn) { + if (numParts <= bypassMergeThreshold) { + // If we're using the original SortShuffleManager and the number of output partitions is + // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which + // doesn't buffer deserialized records. + // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass. + false + } else if (numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) { + // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records + // prior to sorting them. This optimization is only applied in cases where shuffle + // dependency does not specify an aggregator or ordering and the record serializer has + // certain properties and the number of partitions doesn't exceed the limitation. If this + // optimization is enabled, we can safely avoid the copy. + // + // Exchange never configures its ShuffledRDDs with aggregators or key orderings, and the + // serializer in Spark SQL always satisfy the properties, so we only need to check whether + // the number of partitions exceeds the limitation. + false + } else { + // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must + // copy. + true + } + } else { + // Catch-all case to safely handle any future ShuffleManager implementations. + true + } + } + + /** + * Returns a [[ShuffleDependency]] that will partition rows of its child based on + * the partitioning scheme defined in `newPartitioning`. Those partitions of + * the returned ShuffleDependency will be the input of shuffle. 
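+   *
+   * Rows are tagged with their target partition id up front, so the returned dependency only
+   * needs a pass-through partitioner (`PartitionIdPassthrough`) over the precomputed ids.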
+ */ + def prepareShuffleDependency( + rdd: RDD[InternalRow], + outputAttributes: Seq[Attribute], + newPartitioning: Partitioning, + serializer: Serializer, + writeMetrics: Map[String, SQLMetric]) + : ShuffleDependency[Int, InternalRow, InternalRow] = { + val part: Partitioner = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) + case HashPartitioning(_, n) => + new Partitioner { + override def numPartitions: Int = n + // For HashPartitioning, the partitioning key is already a valid partition ID, as we use + // `HashPartitioning.partitionIdExpression` to produce partitioning key. + override def getPartition(key: Any): Int = key.asInstanceOf[Int] + } + case RangePartitioning(sortingExpressions, numPartitions) => + // Extract only fields used for sorting to avoid collecting large fields that does not + // affect sorting result when deciding partition bounds in RangePartitioner + val rddForSampling = rdd.mapPartitionsInternal { iter => + val projection = + UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + val mutablePair = new MutablePair[InternalRow, Null]() + // Internally, RangePartitioner runs a job on the RDD that samples keys to compute + // partition bounds. To get accurate samples, we need to copy the mutable keys. + iter.map(row => mutablePair.update(projection(row).copy(), null)) + } + // Construct ordering on extracted sort key. + val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) => + ord.copy(child = BoundReference(i, ord.dataType, ord.nullable)) + } + implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes) + new RangePartitioner( + numPartitions, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) + case SinglePartition => + new Partitioner { + override def numPartitions: Int = 1 + override def getPartition(key: Any): Int = 0 + } + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") + // TODO: Handle BroadcastPartitioning. + } + def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match { + case RoundRobinPartitioning(numPartitions) => + // Distributes elements evenly across output partitions, starting from a random partition. + // nextInt(numPartitions) implementation has a special case when bound is a power of 2, + // which is basically taking several highest bits from the initial seed, with only a + // minimal scrambling. Due to deterministic seed, using the generator only once, + // and lack of scrambling, the position values for power-of-two numPartitions always + // end up being almost the same regardless of the index. substantially scrambling the + // seed by hashing will help. Refer to SPARK-21782 for more details. 
+ val partitionId = TaskContext.get().partitionId() + var position = new XORShiftRandom(partitionId).nextInt(numPartitions) + (row: InternalRow) => { + // The HashPartitioner will handle the `mod` by the number of partitions + position += 1 + position + } + case h: HashPartitioning => + val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes) + row => projection(row).getInt(0) + case RangePartitioning(sortingExpressions, _) => + val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) + row => projection(row) + case SinglePartition => identity + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") + } + + val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] && + newPartitioning.numPartitions > 1 + + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retry task may output different rows and thus lead to data loss. + // + // Currently we following the most straight-forward way that perform a local sort before + // partitioning. + // + // Note that we don't perform local sort if the new partitioning has only 1 partition, under + // that case all output rows go to the same partition. + val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) { + rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcode, which should always be Integer. + val prefixComparator = PrefixComparators.LONG + + // The prefix computer generates row hashcode as the prefix, so we may decrease the + // probability that the prefixes are equal when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null. + result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + // We are comparing binary here, which does not support radix sort. + // See more details in SPARK-28699. + false) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + // round-robin function is order sensitive if we don't sort the input. + val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition + if (needToCopyObjectsBeforeShuffle(part)) { + newRdd.mapPartitionsWithIndexInternal((_, iter) => { + val getPartitionKey = getPartitionKeyExtractor() + iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } + }, isOrderSensitive = isOrderSensitive) + } else { + newRdd.mapPartitionsWithIndexInternal((_, iter) => { + val getPartitionKey = getPartitionKeyExtractor() + val mutablePair = new MutablePair[Int, InternalRow]() + iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } + }, isOrderSensitive = isOrderSensitive) + } + } + + // Now, we manually create a ShuffleDependency. 
Because pairs in rddWithPartitionIds + // are in the form of (partitionId, row) and every partitionId is in the expected range + // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. + val dependency = + new ShuffleDependency[Int, InternalRow, InternalRow]( + rddWithPartitionIds, + new PartitionIdPassthrough(part.numPartitions), + serializer, + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics)) + + dependency + } + + /** + * Create a customized [[ShuffleWriteProcessor]] for SQL which wrap the default metrics reporter + * with [[SQLShuffleWriteMetricsReporter]] as new reporter for [[ShuffleWriteProcessor]]. + */ + def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = { + new ShuffleWriteProcessor { + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } + } + } +} -- Gitee
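
The heuristic that decides when OptimizeSkewShufflePartition splits a shuffle partition is easiest to read in isolation. The standalone Scala sketch below mirrors the advisory-size computation in tryOptimizeSkewedPartitions; the object name, method name and example sizes are illustrative only and are not part of the patches.

    // Standalone sketch of the advisory-size heuristic used by
    // OptimizeSkewShufflePartition.tryOptimizeSkewedPartitions (names here are illustrative).
    object AdvisorySizeSketch {
      // Returns the size above which a shuffle partition is considered for splitting.
      def advisorySize(bytesByPartitionId: Array[Long], parallelism: Int): Long = {
        val partitionCount = bytesByPartitionId.length
        val partitionMean = bytesByPartitionId.sum / partitionCount
        if (partitionCount < parallelism) {
          // Fewer partitions than available parallelism: split anything above twice the mean,
          // but never use a threshold below 64MB.
          math.max(partitionMean * 2L, 64L * 1024 * 1024)
        } else {
          // Parallelism is already saturated: only split partitions above 1GB.
          1L * 1024 * 1024 * 1024
        }
      }

      def main(args: Array[String]): Unit = {
        val mb = 1024L * 1024
        val sizes = Array(32 * mb, 48 * mb, 900 * mb) // one 900MB straggler partition
        // With 8-way parallelism the threshold is twice the mean (~653MB), so the 900MB
        // partition would be split while the 32MB and 48MB partitions are left alone.
        println(advisorySize(sizes, parallelism = 8))
      }
    }

Read this way, the two branches say: when a stage produces fewer shuffle partitions than the default parallelism, any partition well above the mean (and at least 64MB) is split to fill the idle task slots; once parallelism is already saturated, only partitions above 1GB are split.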