diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index e38bede42479c0e8dca652cee8ba03433f9e435a..865cdd550c758b34839c3385ca6d1dee90eaa7b5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -18,67 +18,45 @@ package com.huawei.boostkit.spark +import com.huawei.boostkit.spark.Constant.OMNI_IS_ADAPTIVE_CONTEXT import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.PhysicalPlanSelector -import nova.hetu.omniruntime.memory.MemoryManager import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} -import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, Expression, Literal, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.execution.window.{WindowExec, TopNPushDownForWindow} +import org.apache.spark.sql.execution.window.{TopNPushDownForWindow, WindowExec} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener -case class ColumnarPreOverrides() extends Rule[SparkPlan] { +import scala.collection.mutable.ListBuffer + +import nova.hetu.omniruntime.memory.MemoryManager + +case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) + extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan - val enableColumnarProject: Boolean = columnarConf.enableColumnarProject - val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter - val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand 
- val enableColumnarHashAgg: Boolean = columnarConf.enableColumnarHashAgg - val enableTakeOrderedAndProject: Boolean = columnarConf.enableTakeOrderedAndProject && - columnarConf.enableColumnarShuffle - val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && - columnarConf.enableColumnarBroadcastJoin - val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin && - columnarConf.enableColumnarBroadcastExchange - val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin - val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort - val enableColumnarSort: Boolean = columnarConf.enableColumnarSort - val enableColumnarWindow: Boolean = columnarConf.enableColumnarWindow + val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle - val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin - val enableColumnarUnion: Boolean = columnarConf.enableColumnarUnion val enableFusion: Boolean = columnarConf.enableFusion - var isSupportAdaptive: Boolean = true val enableColumnarProjectFusion: Boolean = columnarConf.enableColumnarProjectFusion - val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit - val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit val enableDedupLeftSemiJoin: Boolean = columnarConf.enableDedupLeftSemiJoin val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold - val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization val enableRowShuffle: Boolean = columnarConf.enableRowShuffle val columnsThreshold: Int = columnarConf.columnsThreshold - def apply(plan: SparkPlan): SparkPlan = { - replaceWithColumnarPlan(plan) - } - - def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } - def checkBhjRightChild(x: Any): Boolean = { x match { case _: ColumnarFilterExec | _: ColumnarConditionProjectExec => true @@ -86,212 +64,278 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } } - def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { - case plan: RowGuard => - val actualPlan: SparkPlan = plan.child match { - case p: BroadcastHashJoinExec => - p.withNewChildren(p.children.map { - case RowGuard(queryStage: BroadcastQueryStageExec) => - fallBackBroadcastQueryStage(queryStage) - case queryStage: BroadcastQueryStageExec => - fallBackBroadcastQueryStage(queryStage) - case plan: BroadcastExchangeExec => - // if BroadcastHashJoin is row-based, BroadcastExchange should also be row-based - RowGuard(plan) - case other => other - }) - case p: BroadcastNestedLoopJoinExec => - p.withNewChildren(p.children.map { - case RowGuard(queryStage: BroadcastQueryStageExec) => - fallBackBroadcastQueryStage(queryStage) - case queryStage: BroadcastQueryStageExec => - fallBackBroadcastQueryStage(queryStage) - case plan: BroadcastExchangeExec => - // if BroadcastNestedLoopJoin is row-based, BroadcastExchange should also be row-based - RowGuard(plan) - case other => other - }) - case other => - other - } - logDebug(s"Columnar Processing for ${actualPlan.getClass} is under RowGuard.") - actualPlan.withNewChildren(actualPlan.children.map(replaceWithColumnarPlan)) - case plan: FileSourceScanExec - if enableColumnarFileScan && checkColumnarBatchSupport(conf, plan) => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarFileSourceScanExec( - 
plan.relation, - plan.output, - plan.requiredSchema, - plan.partitionFilters, - plan.optionalBucketSet, - plan.optionalNumCoalescedBuckets, - plan.dataFilters, - plan.tableIdentifier, - plan.disableBucketedScan - ) - case range: RangeExec => - new ColumnarRangeExec(range.range) - case plan: ProjectExec if enableColumnarProject => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - child match { - case ColumnarFilterExec(condition, child) => - ColumnarConditionProjectExec(plan.projectList, condition, child) - case join : ColumnarBroadcastHashJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarBroadcastHashJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.buildSide, - join.condition, - join.left, - join.right, - join.isNullAwareAntiJoin, - plan.projectList) - } else { - ColumnarProjectExec(plan.projectList, child) - } - case join : ColumnarShuffledHashJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarShuffledHashJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.buildSide, - join.condition, - join.left, - join.right, - join.isSkewJoin, - plan.projectList) - } else { + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = { + TransformHints.getHint(plan) match { + case _: TRANSFORM_SUPPORTED => + // supported, break + case _: TRANSFORM_UNSUPPORTED => + logDebug(s"Columnar Processing for ${plan.getClass} is under RowGuard.") + return plan.withNewChildren( + plan.children.map(replaceWithColumnarPlan)) + } + plan match { + case plan: FileSourceScanExec => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFileSourceScanExec( + plan.relation, + plan.output, + plan.requiredSchema, + plan.partitionFilters, + plan.optionalBucketSet, + plan.optionalNumCoalescedBuckets, + plan.dataFilters, + plan.tableIdentifier, + plan.disableBucketedScan + ) + case range: RangeExec => + new ColumnarRangeExec(range.range) + case plan: ProjectExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + child match { + case ColumnarFilterExec(condition, child) => + ColumnarConditionProjectExec(plan.projectList, condition, child) + case join: ColumnarBroadcastHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarBroadcastHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isNullAwareAntiJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join: ColumnarShuffledHashJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarShuffledHashJoinExec( + join.leftKeys, + join.rightKeys, + join.joinType, + join.buildSide, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case join: ColumnarSortMergeJoinExec => + if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { + ColumnarSortMergeJoinExec( + join.leftKeys, + 
join.rightKeys, + join.joinType, + join.condition, + join.left, + join.right, + join.isSkewJoin, + plan.projectList) + } else { + ColumnarProjectExec(plan.projectList, child) + } + case _ => ColumnarProjectExec(plan.projectList, child) - } - case join : ColumnarSortMergeJoinExec => - if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { - ColumnarSortMergeJoinExec( - join.leftKeys, - join.rightKeys, - join.joinType, - join.condition, - join.left, - join.right, - join.isSkewJoin, - plan.projectList) + } + case plan: FilterExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarFilterExec(plan.condition, child) + case plan: ExpandExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarExpandExec(plan.projections, plan.output, child) + case plan: HashAggregateExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (enableFusion && !isSupportAdaptive) { + if (plan.aggregateExpressions.forall(_.mode == Partial)) { + child match { + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj4@ColumnarProjectExec(_, + join4@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) + ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + proj4, + join4, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, + filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)), _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case proj1@ColumnarProjectExec(_, + join1@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj2@ColumnarProjectExec(_, + join2@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + proj3@ColumnarProjectExec(_, + join3@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + filter@ColumnarFilterExec(_, + scan@ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)), _, _, _)), _, _, _)) + if checkBhjRightChild( + child.asInstanceOf[ColumnarProjectExec].child.children(1) + 
.asInstanceOf[ColumnarBroadcastExchangeExec].child) => + ColumnarMultipleOperatorExec1( + plan, + proj1, + join1, + proj2, + join2, + proj3, + join3, + filter, + scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case _ => + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } } else { - ColumnarProjectExec(plan.projectList, child) + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) } - case _ => - ColumnarProjectExec(plan.projectList, child) - } - case plan: FilterExec if enableColumnarFilter => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarFilterExec(plan.condition, child) - case plan: ExpandExec if enableColumnarExpand => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarExpandExec(plan.projections, plan.output, child) - case plan: HashAggregateExec if enableColumnarHashAgg => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - if (enableFusion && !isSupportAdaptive) { - if (plan.aggregateExpressions.forall(_.mode == Partial)) { - child match { - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj4 @ ColumnarProjectExec(_, - join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) - ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - proj4, - join4, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - 
scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, + } else { + if (child.isInstanceOf[ColumnarExpandExec]) { + var columnarExpandExec = child.asInstanceOf[ColumnarExpandExec] + val matchRollupOptimization: Boolean = columnarExpandExec.matchRollupOptimization() + if (matchRollupOptimization && enableRollupOptimization) { + // The sparkPlan: ColumnarExpandExec -> ColumnarHashAggExec => ColumnarExpandExec -> ColumnarHashAggExec -> ColumnarOptRollupExec. + // ColumnarHashAggExec handles the first combination by Partial mode, i.e. projections[0]. + // ColumnarOptRollupExec handles the residual combinations by PartialMerge mode, i.e. projections[1]~projections[n]. + val projections = columnarExpandExec.projections + val headProjections = projections.slice(0, 1) + var residualProjections = projections.slice(1, projections.length) + // replace parameters + columnarExpandExec = columnarExpandExec.replace(headProjections) + + // partial + val partialHashAggExec = new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + columnarExpandExec) + + + // If the aggregator has an expression, more than one column in the projection is used + // for expression calculation. Meanwhile, If the single distinct syntax exists, the + // sequence of group columns is disordered. Therefore, we need to calculate the sequence + // of expandSeq first to ensure the project operator correctly processes the columns. 
+ val expectSeq = plan.resultExpressions + val expandSeq = columnarExpandExec.output + // the processing sequences of expandSeq + residualProjections = residualProjections.map(projection => { + val indexSeq: Seq[Expression] = expectSeq.map(expectExpr => { + val index = expandSeq.indexWhere(expandExpr => expectExpr.exprId.equals(expandExpr.exprId)) + if (index != -1) { + projection.apply(index) match { + case literal: Literal => literal + case _ => expectExpr + } + } else { + expectExpr + } + }) + indexSeq + }) + + // partial merge + val groupingExpressions = plan.resultExpressions.slice(0, plan.groupingExpressions.length) + val aggregateExpressions = plan.aggregateExpressions.map(expr => { + expr.copy(expr.aggregateFunction, PartialMerge, expr.isDistinct, expr.filter, expr.resultId) + }) + + // need ExpandExec parameters and HashAggExec parameters + new ColumnarOptRollupExec( + residualProjections, plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case _ => + groupingExpressions, + aggregateExpressions, + plan.aggregateAttributes, + partialHashAggExec) + } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, plan.isStreaming, @@ -302,82 +346,7 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.initialInputBufferOffset, plan.resultExpressions, child) - } - } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) - } - } else { - if (child.isInstanceOf[ColumnarExpandExec]) { - var columnarExpandExec = child.asInstanceOf[ColumnarExpandExec] - val matchRollupOptimization: Boolean = columnarExpandExec.matchRollupOptimization() - if (matchRollupOptimization && enableRollupOptimization) { - // The sparkPlan: ColumnarExpandExec -> ColumnarHashAggExec => ColumnarExpandExec -> ColumnarHashAggExec -> ColumnarOptRollupExec. - // ColumnarHashAggExec handles the first combination by Partial mode, i.e. projections[0]. - // ColumnarOptRollupExec handles the residual combinations by PartialMerge mode, i.e. projections[1]~projections[n]. - val projections = columnarExpandExec.projections - val headProjections = projections.slice(0, 1) - var residualProjections = projections.slice(1, projections.length) - // replace parameters - columnarExpandExec = columnarExpandExec.replace(headProjections) - - // partial - val partialHashAggExec = new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - columnarExpandExec) - - - // If the aggregator has an expression, more than one column in the projection is used - // for expression calculation. Meanwhile, If the single distinct syntax exists, the - // sequence of group columns is disordered. Therefore, we need to calculate the sequence - // of expandSeq first to ensure the project operator correctly processes the columns. 
- val expectSeq = plan.resultExpressions - val expandSeq = columnarExpandExec.output - // the processing sequences of expandSeq - residualProjections = residualProjections.map(projection => { - val indexSeq: Seq[Expression] = expectSeq.map(expectExpr => { - val index = expandSeq.indexWhere(expandExpr => expectExpr.exprId.equals(expandExpr.exprId)) - if (index != -1) { - projection.apply(index) match { - case literal: Literal => literal - case _ => expectExpr - } - } else { - expectExpr - } - }) - indexSeq - }) - - // partial merge - val groupingExpressions = plan.resultExpressions.slice(0, plan.groupingExpressions.length) - val aggregateExpressions = plan.aggregateExpressions.map(expr => { - expr.copy(expr.aggregateFunction, PartialMerge, expr.isDistinct, expr.filter, expr.resultId) - }) - - // need ExpandExec parameters and HashAggExec parameters - new ColumnarOptRollupExec( - residualProjections, - plan.output, - groupingExpressions, - aggregateExpressions, - plan.aggregateAttributes, - partialHashAggExec) + } } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, @@ -390,81 +359,84 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.resultExpressions, child) } - } else { - new ColumnarHashAggregateExec( - plan.requiredChildDistributionExpressions, - plan.isStreaming, - plan.numShufflePartitions, - plan.groupingExpressions, - plan.aggregateExpressions, - plan.aggregateAttributes, - plan.initialInputBufferOffset, - plan.resultExpressions, - child) } - } - case plan: TakeOrderedAndProjectExec if enableTakeOrderedAndProject => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarTakeOrderedAndProjectExec( - plan.limit, - plan.sortOrder, - plan.projectList, - child) - case plan: BroadcastExchangeExec if enableColumnarBroadcastExchange => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarBroadcastExchangeExec(plan.mode, child) - case plan: BroadcastHashJoinExec if enableColumnarBroadcastJoin => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarBroadcastHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right) - case plan: ShuffledHashJoinExec if enableShuffledHashJoin && enableDedupLeftSemiJoin => { - plan.joinType match { - case LeftSemi => { - if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - val partialAgg = PhysicalAggregation.unapply(Aggregate(plan.right.output, plan.right.output, new DummyLogicalPlan)) match { - case Some((groupingExpressions, aggExpressions, resultExpressions, _)) - if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => - ExtendedAggUtils.planPartialAggregateWithoutDistinct( - ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), - aggExpressions.map(_.asInstanceOf[AggregateExpression]), - resultExpressions, - right).asInstanceOf[HashAggregateExec] + case plan: TakeOrderedAndProjectExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently 
supported.") + ColumnarTakeOrderedAndProjectExec( + plan.limit, + plan.sortOrder, + plan.projectList, + child) + case plan: BroadcastExchangeExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarBroadcastExchangeExec(plan.mode, child) + case plan: BroadcastHashJoinExec => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarBroadcastHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right) + case plan: ShuffledHashJoinExec if enableDedupLeftSemiJoin && !SQLConf.get.adaptiveExecutionEnabled => { + plan.joinType match { + case LeftSemi => { + if (plan.condition.isEmpty && plan.right.output.size >= dedupLeftSemiJoinThreshold) { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + val partialAgg = PhysicalAggregation.unapply(Aggregate(plan.right.output, plan.right.output, new DummyLogicalPlan)) match { + case Some((groupingExpressions, aggExpressions, resultExpressions, _)) + if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => + ExtendedAggUtils.planPartialAggregateWithoutDistinct( + ExtendedAggUtils.normalizeGroupingExpressions(groupingExpressions), + aggExpressions.map(_.asInstanceOf[AggregateExpression]), + resultExpressions, + right).asInstanceOf[HashAggregateExec] + } + val newHashAgg = new ColumnarHashAggregateExec( + partialAgg.requiredChildDistributionExpressions, + partialAgg.isStreaming, + partialAgg.numShufflePartitions, + partialAgg.groupingExpressions, + partialAgg.aggregateExpressions, + partialAgg.aggregateAttributes, + partialAgg.initialInputBufferOffset, + partialAgg.resultExpressions, + right) + + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + newHashAgg, + plan.isSkewJoin) + } else { + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) } - val newHashAgg = new ColumnarHashAggregateExec( - partialAgg.requiredChildDistributionExpressions, - partialAgg.isStreaming, - partialAgg.numShufflePartitions, - partialAgg.groupingExpressions, - partialAgg.aggregateExpressions, - partialAgg.aggregateAttributes, - partialAgg.initialInputBufferOffset, - partialAgg.resultExpressions, - right) - - ColumnarShuffledHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - newHashAgg, - plan.isSkewJoin) - } else { + } + case _ => { val left = replaceWithColumnarPlan(plan.left) val right = replaceWithColumnarPlan(plan.right) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") @@ -479,167 +451,134 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.isSkewJoin) } } - case _ => { - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarShuffledHashJoinExec( - plan.leftKeys, - 
plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - plan.isSkewJoin) - } } - } - case plan: ShuffledHashJoinExec if enableShuffledHashJoin => - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarShuffledHashJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - val left = replaceWithColumnarPlan(plan.left) - val right = replaceWithColumnarPlan(plan.right) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarSortMergeJoinExec( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: TopNSortExec if enableColumnarTopNSort => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, child) - case plan: SortExec if enableColumnarSort => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) - case plan: WindowExec if enableColumnarWindow => - val child = replaceWithColumnarPlan(plan.child) - if (child.output.isEmpty) { - return plan - } - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - child match { - case ColumnarSortExec(sortOrder, _, sortChild, _) => - if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) - } else { + case plan: ShuffledHashJoinExec => + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarShuffledHashJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: SortMergeJoinExec => + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + val left = replaceWithColumnarPlan(plan.left) + val right = replaceWithColumnarPlan(plan.right) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + new ColumnarSortMergeJoinExec( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: TopNSortExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, child) + case plan: SortExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency) + case plan: WindowExec => + val child = replaceWithColumnarPlan(plan.child) + if (child.output.isEmpty) { + return plan + } + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + child match { + case 
ColumnarSortExec(sortOrder, _, sortChild, _) => + if (Seq(plan.partitionSpec.map(SortOrder(_, Ascending)) ++ plan.orderSpec) == Seq(sortOrder)) { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, sortChild) + } else { + ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case _ => ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) + } + case plan: UnionExec => + val children = plan.children.map(replaceWithColumnarPlan) + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarUnionExec(children) + case plan: ShuffleExchangeExec => + val child = replaceWithColumnarPlan(plan.child) + if (child.output.nonEmpty) { + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + if (child.output.size > columnsThreshold && enableRowShuffle) { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) + } else { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) } - case _ => - ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, child) - } - case plan: UnionExec if enableColumnarUnion => - val children = plan.children.map(replaceWithColumnarPlan) - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarUnionExec(children) - case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle => - val child = replaceWithColumnarPlan(plan.child) - if (child.output.nonEmpty) { - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - if (child.isInstanceOf[ColumnarHashAggregateExec] && child.output.size > columnsThreshold - && enableRowShuffle) { - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) - } else if (enableColumnarShuffle) { - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) } else { plan } - } else { - plan - } - case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => - plan.child match { - case shuffle: ColumnarShuffleExchangeExec => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => - reused match { - case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - OmniAQEShuffleReadExec( - plan.child, - plan.partitionSpecs) - case _ => - plan - } - case _ => - plan - } - case plan: LocalLimitExec if enableLocalColumnarLimit => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarLocalLimitExec(plan.limit, child) - case plan: GlobalLimitExec if enableGlobalColumnarLimit => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarGlobalLimitExec(plan.limit, child) - case plan: CoalesceExec if enableColumnarCoalesce => - val child = replaceWithColumnarPlan(plan.child) - logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - 
ColumnarCoalesceExec(plan.numPartitions, child) - case p => - val children = plan.children.map(replaceWithColumnarPlan) - logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") - p.withNewChildren(children) + case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => + plan.child match { + case shuffle: ColumnarShuffleExchangeExec => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => + reused match { + case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + OmniAQEShuffleReadExec( + plan.child, + plan.partitionSpecs) + case _ => + plan + } + case _ => + plan + } + case plan: LocalLimitExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarLocalLimitExec(plan.limit, child) + case plan: GlobalLimitExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarGlobalLimitExec(plan.limit, child) + case plan: CoalesceExec => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarCoalesceExec(plan.numPartitions, child) + case p => + val children = plan.children.map(replaceWithColumnarPlan) + logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") + p.withNewChildren(children) + } } - def fallBackBroadcastQueryStage(curPlan: BroadcastQueryStageExec): BroadcastQueryStageExec = { - curPlan.plan match { - case originalBroadcastPlan: ColumnarBroadcastExchangeExec => - BroadcastQueryStageExec( - curPlan.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), - curPlan._canonicalized) - case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => - BroadcastQueryStageExec( - curPlan.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)), - curPlan._canonicalized) - case _ => - curPlan - } + override def apply(plan: SparkPlan): SparkPlan = { + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") + replaceWithColumnarPlan(plan) } } -case class ColumnarPostOverrides() extends Rule[SparkPlan] { +case class ColumnarPostOverrides(isSupportAdaptive: Boolean = true) extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - var isSupportAdaptive: Boolean = true def apply(plan: SparkPlan): SparkPlan = { + logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") handleColumnarToRowPartialFetch(replaceWithColumnarPlan(plan)) } private def handleColumnarToRowPartialFetch(plan: SparkPlan): SparkPlan = { // simple check plan tree have OmniColumnarToRow and no LimitExec and TakeOrderedAndProjectExec plan val noPartialFetch = if (plan.find(_.isInstanceOf[OmniColumnarToRowExec]).isDefined) { - (!plan.find(node => - node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec] || - 
node.isInstanceOf[SortMergeJoinExec]).isDefined) + (!plan.find(node => + node.isInstanceOf[LimitExec] || node.isInstanceOf[TakeOrderedAndProjectExec] || + node.isInstanceOf[SortMergeJoinExec]).isDefined) } else { false } @@ -650,14 +589,12 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { newPlan } - def setAdaptiveSupport(enable: Boolean): Unit = { isSupportAdaptive = enable } - def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { case plan: RowToColumnarExec => val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported") RowToOmniColumnarExec(child) - case ColumnarToRowExec(child: ColumnarShuffleExchangeExec) => + case ColumnarToRowExec(child: ColumnarShuffleExchangeExec) if isSupportAdaptive => replaceWithColumnarPlan(child) case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) => replaceWithColumnarPlan(child) @@ -666,17 +603,9 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { case child: BroadcastQueryStageExec => child.plan match { case originalBroadcastPlan: ColumnarBroadcastExchangeExec => - BroadcastQueryStageExec( - child.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), child._canonicalized) + child case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => - BroadcastQueryStageExec( - child.id, - BroadcastExchangeExec( - originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(child.plan, 1)), child._canonicalized) + child case _ => replaceColumnarToRow(plan, conf) } @@ -699,7 +628,7 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { p.withNewChildren(children) } - def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf) : SparkPlan = { + def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf): SparkPlan = { val child = replaceWithColumnarPlan(plan.child) if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { OmniColumnarToRowExec(child) @@ -710,50 +639,124 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { } case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { - def columnarEnabled: Boolean = session.sqlContext.getConf( - "org.apache.spark.sql.columnar.enabled", "true").trim.toBoolean - def rowGuardOverrides: ColumnarGuardRule = ColumnarGuardRule() + private def preOverrides(): List[SparkSession => Rule[SparkPlan]] = List( + FallbackMultiCodegens, + (_: SparkSession) => AddTransformHintRule(), + (_: SparkSession) => ColumnarPreOverrides(isAdaptiveContext)) - def preOverrides: ColumnarPreOverrides = ColumnarPreOverrides() + private def postOverrides(): List[SparkSession => Rule[SparkPlan]] = List( + (_: SparkSession) => ColumnarPostOverrides(isAdaptiveContext) + ) - def postOverrides: ColumnarPostOverrides = ColumnarPostOverrides() - - var isSupportAdaptive: Boolean = true + private def finallyRules(): List[SparkSession => Rule[SparkPlan]] = { + List( + (_: SparkSession) => RemoveTransformHintRule() + ) + } - private def supportAdaptive(plan: SparkPlan): Boolean = { - // Only QueryStage will have Exchange as Leaf Plan - val isLeafPlanExchange = plan match { - case e: Exchange => true - case other => false + private def transformPlan(getRules: List[SparkSession => Rule[SparkPlan]], + plan: SparkPlan, + step: String) = { + logDebug( + s"${step}ColumnarTransitions preOverriden plan:\n${plan.toString}") + val overridden = getRules.foldLeft(plan) { + 
(p, getRule) => + val rule = getRule(session) + val newPlan = rule(p) + newPlan } - isLeafPlanExchange || (SQLConf.get.adaptiveExecutionEnabled && (sanityCheck(plan) && - !plan.logicalLink.exists(_.isStreaming) && - !plan.expressions.exists(_.find(_.isInstanceOf[DynamicPruningSubquery]).isDefined) && - plan.children.forall(supportAdaptive))) + logDebug( + s"${step}ColumnarTransitions afterOverriden plan:\n${overridden.toString}") + overridden + } + + // Just for test use. + def enableAdaptiveContext(): Unit = { + session.sparkContext.setLocalProperty(OMNI_IS_ADAPTIVE_CONTEXT, "true") + } + + // Holds the original plan for possible entire fallback. + private val localOriginalPlans: ThreadLocal[ListBuffer[SparkPlan]] = + ThreadLocal.withInitial(() => ListBuffer.empty[SparkPlan]) + private val localIsAdaptiveContextFlags: ThreadLocal[ListBuffer[Boolean]] = + ThreadLocal.withInitial(() => ListBuffer.empty[Boolean]) + + private def setOriginalPlan(plan: SparkPlan): Unit = { + localOriginalPlans.get.prepend(plan) + } + + private def originalPlan: SparkPlan = { + val plan = localOriginalPlans.get.head + assert(plan != null) + plan + } + + private def resetOriginalPlan(): Unit = localOriginalPlans.get.remove(0) + + private def fallbackPolicy(): List[SparkSession => Rule[SparkPlan]] = { + List((_: SparkSession) => ExpandFallbackPolicy(isAdaptiveContext, originalPlan)) } - private def sanityCheck(plan: SparkPlan): Boolean = - plan.logicalLink.isDefined + // This is an empirical value, may need to be changed for supporting other versions of spark. + private val aqeStackTraceIndex = 14 + + private def setAdaptiveContext(): Unit = { + val traceElements = Thread.currentThread.getStackTrace + assert( + traceElements.length > aqeStackTraceIndex, + s"The number of stack trace elements is expected to be more than $aqeStackTraceIndex") + // ApplyColumnarRulesAndInsertTransitions is called by either QueryExecution or + // AdaptiveSparkPlanExec. So by checking the stack trace, we can know whether + // columnar rule will be applied in adaptive execution context. This part of code + // needs to be carefully checked when supporting higher versions of spark to make + // sure the calling stack has not been changed. + localIsAdaptiveContextFlags + .get() + .prepend( + traceElements(aqeStackTraceIndex).getClassName + .equals(AdaptiveSparkPlanExec.getClass.getName)) + } + + private def resetAdaptiveContext(): Unit = + localIsAdaptiveContextFlags.get().remove(0) + + def isAdaptiveContext: Boolean = Option(session.sparkContext.getLocalProperty(OMNI_IS_ADAPTIVE_CONTEXT)) + .getOrElse("false") + .toBoolean || localIsAdaptiveContextFlags.get().head override def preColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. maybe(session, plan) { - isSupportAdaptive = supportAdaptive(plan) - val rule = preOverrides - rule.setAdaptiveSupport(isSupportAdaptive) - logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPreOverrides") - rule(rowGuardOverrides(plan)) + setAdaptiveContext() + setOriginalPlan(plan) + transformPlan(preOverrides(), plan, "pre") } override def postColumnarTransitions: Rule[SparkPlan] = plan => PhysicalPlanSelector. 
maybe(session, plan) { - val rule = postOverrides - rule.setAdaptiveSupport(isSupportAdaptive) - logInfo("Using BoostKit Spark Native Sql Engine Extension ColumnarPostOverrides") - rule(plan) + val planWithFallBackPolicy = transformPlan(fallbackPolicy(), plan, "fallback") + + val finalPlan = planWithFallBackPolicy match { + case FallbackNode(fallbackPlan) => + // skip c2r and r2c replaceWithColumnarPlan + fallbackPlan + case plan => + transformPlan(postOverrides(), plan, "post") + } + resetOriginalPlan() + resetAdaptiveContext() + transformPlan(finallyRules(), finalPlan, "final") } } +case class RemoveTransformHintRule() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + plan.foreach(TransformHints.untag) + plan + } +} + + class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { override def apply(extensions: SparkSessionExtensions): Unit = { logInfo("Using BoostKit Spark Native Sql Engine Extension to Speed Up Your Queries.") @@ -764,6 +767,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectOptimizerRule(_ => HeuristicJoinReorder) extensions.injectOptimizerRule(_ => MergeSubqueryFilters) extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) + extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index d580945216cdcdc9dfe25200e705b8c208bec3a7..0b98c8bbac583179fbd9d30ce2743fe6ce1bab7c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -18,6 +18,7 @@ package com.huawei.boostkit.spark +import com.huawei.boostkit.spark.ColumnarPluginConfig.{OMNI_BROADCAST_EXCHANGE_ENABLE_KEY} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.shuffle.sort.ColumnarShuffleManager @@ -68,7 +69,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // enable or disable columnar broadcastexchange val enableColumnarBroadcastExchange: Boolean = - conf.getConfString("spark.omni.sql.columnar.broadcastexchange", "true").toBoolean + conf.getConfString(OMNI_BROADCAST_EXCHANGE_ENABLE_KEY, "true").toBoolean // enable or disable columnar wholestagecodegen val enableColumnarWholeStageCodegen: Boolean = @@ -269,6 +270,19 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val columnsThreshold: Int = conf.getConfString("spark.omni.sql.columnar.columnsThreshold", "10").toInt + + // enable or disable bloomfilter subquery reuse + val enableBloomfilterSubqueryReuse: Boolean = + conf.getConfString("spark.omni.sql.columnar.bloomfilterSubqueryReuse", "false").toBoolean + + // The threshold for whether whole stage will fall back in AQE supported case by counting the number of vanilla SparkPlan. 
+  // If it is set to -1, this feature is turned off;
+  // otherwise, when the number of vanilla SparkPlans in the stage is greater than or equal to the threshold,
+  // the whole stage falls back to vanilla SparkPlans.
+  val wholeStageFallbackThreshold: Int = conf.getConfString("spark.omni.sql.columnar.wholeStage.fallback.threshold", "-1").toInt
+
+  // Same as wholeStageFallbackThreshold, but applied to the whole query when AQE is not used.
+  val queryFallbackThreshold: Int = conf.getConfString("spark.omni.sql.columnar.query.fallback.threshold", "-1").toInt
 }
@@ -276,6 +290,8 @@ object ColumnarPluginConfig {
   val OMNI_ENABLE_KEY: String = "spark.omni.enabled"
+  val OMNI_BROADCAST_EXCHANGE_ENABLE_KEY = "spark.omni.sql.columnar.broadcastexchange"
+
   var ins: ColumnarPluginConfig = null
   def getConf: ColumnarPluginConfig = synchronized {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala
index 9d7f844bcc19601ac065083b988085c340631ad3..6521421175b6144e20a1ad590b2936b605dc8d3e 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/Constant.scala
@@ -36,4 +36,6 @@ object Constant {
   val IS_SKIP_VERIFY_EXP: Boolean = true
   val OMNI_DECIMAL64_TYPE: String = DataTypeId.OMNI_DECIMAL64.ordinal().toString
   val OMNI_DECIMAL128_TYPE: String = DataTypeId.OMNI_DECIMAL128.ordinal().toString
+  // for UT
+  val OMNI_IS_ADAPTIVE_CONTEXT = "omni.isAdaptiveContext"
 }
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala
new file mode 100644
index 0000000000000000000000000000000000000000..73d6a3cd7aad4d35cca57e96987e7bbb8173ad21
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ExpandFallbackPolicy.scala
@@ -0,0 +1,164 @@
+/*
+ * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.spark
+
+import com.huawei.boostkit.spark.util.InsertTransitions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
+import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+
+/**
+ * Note: this rule should only fall back to a row-based plan when doing so causes no harm.
+ * The following cases need to be handled carefully.
+ *
+ * @param isAdaptiveContext whether the rule is applied inside an AQE query stage
+ * @param originalPlan the vanilla SparkPlan before the BoostKit extension transform rules were applied
+ */
+case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan)
+  extends Rule[SparkPlan] {
+
+  val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf
+
+  private def countFallback(plan: SparkPlan): Int = {
+    var fallbacks = 0
+
+    def countFallbackInternal(plan: SparkPlan): Unit = {
+      plan match {
+        case _: QueryStageExec => // Another stage.
+        case _: CommandResultExec | _: ExecutedCommandExec => // ignore
+        case _: AQEShuffleReadExec => // ignore
+        case leafPlan: LeafExecNode if !isOmniSparkPlan(leafPlan) =>
+          // Possible fallback for leaf node.
+          fallbacks = fallbacks + 1
+        case p@(ColumnarToRowExec(_) | RowToColumnarExec(_)) =>
+          p.children.foreach(countFallbackInternal)
+        case p =>
+          if (!isOmniSparkPlan(p)) {
+            fallbacks = fallbacks + 1
+          }
+          p.children.foreach(countFallbackInternal)
+      }
+    }
+
+    countFallbackInternal(plan)
+    fallbacks
+  }
+
+  private def isOmniSparkPlan(plan: SparkPlan): Boolean = plan.nodeName.startsWith("Omni") && plan.supportsColumnar
+
+  private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
+    def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
+      case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
+      case _ => false
+    }
+
+    plan.find {
+      case j: ColumnarBroadcastHashJoinExec
+        if isColumnarBroadcastExchange(j.left) ||
+          isColumnarBroadcastExchange(j.right) =>
+        true
+      case _ => false
+    }.isDefined
+  }
+
+  private def fallback(plan: SparkPlan): Option[String] = {
+    val fallbackThreshold = if (isAdaptiveContext) {
+      columnarConf.wholeStageFallbackThreshold
+    } else if (plan.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined) {
+      // If we get here, we are at `QueryExecution.preparations` and AQE has not actually
+      // been applied yet. Do nothing in this case; `wholeStageFallbackThreshold` is
+      // checked later inside AQE.
+ return None + } else { + // AQE is not applied, so we use the whole query threshold to check if should fallback + columnarConf.queryFallbackThreshold + } + if (fallbackThreshold < 0) { + return None + } + + val fallbackNum = countFallback(plan) + + if (fallbackNum >= fallbackThreshold) { + Some( + s"Fallback policy is taking effect, net fallback number: $fallbackNum, " + + s"threshold: $fallbackThreshold") + } else { + None + } + } + + private def fallbackToRowBasedPlan(): SparkPlan = { + val columnarPostOverrides = ColumnarPostOverrides(isAdaptiveContext) + val planWithColumnarToRow = InsertTransitions.insertTransitions(originalPlan, false) + planWithColumnarToRow.transform { + case c2r@(ColumnarToRowExec(_: ShuffleQueryStageExec) | + ColumnarToRowExec(_: BroadcastQueryStageExec)) => + columnarPostOverrides.replaceColumnarToRow(c2r.asInstanceOf[ColumnarToRowExec], conf) + case c2r@ColumnarToRowExec(aqeShuffleReadExec: AQEShuffleReadExec) => + val newPlan = columnarPostOverrides.replaceColumnarToRow(c2r, conf) + newPlan.withNewChildren(Seq(aqeShuffleReadExec.child match { + case _: ColumnarShuffleExchangeExec => + OmniAQEShuffleReadExec(aqeShuffleReadExec.child, aqeShuffleReadExec.partitionSpecs) + case ShuffleQueryStageExec(_, _: ColumnarShuffleExchangeExec, _) => + OmniAQEShuffleReadExec(aqeShuffleReadExec.child, aqeShuffleReadExec.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => + reused match { + case ReusedExchangeExec(_, _: ColumnarShuffleExchangeExec) => + OmniAQEShuffleReadExec( + aqeShuffleReadExec.child, + aqeShuffleReadExec.partitionSpecs) + case _ => + aqeShuffleReadExec + } + case _ => + aqeShuffleReadExec + })) + } + } + + override def apply(plan: SparkPlan): SparkPlan = { + logInfo("Using BoostKit Spark Native Sql Engine Extension FallbackPolicy") + val reason = fallback(plan) + if (reason.isDefined) { + if (hasColumnarBroadcastExchangeWithJoin(plan)) { + logDebug("plan fallback using ExpandFallbackPolicy contains ColumnarBroadcastExchange") + } + val fallbackPlan = fallbackToRowBasedPlan() + TransformHints.tagAllNotTransformable(fallbackPlan, reason.get) + FallbackNode(fallbackPlan) + } else { + plan + } + } +} + +/** A wrapper to specify the plan is fallback plan, the caller side should unwrap it. */ +case class FallbackNode(fallbackPlan: SparkPlan) extends LeafExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + + override def output: Seq[Attribute] = fallbackPlan.output +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala new file mode 100644 index 0000000000000000000000000000000000000000..6a7c9180c94638b26a86db8d2871c1087ddcdaa0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/FallbackBroadcastExchange.scala @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark + +import com.huawei.boostkit.spark.ColumnarPluginConfig.OMNI_BROADCAST_EXCHANGE_ENABLE_KEY +import com.huawei.boostkit.spark.util.PhysicalPlanSelector +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, SparkPlan} +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec + +case class FallbackBroadcastExchange(session: SparkSession) extends Rule[SparkPlan] { + + override def apply(plan: SparkPlan): SparkPlan = PhysicalPlanSelector.maybe(session, plan) { + + val SQLconf = session.sessionState.conf + val enableColumnarBroadcastExchange: Boolean = + SQLconf.getConfString(OMNI_BROADCAST_EXCHANGE_ENABLE_KEY, "true").toBoolean + + plan.foreach { + case exec: BroadcastExchangeExec => + if (!enableColumnarBroadcastExchange) { + TransformHints.tagNotTransformable(exec) + logInfo(s"BroadcastExchange falls back because ColumnarBroadcastExchange is disabled") + } else { + // check whether ColumnarBroadcastExchangeExec is supported or not + try { + new ColumnarBroadcastExchangeExec( + exec.mode, + exec.child + ).buildCheck() + } catch { + case t: Throwable => + TransformHints.tagNotTransformable(exec) + logDebug(s"BroadcastExchange falls back because ColumnarBroadcastExchangeExec buildCheck failed: ${t}") + } + } + case _ => + } + plan + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala similarity index 46% rename from omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala rename to omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index d20781708771934f58f296c3b1b79ad58b7b380c..e37e701ffc9f1ee08d8f8fd0ee8621209dcb8f65 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership.
@@ -18,34 +18,94 @@ package com.huawei.boostkit.spark -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport -case class RowGuard(child: SparkPlan) extends SparkPlan { - def output: Seq[Attribute] = child.output +trait TransformHint { + val stacktrace: Option[String] = + if (TransformHints.DEBUG) { + Some(ExceptionUtils.getStackTrace(new Throwable())) + } else None +} + +case class TRANSFORM_SUPPORTED() extends TransformHint +case class TRANSFORM_UNSUPPORTED(reason: Option[String]) extends TransformHint + +object TransformHints { + val TAG: TreeNodeTag[TransformHint] = + TreeNodeTag[TransformHint]("omni.transformhint") + + val DEBUG = false + + def isAlreadyTagged(plan: SparkPlan): Boolean = { + plan.getTagValue(TAG).isDefined + } + + def isTransformable(plan: SparkPlan): Boolean = { + if (plan.getTagValue(TAG).isDefined) { + return plan.getTagValue(TAG).get.isInstanceOf[TRANSFORM_SUPPORTED] + } + false + } + + def isNotTransformable(plan: SparkPlan): Boolean = { + if (plan.getTagValue(TAG).isDefined) { + return plan.getTagValue(TAG).get.isInstanceOf[TRANSFORM_UNSUPPORTED] + } + false + } + + def tag(plan: SparkPlan, hint: TransformHint): Unit = { + if (isAlreadyTagged(plan)) { + if (isNotTransformable(plan) && hint.isInstanceOf[TRANSFORM_SUPPORTED]) { + throw new UnsupportedOperationException( + "Plan was already tagged as non-transformable, " + + s"cannot mark it as transformable after that:\n${plan.toString()}") + } + } + plan.setTagValue(TAG, hint) + } + + def untag(plan: SparkPlan): Unit = { + plan.unsetTagValue(TAG) + } - protected def doExecute(): RDD[InternalRow] = { - throw new UnsupportedOperationException + def tagTransformable(plan: SparkPlan): Unit = { + tag(plan, TRANSFORM_SUPPORTED()) } - def children: Seq[SparkPlan] = Seq(child) + def tagNotTransformable(plan: SparkPlan, reason: String = ""): Unit = { + tag(plan, TRANSFORM_UNSUPPORTED(Some(reason))) + } - override protected def 
withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = - legacyWithNewChildren(newChildren) + def tagAllNotTransformable(plan: SparkPlan, reason: String): Unit = { + plan.foreach(other => tagNotTransformable(other, reason)) + } + + def getHint(plan: SparkPlan): TransformHint = { + if (!isAlreadyTagged(plan)) { + throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString()) + } + plan.getTagValue(TAG).getOrElse(throw new IllegalStateException()) + } + + def getHintOption(plan: SparkPlan): Option[TransformHint] = { + plan.getTagValue(TAG) + } } -case class ColumnarGuardRule() extends Rule[SparkPlan] { +case class AddTransformHintRule() extends Rule[SparkPlan] { + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf - val preferColumnar: Boolean = columnarConf.enablePreferColumnar val enableColumnarShuffle: Boolean = columnarConf.enableColumnarShuffle val enableColumnarTopNSort: Boolean = columnarConf.enableColumnarTopNSort val enableColumnarSort: Boolean = columnarConf.enableColumnarSort @@ -57,25 +117,48 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { val enableColumnarProject: Boolean = columnarConf.enableColumnarProject val enableColumnarFilter: Boolean = columnarConf.enableColumnarFilter val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand - val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange && - columnarConf.enableColumnarBroadcastJoin + val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit - val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce - private def tryConvertToColumnar(plan: SparkPlan): Boolean = { + override def apply(plan: SparkPlan): SparkPlan = { + addTransformableTags(plan) + } + + /** Inserts a transformable tag on top of those that are not supported. 
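+   * It walks the plan tree in post-order and tags each operator as TRANSFORM_SUPPORTED or TRANSFORM_UNSUPPORTED (with a reason), based on the columnar enable switches and each operator's buildCheck().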
*/ + private def addTransformableTags(plan: SparkPlan): SparkPlan = { + // Walk the tree with post-order + val out = plan.withNewChildren(plan.children.map(addTransformableTags)) + addTransformableTag(out) + out + } + + private def addTransformableTag(plan: SparkPlan): Unit = { + if (TransformHints.isAlreadyTagged(plan)) { + logDebug( + s"Skip adding transformable tag, since plan already tagged as " + + s"${TransformHints.getHint(plan)}: ${plan.toString()}") + return + } try { - val columnarPlan = plan match { + plan match { case plan: FileSourceScanExec => if (!checkColumnarBatchSupport(conf, plan)) { - return false + TransformHints.tagNotTransformable( + plan, + "columnar Batch is not enabled in FileSourceScanExec") + return + } + if (!enableColumnarFileScan) { + TransformHints.tagNotTransformable( + plan, "columnar FileScan is not enabled in FileSourceScanExec") + return } - if (!enableColumnarFileScan) return false ColumnarFileSourceScanExec( plan.relation, plan.output, @@ -87,17 +170,37 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.tableIdentifier, plan.disableBucketedScan ).buildCheck() + TransformHints.tagTransformable(plan) case plan: ProjectExec => - if (!enableColumnarProject) return false + if (!enableColumnarProject) { + TransformHints.tagNotTransformable( + plan, "columnar Project is not enabled in ProjectExec") + return + } ColumnarProjectExec(plan.projectList, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: FilterExec => - if (!enableColumnarFilter) return false + if (!enableColumnarFilter) { + TransformHints.tagNotTransformable( + plan, "columnar Filter is not enabled in FilterExec") + return + } ColumnarFilterExec(plan.condition, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: ExpandExec => - if (!enableColumnarExpand) return false + if (!enableColumnarExpand) { + TransformHints.tagNotTransformable( + plan, "columnar Expand is not enabled in ExpandExec") + return + } ColumnarExpandExec(plan.projections, plan.output, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: HashAggregateExec => - if (!enableColumnarHashAgg) return false + if (!enableColumnarHashAgg) { + TransformHints.tagNotTransformable( + plan, "columnar HashAggregate is not enabled in HashAggregateExec") + return + } new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, plan.isStreaming, @@ -108,35 +211,70 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.initialInputBufferOffset, plan.resultExpressions, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: TopNSortExec => - if (!enableColumnarTopNSort) return false + if (!enableColumnarTopNSort) { + TransformHints.tagNotTransformable( + plan, "columnar TopNSort is not enabled in TopNSortExec") + return + } ColumnarTopNSortExec(plan.n, plan.strictTopN, plan.partitionSpec, plan.sortOrder, plan.global, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: SortExec => - if (!enableColumnarSort) return false + if (!enableColumnarSort) { + TransformHints.tagNotTransformable( + plan, "columnar Sort is not enabled in SortExec") + return + } ColumnarSortExec(plan.sortOrder, plan.global, plan.child, plan.testSpillFrequency).buildCheck() + TransformHints.tagTransformable(plan) case plan: BroadcastExchangeExec => - if (!enableColumnarBroadcastExchange) return false + if (!enableColumnarBroadcastExchange) { + TransformHints.tagNotTransformable( + plan, "columnar BroadcastExchange is not enabled in 
BroadcastExchangeExec") + return + } new ColumnarBroadcastExchangeExec(plan.mode, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: TakeOrderedAndProjectExec => - if (!enableTakeOrderedAndProject) return false + if (!enableTakeOrderedAndProject) { + TransformHints.tagNotTransformable( + plan, "columnar TakeOrderedAndProject is not enabled in TakeOrderedAndProjectExec") + return + } ColumnarTakeOrderedAndProjectExec( plan.limit, plan.sortOrder, plan.projectList, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: UnionExec => - if (!enableColumnarUnion) return false + if (!enableColumnarUnion) { + TransformHints.tagNotTransformable( + plan, "columnar Union is not enabled in UnionExec") + return + } ColumnarUnionExec(plan.children).buildCheck() + TransformHints.tagTransformable(plan) case plan: ShuffleExchangeExec => - if (!enableColumnarShuffle) return false - new ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin) + if (!enableColumnarShuffle) { + TransformHints.tagNotTransformable( + plan, "columnar ShuffleExchange is not enabled in ShuffleExchangeExec") + return + } + ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin) .buildCheck() + TransformHints.tagTransformable(plan) case plan: BroadcastHashJoinExec => // We need to check if BroadcastExchangeExec can be converted to columnar-based. // If not, BHJ should also be row-based. - if (!enableColumnarBroadcastJoin) return false + if (!enableColumnarBroadcastJoin) { + TransformHints.tagNotTransformable( + plan, "columnar BroadcastHashJoin is not enabled in BroadcastHashJoinExec") + return + } val left = plan.left left match { case exec: BroadcastExchangeExec => @@ -174,9 +312,14 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.left, plan.right, plan.isNullAwareAntiJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: SortMergeJoinExec => - if (!enableColumnarSortMergeJoin) return false - new ColumnarSortMergeJoinExec( + if (!enableColumnarSortMergeJoin) { + TransformHints.tagNotTransformable( + plan, "columnar SortMergeJoin is not enabled in SortMergeJoinExec") + return + } + ColumnarSortMergeJoinExec( plan.leftKeys, plan.rightKeys, plan.joinType, @@ -184,12 +327,22 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.left, plan.right, plan.isSkewJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: WindowExec => - if (!enableColumnarWindow) return false + if (!enableColumnarWindow) { + TransformHints.tagNotTransformable( + plan, "columnar Window is not enabled in WindowExec") + return + } ColumnarWindowExec(plan.windowExpression, plan.partitionSpec, plan.orderSpec, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: ShuffledHashJoinExec => - if (!enableShuffledHashJoin) return false + if (!enableShuffledHashJoin) { + TransformHints.tagNotTransformable( + plan, "columnar ShuffledHashJoin is not enabled in ShuffledHashJoinExec") + return + } ColumnarShuffledHashJoinExec( plan.leftKeys, plan.rightKeys, @@ -199,37 +352,64 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.left, plan.right, plan.isSkewJoin).buildCheck() + TransformHints.tagTransformable(plan) case plan: LocalLimitExec => - if (!enableLocalColumnarLimit) return false + if (!enableLocalColumnarLimit) { + TransformHints.tagNotTransformable( + plan, "columnar LocalLimit is not enabled in LocalLimitExec") + return + } ColumnarLocalLimitExec(plan.limit, 
plan.child).buildCheck() + TransformHints.tagTransformable(plan) case plan: GlobalLimitExec => - if (!enableGlobalColumnarLimit) return false + if (!enableGlobalColumnarLimit) { + TransformHints.tagNotTransformable( + plan, "columnar GlobalLimit is not enabled in GlobalLimitExec") + return + } ColumnarGlobalLimitExec(plan.limit, plan.child).buildCheck() - case plan: BroadcastNestedLoopJoinExec => return false + TransformHints.tagTransformable(plan) + case plan: BroadcastNestedLoopJoinExec => + TransformHints.tagNotTransformable( + plan, "columnar BroadcastNestedLoopJoin is not support") + TransformHints.tagTransformable(plan) case plan: CoalesceExec => - if (!enableColumnarCoalesce) return false + if (!enableColumnarCoalesce) { + TransformHints.tagNotTransformable( + plan, "columnar Coalesce is not enabled in CoalesceExec") + return + } ColumnarCoalesceExec(plan.numPartitions, plan.child).buildCheck() - case p => - p + TransformHints.tagTransformable(plan) + case _ => TransformHints.tagTransformable(plan) + } - } - catch { + } catch { case e: UnsupportedOperationException => - logDebug(s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator") - return false + val message = s"[OPERATOR FALLBACK] ${e} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) case l: UnsatisfiedLinkError => throw l case f: NoClassDefFoundError => throw f case r: RuntimeException => - logDebug(s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator") - return false + val message = s"[OPERATOR FALLBACK] ${r} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) case t: Throwable => - logDebug(s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator") - return false + val message = s"[OPERATOR FALLBACK] ${t} ${plan.getClass} falls back to Spark operator" + logDebug(message) + TransformHints.tagNotTransformable(plan, reason = message) } - true } +} + +case class FallbackMultiCodegens(session: SparkSession) extends Rule[SparkPlan] { + + val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf + val optimizeLevel: Integer = columnarConf.joinOptimizationThrottle + val preferColumnar: Boolean = columnarConf.enablePreferColumnar private def existsMultiCodegens(plan: SparkPlan, count: Int = 0): Boolean = plan match { @@ -239,7 +419,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case plan: ShuffledHashJoinExec => if ((count + 1) >= optimizeLevel) return true plan.children.map(existsMultiCodegens(_, count + 1)).exists(_ == true) - case other => false + case _ => false } private def supportCodegen(plan: SparkPlan): Boolean = plan match { @@ -248,51 +428,50 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case _ => false } + private def tagNotTransformable(plan: SparkPlan): SparkPlan = { + TransformHints.tagNotTransformable(plan, "fallback multi codegens") + plan + } + /** * Inserts an InputAdapter on top of those that do not support codegen. 
*/ - private def insertRowGuardRecursive(plan: SparkPlan): SparkPlan = { + private def tagNotTransformableRecursive(plan: SparkPlan): SparkPlan = { plan match { case p: ShuffleExchangeExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))) case p: BroadcastExchangeExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardOrNot))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens))) case p: ShuffledHashJoinExec => - RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableRecursive))) case p if !supportCodegen(p) => - // insert row guard them recursively - p.withNewChildren(p.children.map(insertRowGuardOrNot)) + // tag them recursively + p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens)) case p: OmniAQEShuffleReadExec => - p.withNewChildren(p.children.map(insertRowGuardOrNot)) + p.withNewChildren(p.children.map(tagNotTransformableForMultiCodegens)) case p: BroadcastQueryStageExec => p - case p => RowGuard(p.withNewChildren(p.children.map(insertRowGuardRecursive))) + case p => tagNotTransformable(p.withNewChildren(p.children.map(tagNotTransformableRecursive))) } } - private def insertRowGuard(plan: SparkPlan): SparkPlan = { - RowGuard(plan.withNewChildren(plan.children.map(insertRowGuardOrNot))) - } - /** * Inserts a WholeStageCodegen on top of those that support codegen. */ - private def insertRowGuardOrNot(plan: SparkPlan): SparkPlan = { + private def tagNotTransformableForMultiCodegens(plan: SparkPlan): SparkPlan = { plan match { // For operators that will output domain object, do not insert WholeStageCodegen for it as // domain object can not be written into unsafe row. case plan if !preferColumnar && existsMultiCodegens(plan) => - insertRowGuardRecursive(plan) - case plan if !tryConvertToColumnar(plan) => - insertRowGuard(plan) - case p: BroadcastQueryStageExec => - p + tagNotTransformableRecursive(plan) case other => - other.withNewChildren(other.children.map(insertRowGuardOrNot)) + other.withNewChildren(other.children.map(tagNotTransformableForMultiCodegens)) } } - def apply(plan: SparkPlan): SparkPlan = { - insertRowGuardOrNot(plan) + override def apply(plan: SparkPlan): SparkPlan = { + tagNotTransformableForMultiCodegens(plan) } + } + diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala new file mode 100644 index 0000000000000000000000000000000000000000..1de01e84006541e766f5f40c71c9b40504e8ee5a --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/InsertTransitions.scala @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.util + +import org.apache.spark.sql.execution.{ColumnarToRowExec, ColumnarToRowTransition, RowToColumnarExec, RowToColumnarTransition, SparkPlan} + +/** Ported from [[ApplyColumnarRulesAndInsertTransitions]] of vanilla Spark. */ +object InsertTransitions { + + private def insertRowToColumnar(plan: SparkPlan): SparkPlan = { + if (!plan.supportsColumnar) { + // The tree feels kind of backwards + // Columnar Processing will start here, so transition from row to columnar + RowToColumnarExec(insertTransitions(plan, outputsColumnar = false)) + } else if (!plan.isInstanceOf[RowToColumnarTransition]) { + plan.withNewChildren(plan.children.map(insertRowToColumnar)) + } else { + plan + } + } + + def insertTransitions(plan: SparkPlan, outputsColumnar: Boolean): SparkPlan = { + if (outputsColumnar) { + insertRowToColumnar(plan) + } else if (plan.supportsColumnar) { + // `outputsColumnar` is false but the plan outputs columnar format, so add a + // to-row transition here. + ColumnarToRowExec(insertRowToColumnar(plan)) + } else if (!plan.isInstanceOf[ColumnarToRowTransition]) { + plan.withNewChildren(plan.children.map(insertTransitions(_, outputsColumnar = false))) + } else { + plan + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala deleted file mode 100644 index 60398f033fd5ebcab78fd0474d16bf865560b51c..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import nova.hetu.omniruntime.vector.VecBatch -import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.{Partition, SparkContext, TaskContext, broadcast} - - -private final case class BroadcastColumnarRDDPartition(index: Int) extends Partition - -case class BroadcastColumnarRDD( - @transient private val sc: SparkContext, - metrics: Map[String, SQLMetric], - numPartitioning: Int, - inputByteBuf: broadcast.Broadcast[ColumnarHashedRelation], - localSchema: StructType) - extends RDD[ColumnarBatch](sc, Nil) { - - override protected def getPartitions: Array[Partition] = { - (0 until numPartitioning).map { index => new BroadcastColumnarRDDPartition(index) }.toArray - } - - private def vecBatchToColumnarBatch(vecBatch: VecBatch): ColumnarBatch = { - val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( - vecBatch.getRowCount, localSchema, false) - vectors.zipWithIndex.foreach { case (vector, i) => - vector.reset() - vector.setVec(vecBatch.getVectors()(i)) - } - vecBatch.close() - new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) - } - - override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - // val relation = inputByteBuf.value.asReadOnlyCopy - // new CloseableColumnBatchIterator(relation.getColumnarBatchAsIter) - val deserializer = VecBatchSerializerFactory.create() - new Iterator[ColumnarBatch] { - var idx = 0 - val total_len = inputByteBuf.value.buildData.length - - override def hasNext: Boolean = idx < total_len - - override def next(): ColumnarBatch = { - val batch: VecBatch = deserializer.deserialize(inputByteBuf.value.buildData(idx)) - idx += 1 - vecBatchToColumnarBatch(batch) - } - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala deleted file mode 100644 index 1d236c16d3849906ec997726ca62ee5957ff0740..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution - -import nova.hetu.omniruntime.vector.Vec -import org.apache.spark.broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection} -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.SparkMemoryUtils -import org.apache.spark.sql.execution.vectorized.OmniColumnVector -import org.apache.spark.sql.types.StructType - -import scala.collection.JavaConverters.asScalaIteratorConverter -import scala.collection.mutable.ListBuffer - -case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: Int) - extends UnaryExecNode { - override def output: Seq[Attribute] = child.output - - override def outputPartitioning: Partitioning = UnknownPartitioning(numPartitions) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def doExecute(): RDD[InternalRow] = { - val numOutputRows: SQLMetric = longMetric("numOutputRows") - val numOutputBatches: SQLMetric = longMetric("numOutputBatches") - val processTime: SQLMetric = longMetric("processTime") - val inputRdd: BroadcastColumnarRDD = BroadcastColumnarRDD( - sparkContext, - metrics, - numPartitions, - child.executeBroadcast(), - StructType.fromAttributes(child.output)) - inputRdd.mapPartitions { batches => - ColumnarBatchToInternalRow.convert(output, batches, numOutputRows, numOutputBatches, processTime) - } - } - - override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - child.executeBroadcast() - } - - override def supportsColumnar: Boolean = true - - override lazy val metrics: Map[String, SQLMetric] = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"), - "processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar")) - - override protected def withNewChildInternal(newChild: SparkPlan): - ColumnarBroadcastExchangeAdaptorExec = copy(child = newChild) -} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index cec2012e6229d81822ad66abc2f0b74091b5efb1..d8a2bea7715d05cf89bafe0c44608b7417901bde 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -17,26 +17,23 @@ package org.apache.spark.sql.execution -import java.util.concurrent.TimeUnit.NANOSECONDS - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ListBuffer - import org.apache.spark.broadcast +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, SpecializedGetters, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.SparkMemoryUtils +import 
org.apache.spark.sql.execution.util.{BroadcastUtils, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector} -import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.Utils -import nova.hetu.omniruntime.vector.Vec +import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import nova.hetu.omniruntime.vector.Vec /** * Provides an optimized set of APIs to append row based data to an array of @@ -205,7 +202,32 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti } override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - child.doExecuteBroadcast() + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime") + // Instead of creating a new config we are reusing columnBatchSize. In the future if we do + // combine with some of the Arrow conversion tools we will need to unify some of the configs. + val numRows = conf.columnBatchSize + val enableOffHeapColumnVector = session.sqlContext.conf.offHeapColumnVectorEnabled + val localSchema = this.schema + val relation = child.executeBroadcast() + val mode = BroadcastUtils.getBroadCastMode(outputPartitioning) + val broadcast: Broadcast[T] = BroadcastUtils.sparkToOmniUnsafe(sparkContext, + mode, + relation, + logError, + InternalRowToColumnarBatch.convert( + enableOffHeapColumnVector, + numInputRows, + numOutputBatches, + rowToOmniColumnarTime, + numRows, + localSchema, + _) + ) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime)) + broadcast } override def nodeName: String = "RowToOmniColumnar" @@ -233,44 +255,10 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti // plan (this) in the closure. 
val localSchema = this.schema child.execute().mapPartitionsInternal { rowIterator => - if (rowIterator.hasNext) { - new Iterator[ColumnarBatch] { - private val converters = new OmniRowToColumnConverter(localSchema) - - override def hasNext: Boolean = { - rowIterator.hasNext - } - - override def next(): ColumnarBatch = { - val startTime = System.nanoTime() - val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, - localSchema, true) - val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) - cb.setNumRows(0) - vectors.foreach(_.reset()) - var rowCount = 0 - while (rowCount < numRows && rowIterator.hasNext) { - val row = rowIterator.next() - converters.convert(row, vectors.toArray) - rowCount += 1 - } - if (!enableOffHeapColumnVector) { - vectors.foreach { v => - v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount) - } - } - cb.setNumRows(rowCount) - numInputRows += rowCount - numOutputBatches += 1 - rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) - cb - } - } - } else { - Iterator.empty - } + InternalRowToColumnarBatch.convert(enableOffHeapColumnVector, numInputRows, numOutputBatches, rowToOmniColumnarTime, numRows, localSchema, rowIterator) } } + } @@ -298,6 +286,23 @@ case class OmniColumnarToRowExec(child: SparkPlan, |""".stripMargin } + override def doExecuteBroadcast[T](): Broadcast[T] = { + val numOutputRows = longMetric("numOutputRows") + val numInputBatches = longMetric("numInputBatches") + val omniColumnarToRowTime = longMetric("omniColumnarToRowTime") + val mode = BroadcastUtils.getBroadCastMode(outputPartitioning) + val relation = child.executeBroadcast() + val broadcast: Broadcast[T] = BroadcastUtils.omniToSparkUnsafe(sparkContext, + mode, + relation, + StructType.fromAttributes(output), + ColumnarBatchToInternalRow.convert(output, _, numOutputRows, numInputBatches, omniColumnarToRowTime) + ) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numOutputRows, numInputBatches, omniColumnarToRowTime)) + broadcast + } + override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val numInputBatches = longMetric("numInputBatches") @@ -311,7 +316,55 @@ case class OmniColumnarToRowExec(child: SparkPlan, } override protected def withNewChildInternal(newChild: SparkPlan): - OmniColumnarToRowExec = copy(child = newChild) + OmniColumnarToRowExec = copy(child = newChild) +} + +object InternalRowToColumnarBatch { + + def convert(enableOffHeapColumnVector: Boolean, + numInputRows: SQLMetric, + numOutputBatches: SQLMetric, + rowToOmniColumnarTime: SQLMetric, + numRows: Int, localSchema: StructType, + rowIterator: Iterator[InternalRow]): Iterator[ColumnarBatch] = { + if (rowIterator.hasNext) { + new Iterator[ColumnarBatch] { + private val converters = new OmniRowToColumnConverter(localSchema) + + override def hasNext: Boolean = { + rowIterator.hasNext + } + + override def next(): ColumnarBatch = { + val startTime = System.nanoTime() + val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows, + localSchema, true) + val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray) + cb.setNumRows(0) + vectors.foreach(_.reset()) + var rowCount = 0 + while (rowCount < numRows && rowIterator.hasNext) { + val row = rowIterator.next() + converters.convert(row, vectors.toArray) + rowCount += 1 + } + if (!enableOffHeapColumnVector) { + vectors.foreach { v => + 
v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount) + } + } + cb.setNumRows(rowCount) + numInputRows += rowCount + numOutputBatches += 1 + rowToOmniColumnarTime += NANOSECONDS.toMillis(System.nanoTime() - startTime) + cb + } + } + } else { + Iterator.empty + } + } + } object ColumnarBatchToInternalRow { @@ -346,16 +399,16 @@ object ColumnarBatchToInternalRow { SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - toClosedVecs.foreach {vec => - vec.close() - } + toClosedVecs.foreach { vec => + vec.close() + } } - override def hasNext: Boolean = { + override def hasNext: Boolean = { val has = iter.hasNext // fetch all rows if (!has) { - toClosedVecs.foreach {vec => + toClosedVecs.foreach { vec => vec.close() toClosedVecs.remove(toClosedVecs.indexOf(vec)) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..2204e0fe4b0e3c04ede12b58260ce06cbbb15ce1 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/BroadcastUtils.scala @@ -0,0 +1,276 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.util + +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import org.apache.spark.{SparkConf, SparkContext, TaskContext, TaskContextImpl} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} +import org.apache.spark.sql.execution.ColumnarHashedRelation +import org.apache.spark.sql.execution.joins.{HashedRelation, HashedRelationBroadcastMode} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import java.util.Properties +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import nova.hetu.omniruntime.vector.VecBatch + +object BroadcastUtils { + + def getBroadCastMode(partitioning: Partitioning): BroadcastMode = { + partitioning match { + case BroadcastPartitioning(mode) => + mode + case _ => + throw new IllegalArgumentException("Unexpected partitioning: " + partitioning.toString) + } + } + + def omniToSparkUnsafe[F, T]( + context: SparkContext, + mode: BroadcastMode, + from: Broadcast[F], + fromSchema: StructType, + fn: Iterator[ColumnarBatch] => Iterator[InternalRow]): Broadcast[T] = { + mode match { + case HashedRelationBroadcastMode(_, _) => + // ColumnarHashedRelation to HashedRelation. + val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarHashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (rowCount, rowIterator) = deserializeRelation(fromSchema, fn, fromRelation) + mode.transform(rowIterator, Some(rowCount)) + } + // Rebroadcast Spark relation. + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case IdentityBroadcastMode => + // ColumnarBuildSideRelation to Array. 
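+        // Deserialize the Omni batches into rows and copy each UnsafeRow, since the row object returned by the iterator may be reused.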
+ val fromBroadcast = from.asInstanceOf[Broadcast[ColumnarHashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (_, rowIterator) = deserializeRelation(fromSchema, fn, fromRelation) + val rowArray = new ArrayBuffer[InternalRow]() + while (rowIterator.hasNext) { + val unsafeRow = rowIterator.next().asInstanceOf[UnsafeRow] + rowArray.append(unsafeRow.copy()) + } + rowArray.toArray + } + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case _ => throw new IllegalStateException("Unexpected broadcast mode: " + mode) + } + } + + def sparkToOmniUnsafe[F, T](context: SparkContext, + mode: BroadcastMode, + from: Broadcast[F], + logFunc: (=> String) => Unit, + fn: Iterator[InternalRow] => Iterator[ColumnarBatch]): Broadcast[T] = { + mode match { + case HashedRelationBroadcastMode(_, _) => + // HashedRelation to ColumnarHashedRelation + val fromBroadcast = from.asInstanceOf[Broadcast[HashedRelation]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (nullBatchCount, input) = serializeRelation(mode, fn(fromRelation.keys().flatMap(fromRelation.get)), logError = logFunc) + val relation = new ColumnarHashedRelation + relation.converterData(mode, nullBatchCount, input) + relation + } + // Rebroadcast Omni relation. + context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case IdentityBroadcastMode => + // ColumnarBuildSideRelation to Array. + val fromBroadcast = from.asInstanceOf[Broadcast[Array[InternalRow]]] + val fromRelation = fromBroadcast.value + val toRelation = runUnSafe(context.getConf) { + val (nullBatchCount, input) = serializeRelation(mode, fn(fromRelation.toIterator), logError = logFunc) + val relation = new ColumnarHashedRelation + relation.converterData(mode, nullBatchCount, input) + relation + } + // Rebroadcast Omni relation. 
+ context.broadcast(toRelation).asInstanceOf[Broadcast[T]] + case _ => throw new IllegalStateException("Unexpected broadcast mode: " + mode) + } + } + + private def deserializeRelation(fromSchema: StructType, + fn: Iterator[ColumnarBatch] => Iterator[InternalRow], + fromRelation: ColumnarHashedRelation): (Long, Iterator[InternalRow]) = { + val deserializer = VecBatchSerializerFactory.create() + val data = fromRelation.buildData + var rowCount = 0 + val batchBuffer = ListBuffer[ColumnarBatch]() + val rowIterator = fn( + try { + data.map(bytes => { + val batch: VecBatch = deserializer.deserialize(bytes) + val columnarBatch = vecBatchToColumnarBatch(batch, fromSchema) + batchBuffer.append(columnarBatch) + rowCount += columnarBatch.numRows() + columnarBatch + }).toIterator + } catch { + case exception: Exception => + batchBuffer.foreach(_.close()) + throw exception + } + ) + (rowCount, rowIterator) + } + + private def serializeRelation(mode: BroadcastMode, + batches: Iterator[ColumnarBatch], + logError: (=> String) => Unit): (Int, Array[Array[Byte]]) = { + val serializer = VecBatchSerializerFactory.create() + var nullBatchCount = 0 + val nullRelationFlag = mode match { + case hashRelMode: HashedRelationBroadcastMode => + hashRelMode.isNullAware + case _ => false + } + val input: Array[Array[Byte]] = batches.map(batch => { + // When nullRelationFlag is true, it means anti-join + // Only one column of data is involved in the anti- + if (nullRelationFlag && batch.numCols() > 0) { + val vec = batch.column(0) + if (vec.hasNull) { + try { + nullBatchCount += 1 + } catch { + case e: Exception => + logError(s"compute null BatchCount error : ${e.getMessage}.") + } + } + } + val vectors = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(vectors, batch.numRows()) + val vecBatchSer = serializer.serialize(vecBatch) + // close omni vec + vecBatch.releaseAllVectors() + vecBatch.close() + vecBatchSer + }).toArray + (nullBatchCount, input) + } + + private def vecBatchToColumnarBatch(vecBatch: VecBatch, schema: StructType): ColumnarBatch = { + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, schema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(i)) + } + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + + // Run code with unsafe task context. If the call took place from Spark driver or test code + // without a Spark task context registered, a temporary unsafe task context instance will + // be created and used. Since unsafe task context is not managed by Spark's task memory manager, + // Spark may not be aware of the allocations happened inside the user code. + // + // The API should only be used in the following cases: + // + // 1. Run code on driver + // 2. 
Run test code + private def runUnSafe[T](sparkConf: SparkConf)(body: => T): T = { + if (inSparkTask()) { + throw new UnsupportedOperationException("runUnsafe should only be used outside Spark task") + } + TaskContext.setTaskContext(createUnSafeTaskContext(sparkConf)) + val context = getLocalTaskContext() + try { + val out = + try { + body + } catch { + case t: Throwable => + // Similar code with those in Task.scala + try { + context.markTaskFailed(t) + } catch { + case t: Throwable => + t.addSuppressed(t) + } + context.markTaskCompleted(Some(t)) + throw t + } finally { + try { + context.markTaskCompleted(None) + } finally { + unsetUnsafeTaskContext() + } + } + out + } catch { + case t: Throwable => + throw t + } + } + + private def getLocalTaskContext(): TaskContext = { + TaskContext.get() + } + + private def inSparkTask(): Boolean = { + TaskContext.get() != null + } + + private def unsetUnsafeTaskContext(): Unit = { + if (!inSparkTask()) { + throw new IllegalStateException() + } + if (getLocalTaskContext().taskAttemptId() != -1) { + throw new IllegalStateException() + } + TaskContext.unset() + } + + private def createUnSafeTaskContext(sparkConf: SparkConf): TaskContext = { + // driver code run on unsafe task context which is not managed by Spark's task memory manager, + // so the maxHeapMemory set ByteUnit.TiB.toBytes(2) + val memoryManager = + new UnifiedMemoryManager(sparkConf, ByteUnit.TiB.toBytes(2), ByteUnit.TiB.toBytes(1), 1) + new TaskContextImpl( + -1, + -1, + -1, + -1L, + -1, + new TaskMemoryManager(memoryManager, -1L), + new Properties, + MetricsSystem.createMetricsSystem("OMNI_UNSAFE", sparkConf), + TaskMetrics.empty, + 1, + Map.empty + ) + } + +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..9b97295fc1a8dd7b92171ade804e5997094ec92d --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/FallbackStrategiesSuite.scala @@ -0,0 +1,426 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import com.huawei.boostkit.spark.util.InsertTransitions +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec} +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, LeafExecNode, OmniColumnarToRowExec, ProjectExec, RowToOmniColumnarExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.SharedSparkSession + +import scala.concurrent.Future + +class FallbackStrategiesSuite extends QueryTest with SharedSparkSession { + + import testImplicits._ + + override def sparkConf: SparkConf = super.sparkConf + .setAppName("test FallbackStrategiesSuite") + .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") + + override def beforeAll(): Unit = { + super.beforeAll() + val employees = Seq[(String, String, Int, Int)]( + ("Lisa", "Sales", 10000, 35), + ("Evan", "Sales", 32000, 38), + ("Fred", "Engineering", 21000, 28), + ("Alex", "Sales", 30000, 33), + ("Tom", "Engineering", 23000, 33), + ("Jane", "Marketing", 29000, 28), + ("Jeff", "Marketing", 35000, 38), + ("Paul", "Engineering", 29000, 23), + ("Chloe", "Engineering", 23000, 25) + ).toDF("name", "dept", "salary", "age") + employees.createOrReplaceTempView("employees_for_fallback_ut_test") + } + + test("c2r doExecuteBroadcast") { + withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"), + (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"), + (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB"), + ("spark.omni.sql.columnar.broadcastJoin", "false")) { + val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age") + val runRows = df.collect() + val expectedRows = Seq(Row(56, 21000), Row(56, 21000), + Row(66, 30000), Row(66, 30000), + Row(70, 10000), Row(76, 32000), + Row(76, 32000), Row(46, 29000), + Row(50, 23000), Row(56, 29000), + Row(56, 29000), Row(66, 23000), + Row(66, 23000),Row(76, 35000),Row(76, 35000)) + assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the run value is error") + val bhj = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.find({ + case _: BroadcastHashJoinExec => true + case _ => false + }) + assert(bhj.isDefined, "bhj is fallback") + val c2rWithOmniBroadCast = bhj.get.children.find({ + case p: OmniColumnarToRowExec => + p.child.isInstanceOf[BroadcastQueryStageExec] && + p.child.asInstanceOf[BroadcastQueryStageExec].plan.isInstanceOf[ColumnarBroadcastExchangeExec] + case _ => false + }) + assert(c2rWithOmniBroadCast.isDefined, "bhj should use omni c2r to adapt to omni broadcast exchange") + } + } + + test("r2c doExecuteBroadcast") { + 
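+    // BroadcastExchange is forced to stay row-based (spark.omni.sql.columnar.broadcastexchange=false) while the join remains columnar, so a RowToOmniColumnar transition is expected to adapt the row-based broadcast for ColumnarBroadcastHashJoinExec.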
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"),
+      (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB"),
+      ("spark.omni.sql.columnar.broadcastexchange", "false")) {
+      val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age")
+      val runRows = df.collect()
+      val expectedRows = Seq(Row(56, 21000), Row(56, 21000),
+        Row(66, 30000), Row(66, 30000),
+        Row(70, 10000), Row(76, 32000),
+        Row(76, 32000), Row(46, 29000),
+        Row(50, 23000), Row(56, 29000),
+        Row(56, 29000), Row(66, 23000),
+        Row(66, 23000), Row(76, 35000), Row(76, 35000))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the returned rows do not match the expected rows")
+      val omniBhj = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.find({
+        case _: ColumnarBroadcastHashJoinExec => true
+        case _ => false
+      })
+      assert(omniBhj.isDefined, "bhj should be columnar")
+      val r2cWithBroadCast = omniBhj.get.children.find({
+        case p: RowToOmniColumnarExec =>
+          p.child.isInstanceOf[BroadcastQueryStageExec] &&
+            p.child.asInstanceOf[BroadcastQueryStageExec].plan.isInstanceOf[BroadcastExchangeExec]
+        case _ => false
+      })
+      assert(r2cWithBroadCast.isDefined, "omni bhj should use omni r2c to adapt to the row-based broadcast exchange")
+    }
+  }
+
+  test("Fall back the stage containing bhj") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
+      ("spark.omni.sql.columnar.project", "false"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true"),
+      (SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "10MB")) {
+      val df = spark.sql("select t1.age * 2, t2.salary from employees_for_fallback_ut_test t1 join employees_for_fallback_ut_test t2 on t1.age = t2.age sort by t1.age")
+      val runRows = df.collect()
+      val expectedRows = Seq(Row(56, 21000), Row(56, 21000),
+        Row(66, 30000), Row(66, 30000),
+        Row(70, 10000), Row(76, 32000),
+        Row(76, 32000), Row(46, 29000),
+        Row(50, 23000), Row(56, 29000),
+        Row(56, 29000), Row(66, 23000),
+        Row(66, 23000), Row(76, 35000), Row(76, 35000))
+      assert(QueryTest.sameRows(runRows, expectedRows).isEmpty, "the returned rows do not match the expected rows")
+      val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({
+        case plan: BroadcastHashJoinExec => plan
+      })
+      assert(plans.size == 1, "the last stage containing bhj should fall back")
+    }
+  }
+
+  test("Fall back the last stage containing unsupported bnlj if meeting the configured threshold") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "2"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")) {
+      val df = spark.sql("select age + salary from (select age, salary from (select age from employees_for_fallback_ut_test order by age limit 1) s1," +
+        " (select salary from employees_for_fallback_ut_test order by salary limit 1) s2)")
+      QueryTest.checkAnswer(df, Seq(Row(10023)))
+      val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({
+        case plan: ProjectExec => plan
+        case plan: TakeOrderedAndProjectExec => plan
+      })
+      assert(plans.count(_.isInstanceOf[ProjectExec]) == 1, "the last stage containing ProjectExec should fall back")
+      assert(plans.count(_.isInstanceOf[TakeOrderedAndProjectExec]) == 1, "the last stage containing TakeOrderedAndProjectExec should fall back")
+    }
+  }
+
+  test("Don't fall back the last stage containing unsupported bnlj if NOT meeting the configured threshold") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3"),
+      (SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")) {
+      val df = spark.sql("select age + salary from (select age, salary from (select age from employees_for_fallback_ut_test order by age limit 1) s1," +
+        " (select salary from employees_for_fallback_ut_test order by salary limit 1) s2)")
+      QueryTest.checkAnswer(df, Seq(Row(10023)))
+      val plans = df.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect({
+        case plan: ColumnarProjectExec => plan
+        case plan: ColumnarTakeOrderedAndProjectExec => plan
+      })
+      assert(plans.count(_.isInstanceOf[ColumnarProjectExec]) == 1, "the last stage containing ColumnarProjectExec should not fall back")
+      assert(plans.count(_.isInstanceOf[ColumnarTakeOrderedAndProjectExec]) == 1, "the last stage containing ColumnarTakeOrderedAndProjectExec should not fall back")
+    }
+  }
+
+  test("Fall back the whole query if one unsupported") {
+    withSQLConf(("spark.omni.sql.columnar.query.fallback.threshold", "1")) {
+      val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1.
+      val planAfterPreOverride =
+        UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to fall back the entire plan.
+      assert(outputPlan == originalPlan)
+    }
+  }
+
+  test("Fall back the whole plan if meeting the configured threshold") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) {
+      val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1.
+      val planAfterPreOverride =
+        UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to fall back the entire plan.
+      assert(outputPlan == originalPlan)
+    }
+  }
+
+  test("Don't fall back the whole plan if NOT meeting the configured threshold") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "4")) {
+      val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1.
+      val planAfterPreOverride =
+        UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(LeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      val columnarUnaryOps = outputPlan.collect({
+        case p: ColumnarUnaryOp1 => p
+      })
+      // Expect to get the plan with columnar rule applied.
+      assert(columnarUnaryOps.size == 2)
+      assert(outputPlan != originalPlan)
+    }
+  }
+
+  test(
+    "Fall back the whole plan if meeting the configured threshold (leaf node is" +
+      " transformable)") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "2")) {
+      val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1
+      // and replacing LeafOp with ColumnarLeafOp.
+      val planAfterPreOverride =
+        UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(ColumnarLeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to fall back the entire plan.
+      val columnarUnaryOps = outputPlan.collect({
+        case p: ColumnarUnaryOp1 => p
+      })
+      assert(columnarUnaryOps.isEmpty)
+      assert(outputPlan == originalPlan)
+    }
+  }
+
+  test(
+    "Don't fall back the whole plan if NOT meeting the configured threshold (" +
+      "leaf node is transformable)") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "3")) {
+      val originalPlan = UnaryOp2(UnaryOp1(UnaryOp2(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1
+      // and replacing LeafOp with ColumnarLeafOp.
+      val planAfterPreOverride =
+        UnaryOp2(ColumnarUnaryOp1(UnaryOp2(ColumnarUnaryOp1(ColumnarLeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to get the plan with columnar rule applied.
+      val columnarUnaryOps = outputPlan.collect({
+        case p: ColumnarUnaryOp1 => p
+        case p: ColumnarLeafOp => p
+      })
+      assert(columnarUnaryOps.size == 3)
+      assert(outputPlan != originalPlan)
+    }
+  }
+
+  test("Don't fall back the whole query if all supported") {
+    withSQLConf(("spark.omni.sql.columnar.query.fallback.threshold", "1")) {
+      val originalPlan = UnaryOp1(UnaryOp1(UnaryOp1(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      // Fake output of preColumnarTransitions, mocking replacing UnaryOp1 with ColumnarUnaryOp1
+      // and LeafOp with ColumnarLeafOp.
+      val planAfterPreOverride =
+        ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarLeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to not fall back the entire plan.
+      val columnPlans = outputPlan.collect({
+        case p @ (ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p
+      })
+      assert(columnPlans.size == 5)
+      assert(outputPlan != originalPlan)
+    }
+  }
+
+  test("Don't fall back the whole plan if all supported") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) {
+      val originalPlan = Exchange1(UnaryOp1(UnaryOp1(UnaryOp1(LeafOp()))))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing Exchange1, UnaryOp1 and LeafOp
+      // with their columnar counterparts.
+      val planAfterPreOverride =
+        ColumnarExchange1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarUnaryOp1(ColumnarLeafOp()))))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to not fall back the entire plan.
+      val columnPlans = outputPlan.collect({
+        case p @ (ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p
+      })
+      assert(columnPlans.size == 5)
+      assert(outputPlan != originalPlan)
+    }
+  }
+
+  test("Fall back the whole plan if one supported plan before queryStage") {
+    withSQLConf(("spark.omni.sql.columnar.wholeStage.fallback.threshold", "1")) {
+      val mockQueryStageExec = QueryStageExec1(LeafOp(), LeafOp())
+      val originalPlan = Exchange1(UnaryOp1(UnaryOp1(mockQueryStageExec)))
+      val rule = ColumnarOverrideRules(spark)
+      rule.preColumnarTransitions(originalPlan)
+      rule.enableAdaptiveContext()
+      // Fake output of preColumnarTransitions, mocking replacing the outer UnaryOp1 with
+      // ColumnarUnaryOp1 and Exchange1 with ColumnarExchange1.
+      val planAfterPreOverride =
+        ColumnarExchange1(ColumnarUnaryOp1(UnaryOp1(mockQueryStageExec)))
+      val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+      val outputPlan = rule.postColumnarTransitions(planWithTransition)
+      // Expect to fall back the entire plan.
+      val columnPlans = outputPlan.collect({
+        case p @ (ColumnarExchange1(_, _) | ColumnarUnaryOp1(_, _) | ColumnarLeafOp(_)) => p
+      })
+      assert(columnPlans.isEmpty)
+      assert(outputPlan == originalPlan)
+    }
+  }
+
+}
+
+case class LeafOp(override val supportsColumnar: Boolean = false) extends LeafExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = Seq.empty
+}
+
+case class UnaryOp1(child: SparkPlan, override val supportsColumnar: Boolean = false)
+  extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp1 =
+    copy(child = newChild)
+}
+
+case class UnaryOp2(child: SparkPlan, override val supportsColumnar: Boolean = false)
+  extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: SparkPlan): UnaryOp2 =
+    copy(child = newChild)
+}
+
+// For replacing LeafOp.
+case class ColumnarLeafOp(override val supportsColumnar: Boolean = true)
+  extends LeafExecNode {
+
+  override def nodeName: String = "OmniColumnarLeafOp"
+
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = Seq.empty
+}
+
+// For replacing UnaryOp1.
+case class ColumnarUnaryOp1(
+    override val child: SparkPlan,
+    override val supportsColumnar: Boolean = true)
+  extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = child.output
+
+  override def nodeName: String = "OmniColumnarUnaryOp1"
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarUnaryOp1 =
+    copy(child = newChild)
+}
+
+case class Exchange1(
+    override val child: SparkPlan,
+    override val supportsColumnar: Boolean = false)
+  extends Exchange {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = child.output
+
+  override protected def withNewChildInternal(newChild: SparkPlan): Exchange1 =
+    copy(child = newChild)
+}
+
+// For replacing Exchange1.
+case class ColumnarExchange1(
+    override val child: SparkPlan,
+    override val supportsColumnar: Boolean = true)
+  extends Exchange {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+
+  override def output: Seq[Attribute] = child.output
+
+  override def nodeName: String = "OmniColumnarExchange"
+
+  override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExchange1 =
+    copy(child = newChild)
+}
+
+// Mocks a materialized query stage, so the plan above the stage boundary can be checked on its own.
+case class QueryStageExec1(override val plan: SparkPlan,
+    override val _canonicalized: SparkPlan) extends QueryStageExec {
+
+  override val id: Int = 0
+
+  override def doMaterialize(): Future[Any] = throw new UnsupportedOperationException("this is a mock Spark plan")
+
+  override def cancel(): Unit = {}
+
+  override def newReuseInstance(newStageId: Int, newOutput: Seq[Attribute]): QueryStageExec = throw new UnsupportedOperationException("this is a mock Spark plan")
+
+  override def getRuntimeStatistics: Statistics = throw new UnsupportedOperationException("this is a mock Spark plan")
+}
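+
+// A minimal additional sketch, kept as a comment because it sits outside FallbackStrategiesSuite:
+// it reuses only the helpers exercised above (ColumnarOverrideRules, InsertTransitions and the
+// mock operators) to illustrate a partially replaced plan against the query-level threshold.
+// The test name and the expected outcome are assumptions, under the premise that every remaining
+// row-based operator counts toward spark.omni.sql.columnar.query.fallback.threshold.
+//
+//   test("Fall back the whole query when only part of it is columnar (sketch)") {
+//     withSQLConf(("spark.omni.sql.columnar.query.fallback.threshold", "2")) {
+//       val originalPlan = UnaryOp2(UnaryOp1(LeafOp()))
+//       val rule = ColumnarOverrideRules(spark)
+//       rule.preColumnarTransitions(originalPlan)
+//       // Only the inner UnaryOp1 is replaced, so UnaryOp2 and LeafOp stay row-based.
+//       val planAfterPreOverride = UnaryOp2(ColumnarUnaryOp1(LeafOp()))
+//       val planWithTransition = InsertTransitions.insertTransitions(planAfterPreOverride, false)
+//       val outputPlan = rule.postColumnarTransitions(planWithTransition)
+//       // With two row-based operators left, the whole query is expected to fall back.
+//       assert(outputPlan == originalPlan)
+//     }
+//   }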