From 1487d6aedad44ef07e400837a66d6fae988dedb8 Mon Sep 17 00:00:00 2001 From: hyy_cyan Date: Fri, 13 Dec 2024 11:20:07 +0800 Subject: [PATCH 1/5] reorganize Scala plan transform codes --- .../boostkit/spark/ColumnarPlugin.scala | 986 +++++++++--------- ...uardRule.scala => TransformHintRule.scala} | 303 ++++-- 2 files changed, 702 insertions(+), 587 deletions(-) rename omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/{ColumnarGuardRule.scala => TransformHintRule.scala} (49%) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index e38bede42..d445c941e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubqu import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} +import org.apache.spark.sql.execution.{RowToOmniColumnarExec, BroadcastExchangeExecProxy, _} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} @@ -41,7 +41,8 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener -case class ColumnarPreOverrides() extends Rule[SparkPlan] { +case class ColumnarPreOverrides(isAdaptiveContext: Boolean) + extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan val enableColumnarProject: Boolean = columnarConf.enableColumnarProject @@ -73,12 +74,6 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val enableRowShuffle: Boolean = columnarConf.enableRowShuffle val columnsThreshold: Int = columnarConf.columnsThreshold - def apply(plan: SparkPlan): SparkPlan = { - replaceWithColumnarPlan(plan) - } - - def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } - def checkBhjRightChild(x: Any): Boolean = { x match { case _: ColumnarFilterExec | _: ColumnarConditionProjectExec => true @@ -86,212 +81,279 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } } - def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { - case plan: RowGuard => - val actualPlan: SparkPlan = plan.child match { - case p: BroadcastHashJoinExec => - p.withNewChildren(p.children.map { - case RowGuard(queryStage: BroadcastQueryStageExec) => - fallBackBroadcastQueryStage(queryStage) - case queryStage: BroadcastQueryStageExec => - fallBackBroadcastQueryStage(queryStage) - case plan: BroadcastExchangeExec => - // if 
BroadcastHashJoin is row-based, BroadcastExchange should also be row-based - RowGuard(plan) - case other => other - }) - case p: BroadcastNestedLoopJoinExec => - p.withNewChildren(p.children.map { - case RowGuard(queryStage: BroadcastQueryStageExec) => - fallBackBroadcastQueryStage(queryStage) - case queryStage: BroadcastQueryStageExec => - fallBackBroadcastQueryStage(queryStage) - case plan: BroadcastExchangeExec => - // if BroadcastNestedLoopJoin is row-based, BroadcastExchange should also be row-based - RowGuard(plan) - case other => other - }) - case other => - other - } - logDebug(s"Columnar Processing for ${actualPlan.getClass} is under RowGuard.") - actualPlan.withNewChildren(actualPlan.children.map(replaceWithColumnarPlan)) - case plan: FileSourceScanExec - if enableColumnarFileScan && checkColumnarBatchSupport(conf, plan) => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarFileSourceScanExec( - plan.relation, - plan.output, - plan.requiredSchema, - plan.partitionFilters, - plan.optionalBucketSet, - plan.optionalNumCoalescedBuckets, - plan.dataFilters, - plan.tableIdentifier, - plan.disableBucketedScan - ) - case range: RangeExec => - new ColumnarRangeExec(range.range) - case plan: ProjectExec if enableColumnarProject => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - child match { - case ColumnarFilterExec(condition, child) => - ColumnarConditionProjectExec(plan.projectList, condition, child) - case join : ColumnarBroadcastHashJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarBroadcastHashJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.buildSide, - join.condition, - join.left, - join.right, - join.isNullAwareAntiJoin, - plan.projectList) - } else { - ColumnarProjectExec(plan.projectList, child) - } - case join : ColumnarShuffledHashJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarShuffledHashJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.buildSide, - join.condition, - join.left, - join.right, - join.isSkewJoin, - plan.projectList) - } else { + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = { + TransformHints.getHint(plan) match { + case _: TRANSFORM_SUPPORTED => + // supported, break + case _: TRANSFORM_UNSUPPORTED => + logDebug(s"Columnar Processing for ${plan.getClass} is under RowGuard.") + return plan.withNewChildren( + plan.children.map(replaceWithColumnarPlan)) + } + plan match { + case plan: FileSourceScanExec + if enableColumnarFileScan && checkColumnarBatchSupport(conf, plan) => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFileSourceScanExec( + plan.relation, + plan.output, + plan.requiredSchema, + plan.partitionFilters, + plan.optionalBucketSet, + plan.optionalNumCoalescedBuckets, + plan.dataFilters, + plan.tableIdentifier, + plan.disableBucketedScan + ) + case range: RangeExec => + new ColumnarRangeExec(range.range) + case plan: ProjectExec if enableColumnarProject => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + child match { + case ColumnarFilterExec(condition, child) => + ColumnarConditionProjectExec(plan.projectList, condition, child) + case join : 
ColumnarBroadcastHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarBroadcastHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isNullAwareAntiJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarShuffledHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarShuffledHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join : ColumnarSortMergeJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarSortMergeJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case _ => ColumnarProjectExec(plan.projectList, child) - } - case join : ColumnarSortMergeJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarSortMergeJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.condition, - join.left, - join.right, - join.isSkewJoin, - plan.projectList) + } + case plan: FilterExec if enableColumnarFilter => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFilterExec(plan.condition, child) + case plan: ExpandExec if enableColumnarExpand => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarExpandExec(plan.projections, plan.output, child) + case plan: HashAggregateExec if enableColumnarHashAgg => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableFusion && !isSupportAdaptive) { + if (plan.aggregateExpressions.forall(_.mode == Partial)) { + child match { + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj4 @ ColumnarProjectExec(_, + join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) + ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + proj4, + join4, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ 
ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1 @ ColumnarProjectExec(_, + join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2 @ ColumnarProjectExec(_, + join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3 @ ColumnarProjectExec(_, + join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter @ ColumnarFilterExec(_, + scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case _ => + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } } else { - ColumnarProjectExec(plan.projectList, child) + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) } - case _ => - ColumnarProjectExec(plan.projectList, child) - } - case plan: FilterExec if enableColumnarFilter => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarFilterExec(plan.condition, child) - case plan: ExpandExec if enableColumnarExpand => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarExpandExec(plan.projections, plan.output, child) - case plan: HashAggregateExec if enableColumnarHashAgg => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - if (enableFusion && !isSupportAdaptive) { - if (plan.aggregateExpressions.forall(_.mode == Partial)) { - child match { - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj4 @ ColumnarProjectExec(_, - join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) - ), _, _, _)), _, _, _)), _, _, _)), _, _, 
_)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - proj4, - join4, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, + } else { + if (child.isInstanceOf[ColumnarExpandExec]) { + var columnarExpandExec = child.asInstanceOf[ColumnarExpandExec] + val matchRollupOptimization: Boolean = columnarExpandExec.matchRollupOptimization() + if (matchRollupOptimization && enableRollupOptimization) { + // The sparkPlan: ColumnarExpandExec -> ColumnarHashAggExec => ColumnarExpandExec -> ColumnarHashAggExec -> ColumnarOptRollupExec. + // ColumnarHashAggExec handles the first combination by Partial mode, i.e. projections[0]. + // ColumnarOptRollupExec handles the residual combinations by PartialMerge mode, i.e. projections[1]~projections[n]. + val projections = columnarExpandExec.projections + val headProjections = projections.slice(0, 1) + var residualProjections = projections.slice(1, projections.length) + // replace parameters + columnarExpandExec = columnarExpandExec.replace(headProjections) + + // partial + val partialHashAggExec = new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + columnarExpandExec) + + + // If the aggregator has an expression, more than one column in the projection is used + // for expression calculation. Meanwhile, If the single distinct syntax exists, the + // sequence of group columns is disordered. 
Therefore, we need to calculate the sequence + // of expandSeq first to ensure the project operator correctly processes the columns. + val expectSeq = plan.resultExpressions + val expandSeq = columnarExpandExec.output + // the processing sequences of expandSeq + residualProjections = residualProjections.map(projection => { + val indexSeq: Seq[Expression] = expectSeq.map(expectExpr => { + val index = expandSeq.indexWhere(expandExpr => expectExpr.exprId.equals(expandExpr.exprId)) + if (index != -1) { + projection.apply(index) match { + case literal: Literal => literal + case _ => expectExpr + } + } else { + expectExpr + } + }) + indexSeq + }) + + // partial merge + val groupingExpressions = plan.resultExpressions.slice(0, plan.groupingExpressions.length) + val aggregateExpressions = plan.aggregateExpressions.map(expr => { + expr.copy(expr.aggregateFunction, PartialMerge, expr.isDistinct, expr.filter, expr.resultId) + }) + + // need ExpandExec parameters and HashAggExec parameters + new ColumnarOptRollupExec( + residualProjections, plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case _ => + groupingExpressions, + aggregateExpressions, + plan.aggregateAttributes, + partialHashAggExec) + } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, plan.isStreaming, @@ -302,82 +364,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.initialInputBufferOffset, plan.resultExpressions, child) - } - } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) - } - } else { - if (child.isInstanceOf[ColumnarExpandExec]) { - var columnarExpandExec = child.asInstanceOf[ColumnarExpandExec] - val matchRollupOptimization: Boolean = columnarExpandExec.matchRollupOptimization() - if (matchRollupOptimization && enableRollupOptimization) { - // The sparkPlan: ColumnarExpandExec -> ColumnarHashAggExec => ColumnarExpandExec -> ColumnarHashAggExec -> ColumnarOptRollupExec. - // ColumnarHashAggExec handles the first combination by Partial mode, i.e. projections[0]. - // ColumnarOptRollupExec handles the residual combinations by PartialMerge mode, i.e. projections[1]~projections[n]. - val projections = columnarExpandExec.projections - val headProjections = projections.slice(0, 1) - var residualProjections = projections.slice(1, projections.length) - // replace parameters - columnarExpandExec = columnarExpandExec.replace(headProjections) - - // partial - val partialHashAggExec = new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - columnarExpandExec) - - - // If the aggregator has an expression, more than one column in the projection is used - // for expression calculation. Meanwhile, If the single distinct syntax exists, the - // sequence of group columns is disordered. Therefore, we need to calculate the sequence - // of expandSeq first to ensure the project operator correctly processes the columns. 
- val expectSeq = plan.resultExpressions - val expandSeq = columnarExpandExec.output - // the processing sequences of expandSeq - residualProjections = residualProjections.map(projection => { - val indexSeq: Seq[Expression] = expectSeq.map(expectExpr => { - val index = expandSeq.indexWhere(expandExpr => expectExpr.exprId.equals(expandExpr.exprId)) - if (index != -1) { - projection.apply(index) match { - case literal: Literal => literal - case _ => expectExpr - } - } else { - expectExpr - } - }) - indexSeq - }) - - // partial merge - val groupingExpressions = plan.resultExpressions.slice(0, plan.groupingExpressions.length) - val aggregateExpressions = plan.aggregateExpressions.map(expr => { - expr.copy(expr.aggregateFunction, PartialMerge, expr.isDistinct, expr.filter, expr.resultId) - }) - - // need ExpandExec parameters and HashAggExec parameters - new ColumnarOptRollupExec( - residualProjections, - plan.output, - groupingExpressions, - aggregateExpressions, - plan.aggregateAttributes, - partialHashAggExec) + } } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, @@ -390,81 +377,84 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.resultExpressions, child) } - } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) } - } - case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarTakeOrderedAndProjectExec( - plan.limit, - plan.sortOrder, - plan.projectList, - child) - case plan: BroadcastExchangeExec if enableColumnarBroadcastExchange => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarBroadcastExchangeExec(plan.mode, child) - case plan: BroadcastHashJoinExec if enableColumnarBroadcastJoin => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarBroadcastHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right) - case plan: ShuffledHashJoinExec if enableShuffledHashJoin && enableDedupLeftSemiJoin => { - plan.joinType match { - case LeftSemi => { - if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - val partialAgg = PhysicalAggregation.unapply(Aggregate(plan.right.output, plan.right.output, new DummyLogicalPlan)) match { - case Some((groupingExpressions, aggExpressions, resultExpressions, _)) - if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => - ExtendedAggUtils.planPartialAggregateWithoutDistinct( - ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), - aggExpressions.map(_.asInstanceOf[AggregateExpression]), - resultExpressions, - right).asInstanceOf[HashAggregateExec] + case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for 
${plan.getClass} is currently supported.") + ColumnarTakeOrderedAndProjectExec( + plan.limit, + plan.sortOrder, + plan.projectList, + child) + case plan: BroadcastExchangeExec if enableColumnarBroadcastExchange => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarBroadcastExchangeExec(plan.mode, child) + case plan: BroadcastHashJoinExec if enableColumnarBroadcastJoin => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right) + case plan: ShuffledHashJoinExec if enableShuffledHashJoin && enableDedupLeftSemiJoin && !SQLConf.get.adaptiveExecutionEnabled => { + plan.joinType match { + case LeftSemi => { + if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + val partialAgg = PhysicalAggregation.unapply(Aggregate(plan.right.output, plan.right.output, new DummyLogicalPlan)) match { + case Some((groupingExpressions, aggExpressions, resultExpressions, _)) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + ExtendedAggUtils.planPartialAggregateWithoutDistinct( + ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), + aggExpressions.map(_.asInstanceOf[AggregateExpression]), + resultExpressions, + right).asInstanceOf[HashAggregateExec] + } + val newHashAgg = new ColumnarHashAggregateExec( + partialAgg.requiredChildDistributionExpressions, + partialAgg.isStreaming, + partialAgg.numShufflePartitions, + partialAgg.groupingExpressions, + partialAgg.aggregateExpressions, + partialAgg.aggregateAttributes, + partialAgg.initialInputBufferOffset, + partialAgg.resultExpressions, + right) + + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + newHashAgg, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) } - val newHashAgg = new ColumnarHashAggregateExec( - partialAgg.requiredChildDistributionExpressions, - partialAgg.isStreaming, - partialAgg.numShufflePartitions, - partialAgg.groupingExpressions, - partialAgg.aggregateExpressions, - partialAgg.aggregateAttributes, - partialAgg.initialInputBufferOffset, - partialAgg.resultExpressions, - right) - - ColumnarShuffledHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - newHashAgg, - plan.isSkewJoin) - } else { + } + case _ => { val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -479,149 +469,117 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.isSkewJoin) } } - case _ => { - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - 
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarShuffledHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - plan.isSkewJoin) - } - } - } - case plan: ShuffledHashJoinExec if enableShuffledHashJoin => - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarShuffledHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarSortMergeJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: TopNSortExec if enableColumnarTopNSort => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, child) - case plan: SortExec if enableColumnarSort => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) - case plan: WindowExec if enableColumnarWindow => - val child = replaceWithColumnarPlan(plan.child) - if (child.output.isEmpty) { - return plan } - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - child match { - case ColumnarSortExec(sortOrder, _, sortChild, _) => - if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) - } else { + case plan: ShuffledHashJoinExec if enableShuffledHashJoin => + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: TopNSortExec if enableColumnarTopNSort => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, child) + case plan: SortExec if enableColumnarSort => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) + case 
plan: WindowExec if enableColumnarWindow => + val child = replaceWithColumnarPlan(plan.child) + if (child.output.isEmpty) { + return plan + } + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + child match { + case ColumnarSortExec(sortOrder, _, sortChild, _) => + if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) + } else { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case _ => ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case plan: UnionExec if enableColumnarUnion => + val children = plan.children.map(replaceWithColumnarPlan) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarUnionExec(children) + case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle => + val child = replaceWithColumnarPlan(plan.child) + if (child.output.nonEmpty) { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (child.output.size > columnsThreshold && enableRowShuffle) { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) + } else if (enableColumnarShuffle) { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) + } else { + plan } - case _ => - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) - } - case plan: UnionExec if enableColumnarUnion => - val children = plan.children.map(replaceWithColumnarPlan) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarUnionExec(children) - case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle => - val child = replaceWithColumnarPlan(plan.child) - if (child.output.nonEmpty) { - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - if (child.isInstanceOf[ColumnarHashAggregateExec] && child.output.size > columnsThreshold - && enableRowShuffle) { - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) - } else if (enableColumnarShuffle) { - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) } else { plan } - } else { - plan - } - case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => - plan.child match { - case shuffle: ColumnarShuffleExchangeExec => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => - reused match { - case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec( - plan.child, - plan.partitionSpecs) - case _ => - plan - } - case _ => - plan - } - case plan: LocalLimitExec if enableLocalColumnarLimit => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarLocalLimitExec(plan.limit, child) - case plan: GlobalLimitExec if enableGlobalColumnarLimit => - val child = 
replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarGlobalLimitExec(plan.limit, child) - case plan: CoalesceExec if enableColumnarCoalesce => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCoalesceExec(plan.numPartitions, child) - case p => - val children = plan.children.map(replaceWithColumnarPlan) - logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") - p.withNewChildren(children) + case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => + plan.child match { + case shuffle: ColumnarShuffleExchangeExec => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => + reused match { + case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec( + plan.child, + plan.partitionSpecs) + case _ => + plan + } + case _ => + plan + } + case plan: LocalLimitExec if enableLocalColumnarLimit => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarLocalLimitExec(plan.limit, child) + case plan: GlobalLimitExec if enableGlobalColumnarLimit => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarGlobalLimitExec(plan.limit, child) + case plan: CoalesceExec if enableColumnarCoalesce => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarCoalesceExec(plan.numPartitions, child) + case p => + val children = plan.children.map(replaceWithColumnarPlan) + logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") + p.withNewChildren(children) + } } - def fallBackBroadcastQueryStage(curPlan: BroadcastQueryStageExec): BroadcastQueryStageExec = { - curPlan.plan match { - case originalBroadcastPlan: ColumnarBroadcastExchangeExec => - BroadcastQueryStageExec( - curPlan.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), - curPlan._canonicalized) - case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => - BroadcastQueryStageExec( - curPlan.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)), - curPlan._canonicalized) - case _ => - curPlan - } + override def apply(plan: SparkPlan): SparkPlan = { + replaceWithColumnarPlan(plan) } } @@ -666,23 +624,23 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { case child: BroadcastQueryStageExec => child.plan match { case originalBroadcastPlan: ColumnarBroadcastExchangeExec => - BroadcastQueryStageExec( - child.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), child._canonicalized) + child case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => - 
BroadcastQueryStageExec( - child.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(child.plan, 1)), child._canonicalized) + child case _ => replaceColumnarToRow(plan, conf) } case _ => replaceColumnarToRow(plan, conf) } + case plan: BroadcastExchangeExecProxy => + val children = plan.children.map { + case c: ColumnarToRowExec => + replaceWithColumnarPlan(c.child) + case other => + replaceWithColumnarPlan(other) + } + plan.withNewChildren(children) case r: SparkPlan if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => c.isInstanceOf[ColumnarToRowExec]) => @@ -710,12 +668,11 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { } case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { - def columnarEnabled: Boolean = session.sqlContext.getConf( - "org.apache.spark.sql.columnar.enabled", "true").trim.toBoolean - - def rowGuardOverrides: ColumnarGuardRule = ColumnarGuardRule() - def preOverrides: ColumnarPreOverrides = ColumnarPreOverrides() + private def preOverrides: List[SparkSession => Rule[SparkPlan]] = List( + FallbackMultiCodegens, + (_: SparkSession) => AddTransformHintRule(), + (_: SparkSession) => ColumnarPreOverrides(isSupportAdaptive)) def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() @@ -740,9 +697,14 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit maybe(session, plan) { isSupportAdaptive = supportAdaptive(plan) val rule = preOverrides - rule.setAdaptiveSupport(isSupportAdaptive) logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") - rule(rowGuardOverrides(plan)) + val overridden = rule.foldLeft(plan) { + (p, getRule) => + val rule = getRule(session) + val newPlan = rule(p) + newPlan + } + overridden } override def postColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. @@ -763,6 +725,8 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectOptimizerRule(_ => DelayCartesianProduct) extensions.injectOptimizerRule(_ => HeuristicJoinReorder) extensions.injectOptimizerRule(_ => MergeSubqueryFilters) + extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) + extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) } } @@ -770,7 +734,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { private class OmniTaskStartExecutorPlugin extends ExecutorPlugin { override def onTaskStart(): Unit = { addLeakSafeTaskCompletionListener[Unit](_ => { - MemoryManager.clearMemory() + MemoryManager.reclaimMemory() }) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala similarity index 49% rename from omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala rename to omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index d20781708..5191c4a10 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2024. 
Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. @@ -18,34 +18,90 @@ package com.huawei.boostkit.spark -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport -case class RowGuard(child: SparkPlan) extends SparkPlan { - def output: Seq[Attribute] = child.output +trait TransformHint { + val stacktrace: Option[String] = + if (TransformHints.DEBUG) { + Some(ExceptionUtils.getStackTrace(new Throwable())) + } else None +} + +case class TRANSFORM_SUPPORTED() extends TransformHint +case class TRANSFORM_UNSUPPORTED(reason: Option[String]) extends TransformHint + +object TransformHints { + val TAG: TreeNodeTag[TransformHint] = + TreeNodeTag[TransformHint]("omni.transformhint") + + val DEBUG = false + + def isAlreadyTagged(plan: SparkPlan): Boolean = { + plan.getTagValue(TAG).isDefined + } + + def isTransformable(plan: SparkPlan): Boolean = { + if (plan.getTagValue(TAG).isDefined) { + return plan.getTagValue(TAG).get.isInstanceOf[TRANSFORM_SUPPORTED] + } + false + } + + def isNotTransformable(plan: SparkPlan): Boolean = { + if (plan.getTagValue(TAG).isDefined) { + return plan.getTagValue(TAG).get.isInstanceOf[TRANSFORM_UNSUPPORTED] + } + false + } + + def tag(plan: SparkPlan, hint: TransformHint): Unit = { + if (isAlreadyTagged(plan)) { + if (isNotTransformable(plan) && hint.isInstanceOf[TRANSFORM_SUPPORTED]) { + throw new UnsupportedOperationException( + "Plan was already tagged as non-transformable, " + + s"cannot mark it as transformable after that:\n${plan.toString()}") + } + } + plan.setTagValue(TAG, hint) + } - protected def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException + def untag(plan: 
SparkPlan): Unit = { + plan.unsetTagValue(TAG) } - def children: Seq[SparkPlan] = Seq(child) + def tagTransformable(plan: SparkPlan): Unit = { + tag(plan, TRANSFORM_SUPPORTED()) + } + + def tagNotTransformable(plan: SparkPlan, reason: String = ""): Unit = { + tag(plan, TRANSFORM_UNSUPPORTED(Some(reason))) + } + + def getHint(plan: SparkPlan): TransformHint = { + if (!isAlreadyTagged(plan)) { + throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString()) + } + plan.getTagValue(TAG).getOrElse(throw new IllegalStateException()) + } - override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = - legacyWithNewChildren(newChildren) + def getHintOption(plan: SparkPlan): Option[TransformHint] = { + plan.getTagValue(TAG) + } } -case class ColumnarGuardRule() extends Rule[SparkPlan] { +case class AddTransformHintRule() extends Rule[SparkPlan] { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val preferColumnar: Boolean = columnarConf.enablePreferColumnar val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort val enableColumnarSort: Boolean = columnarConf.enableColumnarSort @@ -65,17 +121,41 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit - val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce - private def tryConvertToColumnar(plan: SparkPlan): Boolean = { + override def apply(plan: SparkPlan): SparkPlan = { + addTransformableTags(plan) + } + + /** Inserts a transformable tag on top of those that are not supported. 
*/ + private def addTransformableTags(plan: SparkPlan): SparkPlan = { + // Walk the tree with post-order + val out = plan.withNewChildren(plan.children.map(addTransformableTags)) + addTransformableTag(out) + out + } + + private def addTransformableTag(plan: SparkPlan): Unit = { + if (TransformHints.isAlreadyTagged(plan)) { + logDebug( + s"Skip adding transformable tag, since plan already tagged as " + + s"${TransformHints.getHint(plan)}: ${plan.toString()}") + return + } try { - val columnarPlan = plan match { + plan match { case plan: FileSourceScanExec => if (!checkColumnarBatchSupport(conf, plan)) { - return false + TransformHints.tagNotTransformable( + plan, + "columnar Batch is not enabled in FileSourceScanExec") + return + } + if (!enableColumnarFileScan) { + TransformHints.tagNotTransformable( + plan, "columnar FileScan is not enabled in FileSourceScanExec") + return } - if (!enableColumnarFileScan) return false ColumnarFileSourceScanExec( plan.relation, plan.output, @@ -88,16 +168,32 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.disableBucketedScan ).buildCheck() case plan: ProjectExec => - if (!enableColumnarProject) return false + if (!enableColumnarProject) { + TransformHints.tagNotTransformable( + plan, "columnar Project is not enabled in ProjectExec") + return + } ColumnarProjectExec(plan.projectList, plan.child).buildCheck() case plan: FilterExec => - if (!enableColumnarFilter) return false + if (!enableColumnarFilter) { + TransformHints.tagNotTransformable( + plan, "columnar Filter is not enabled in FilterExec") + return + } ColumnarFilterExec(plan.condition, plan.child).buildCheck() case plan: ExpandExec => - if (!enableColumnarExpand) return false + if (!enableColumnarExpand) { + TransformHints.tagNotTransformable( + plan, "columnar Expand is not enabled in ExpandExec") + return + } ColumnarExpandExec(plan.projections, plan.output, plan.child).buildCheck() case plan: HashAggregateExec => - if (!enableColumnarHashAgg) return false + if (!enableColumnarHashAgg) { + TransformHints.tagNotTransformable( + plan, "columnar HashAggregate is not enabled in HashAggregateExec") + return + } new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, plan.isStreaming, @@ -109,34 +205,62 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.resultExpressions, plan.child).buildCheck() case plan: TopNSortExec => - if (!enableColumnarTopNSort) return false + if (!enableColumnarTopNSort) { + TransformHints.tagNotTransformable( + plan, "columnar TopNSort is not enabled in TopNSortExec") + return + } ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, plan.child).buildCheck() case plan: SortExec => - if (!enableColumnarSort) return false + if (!enableColumnarSort) { + TransformHints.tagNotTransformable( + plan, "columnar Sort is not enabled in SortExec") + return + } ColumnarSortExec(plan.sortOrder, plan.global, plan.child, plan.testSpillFrequency).buildCheck() case plan: BroadcastExchangeExec => - if (!enableColumnarBroadcastExchange) return false + if (!enableColumnarBroadcastExchange) { + TransformHints.tagNotTransformable( + plan, "columnar BroadcastExchange is not enabled in BroadcastExchangeExec") + return + } new ColumnarBroadcastExchangeExec(plan.mode, plan.child).buildCheck() case plan: TakeOrderedAndProjectExec => - if (!enableTakeOrderedAndProject) return false + if (!enableTakeOrderedAndProject) { + TransformHints.tagNotTransformable( + plan, "columnar TakeOrderedAndProject is not enabled in 
TakeOrderedAndProjectExec") + return + } ColumnarTakeOrderedAndProjectExec( plan.limit, plan.sortOrder, plan.projectList, plan.child).buildCheck() case plan: UnionExec => - if (!enableColumnarUnion) return false + if (!enableColumnarUnion) { + TransformHints.tagNotTransformable( + plan, "columnar Union is not enabled in UnionExec") + return + } ColumnarUnionExec(plan.children).buildCheck() case plan: ShuffleExchangeExec => - if (!enableColumnarShuffle) return false - new ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin) + if (!enableColumnarShuffle) { + TransformHints.tagNotTransformable( + plan, "columnar ShuffleExchange is not enabled in ShuffleExchangeExec") + return + } + ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin) .buildCheck() case plan: BroadcastHashJoinExec => // We need to check if BroadcastExchangeExec can be converted to columnar-based. // If not, BHJ should also be row-based. - if (!enableColumnarBroadcastJoin) return false + if (!enableColumnarBroadcastJoin) { + TransformHints.tagNotTransformable( + plan, "columnar BroadcastHashJoin is not enabled in BroadcastHashJoinExec") + return + } val left = plan.left left match { case exec: BroadcastExchangeExec => @@ -175,8 +299,12 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.right, plan.isNullAwareAntiJoin).buildCheck() case plan: SortMergeJoinExec => - if (!enableColumnarSortMergeJoin) return false - new ColumnarSortMergeJoinExec( + if (!enableColumnarSortMergeJoin) { + TransformHints.tagNotTransformable( + plan, "columnar SortMergeJoin is not enabled in SortMergeJoinExec") + return + } + ColumnarSortMergeJoinExec( plan.leftKeys, plan.rightKeys, plan.joinType, @@ -185,11 +313,19 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.right, plan.isSkewJoin).buildCheck() case plan: WindowExec => - if (!enableColumnarWindow) return false + if (!enableColumnarWindow) { + TransformHints.tagNotTransformable( + plan, "columnar Window is not enabled in WindowExec") + return + } ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, plan.child).buildCheck() case plan: ShuffledHashJoinExec => - if (!enableShuffledHashJoin) return false + if (!enableShuffledHashJoin) { + TransformHints.tagNotTransformable( + plan, "columnar ShuffledHashJoin is not enabled in ShuffledHashJoinExec") + return + } ColumnarShuffledHashJoinExec( plan.leftKeys, plan.rightKeys, @@ -200,36 +336,52 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.right, plan.isSkewJoin).buildCheck() case plan: LocalLimitExec => - if (!enableLocalColumnarLimit) return false + if (!enableLocalColumnarLimit) { + TransformHints.tagNotTransformable( + plan, "columnar LocalLimit is not enabled in LocalLimitExec") + return + } ColumnarLocalLimitExec(plan.limit, plan.child).buildCheck() case plan: GlobalLimitExec => - if (!enableGlobalColumnarLimit) return false + if (!enableGlobalColumnarLimit) { + TransformHints.tagNotTransformable( + plan, "columnar GlobalLimit is not enabled in GlobalLimitExec") + return + } ColumnarGlobalLimitExec(plan.limit, plan.child).buildCheck() - case plan: BroadcastNestedLoopJoinExec => return false + case plan: BroadcastNestedLoopJoinExec => + TransformHints.tagNotTransformable( + plan, "columnar BroadcastNestedLoopJoin is not support") case plan: CoalesceExec => - if (!enableColumnarCoalesce) return false + if (!enableColumnarCoalesce) { + TransformHints.tagNotTransformable( + plan, "columnar Coalesce is not enabled in 
CoalesceExec") + return + } ColumnarCoalesceExec(plan.numPartitions, plan.child).buildCheck() - case p => - p + case _ => } - } - catch { - case e: UnsupportedOperationException => - logDebug(s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator") - return false + TransformHints.tagTransformable(plan) + } catch { + case throwable @ (_:UnsupportedOperationException | _:RuntimeException | _:Throwable) => + val message = s"[OPERATOR FALLBACK] ${throwable} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) case l: UnsatisfiedLinkError => + TransformHints.tagNotTransformable(plan) throw l case f: NoClassDefFoundError => + TransformHints.tagNotTransformable(plan) throw f - case r: RuntimeException => - logDebug(s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator") - return false - case t: Throwable => - logDebug(s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator") - return false } - true } +} + +case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] { + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle + val preferColumnar: Boolean = columnarConf.enablePreferColumnar private def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean = plan match { @@ -239,7 +391,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case plan: ShuffledHashJoinExec => if ((count + 1) >= optimizeLevel) return true plan.children.map(existsMultiCodegens(_, count + 1)).exists(_ == true) - case other => false + case _ => false } private def supportCodegen(plan: SparkPlan): Boolean = plan match { @@ -248,51 +400,50 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case _ => false } + private def tagNotTransformable(plan: SparkPlan): SparkPlan = { + TransformHints.tagNotTransformable(plan, "fallback multi codegens") + plan + } + /** * Inserts an InputAdapter on top of those that do not support codegen. 
*/ - private def insertRowGuardRecursive(plan: SparkPlan): SparkPlan = { + private def tagNotTransformableRecursive(plan: SparkPlan): SparkPlan = { plan match { case p: ShuffleExchangeExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))) case p: BroadcastExchangeExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))) case p: ShuffledHashJoinExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableRecursive))) case p if !supportCodegen(p) => - // insert row guard them recursively - p.withNewChildren(p.children.map(insertRowGuardOrNot)) + // tag them recursively + p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens)) case p: OmniAQEShuffleReadExec => - p.withNewChildren(p.children.map(insertRowGuardOrNot)) + p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens)) case p: BroadcastQueryStageExec => p - case p => RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + case p => tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableRecursive))) } } - private def insertRowGuard(plan: SparkPlan): SparkPlan = { - RowGuard(plan.withNewChildren(plan.children.map(insertRowGuardOrNot))) - } - /** * Inserts a WholeStageCodegen on top of those that support codegen. */ - private def insertRowGuardOrNot(plan: SparkPlan): SparkPlan = { + private def tagNotTransformableForMultiCodegens(plan: SparkPlan): SparkPlan = { plan match { // For operators that will output domain object, do not insert WholeStageCodegen for it as // domain object can not be written into unsafe row. 
case plan if !preferColumnar && existsMultiCodegens(plan) => - insertRowGuardRecursive(plan) - case plan if !tryConvertToColumnar(plan) => - insertRowGuard(plan) - case p: BroadcastQueryStageExec => - p + tagNotTransformableRecursive(plan) case other => - other.withNewChildren(other.children.map(insertRowGuardOrNot)) + other.withNewChildren(other.children.map(tagNotTransformableForMultiCodegens)) } } - def apply(plan: SparkPlan): SparkPlan = { - insertRowGuardOrNot(plan) + override def apply(plan: SparkPlan): SparkPlan = { + tagNotTransformableForMultiCodegens(plan) } + } + -- Gitee From ab4c3bd08f1ab0701fc0fc167dbec58b4cf220f0 Mon Sep 17 00:00:00 2001 From: dengzhaochu Date: Thu, 22 Aug 2024 09:47:06 +0800 Subject: [PATCH 2/5] add fallback policy --- .../boostkit/spark/ColumnarPlugin.scala | 93 ++++- .../boostkit/spark/ColumnarPluginConfig.scala | 4 + .../boostkit/spark/ExpandFallbackPolicy.scala | 153 ++++++++ .../boostkit/spark/TransformHintRule.scala | 190 ++++++++- .../spark/util/InsertTransitions.scala | 51 +++ .../spark/FallbackStrategiesSuite.scala | 367 ++++++++++++++++++ 6 files changed, 826 insertions(+), 32 deletions(-) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala create mode 100644 omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index d445c941e..5d271f2d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubqu import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{RowToOmniColumnarExec, BroadcastExchangeExecProxy, _} +import org.apache.spark.sql.execution.{BroadcastExchangeExecProxy, RowToOmniColumnarExec, _} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.execution.window.{WindowExec, TopNPushDownForWindow} +import org.apache.spark.sql.execution.window.{TopNPushDownForWindow, WindowExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport import org.apache.spark.sql.catalyst.planning.PhysicalAggregation @@ -41,7 +41,9 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import 
org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener -case class ColumnarPreOverrides(isAdaptiveContext: Boolean) +import scala.collection.mutable.ListBuffer + +case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan @@ -63,7 +65,6 @@ case class ColumnarPreOverrides(isAdaptiveContext: Boolean) val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion val enableFusion: Boolean = columnarConf.enableFusion - var isSupportAdaptive: Boolean = true val enableColumnarProjectFusion: Boolean = columnarConf.enableColumnarProjectFusion val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit @@ -579,6 +580,7 @@ case class ColumnarPreOverrides(isAdaptiveContext: Boolean) } override def apply(plan: SparkPlan): SparkPlan = { + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") replaceWithColumnarPlan(plan) } } @@ -589,6 +591,7 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") handleColumnarToRowPartialFetch(replaceWithColumnarPlan(plan)) } @@ -615,7 +618,7 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported") RowToOmniColumnarExec(child) - case ColumnarToRowExec(child: ColumnarShuffleExchangeExec) => + case ColumnarToRowExec(child: ColumnarShuffleExchangeExec) if isSupportAdaptive => replaceWithColumnarPlan(child) case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) => replaceWithColumnarPlan(child) @@ -674,7 +677,49 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit (_: SparkSession) => AddTransformHintRule(), (_: SparkSession) => ColumnarPreOverrides(isSupportAdaptive)) - def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() + private def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() + + private def finallyRules(): List[SparkSession => Rule[SparkPlan]] = { + List( + (_: SparkSession) => RemoveTransformHintRule() + ) + } + + private def transformPlan(getRules: List[SparkSession => Rule[SparkPlan]], + plan: SparkPlan, + step: String) = { + logDebug( + s"${step}ColumnarTransitions preOverriden plan:\n${plan.toString}") + val overridden = getRules.foldLeft(plan) { + (p, getRule) => + val rule = getRule(session) + val newPlan = rule(p) + newPlan + } + logDebug( + s"${step}ColumnarTransitions afterOverriden plan:\n${overridden.toString}") + overridden + } + + // Holds the original plan for possible entire fallback. 
+ private val localOriginalPlans: ThreadLocal[ListBuffer[SparkPlan]] = + ThreadLocal.withInitial(() => ListBuffer.empty[SparkPlan]) + + private def setOriginalPlan(plan: SparkPlan): Unit = { + localOriginalPlans.get.prepend(plan) + } + + private def originalPlan: SparkPlan = { + val plan = localOriginalPlans.get.head + assert(plan != null) + plan + } + + private def resetOriginalPlan(): Unit = localOriginalPlans.get.remove(0) + + private def fallbackPolicy(): List[SparkSession => Rule[SparkPlan]] = { + List((_: SparkSession) => ExpandFallbackPolicy(isSupportAdaptive, originalPlan)) + } var isSupportAdaptive: Boolean = true @@ -696,26 +741,36 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit override def preColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. maybe(session, plan) { isSupportAdaptive = supportAdaptive(plan) - val rule = preOverrides - logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") - val overridden = rule.foldLeft(plan) { - (p, getRule) => - val rule = getRule(session) - val newPlan = rule(p) - newPlan - } - overridden + setOriginalPlan(plan) + transformPlan(preOverrides, plan, "pre") } override def postColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. maybe(session, plan) { - val rule = postOverrides - rule.setAdaptiveSupport(isSupportAdaptive) - logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") - rule(plan) + val planWithFallBackPolicy = transformPlan(fallbackPolicy(), plan, "fallback") + + val finalPlan = planWithFallBackPolicy match { + case FallbackNode(fallbackPlan) => + // skip c2r and r2c replaceWithColumnarPlan + fallbackPlan + case plan => + val rule = postOverrides + rule.setAdaptiveSupport(isSupportAdaptive) + rule(plan) + } + resetOriginalPlan() + transformPlan(finallyRules(), finalPlan, "final") } } +case class RemoveTransformHintRule() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + plan.foreach(TransformHints.untag) + plan + } +} + + class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logInfo("Using BoostKit Spark Native Sql Engine Extension to Speed Up Your Queries.") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index d58094521..a6c15e104 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -269,6 +269,10 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val columnsThreshold: Int = conf.getConfString("spark.omni.sql.columnar.columnsThreshold", "10").toInt + + // enable or disable bloomfilter subquery reuse + val enableBloomfilterSubqueryReuse: Boolean = + conf.getConfString("spark.omni.sql.columnar.bloomfilterSubqueryReuse", "false").toBoolean } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala new file mode 100644 index 000000000..1833cca66 --- /dev/null +++ 
b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala
@@ -0,0 +1,153 @@
+/*
+ * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.spark
+
+import com.huawei.boostkit.spark.util.InsertTransitions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}
+import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec
+import org.apache.spark.sql.execution._
+
+/**
+ * Note: this rule should only fall back to a row-based plan if there is no harm.
+ * The following cases should be handled carefully.
+ *
+ * @param isAdaptiveContext Whether we are inside AQE
+ * @param originalPlan The vanilla SparkPlan before the gluten transform rules are applied
+ *
+ * */
+case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan)
+  extends Rule[SparkPlan] {
+
+  val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf
+
+  private def countFallback(plan: SparkPlan): Int = {
+    var fallbacks = 0
+
+    def countFallbackInternal(plan: SparkPlan): Unit = {
+      plan match {
+        case _: QueryStageExec => // Another stage.
+        case _: CommandResultExec | _: ExecutedCommandExec => // ignore
+        case _: AQEShuffleReadExec => // ignore
+        case leafPlan: LeafExecNode if !isOmniSparkPlan(leafPlan) =>
+          // Possible fallback for leaf node.
+          fallbacks = fallbacks + 1
+        case p: ColumnarToRowExec => p.children.foreach(countFallbackInternal)
+        case p: RowToColumnarExec => p.children.foreach(countFallbackInternal)
+        case p =>
+          if (!isOmniSparkPlan(p)) {
+            fallbacks = fallbacks + 1
+          }
+          p.children.foreach(countFallbackInternal)
+      }
+    }
+
+    countFallbackInternal(plan)
+    fallbacks
+  }
+
+  private def isOmniSparkPlan(plan: SparkPlan): Boolean = plan.nodeName.startsWith("Omni") && plan.supportsColumnar
+
+  private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
+    def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
+      case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
+      case _ => false
+    }
+
+    plan.find {
+      case j: ColumnarBroadcastHashJoinExec
+        if isColumnarBroadcastExchange(j.left) ||
+          isColumnarBroadcastExchange(j.right) =>
+        true
+      case _ => false
+    }.isDefined
+  }
+
+  private def fallback(plan: SparkPlan): Option[String] = {
+    val fallbackThreshold = if (isAdaptiveContext) {
+      columnarConf.wholeStageFallbackThreshold
+    } else if (plan.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined) {
+      // if we are here, that means we are now at `QueryExecution.preparations` and
+      // AQE is actually not applied. We do nothing for this case, and later in
+      // AQE we can check `wholeStageFallbackThreshold`.
+      return None
+    } else {
+      // AQE is not applied, so we use the whole query threshold to check whether we should fall back
+      columnarConf.queryFallbackThreshold
+    }
+    if (fallbackThreshold < 0) {
+      return None
+    }
+
+    // TODO: decide whether this logic still needs to be added once the adaptation for 157-162 is done
+    // // not safe to fallback row-based BHJ as the broadcast exchange is already columnar
+    // if (hasColumnarBroadcastExchangeWithJoin(plan)) {
+    //   return None
+    // }
+
+    val netFallbackNum = countFallback(plan)
+
+    if (netFallbackNum >= fallbackThreshold) {
+      Some(
+        s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " +
+          s"threshold: $fallbackThreshold")
+    } else {
+      None
+    }
+  }
+
+  private def fallbackToRowBasedPlan(): SparkPlan = {
+    val columnarPostOverrides = ColumnarPostOverrides()
+    val planWithColumnarToRow = InsertTransitions.insertTransitions(originalPlan, false)
+    planWithColumnarToRow.transform {
+      case ColumnarToRowExec(bqe: BroadcastQueryStageExec) if bqe.plan.isInstanceOf[ColumnarBroadcastExchangeExec] =>
+        val columnarBroadcastExchangeExec = bqe.plan.asInstanceOf[ColumnarBroadcastExchangeExec]
+        BroadcastQueryStageExec(bqe.id, BroadcastExchangeExec(columnarBroadcastExchangeExec.mode, ColumnarBroadcastExchangeAdaptorExec(columnarBroadcastExchangeExec, 1)), bqe._canonicalized)
+      case ColumnarToRowExec(bqe: BroadcastQueryStageExec) if bqe.plan.isInstanceOf[ReusedExchangeExec] && bqe.plan.asInstanceOf[ReusedExchangeExec].child.isInstanceOf[ColumnarBroadcastExchangeExec] =>
+        val columnarBroadcastExchangeExec = bqe.plan.asInstanceOf[ReusedExchangeExec].child.asInstanceOf[ColumnarBroadcastExchangeExec]
+        BroadcastQueryStageExec(bqe.id, BroadcastExchangeExec(columnarBroadcastExchangeExec.mode, ColumnarBroadcastExchangeAdaptorExec(columnarBroadcastExchangeExec, 1)), bqe._canonicalized)
+      case c2r@(ColumnarToRowExec(_: ShuffleQueryStageExec) | ColumnarToRowExec(_: AQEShuffleReadExec)) =>
+        columnarPostOverrides.replaceColumnarToRow(c2r.asInstanceOf[ColumnarToRowExec], conf)
+    }
+  }
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    logInfo("Using BoostKit Spark Native Sql Engine Extension FallbackPolicy")
+    val reason = fallback(plan)
+    if (reason.isDefined) {
+      val fallbackPlan =
fallbackToRowBasedPlan() + TransformHints.tagAllNotTransformable(fallbackPlan, reason.get) + FallbackNode(fallbackPlan) + } else { + plan + } + } +} + +/** A wrapper to specify the plan is fallback plan, the caller side should unwrap it. */ +case class FallbackNode(fallbackPlan: SparkPlan) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = fallbackPlan.output +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index 5191c4a10..16701e665 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -20,16 +20,19 @@ package com.huawei.boostkit.spark import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport +import scala.util.control.Breaks.{break, breakable} + trait TransformHint { val stacktrace: Option[String] = if (TransformHints.DEBUG) { @@ -87,6 +90,10 @@ object TransformHints { tag(plan, TRANSFORM_UNSUPPORTED(Some(reason))) } + def tagAllNotTransformable(plan: SparkPlan, reason: String): Unit = { + plan.foreach(other => tagNotTransformable(other, reason)) + } + def getHint(plan: SparkPlan): TransformHint = { if (!isAlreadyTagged(plan)) { throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString()) @@ -115,7 +122,8 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && columnarConf.enableColumnarBroadcastJoin - val enableColumnarBroadcastJoin: Boolean = 
columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastExchange && + columnarConf.enableColumnarBroadcastJoin val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan @@ -167,6 +175,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { plan.tableIdentifier, plan.disableBucketedScan ).buildCheck() + TransformHints.tagTransformable(plan) case plan: ProjectExec => if (!enableColumnarProject) { TransformHints.tagNotTransformable( @@ -174,6 +183,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarProjectExec(plan.projectList, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: FilterExec => if (!enableColumnarFilter) { TransformHints.tagNotTransformable( @@ -181,6 +191,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarFilterExec(plan.condition, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: ExpandExec => if (!enableColumnarExpand) { TransformHints.tagNotTransformable( @@ -188,6 +199,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarExpandExec(plan.projections, plan.output, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: HashAggregateExec => if (!enableColumnarHashAgg) { TransformHints.tagNotTransformable( @@ -204,6 +216,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { plan.initialInputBufferOffset, plan.resultExpressions, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: TopNSortExec => if (!enableColumnarTopNSort) { TransformHints.tagNotTransformable( @@ -212,6 +225,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: SortExec => if (!enableColumnarSort) { TransformHints.tagNotTransformable( @@ -220,6 +234,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarSortExec(plan.sortOrder, plan.global, plan.child, plan.testSpillFrequency).buildCheck() + TransformHints.tagTransformable(plan) case plan: BroadcastExchangeExec => if (!enableColumnarBroadcastExchange) { TransformHints.tagNotTransformable( @@ -227,6 +242,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } new ColumnarBroadcastExchangeExec(plan.mode, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: TakeOrderedAndProjectExec => if (!enableTakeOrderedAndProject) { TransformHints.tagNotTransformable( @@ -238,6 +254,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { plan.sortOrder, plan.projectList, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: UnionExec => if (!enableColumnarUnion) { TransformHints.tagNotTransformable( @@ -245,6 +262,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarUnionExec(plan.children).buildCheck() + TransformHints.tagTransformable(plan) case plan: ShuffleExchangeExec => if (!enableColumnarShuffle) { TransformHints.tagNotTransformable( @@ -253,6 +271,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin) .buildCheck() + 
TransformHints.tagTransformable(plan) case plan: BroadcastHashJoinExec => // We need to check if BroadcastExchangeExec can be converted to columnar-based. // If not, BHJ should also be row-based. @@ -289,15 +308,153 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } case _ => } - ColumnarBroadcastHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - plan.left, - plan.right, - plan.isNullAwareAntiJoin).buildCheck() + val isBhjColumnar: Boolean = try { + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + plan.left, + plan.right, + plan.isNullAwareAntiJoin).buildCheck() + true + } catch { + case throwable@(_: UnsupportedOperationException | _: RuntimeException | _: Throwable) => + val message = s"[OPERATOR FALLBACK] ${throwable} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) + false + case l: UnsatisfiedLinkError => + TransformHints.tagNotTransformable(plan) + throw l + case f: NoClassDefFoundError => + TransformHints.tagNotTransformable(plan) + throw f + } + + val buildSidePlan = plan.buildSide match { + case BuildLeft => plan.left + case BuildRight => plan.right + } + + val maybeExchange = buildSidePlan + .find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + .map(_.asInstanceOf[BroadcastExchangeExec]) + + maybeExchange match { + case Some(exchange @ BroadcastExchangeExec(_, _)) => + if (isBhjColumnar) { + TransformHints.tagTransformable(plan) + }else{ + TransformHints.tagNotTransformable(exchange) + TransformHints.tagNotTransformable(plan) + } + case None => + // we are in AQE, find the hidden exchange + // FIXME did we consider the case that AQE: OFF && Reuse: ON ? 
+ var maybeHiddenExchange: Option[BroadcastExchangeLike] = None + breakable { + buildSidePlan.foreach { + case e: BroadcastExchangeLike => + maybeHiddenExchange = Some(e) + break + case t: BroadcastQueryStageExec => + t.plan.foreach { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case r: ReusedExchangeExec => + r.child match { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case _ => + } + case _ => + } + case _ => + } + } + // restriction to force the hidden exchange to be found + val exchange = maybeHiddenExchange.get + // to conform to the underlying exchange's type, columnar or vanilla + exchange match { + case _: ColumnarBroadcastExchangeExec => + if (!isBhjColumnar) { + throw new IllegalStateException( + s"BroadcastExchange has already been" + + s" transformed to columnar version but BHJ is determined as" + + s" non-transformable: ${plan.toString()}") + } + TransformHints.tagTransformable(plan) + case _: BroadcastExchangeExec => + TransformHints.tagNotTransformable( + plan, + "it's a materialized broadcast exchange or reused broadcast exchange") + } + } + case plan: BroadcastNestedLoopJoinExec => + val buildSidePlan = plan.buildSide match { + case BuildLeft => plan.left + case BuildRight => plan.right + } + + val maybeExchange = buildSidePlan + .find { + case BroadcastExchangeExec(_, _) => true + case _ => false + } + .map(_.asInstanceOf[BroadcastExchangeExec]) + + maybeExchange match { + case Some(exchange@BroadcastExchangeExec(_, _)) => + TransformHints.tagNotTransformable(exchange) + TransformHints.tagNotTransformable(plan) + case None => + // we are in AQE, find the hidden exchange + // FIXME did we consider the case that AQE: OFF && Reuse: ON ? + var maybeHiddenExchange: Option[BroadcastExchangeLike] = None + breakable { + buildSidePlan.foreach { + case e: BroadcastExchangeLike => + maybeHiddenExchange = Some(e) + break + case t: BroadcastQueryStageExec => + t.plan.foreach { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case r: ReusedExchangeExec => + r.child match { + case e2: BroadcastExchangeLike => + maybeHiddenExchange = Some(e2) + break + case _ => + } + case _ => + } + case _ => + } + } + // restriction to force the hidden exchange to be found + val exchange = maybeHiddenExchange.get + // to conform to the underlying exchange's type, columnar or vanilla + exchange match { + case _: ColumnarBroadcastExchangeExec => + throw new IllegalStateException( + s"BroadcastExchange has already been" + + s" transformed to columnar version but BHJ is determined as" + + s" non-transformable: ${plan.toString()}") + case _: BroadcastExchangeExec => + TransformHints.tagNotTransformable( + plan, + "it's a materialized broadcast exchange or reused broadcast exchange") + } + } case plan: SortMergeJoinExec => if (!enableColumnarSortMergeJoin) { TransformHints.tagNotTransformable( @@ -312,6 +469,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { plan.left, plan.right, plan.isSkewJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: WindowExec => if (!enableColumnarWindow) { TransformHints.tagNotTransformable( @@ -320,6 +478,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: ShuffledHashJoinExec => if (!enableShuffledHashJoin) { TransformHints.tagNotTransformable( @@ -335,6 +494,7 @@ case class 
AddTransformHintRule() extends Rule[SparkPlan] { plan.left, plan.right, plan.isSkewJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: LocalLimitExec => if (!enableLocalColumnarLimit) { TransformHints.tagNotTransformable( @@ -342,6 +502,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarLocalLimitExec(plan.limit, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: GlobalLimitExec => if (!enableGlobalColumnarLimit) { TransformHints.tagNotTransformable( @@ -349,9 +510,11 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarGlobalLimitExec(plan.limit, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: BroadcastNestedLoopJoinExec => TransformHints.tagNotTransformable( plan, "columnar BroadcastNestedLoopJoin is not support") + TransformHints.tagTransformable(plan) case plan: CoalesceExec => if (!enableColumnarCoalesce) { TransformHints.tagNotTransformable( @@ -359,9 +522,10 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { return } ColumnarCoalesceExec(plan.numPartitions, plan.child).buildCheck() - case _ => + TransformHints.tagTransformable(plan) + case _ => TransformHints.tagTransformable(plan) + } - TransformHints.tagTransformable(plan) } catch { case throwable @ (_:UnsupportedOperationException | _:RuntimeException | _:Throwable) => val message = s"[OPERATOR FALLBACK] ${throwable} ${plan.getClass} falls back to Spark operator" diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala new file mode 100644 index 000000000..1de01e840 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.util + +import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, RowToColumnarTransition, SparkPlan} + +/** Ported from [[ApplyColumnarRulesAndInsertTransitions]] of vanilla Spark. 
*/ +object InsertTransitions { + + private def insertRowToColumnar(plan: SparkPlan): SparkPlan = { + if (!plan.supportsColumnar) { + // The tree feels kind of backwards + // Columnar Processing will start here, so transition from row to columnar + RowToColumnarExec(insertTransitions(plan, outputsColumnar = false)) + } else if (!plan.isInstanceOf[RowToColumnarTransition]) { + plan.withNewChildren(plan.children.map(insertRowToColumnar)) + } else { + plan + } + } + + def insertTransitions(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { + if (outputsColumnar) { + insertRowToColumnar(plan) + } else if (plan.supportsColumnar) { + // `outputsColumnar` is false but the plan outputs columnar format, so add a + // to-row transition here. + ColumnarToRowExec(insertRowToColumnar(plan)) + } else if (!plan.isInstanceOf[ColumnarToRowTransition]) { + plan.withNewChildren(plan.children.map(insertTransitions(_, outputsColumnar = false))) + } else { + plan + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala new file mode 100644 index 000000000..b65c0d547 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala @@ -0,0 +1,367 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import com.huawei.boostkit.spark.util.InsertTransitions +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.exchange.Exchange +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, LeafExecNode, ProjectExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.SharedSparkSession + +import scala.concurrent.Future + +class FallbackStrategiesSuite extends QueryTest with SharedSparkSession { + + import testImplicits._ + + override def sparkConf: SparkConf = super.sparkConf + .setAppName("test FallbackStrategiesSuite") + .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") + + override def beforeAll(): Unit = { + super.beforeAll() + val employees = Seq[(String, String, Int, Int)]( + ("Lisa", "Sales", 10000, 35), + ("Evan", "Sales", 32000, 38), + ("Fred", "Engineering", 21000, 28), + ("Alex", "Sales", 30000, 33), + ("Tom", "Engineering", 23000, 33), + ("Jane", "Marketing", 29000, 28), + ("Jeff", "Marketing", 35000, 38), + ("Paul", "Engineering", 29000, 23), + ("Chloe", "Engineering", 23000, 25) + ).toDF("name", "dept", "salary", "age") + employees.createOrReplaceTempView("employees_for_fallback_ut_test") + } + + test("Fall back stage contain bhj") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"), + ("spark.omni.sql.columnar.project", "false"), + (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"), + (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB")) { + val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age") + val runRows = df.collect().sortBy(row => row.getInt(1)) + val expectedRows = Seq(Row(20000, 35), Row(64000, 38), + Row(42000, 28), Row(60000, 33), + Row(46000, 33), Row(46000, 33), + Row(58000, 28), Row(58000, 28), + Row(70000, 38), Row(58000, 23), + Row(58000, 23), Row(46000, 25), + Row(46000, 25)).sortBy(row => row.getInt(1)) + QueryTest.sameRows(runRows, expectedRows) + val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({ + case plan: BroadcastHashJoinExec => plan + }) + assert(plans.size == 1, "the last stage containing bhj should fallback") + } + } + + + test("Fall back the last stage contains unsupported bnlj if meeting the configured threshold") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "2"), + (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")) { + val df = spark.sql("select age + salary from (select age, salary from (select age from employees_for_fallback_ut_test order by age limit 1) s1," + + " (select salary from employees_for_fallback_ut_test order by salary limit 1) s2)") + QueryTest.checkAnswer(df, Seq(Row(10023))) + val plans = 
df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({ + case plan: ProjectExec => plan + case plan: TakeOrderedAndProjectExec => plan + }) + assert(plans.count(_.isInstanceOf[ProjectExec]) == 1, "the last stage containing projectExec should fallback") + assert(plans.count(_.isInstanceOf[TakeOrderedAndProjectExec]) == 1, "the last stage containing projectExec should fallback") + } + } + + test("Don't Fall back the last stage contains unsupported bnlj if NOT meeting the configured threshold") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"), + (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")) { + val df = spark.sql("select age + salary from (select age, salary from (select age from employees_for_fallback_ut_test order by age limit 1) s1," + + " (select salary from employees_for_fallback_ut_test order by salary limit 1) s2)") + QueryTest.checkAnswer(df, Seq(Row(10023))) + val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({ + case plan: ColumnarProjectExec => plan + case plan: ColumnarTakeOrderedAndProjectExec => plan + }) + assert(plans.count(_.isInstanceOf[ColumnarProjectExec]) == 1, "the last stage containing projectExec should not fallback") + assert(plans.count(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]) == 1, "the last stage containing projectExec should not fallback") + } + } + + + test("Fall back the whole query if one unsupported") { + withSQLConf(("spark.omni.sql.columnar.query.fallback.threshold", "1")) { + val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. + val planAfterPreOverride = + UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to fall back the entire plan. + assert(outputPlan == originalPlan) + } + } + + test("Fall back the whole plan if meeting the configured threshold") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) { + val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. + val planAfterPreOverride = + UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to fall back the entire plan. + assert(outputPlan == originalPlan) + } + } + + test("Don't fall back the whole plan if NOT meeting the configured threshold") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "4")) { + val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. 
+ val planAfterPreOverride = + UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + val columnarUnaryOps = outputPlan.collect({ + case p: ColumnarUnaryOp1 => p + }) + // Expect to get the plan with columnar rule applied. + assert(columnarUnaryOps.size == 2) + assert(outputPlan != originalPlan) + } + } + + test( + "Fall back the whole plan if meeting the configured threshold (leaf node is" + + " transformable)") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "2")) { + val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer + // and replacing LeafOp with LeafOpTransformer. + val planAfterPreOverride = + UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(ColumnarLeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to fall back the entire plan. + val columnarUnaryOps = outputPlan.collect({ + case p: ColumnarUnaryOp1 => p + }) + assert(columnarUnaryOps.isEmpty) + assert(outputPlan == originalPlan) + } + } + + test( + "Don't Fall back the whole plan if NOT meeting the configured threshold (" + + "leaf node is transformable)") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3")) { + val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer + // and replacing LeafOp with LeafOpTransformer. + val planAfterPreOverride = + UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(ColumnarLeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to get the plan with columnar rule applied. + val columnarUnaryOps = outputPlan.collect({ + case p: ColumnarUnaryOp1 => p + case p: ColumnarLeafOp => p + }) + assert(columnarUnaryOps.size == 3) + assert(outputPlan != originalPlan) + } + } + + test("Don't fall back the whole query if all supported") { + withSQLConf(("spark.omni.sql.columnar.query.fallback.threshold", "1")) { + val originalPlan = UnaryOp1(UnaryOp1(UnaryOp1(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. + val planAfterPreOverride = + ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarLeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to not fall back the entire plan. 
+ val columnPlans = outputPlan.collect({ + case p@(ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p + }) + assert(columnPlans.size == 5) + assert(outputPlan != originalPlan) + } + } + + test("Don't fall back the whole plan if all supported") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) { + val originalPlan = Exchange1(UnaryOp1(UnaryOp1(UnaryOp1(LeafOp())))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. + val planAfterPreOverride = + ColumnarExchange1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarLeafOp())))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to not fall back the entire plan. + val columnPlans = outputPlan.collect({ + case p@(ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p + }) + assert(columnPlans.size == 5) + assert(outputPlan != originalPlan) + } + } + + test("Fall back the whole plan if one supported plan before queryStage") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) { + val mockQueryStageExec = QueryStageExec1(LeafOp(), LeafOp()) + val originalPlan = Exchange1(UnaryOp1(UnaryOp1(mockQueryStageExec))) + val rule = ColumnarOverrideRules(spark) + rule.preColumnarTransitions(originalPlan) + rule.isSupportAdaptive = true + // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer. + val planAfterPreOverride = + ColumnarExchange1(ColumnarUnaryOp1(UnaryOp1(mockQueryStageExec))) + val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false) + val outputPlan = rule.postColumnarTransitions(planWithTransition) + // Expect to fall back the entire plan. + val columnPlans = outputPlan.collect({ + case p@(ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p + }) + assert(columnPlans.isEmpty) + assert(outputPlan == originalPlan) + } + } + +} + +case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = Seq.empty +} + +case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 = + copy(child = newChild) +} + +case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 = + copy(child = newChild) +} + +// For replacing LeafOp. 
+case class ColumnarLeafOp(override val supportsColumnar: Boolean = true) + extends LeafExecNode { + + override def nodeName: String = "OmniColumnarLeafOp" + + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = Seq.empty +} + +// For replacing UnaryOp1. +case class ColumnarUnaryOp1( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = child.output + + override def nodeName: String = "OmniColumnarUnaryOp1" + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarUnaryOp1 = + copy(child = newChild) +} + +case class Exchange1( + override val child: SparkPlan, + override val supportsColumnar: Boolean = false) + extends Exchange { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = child.output + + override protected def withNewChildInternal(newChild: SparkPlan): Exchange1 = + copy(child = newChild) +} + +case class ColumnarExchange1( + override val child: SparkPlan, + override val supportsColumnar: Boolean = true) + extends Exchange { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = child.output + + override def nodeName: String = "OmniColumnarExchange" + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExchange1 = + copy(child = newChild) +} + +case class QueryStageExec1(override val plan: SparkPlan, + override val _canonicalized: SparkPlan) extends QueryStageExec { + + override val id: Int = 0 + + override def doMaterialize(): Future[Any] = throw new UnsupportedOperationException("it is mock spark plan") + + override def cancel(): Unit = {} + + override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = throw new UnsupportedOperationException("it is mock spark plan") + + override def getRuntimeStatistics: Statistics = throw new UnsupportedOperationException("it is mock spark plan") +} -- Gitee From b08af4d3bcae3efd92114622a4894ad78fdf6716 Mon Sep 17 00:00:00 2001 From: dengzhaochu Date: Sat, 31 Aug 2024 14:42:53 +0800 Subject: [PATCH 3/5] support r2c broadcast and support c2r broadcast --- .../boostkit/spark/ColumnarPlugin.scala | 214 +++++++------- .../boostkit/spark/ColumnarPluginConfig.scala | 9 + .../com/huawei/boostkit/spark/Constant.scala | 2 + .../boostkit/spark/ExpandFallbackPolicy.scala | 57 ++-- .../boostkit/spark/TransformHintRule.scala | 182 ++---------- .../sql/execution/BroadcastColumnarRDD.scala | 72 ----- ...ColumnarBroadcastExchangeAdaptorExec.scala | 70 ----- .../spark/sql/execution/ColumnarExec.scala | 159 ++++++---- .../sql/execution/util/BroadcastUtils.scala | 276 ++++++++++++++++++ .../spark/FallbackStrategiesSuite.scala | 107 +++++-- 10 files changed, 637 insertions(+), 511 deletions(-) delete mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala delete mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 5d271f2d9..76e0d2eb6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -18,20 +18,20 @@ package com.huawei.boostkit.spark +import com.huawei.boostkit.spark.Constant.OMNI_IS_ADAPTIVE_CONTEXT import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.PhysicalPlanSelector -import nova.hetu.omniruntime.memory.MemoryManager import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{BroadcastExchangeExecProxy, RowToOmniColumnarExec, _} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.execution.window.{TopNPushDownForWindow, WindowExec} import org.apache.spark.sql.internal.SQLConf @@ -43,34 +43,17 @@ import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompl import scala.collection.mutable.ListBuffer +import nova.hetu.omniruntime.memory.MemoryManager + case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan - val enableColumnarProject: Boolean = columnarConf.enableColumnarProject - val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter - val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand - val enableColumnarHashAgg: Boolean = columnarConf.enableColumnarHashAgg - val enableTakeOrderedAndProject: Boolean = columnarConf.enableTakeOrderedAndProject && - columnarConf.enableColumnarShuffle - val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && - columnarConf.enableColumnarBroadcastJoin - val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin && - columnarConf.enableColumnarBroadcastExchange - val 
enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin - val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort - val enableColumnarSort: Boolean = columnarConf.enableColumnarSort - val enableColumnarWindow: Boolean = columnarConf.enableColumnarWindow + val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle - val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin - val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion val enableFusion: Boolean = columnarConf.enableFusion val enableColumnarProjectFusion: Boolean = columnarConf.enableColumnarProjectFusion - val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit - val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit val enableDedupLeftSemiJoin: Boolean = columnarConf.enableDedupLeftSemiJoin val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold - val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization val enableRowShuffle: Boolean = columnarConf.enableRowShuffle val columnsThreshold: Int = columnarConf.columnsThreshold @@ -92,8 +75,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.children.map(replaceWithColumnarPlan)) } plan match { - case plan: FileSourceScanExec - if enableColumnarFileScan && checkColumnarBatchSupport(conf, plan) => + case plan: FileSourceScanExec => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarFileSourceScanExec( plan.relation, @@ -108,13 +90,13 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) ) case range: RangeExec => new ColumnarRangeExec(range.range) - case plan: ProjectExec if enableColumnarProject => + case plan: ProjectExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") child match { case ColumnarFilterExec(condition, child) => ColumnarConditionProjectExec(plan.projectList, condition, child) - case join : ColumnarBroadcastHashJoinExec => + case join: ColumnarBroadcastHashJoinExec => if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { ColumnarBroadcastHashJoinExec( join.leftKeys, @@ -129,7 +111,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } else { ColumnarProjectExec(plan.projectList, child) } - case join : ColumnarShuffledHashJoinExec => + case join: ColumnarShuffledHashJoinExec => if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { ColumnarShuffledHashJoinExec( join.leftKeys, @@ -144,7 +126,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } else { ColumnarProjectExec(plan.projectList, child) } - case join : ColumnarSortMergeJoinExec => + case join: ColumnarSortMergeJoinExec => if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { ColumnarSortMergeJoinExec( join.leftKeys, @@ -161,30 +143,30 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) case _ => ColumnarProjectExec(plan.projectList, child) } - case plan: FilterExec if enableColumnarFilter => + case plan: FilterExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") 
ColumnarFilterExec(plan.condition, child) - case plan: ExpandExec if enableColumnarExpand => + case plan: ExpandExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarExpandExec(plan.projections, plan.output, child) - case plan: HashAggregateExec if enableColumnarHashAgg => + case plan: HashAggregateExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") if (enableFusion && !isSupportAdaptive) { if (plan.aggregateExpressions.forall(_.mode == Partial)) { child match { - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj4 @ ColumnarProjectExec(_, - join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj4@ColumnarProjectExec(_, + join4@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) @@ -209,14 +191,14 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) scan.dataFilters, scan.tableIdentifier, scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, + filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)), _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -238,14 +220,14 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) scan.dataFilters, scan.tableIdentifier, scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + 
filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)), _, _, _)), _, _, _)) if checkBhjRightChild( child.asInstanceOf[ColumnarProjectExec].child.children(1) .asInstanceOf[ColumnarBroadcastExchangeExec].child) => @@ -380,7 +362,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } } - case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => + case plan: TakeOrderedAndProjectExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarTakeOrderedAndProjectExec( @@ -388,11 +370,11 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.sortOrder, plan.projectList, child) - case plan: BroadcastExchangeExec if enableColumnarBroadcastExchange => + case plan: BroadcastExchangeExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") new ColumnarBroadcastExchangeExec(plan.mode, child) - case plan: BroadcastHashJoinExec if enableColumnarBroadcastJoin => + case plan: BroadcastHashJoinExec => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) @@ -405,7 +387,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) plan.condition, left, right) - case plan: ShuffledHashJoinExec if enableShuffledHashJoin && enableDedupLeftSemiJoin && !SQLConf.get.adaptiveExecutionEnabled => { + case plan: ShuffledHashJoinExec if enableDedupLeftSemiJoin && !SQLConf.get.adaptiveExecutionEnabled => { plan.joinType match { case LeftSemi => { if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { @@ -471,7 +453,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } } } - case plan: ShuffledHashJoinExec if enableShuffledHashJoin => + case plan: ShuffledHashJoinExec => val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -484,7 +466,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) left, right, plan.isSkewJoin) - case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => + case plan: SortMergeJoinExec => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) @@ -497,15 +479,15 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) left, right, plan.isSkewJoin) - case plan: TopNSortExec if enableColumnarTopNSort => + case plan: TopNSortExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, child) - case plan: SortExec if enableColumnarSort => + case plan: SortExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) - case plan: WindowExec if enableColumnarWindow => + case plan: WindowExec => val child = replaceWithColumnarPlan(plan.child) if (child.output.isEmpty) { return plan @@ -521,20 +503,18 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) case _ => 
ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) } - case plan: UnionExec if enableColumnarUnion => + case plan: UnionExec => val children = plan.children.map(replaceWithColumnarPlan) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarUnionExec(children) - case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle => + case plan: ShuffleExchangeExec => val child = replaceWithColumnarPlan(plan.child) if (child.output.nonEmpty) { logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") if (child.output.size > columnsThreshold && enableRowShuffle) { new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) - } else if (enableColumnarShuffle) { - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) } else { - plan + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) } } else { plan @@ -560,15 +540,15 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) case _ => plan } - case plan: LocalLimitExec if enableLocalColumnarLimit => + case plan: LocalLimitExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarLocalLimitExec(plan.limit, child) - case plan: GlobalLimitExec if enableGlobalColumnarLimit => + case plan: GlobalLimitExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarGlobalLimitExec(plan.limit, child) - case plan: CoalesceExec if enableColumnarCoalesce => + case plan: CoalesceExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarCoalesceExec(plan.numPartitions, child) @@ -585,10 +565,9 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) } } -case class ColumnarPostOverrides() extends Rule[SparkPlan] { +case class ColumnarPostOverrides(isSupportAdaptive: Boolean = true) extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") @@ -598,9 +577,9 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { private def handleColumnarToRowPartialFetch(plan: SparkPlan): SparkPlan = { // simple check plan tree have OmniColumnarToRow and no LimitExec and TakeOrderedAndProjectExec plan val noPartialFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { - (!plan.find(node => - node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec] || - node.isInstanceOf[SortMergeJoinExec]).isDefined) + (!plan.find(node => + node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec] || + node.isInstanceOf[SortMergeJoinExec]).isDefined) } else { false } @@ -611,8 +590,6 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { newPlan } - def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } - def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { case plan: RowToColumnarExec => val child = replaceWithColumnarPlan(plan.child) @@ -660,7 +637,7 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { p.withNewChildren(children) } - def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf) : 
SparkPlan = { + def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf): SparkPlan = { val child = replaceWithColumnarPlan(plan.child) if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { OmniColumnarToRowExec(child) @@ -672,12 +649,14 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { - private def preOverrides: List[SparkSession => Rule[SparkPlan]] = List( + private def preOverrides(): List[SparkSession => Rule[SparkPlan]] = List( FallbackMultiCodegens, (_: SparkSession) => AddTransformHintRule(), - (_: SparkSession) => ColumnarPreOverrides(isSupportAdaptive)) + (_: SparkSession) => ColumnarPreOverrides(isAdaptiveContext)) - private def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() + private def postOverrides(): List[SparkSession => Rule[SparkPlan]] = List( + (_: SparkSession) => ColumnarPostOverrides(isAdaptiveContext) + ) private def finallyRules(): List[SparkSession => Rule[SparkPlan]] = { List( @@ -686,8 +665,8 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit } private def transformPlan(getRules: List[SparkSession => Rule[SparkPlan]], - plan: SparkPlan, - step: String) = { + plan: SparkPlan, + step: String) = { logDebug( s"${step}ColumnarTransitions preOverriden plan:\n${plan.toString}") val overridden = getRules.foldLeft(plan) { @@ -701,9 +680,16 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit overridden } + // Just for test use. + def enableAdaptiveContext(): Unit = { + session.sparkContext.setLocalProperty(OMNI_IS_ADAPTIVE_CONTEXT, "true") + } + // Holds the original plan for possible entire fallback. private val localOriginalPlans: ThreadLocal[ListBuffer[SparkPlan]] = ThreadLocal.withInitial(() => ListBuffer.empty[SparkPlan]) + private val localIsAdaptiveContextFlags: ThreadLocal[ListBuffer[Boolean]] = + ThreadLocal.withInitial(() => ListBuffer.empty[Boolean]) private def setOriginalPlan(plan: SparkPlan): Unit = { localOriginalPlans.get.prepend(plan) @@ -718,31 +704,41 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit private def resetOriginalPlan(): Unit = localOriginalPlans.get.remove(0) private def fallbackPolicy(): List[SparkSession => Rule[SparkPlan]] = { - List((_: SparkSession) => ExpandFallbackPolicy(isSupportAdaptive, originalPlan)) + List((_: SparkSession) => ExpandFallbackPolicy(isAdaptiveContext, originalPlan)) } - var isSupportAdaptive: Boolean = true - - private def supportAdaptive(plan: SparkPlan): Boolean = { - // Only QueryStage will have Exchange as Leaf Plan - val isLeafPlanExchange = plan match { - case e: Exchange => true - case other => false - } - isLeafPlanExchange || (SQLConf.get.adaptiveExecutionEnabled && (sanityCheck(plan) && - !plan.logicalLink.exists(_.isStreaming) && - !plan.expressions.exists(_.find(_.isInstanceOf[DynamicPruningSubquery]).isDefined) && - plan.children.forall(supportAdaptive))) + // This is an empirical value, may need to be changed for supporting other versions of spark. + private val aqeStackTraceIndex = 14 + + private def setAdaptiveContext(): Unit = { + val traceElements = Thread.currentThread.getStackTrace + assert( + traceElements.length > aqeStackTraceIndex, + s"The number of stack trace elements is expected to be more than $aqeStackTraceIndex") + // ApplyColumnarRulesAndInsertTransitions is called by either QueryExecution or + // AdaptiveSparkPlanExec. 
So by checking the stack trace, we can know whether + // columnar rule will be applied in adaptive execution context. This part of code + // needs to be carefully checked when supporting higher versions of spark to make + // sure the calling stack has not been changed. + localIsAdaptiveContextFlags + .get() + .prepend( + traceElements(aqeStackTraceIndex).getClassName + .equals(AdaptiveSparkPlanExec.getClass.getName)) } - private def sanityCheck(plan: SparkPlan): Boolean = - plan.logicalLink.isDefined + private def resetAdaptiveContext(): Unit = + localIsAdaptiveContextFlags.get().remove(0) + + def isAdaptiveContext: Boolean = Option(session.sparkContext.getLocalProperty(OMNI_IS_ADAPTIVE_CONTEXT)) + .getOrElse("false") + .toBoolean || localIsAdaptiveContextFlags.get().head override def preColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. maybe(session, plan) { - isSupportAdaptive = supportAdaptive(plan) + setAdaptiveContext() setOriginalPlan(plan) - transformPlan(preOverrides, plan, "pre") + transformPlan(preOverrides(), plan, "pre") } override def postColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. @@ -754,11 +750,10 @@ case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule wit // skip c2r and r2c replaceWithColumnarPlan fallbackPlan case plan => - val rule = postOverrides - rule.setAdaptiveSupport(isSupportAdaptive) - rule(plan) + transformPlan(postOverrides(), plan, "post") } resetOriginalPlan() + resetAdaptiveContext() transformPlan(finallyRules(), finalPlan, "final") } } @@ -780,7 +775,6 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectOptimizerRule(_ => DelayCartesianProduct) extensions.injectOptimizerRule(_ => HeuristicJoinReorder) extensions.injectOptimizerRule(_ => MergeSubqueryFilters) - extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index a6c15e104..2622f9b25 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -273,6 +273,15 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // enable or disable bloomfilter subquery reuse val enableBloomfilterSubqueryReuse: Boolean = conf.getConfString("spark.omni.sql.columnar.bloomfilterSubqueryReuse", "false").toBoolean + + // The threshold for whether whole stage will fall back in AQE supported case by counting the number of vanilla SparkPlan. 
+ // If it is set to -1, it means that this function is turned off + // otherwise, when the number of vanilla SparkPlan of the stage is greater than or equal to the threshold, + // all the SparkPlan of the stage will fall back to vanilla SparkPlan + val wholeStageFallbackThreshold: Int = conf.getConfString("spark.omni.sql.columnar.wholeStage.fallback.threshold", "-1").toInt + + // it is the same as wholeStageFallbackThreshold, but it is used when AQE is disabled + val queryFallbackThreshold: Int = conf.getConfString("spark.omni.sql.columnar.query.fallback.threshold", "-1").toInt } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala index 9d7f844bc..652142117 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala @@ -36,4 +36,6 @@ object Constant { val IS_SKIP_VERIFY_EXP: Boolean = true val OMNI_DECIMAL64_TYPE: String = DataTypeId.OMNI_DECIMAL64.ordinal().toString val OMNI_DECIMAL128_TYPE: String = DataTypeId.OMNI_DECIMAL128.ordinal().toString + // for UT + val OMNI_IS_ADAPTIVE_CONTEXT = "omni.isAdaptiveContext" } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala index 1833cca66..73d6a3cd7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala @@ -23,18 +23,18 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.command.ExecutedCommandExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec /** * Note, this rule should only fallback to row-based plan if there is no harm. * The follow case should be handled carefully * * @param isAdaptiveContext If is inside AQE - * @param originalPlan The vanilla SparkPlan without apply gluten transform rules + * @param originalPlan The vanilla SparkPlan without applying BoostKit extension transform rules * * */ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan) @@ -53,10 +53,10 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP case leafPlan: LeafExecNode if !isOmniSparkPlan(leafPlan) => // Possible fallback for leaf node. 
fallbacks = fallbacks + 1 - case p:ColumnarToRowExec => p.children.foreach(countFallbackInternal) - case p:RowToColumnarExec => p.children.foreach(countFallbackInternal) + case p@(ColumnarToRowExec(_) | RowToColumnarExec(_)) => + p.children.foreach(countFallbackInternal) case p => - if(!isOmniSparkPlan(p)) { + if (!isOmniSparkPlan(p)) { fallbacks = fallbacks + 1 } p.children.foreach(countFallbackInternal) @@ -100,17 +100,11 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP return None } - // todo 157-162添加適配后是否需要添加這段邏輯 - // // not safe to fallback row-based BHJ as the broadcast exchange is already columnar - // if (hasColumnarBroadcastExchangeWithJoin(plan)) { - // return None - // } - - val netFallbackNum = countFallback(plan) + val fallbackNum = countFallback(plan) - if (netFallbackNum >= fallbackThreshold) { + if (fallbackNum >= fallbackThreshold) { Some( - s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " + + s"Fallback policy is taking effect, net fallback number: $fallbackNum, " + s"threshold: $fallbackThreshold") } else { None @@ -118,17 +112,31 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP } private def fallbackToRowBasedPlan(): SparkPlan = { - val columnarPostOverrides = ColumnarPostOverrides() + val columnarPostOverrides = ColumnarPostOverrides(isAdaptiveContext) val planWithColumnarToRow = InsertTransitions.insertTransitions(originalPlan, false) planWithColumnarToRow.transform { - case ColumnarToRowExec(bqe: BroadcastQueryStageExec) if bqe.plan.isInstanceOf[ColumnarBroadcastExchangeExec] => - val columnarBroadcastExchangeExec = bqe.plan.asInstanceOf[ColumnarBroadcastExchangeExec] - BroadcastQueryStageExec(bqe.id, BroadcastExchangeExec(columnarBroadcastExchangeExec.mode, ColumnarBroadcastExchangeAdaptorExec(columnarBroadcastExchangeExec, 1)), bqe._canonicalized) - case ColumnarToRowExec(bqe: BroadcastQueryStageExec) if bqe.plan.isInstanceOf[ReusedExchangeExec] && bqe.plan.asInstanceOf[ReusedExchangeExec].child.isInstanceOf[ColumnarBroadcastExchangeExec] => - val columnarBroadcastExchangeExec = bqe.plan.asInstanceOf[ReusedExchangeExec].child.asInstanceOf[ColumnarBroadcastExchangeExec] - BroadcastQueryStageExec(bqe.id, BroadcastExchangeExec(columnarBroadcastExchangeExec.mode, ColumnarBroadcastExchangeAdaptorExec(columnarBroadcastExchangeExec, 1)), bqe._canonicalized) - case c2r@(ColumnarToRowExec(_: ShuffleQueryStageExec) | ColumnarToRowExec(_: AQEShuffleReadExec)) => + case c2r@(ColumnarToRowExec(_: ShuffleQueryStageExec) | + ColumnarToRowExec(_: BroadcastQueryStageExec)) => columnarPostOverrides.replaceColumnarToRow(c2r.asInstanceOf[ColumnarToRowExec], conf) + case c2r@ColumnarToRowExec(aqeShuffleReadExec: AQEShuffleReadExec) => + val newPlan = columnarPostOverrides.replaceColumnarToRow(c2r, conf) + newPlan.withNewChildren(Seq(aqeShuffleReadExec.child match { + case _: ColumnarShuffleExchangeExec => + OmniAQEShuffleReadExec(aqeShuffleReadExec.child, aqeShuffleReadExec.partitionSpecs) + case ShuffleQueryStageExec(_, _: ColumnarShuffleExchangeExec, _) => + OmniAQEShuffleReadExec(aqeShuffleReadExec.child, aqeShuffleReadExec.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => + reused match { + case ReusedExchangeExec(_, _: ColumnarShuffleExchangeExec) => + OmniAQEShuffleReadExec( + aqeShuffleReadExec.child, + aqeShuffleReadExec.partitionSpecs) + case _ => + aqeShuffleReadExec + } + case _ => + aqeShuffleReadExec + })) } } @@ -136,6 +144,9 @@ case class 
ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP logInfo("Using BoostKit Spark Native Sql Engine Extension FallbackPolicy") val reason = fallback(plan) if (reason.isDefined) { + if (hasColumnarBroadcastExchangeWithJoin(plan)) { + logDebug("plan fallback using ExpandFallbackPolicy contains ColumnarBroadcastExchange") + } val fallbackPlan = fallbackToRowBasedPlan() TransformHints.tagAllNotTransformable(fallbackPlan, reason.get) FallbackNode(fallbackPlan) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index 16701e665..e37e701ff 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -20,19 +20,16 @@ package com.huawei.boostkit.spark import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport -import scala.util.control.Breaks.{break, breakable} - trait TransformHint { val stacktrace: Option[String] = if (TransformHints.DEBUG) { @@ -120,10 +117,8 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { val enableColumnarProject: Boolean = columnarConf.enableColumnarProject val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand - val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && - columnarConf.enableColumnarBroadcastJoin - val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastExchange && - columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange + val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin val 
enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan @@ -308,153 +303,16 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } case _ => } - val isBhjColumnar: Boolean = try { - ColumnarBroadcastHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - plan.left, - plan.right, - plan.isNullAwareAntiJoin).buildCheck() - true - } catch { - case throwable@(_: UnsupportedOperationException | _: RuntimeException | _: Throwable) => - val message = s"[OPERATOR FALLBACK] ${throwable} ${plan.getClass} falls back to Spark operator" - logDebug(message) - TransformHints.tagNotTransformable(plan, reason = message) - false - case l: UnsatisfiedLinkError => - TransformHints.tagNotTransformable(plan) - throw l - case f: NoClassDefFoundError => - TransformHints.tagNotTransformable(plan) - throw f - } - - val buildSidePlan = plan.buildSide match { - case BuildLeft => plan.left - case BuildRight => plan.right - } - - val maybeExchange = buildSidePlan - .find { - case BroadcastExchangeExec(_, _) => true - case _ => false - } - .map(_.asInstanceOf[BroadcastExchangeExec]) - - maybeExchange match { - case Some(exchange @ BroadcastExchangeExec(_, _)) => - if (isBhjColumnar) { - TransformHints.tagTransformable(plan) - }else{ - TransformHints.tagNotTransformable(exchange) - TransformHints.tagNotTransformable(plan) - } - case None => - // we are in AQE, find the hidden exchange - // FIXME did we consider the case that AQE: OFF && Reuse: ON ? - var maybeHiddenExchange: Option[BroadcastExchangeLike] = None - breakable { - buildSidePlan.foreach { - case e: BroadcastExchangeLike => - maybeHiddenExchange = Some(e) - break - case t: BroadcastQueryStageExec => - t.plan.foreach { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case r: ReusedExchangeExec => - r.child match { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case _ => - } - case _ => - } - case _ => - } - } - // restriction to force the hidden exchange to be found - val exchange = maybeHiddenExchange.get - // to conform to the underlying exchange's type, columnar or vanilla - exchange match { - case _: ColumnarBroadcastExchangeExec => - if (!isBhjColumnar) { - throw new IllegalStateException( - s"BroadcastExchange has already been" + - s" transformed to columnar version but BHJ is determined as" + - s" non-transformable: ${plan.toString()}") - } - TransformHints.tagTransformable(plan) - case _: BroadcastExchangeExec => - TransformHints.tagNotTransformable( - plan, - "it's a materialized broadcast exchange or reused broadcast exchange") - } - } - case plan: BroadcastNestedLoopJoinExec => - val buildSidePlan = plan.buildSide match { - case BuildLeft => plan.left - case BuildRight => plan.right - } - - val maybeExchange = buildSidePlan - .find { - case BroadcastExchangeExec(_, _) => true - case _ => false - } - .map(_.asInstanceOf[BroadcastExchangeExec]) - - maybeExchange match { - case Some(exchange@BroadcastExchangeExec(_, _)) => - TransformHints.tagNotTransformable(exchange) - TransformHints.tagNotTransformable(plan) - case None => - // we are in AQE, find the hidden exchange - // FIXME did we consider the case that AQE: OFF && Reuse: ON ? 
- var maybeHiddenExchange: Option[BroadcastExchangeLike] = None - breakable { - buildSidePlan.foreach { - case e: BroadcastExchangeLike => - maybeHiddenExchange = Some(e) - break - case t: BroadcastQueryStageExec => - t.plan.foreach { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case r: ReusedExchangeExec => - r.child match { - case e2: BroadcastExchangeLike => - maybeHiddenExchange = Some(e2) - break - case _ => - } - case _ => - } - case _ => - } - } - // restriction to force the hidden exchange to be found - val exchange = maybeHiddenExchange.get - // to conform to the underlying exchange's type, columnar or vanilla - exchange match { - case _: ColumnarBroadcastExchangeExec => - throw new IllegalStateException( - s"BroadcastExchange has already been" + - s" transformed to columnar version but BHJ is determined as" + - s" non-transformable: ${plan.toString()}") - case _: BroadcastExchangeExec => - TransformHints.tagNotTransformable( - plan, - "it's a materialized broadcast exchange or reused broadcast exchange") - } - } + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + plan.left, + plan.right, + plan.isNullAwareAntiJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: SortMergeJoinExec => if (!enableColumnarSortMergeJoin) { TransformHints.tagNotTransformable( @@ -527,16 +385,22 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } } catch { - case throwable @ (_:UnsupportedOperationException | _:RuntimeException | _:Throwable) => - val message = s"[OPERATOR FALLBACK] ${throwable} ${plan.getClass} falls back to Spark operator" + case e: UnsupportedOperationException => + val message = s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator" logDebug(message) TransformHints.tagNotTransformable(plan, reason = message) case l: UnsatisfiedLinkError => - TransformHints.tagNotTransformable(plan) throw l case f: NoClassDefFoundError => - TransformHints.tagNotTransformable(plan) throw f + case r: RuntimeException => + val message = s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) + case t: Throwable => + val message = s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) } } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala deleted file mode 100644 index 60398f033..000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import nova.hetu.omniruntime.vector.VecBatch -import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.{Partition, SparkContext, TaskContext, broadcast} - - -private final case class BroadcastColumnarRDDPartition(index: Int) extends Partition - -case class BroadcastColumnarRDD( - @transient private val sc: SparkContext, - metrics: Map[String, SQLMetric], - numPartitioning: Int, - inputByteBuf: broadcast.Broadcast[ColumnarHashedRelation], - localSchema: StructType) - extends RDD[ColumnarBatch](sc, Nil) { - - override protected def getPartitions: Array[Partition] = { - (0 until numPartitioning).map { index => new BroadcastColumnarRDDPartition(index) }.toArray - } - - private def vecBatchToColumnarBatch(vecBatch: VecBatch): ColumnarBatch = { - val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( - vecBatch.getRowCount, localSchema, false) - vectors.zipWithIndex.foreach { case (vector, i) => - vector.reset() - vector.setVec(vecBatch.getVectors()(i)) - } - vecBatch.close() - new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) - } - - override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - // val relation = inputByteBuf.value.asReadOnlyCopy - // new CloseableColumnBatchIterator(relation.getColumnarBatchAsIter) - val deserializer = VecBatchSerializerFactory.create() - new Iterator[ColumnarBatch] { - var idx = 0 - val total_len = inputByteBuf.value.buildData.length - - override def hasNext: Boolean = idx < total_len - - override def next(): ColumnarBatch = { - val batch: VecBatch = deserializer.deserialize(inputByteBuf.value.buildData(idx)) - idx += 1 - vecBatchToColumnarBatch(batch) - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala deleted file mode 100644 index 1d236c16d..000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import nova.hetu.omniruntime.vector.Vec -import org.apache.spark.broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.SparkMemoryUtils -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.types.StructType - -import scala.collection.JavaConverters.asScalaIteratorConverter -import scala.collection.mutable.ListBuffer - -case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: Int) - extends UnaryExecNode { - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = UnknownPartitioning(numPartitions) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doExecute(): RDD[InternalRow] = { - val numOutputRows: SQLMetric = longMetric("numOutputRows") - val numOutputBatches: SQLMetric = longMetric("numOutputBatches") - val processTime: SQLMetric = longMetric("processTime") - val inputRdd: BroadcastColumnarRDD = BroadcastColumnarRDD( - sparkContext, - metrics, - numPartitions, - child.executeBroadcast(), - StructType.fromAttributes(child.output)) - inputRdd.mapPartitions { batches => - ColumnarBatchToInternalRow.convert(output, batches, numOutputRows, numOutputBatches, processTime) - } - } - - override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - child.executeBroadcast() - } - - override def supportsColumnar: Boolean = true - - override lazy val metrics: Map[String, SQLMetric] = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"), - "processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar")) - - override protected def withNewChildInternal(newChild: SparkPlan): - ColumnarBroadcastExchangeAdaptorExec = copy(child = newChild) -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index cec2012e6..d8a2bea77 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -17,26 +17,23 @@ package org.apache.spark.sql.execution -import java.util.concurrent.TimeUnit.NANOSECONDS - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ListBuffer - import org.apache.spark.broadcast +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, SpecializedGetters, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.util.{BroadcastUtils, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector} -import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import nova.hetu.omniruntime.vector.Vec +import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import nova.hetu.omniruntime.vector.Vec /** * Provides an optimized set of APIs to append row based data to an array of @@ -205,7 +202,32 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti } override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - child.doExecuteBroadcast() + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime") + // Instead of creating a new config we are reusing columnBatchSize. In the future if we do + // combine with some of the Arrow conversion tools we will need to unify some of the configs. + val numRows = conf.columnBatchSize + val enableOffHeapColumnVector = session.sqlContext.conf.offHeapColumnVectorEnabled + val localSchema = this.schema + val relation = child.executeBroadcast() + val mode = BroadcastUtils.getBroadCastMode(outputPartitioning) + val broadcast: Broadcast[T] = BroadcastUtils.sparkToOmniUnsafe(sparkContext, + mode, + relation, + logError, + InternalRowToColumnarBatch.convert( + enableOffHeapColumnVector, + numInputRows, + numOutputBatches, + rowToOmniColumnarTime, + numRows, + localSchema, + _) + ) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime)) + broadcast } override def nodeName: String = "RowToOmniColumnar" @@ -233,44 +255,10 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti // plan (this) in the closure. 
val localSchema = this.schema child.execute().mapPartitionsInternal { rowIterator => - if (rowIterator.hasNext) { - new Iterator[ColumnarBatch] { - private val converters = new OmniRowToColumnConverter(localSchema) - - override def hasNext: Boolean = { - rowIterator.hasNext - } - - override def next(): ColumnarBatch = { - val startTime = System.nanoTime() - val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, - localSchema, true) - val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) - cb.setNumRows(0) - vectors.foreach(_.reset()) - var rowCount = 0 - while (rowCount < numRows && rowIterator.hasNext) { - val row = rowIterator.next() - converters.convert(row, vectors.toArray) - rowCount += 1 - } - if (!enableOffHeapColumnVector) { - vectors.foreach { v => - v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount) - } - } - cb.setNumRows(rowCount) - numInputRows += rowCount - numOutputBatches += 1 - rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) - cb - } - } - } else { - Iterator.empty - } + InternalRowToColumnarBatch.convert(enableOffHeapColumnVector, numInputRows, numOutputBatches, rowToOmniColumnarTime, numRows, localSchema, rowIterator) } } + } @@ -298,6 +286,23 @@ case class OmniColumnarToRowExec(child: SparkPlan, |""".stripMargin } + override def doExecuteBroadcast[T](): Broadcast[T] = { + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + val omniColumnarToRowTime = longMetric("omniColumnarToRowTime") + val mode = BroadcastUtils.getBroadCastMode(outputPartitioning) + val relation = child.executeBroadcast() + val broadcast: Broadcast[T] = BroadcastUtils.omniToSparkUnsafe(sparkContext, + mode, + relation, + StructType.fromAttributes(output), + ColumnarBatchToInternalRow.convert(output, _, numOutputRows, numInputBatches, omniColumnarToRowTime) + ) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numOutputRows, numInputBatches, omniColumnarToRowTime)) + broadcast + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") @@ -311,7 +316,55 @@ case class OmniColumnarToRowExec(child: SparkPlan, } override protected def withNewChildInternal(newChild: SparkPlan): - OmniColumnarToRowExec = copy(child = newChild) + OmniColumnarToRowExec = copy(child = newChild) +} + +object InternalRowToColumnarBatch { + + def convert(enableOffHeapColumnVector: Boolean, + numInputRows: SQLMetric, + numOutputBatches: SQLMetric, + rowToOmniColumnarTime: SQLMetric, + numRows: Int, localSchema: StructType, + rowIterator: Iterator[InternalRow]): Iterator[ColumnarBatch] = { + if (rowIterator.hasNext) { + new Iterator[ColumnarBatch] { + private val converters = new OmniRowToColumnConverter(localSchema) + + override def hasNext: Boolean = { + rowIterator.hasNext + } + + override def next(): ColumnarBatch = { + val startTime = System.nanoTime() + val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, + localSchema, true) + val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) + cb.setNumRows(0) + vectors.foreach(_.reset()) + var rowCount = 0 + while (rowCount < numRows && rowIterator.hasNext) { + val row = rowIterator.next() + converters.convert(row, vectors.toArray) + rowCount += 1 + } + if (!enableOffHeapColumnVector) { + vectors.foreach { v => + 
v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount) + } + } + cb.setNumRows(rowCount) + numInputRows += rowCount + numOutputBatches += 1 + rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) + cb + } + } + } else { + Iterator.empty + } + } + } object ColumnarBatchToInternalRow { @@ -346,16 +399,16 @@ object ColumnarBatchToInternalRow { SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - toClosedVecs.foreach {vec => - vec.close() - } + toClosedVecs.foreach { vec => + vec.close() + } } - override def hasNext: Boolean = { + override def hasNext: Boolean = { val has = iter.hasNext // fetch all rows if (!has) { - toClosedVecs.foreach {vec => + toClosedVecs.foreach { vec => vec.close() toClosedVecs.remove(toClosedVecs.indexOf(vec)) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala new file mode 100644 index 000000000..2204e0fe4 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala @@ -0,0 +1,276 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.util + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import org.apache.spark.{SparkConf, SparkContext, TaskContext, TaskContextImpl} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} +import org.apache.spark.sql.execution.ColumnarHashedRelation +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import java.util.Properties +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import nova.hetu.omniruntime.vector.VecBatch + +object BroadcastUtils { + + def getBroadCastMode(partitioning: Partitioning): BroadcastMode = { + partitioning match { + case BroadcastPartitioning(mode) => + mode + case _ => + throw new IllegalArgumentException("Unexpected partitioning: " + partitioning.toString) + } + } + + def omniToSparkUnsafe[F, T]( + context: SparkContext, + mode: BroadcastMode, + from: Broadcast[F], + fromSchema: StructType, + fn: Iterator[ColumnarBatch] => Iterator[InternalRow]): Broadcast[T] = { + mode match { + case HashedRelationBroadcastMode(_, _) => + // ColumnarHashedRelation to HashedRelation. + val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarHashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (rowCount, rowIterator) = deserializeRelation(fromSchema, fn, fromRelation) + mode.transform(rowIterator, Some(rowCount)) + } + // Rebroadcast Spark relation. + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case IdentityBroadcastMode => + // ColumnarBuildSideRelation to Array. 
+ val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarHashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (_, rowIterator) = deserializeRelation(fromSchema, fn, fromRelation) + val rowArray = new ArrayBuffer[InternalRow]() + while (rowIterator.hasNext) { + val unsafeRow = rowIterator.next().asInstanceOf[UnsafeRow] + rowArray.append(unsafeRow.copy()) + } + rowArray.toArray + } + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case _ => throw new IllegalStateException("Unexpected broadcast mode: " + mode) + } + } + + def sparkToOmniUnsafe[F, T](context: SparkContext, + mode: BroadcastMode, + from: Broadcast[F], + logFunc: (=> String) => Unit, + fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = { + mode match { + case HashedRelationBroadcastMode(_, _) => + // HashedRelation to ColumnarHashedRelation + val fromBroadcast = from.asInstanceOf[Broadcast[HashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (nullBatchCount, input) = serializeRelation(mode, fn(fromRelation.keys().flatMap(fromRelation.get)), logError = logFunc) + val relation = new ColumnarHashedRelation + relation.converterData(mode, nullBatchCount, input) + relation + } + // Rebroadcast Omni relation. + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case IdentityBroadcastMode => + // Array[InternalRow] to ColumnarHashedRelation. + val fromBroadcast = from.asInstanceOf[Broadcast[Array[InternalRow]]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (nullBatchCount, input) = serializeRelation(mode, fn(fromRelation.toIterator), logError = logFunc) + val relation = new ColumnarHashedRelation + relation.converterData(mode, nullBatchCount, input) + relation + } + // Rebroadcast Omni relation. 
+ context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case _ => throw new IllegalStateException("Unexpected broadcast mode: " + mode) + } + } + + private def deserializeRelation(fromSchema: StructType, + fn: Iterator[ColumnarBatch] => Iterator[InternalRow], + fromRelation: ColumnarHashedRelation): (Long, Iterator[InternalRow]) = { + val deserializer = VecBatchSerializerFactory.create() + val data = fromRelation.buildData + var rowCount = 0 + val batchBuffer = ListBuffer[ColumnarBatch]() + val rowIterator = fn( + try { + data.map(bytes => { + val batch: VecBatch = deserializer.deserialize(bytes) + val columnarBatch = vecBatchToColumnarBatch(batch, fromSchema) + batchBuffer.append(columnarBatch) + rowCount += columnarBatch.numRows() + columnarBatch + }).toIterator + } catch { + case exception: Exception => + batchBuffer.foreach(_.close()) + throw exception + } + ) + (rowCount, rowIterator) + } + + private def serializeRelation(mode: BroadcastMode, + batches: Iterator[ColumnarBatch], + logError: (=> String) => Unit): (Int, Array[Array[Byte]]) = { + val serializer = VecBatchSerializerFactory.create() + var nullBatchCount = 0 + val nullRelationFlag = mode match { + case hashRelMode: HashedRelationBroadcastMode => + hashRelMode.isNullAware + case _ => false + } + val input: Array[Array[Byte]] = batches.map(batch => { + // When nullRelationFlag is true, it means anti-join + // Only one column of data is involved in the anti-join + if (nullRelationFlag && batch.numCols() > 0) { + val vec = batch.column(0) + if (vec.hasNull) { + try { + nullBatchCount += 1 + } catch { + case e: Exception => + logError(s"compute null BatchCount error : ${e.getMessage}.") + } + } + } + val vectors = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(vectors, batch.numRows()) + val vecBatchSer = serializer.serialize(vecBatch) + // close omni vec + vecBatch.releaseAllVectors() + vecBatch.close() + vecBatchSer + }).toArray + (nullBatchCount, input) + } + + private def vecBatchToColumnarBatch(vecBatch: VecBatch, schema: StructType): ColumnarBatch = { + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, schema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + + // Run code with an unsafe task context. If the call took place from the Spark driver or test code + // without a Spark task context registered, a temporary unsafe task context instance will + // be created and used. Since the unsafe task context is not managed by Spark's task memory manager, + // Spark may not be aware of the allocations that happen inside the user code. + // + // The API should only be used in the following cases: + // + // 1. Run code on driver + // 2. 
Run test code + private def runUnSafe[T](sparkConf: SparkConf)(body: => T): T = { + if (inSparkTask()) { + throw new UnsupportedOperationException("runUnsafe should only be used outside Spark task") + } + TaskContext.setTaskContext(createUnSafeTaskContext(sparkConf)) + val context = getLocalTaskContext() + try { + val out = + try { + body + } catch { + case t: Throwable => + // Similar to the error-handling code in Task.scala + try { + context.markTaskFailed(t) + } catch { + case inner: Throwable => + t.addSuppressed(inner) + } + context.markTaskCompleted(Some(t)) + throw t + } finally { + try { + context.markTaskCompleted(None) + } finally { + unsetUnsafeTaskContext() + } + } + out + } catch { + case t: Throwable => + throw t + } + } + + private def getLocalTaskContext(): TaskContext = { + TaskContext.get() + } + + private def inSparkTask(): Boolean = { + TaskContext.get() != null + } + + private def unsetUnsafeTaskContext(): Unit = { + if (!inSparkTask()) { + throw new IllegalStateException() + } + if (getLocalTaskContext().taskAttemptId() != -1) { + throw new IllegalStateException() + } + TaskContext.unset() + } + + private def createUnSafeTaskContext(sparkConf: SparkConf): TaskContext = { + // driver code runs on an unsafe task context which is not managed by Spark's task memory manager, + // so maxHeapMemory is set to ByteUnit.TiB.toBytes(2) + val memoryManager = + new UnifiedMemoryManager(sparkConf, ByteUnit.TiB.toBytes(2), ByteUnit.TiB.toBytes(1), 1) + new TaskContextImpl( + -1, + -1, + -1, + -1L, + -1, + new TaskMemoryManager(memoryManager, -1L), + new Properties, + MetricsSystem.createMetricsSystem("OMNI_UNSAFE", sparkConf), + TaskMetrics.empty, + 1, + Map.empty + ) + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala index b65c0d547..9b97295fc 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} -import org.apache.spark.sql.execution.exchange.Exchange -import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec -import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, LeafExecNode, ProjectExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec} +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, LeafExecNode, OmniColumnarToRowExec, ProjectExec, RowToOmniColumnarExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SharedSparkSession @@ -60,21 +60,81 @@ class FallbackStrategiesSuite extends QueryTest with 
     employees.createOrReplaceTempView("employees_for_fallback_ut_test")
   }
 
+  test("c2r doExecuteBroadcast") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"),
+      (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB"),
+      ("spark.omni.sql.columnar.broadcastJoin", "false")) {
+      val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age")
+      val runRows = df.collect()
+      val expectedRows = Seq(Row(56, 21000), Row(56, 21000),
+        Row(66, 30000), Row(66, 30000),
+        Row(70, 10000), Row(76, 32000),
+        Row(76, 32000), Row(46, 29000),
+        Row(50, 23000), Row(56, 29000),
+        Row(56, 29000), Row(66, 23000),
+        Row(66, 23000), Row(76, 35000), Row(76, 35000))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the result rows do not match the expected rows")
+      val bhj = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.find({
+        case _: BroadcastHashJoinExec => true
+        case _ => false
+      })
+      assert(bhj.isDefined, "bhj should have fallen back to the row-based BroadcastHashJoinExec")
+      val c2rWithOmniBroadCast = bhj.get.children.find({
+        case p: OmniColumnarToRowExec =>
+          p.child.isInstanceOf[BroadcastQueryStageExec] &&
+            p.child.asInstanceOf[BroadcastQueryStageExec].plan.isInstanceOf[ColumnarBroadcastExchangeExec]
+        case _ => false
+      })
+      assert(c2rWithOmniBroadCast.isDefined, "bhj should use omni c2r to adapt to omni broadcast exchange")
+    }
+  }
+
+  test("r2c doExecuteBroadcast") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"),
+      (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB"),
+      ("spark.omni.sql.columnar.broadcastexchange", "false")) {
+      val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age")
+      val runRows = df.collect()
+      val expectedRows = Seq(Row(56, 21000), Row(56, 21000),
+        Row(66, 30000), Row(66, 30000),
+        Row(70, 10000), Row(76, 32000),
+        Row(76, 32000), Row(46, 29000),
+        Row(50, 23000), Row(56, 29000),
+        Row(56, 29000), Row(66, 23000),
+        Row(66, 23000), Row(76, 35000), Row(76, 35000))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the result rows do not match the expected rows")
+      val omniBhj = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.find({
+        case _: ColumnarBroadcastHashJoinExec => true
+        case _ => false
+      })
+      assert(omniBhj.isDefined, "bhj should be columnar")
+      val r2cWithBroadCast = omniBhj.get.children.find({
+        case p: RowToOmniColumnarExec =>
+          p.child.isInstanceOf[BroadcastQueryStageExec] &&
+            p.child.asInstanceOf[BroadcastQueryStageExec].plan.isInstanceOf[BroadcastExchangeExec]
+        case _ => false
+      })
+      assert(r2cWithBroadCast.isDefined, "OmniBhj should use omni r2c to adapt to broadcast exchange")
+    }
+  }
+
   test("Fall back stage contain bhj") {
     withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
       ("spark.omni.sql.columnar.project", "false"),
       (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"),
       (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB")) {
       val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age")
-      val runRows = df.collect().sortBy(row => row.getInt(1))
-      val expectedRows = Seq(Row(20000, 35), Row(64000, 38),
-        Row(42000, 28), Row(60000, 33),
-        Row(46000, 33), Row(46000, 33),
-        Row(58000, 28), Row(58000, 28),
-        Row(70000, 38), Row(58000, 23),
-        Row(58000, 23), Row(46000, 25),
-        Row(46000, 25)).sortBy(row => row.getInt(1))
-      QueryTest.sameRows(runRows, expectedRows)
+      val runRows = df.collect()
+      val expectedRows = Seq(Row(56, 21000), Row(56, 21000),
+        Row(66, 30000), Row(66, 30000),
+        Row(70, 10000), Row(76, 32000),
+        Row(76, 32000), Row(46, 29000),
+        Row(50, 23000), Row(56, 29000),
+        Row(56, 29000), Row(66, 23000),
+        Row(66, 23000), Row(76, 35000), Row(76, 35000))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the result rows do not match the expected rows")
       val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({
         case plan: BroadcastHashJoinExec => plan
       })
@@ -134,7 +194,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
+    rule.enableAdaptiveContext()
     // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer.
     val planAfterPreOverride =
       UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp()))))
@@ -150,7 +210,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
+    rule.enableAdaptiveContext()
     // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer.
     val planAfterPreOverride =
       UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp()))))
@@ -172,8 +232,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
-    // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer
+    rule.enableAdaptiveContext() // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer
     // and replacing LeafOp with LeafOpTransformer.
     val planAfterPreOverride =
       UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(ColumnarLeafOp()))))
@@ -195,7 +254,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
+    rule.enableAdaptiveContext()
     // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer
     // and replacing LeafOp with LeafOpTransformer.
     val planAfterPreOverride =
@@ -236,7 +295,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = Exchange1(UnaryOp1(UnaryOp1(UnaryOp1(LeafOp()))))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
+    rule.enableAdaptiveContext()
     // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer.
     val planAfterPreOverride =
       ColumnarExchange1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarLeafOp()))))
@@ -257,7 +316,7 @@ class FallbackStrategiesSuite extends QueryTest with SharedSparkSession {
     val originalPlan = Exchange1(UnaryOp1(UnaryOp1(mockQueryStageExec)))
     val rule = ColumnarOverrideRules(spark)
     rule.preColumnarTransitions(originalPlan)
-    rule.isSupportAdaptive = true
+    rule.enableAdaptiveContext()
     // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with UnaryOp1Transformer.
     val planAfterPreOverride =
       ColumnarExchange1(ColumnarUnaryOp1(UnaryOp1(mockQueryStageExec)))
@@ -327,8 +386,8 @@ case class ColumnarUnaryOp1(
 }
 
 case class Exchange1(
-    override val child: SparkPlan,
-    override val supportsColumnar: Boolean = false)
+  override val child: SparkPlan,
+  override val supportsColumnar: Boolean = false)
   extends Exchange {
   override protected def doExecute(): RDD[InternalRow] =
     throw new UnsupportedOperationException()
@@ -339,8 +398,8 @@ case class Exchange1(
 }
 
 case class ColumnarExchange1(
-    override val child: SparkPlan,
-    override val supportsColumnar: Boolean = true)
+  override val child: SparkPlan,
+  override val supportsColumnar: Boolean = true)
   extends Exchange {
   override protected def doExecute(): RDD[InternalRow] =
     throw new UnsupportedOperationException()
-- Gitee

From e579de9bcce0e779c61e58de1a92a99ef52e80e5 Mon Sep 17 00:00:00 2001
From: dengzhaochu
Date: Thu, 10 Oct 2024 11:29:39 +0800
Subject: [PATCH 4/5] fix subquery

---
 .../boostkit/spark/ColumnarPlugin.scala       |  1 +
 .../boostkit/spark/ColumnarPluginConfig.scala |  5 +-
 .../spark/FallbackBroadcastExchange.scala     | 58 +++++++++++++++++++
 3 files changed, 63 insertions(+), 1 deletion(-)
 create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala

diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
index 76e0d2eb6..7744566bf 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
@@ -777,6 +777,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging {
     extensions.injectOptimizerRule(_ => MergeSubqueryFilters)
     extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session))
     extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow)
+    extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session))
   }
 }
 
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
index 2622f9b25..0b98c8bba 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
@@ -18,6 +18,7 @@
 
 package com.huawei.boostkit.spark
 
+import com.huawei.boostkit.spark.ColumnarPluginConfig.{OMNI_BROADCAST_EXCHANGE_ENABLE_KEY}
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle.sort.ColumnarShuffleManager
@@ -68,7 +69,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
 
   // enable or disable columnar broadcastexchange
   val enableColumnarBroadcastExchange: Boolean =
-    conf.getConfString("spark.omni.sql.columnar.broadcastexchange", "true").toBoolean
+    conf.getConfString(OMNI_BROADCAST_EXCHANGE_ENABLE_KEY, "true").toBoolean
 
   // enable or disable columnar wholestagecodegen
   val enableColumnarWholeStageCodegen: Boolean =
@@ -289,6 +290,8 @@ object ColumnarPluginConfig {
 
   val OMNI_ENABLE_KEY: String = "spark.omni.enabled"
 
+  val OMNI_BROADCAST_EXCHANGE_ENABLE_KEY = "spark.omni.sql.columnar.broadcastexchange"
+
   var ins: ColumnarPluginConfig = null
 
   def getConf: ColumnarPluginConfig = synchronized {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala
new file mode 100644
index 000000000..6a7c9180c
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala
@@ -0,0 +1,58 @@
+/*
+ * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.spark
+
+import com.huawei.boostkit.spark.ColumnarPluginConfig.OMNI_BROADCAST_EXCHANGE_ENABLE_KEY
+import com.huawei.boostkit.spark.util.PhysicalPlanSelector
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan}
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
+
+case class FallbackBroadcastExchange(session: SparkSession) extends Rule[SparkPlan] {
+
+  override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(session, plan) {
+
+    val sqlConf = session.sessionState.conf
+    val enableColumnarBroadcastExchange: Boolean =
+      sqlConf.getConfString(OMNI_BROADCAST_EXCHANGE_ENABLE_KEY, "true").toBoolean
+
+    plan.foreach {
+      case exec: BroadcastExchangeExec =>
+        if (!enableColumnarBroadcastExchange) {
+          TransformHints.tagNotTransformable(exec)
+          logInfo(s"BroadcastExchange falls back because ColumnarBroadcastExchange is disabled")
+        } else {
+          // check whether ColumnarBroadcastExchangeExec can be built for this exchange
+          try {
+            new ColumnarBroadcastExchangeExec(
+              exec.mode,
+              exec.child
+            ).buildCheck()
+          } catch {
+            case t: Throwable =>
+              TransformHints.tagNotTransformable(exec)
+              logDebug(s"BroadcastExchange falls back because the ColumnarBroadcastExchangeExec build check failed: ${t}")
+          }
+        }
+      case _ =>
+    }
+    plan
+  }
+}
-- Gitee

From 5483979846654e4313039dfe76fc3d2e002193f4 Mon Sep 17 00:00:00 2001
From: hyy_cyan
Date: Fri, 13 Dec 2024 15:29:39 +0800
Subject: [PATCH 5/5] fix compile

---
 .../com/huawei/boostkit/spark/ColumnarPlugin.scala | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
index 7744566bf..865cdd550 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
@@ -28,14 +28,13 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Literal
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge}
 import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{BroadcastExchangeExecProxy, RowToOmniColumnarExec, _}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins._
 import org.apache.spark.sql.execution.window.{TopNPushDownForWindow, WindowExec}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport
 import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
 import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
@@ -613,14 +612,6 @@ case class ColumnarPostOverrides(isSupportAdaptive: Boolean = true) extends Rule
       case _ =>
         replaceColumnarToRow(plan, conf)
     }
-    case plan: BroadcastExchangeExecProxy =>
-      val children = plan.children.map {
-        case c: ColumnarToRowExec =>
-          replaceWithColumnarPlan(c.child)
-        case other =>
-          replaceWithColumnarPlan(other)
-      }
-      plan.withNewChildren(children)
     case r: SparkPlan
       if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar &&
         r.children.exists(c => c.isInstanceOf[ColumnarToRowExec]) =>
@@ -775,7 +766,6 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging {
     extensions.injectOptimizerRule(_ => DelayCartesianProduct)
     extensions.injectOptimizerRule(_ => HeuristicJoinReorder)
     extensions.injectOptimizerRule(_ => MergeSubqueryFilters)
-    extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session))
     extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow)
     extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session))
   }
@@ -784,7 +774,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging {
 private class OmniTaskStartExecutorPlugin extends ExecutorPlugin {
   override def onTaskStart(): Unit = {
     addLeakSafeTaskCompletionListener[Unit](_ => {
-      MemoryManager.reclaimMemory()
+      MemoryManager.clearMemory()
     })
   }
 }
-- Gitee
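
Reviewer note, not part of the patch series: the sketch below shows how driver-side test code might call the runUnSafe helper added in the test utilities above. It is a minimal sketch under assumptions: the helper is reachable from the calling scope (in the patch it is private), SparkConf and TaskContext are already imported, and the test name is hypothetical.

  // Hypothetical usage sketch: runUnSafe installs a synthetic TaskContext, runs the
  // body, then marks the fake task completed and unsets the context again.
  test("driver-side code observes a synthetic TaskContext") {
    val conf = new SparkConf()
    val attemptId = runUnSafe(conf) {
      // createUnSafeTaskContext builds the context with every id set to -1, so the
      // body sees taskAttemptId == -1 even though no real Spark task is running.
      TaskContext.get().taskAttemptId()
    }
    assert(attemptId == -1L)
    // unsetUnsafeTaskContext() has removed the synthetic context again.
    assert(TaskContext.get() == null)
  }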
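
A second hedged sketch, also not part of the patches: exercising the FallbackBroadcastExchange path introduced in PATCH 4 from an interactive session. It assumes an active SparkSession named spark with the plugin installed and the temp view created as in the suite above; the key string mirrors OMNI_BROADCAST_EXCHANGE_ENABLE_KEY.

  // Hypothetical: disabling the columnar broadcast exchange makes FallbackBroadcastExchange
  // tag every BroadcastExchangeExec as not transformable during query-stage preparation.
  spark.conf.set("spark.omni.sql.columnar.broadcastexchange", "false")
  val joined = spark.sql(
    "select t1.age, t2.salary from employees_for_fallback_ut_test t1 " +
      "join employees_for_fallback_ut_test t2 on t1.age = t2.age")
  // The plan should now contain a row-based BroadcastExchangeExec feeding the join,
  // adapted through RowToOmniColumnarExec, as asserted in the "r2c doExecuteBroadcast" test.
  joined.explain()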