diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml index caafa313fbd2cb88b124e370f2d73460199b7051..3e3175bab7d6d61499e28ea729810f9865cf4474 100644 --- a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.1.1-1.1.0 + 3.3.1-1.1.0 ../pom.xml @@ -103,20 +103,20 @@ spark-core_${scala.binary.version} test-jar test - 3.1.1 + 3.3.1 org.apache.spark spark-catalyst_${scala.binary.version} test-jar test - 3.1.1 + 3.3.1 org.apache.spark spark-sql_${scala.binary.version} test-jar - 3.1.1 + 3.3.1 test @@ -127,7 +127,7 @@ org.apache.spark spark-hive_${scala.binary.version} - 3.1.1 + 3.3.1 provided diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java index 808f96e1fb666def4ff9fc224f01020a81a5baf7..cc750a371cdb64c1e60eac27f5b5881a964b5ac4 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java @@ -354,6 +354,11 @@ public class OmniColumnVector extends WritableColumnVector { } } + @Override + public void putBooleans(int rowId, byte src) { + throw new UnsupportedOperationException("putBooleans is not supported"); + } + @Override public boolean getBoolean(int rowId) { if (dictionaryData != null) { @@ -453,6 +458,11 @@ public class OmniColumnVector extends WritableColumnVector { return UTF8String.fromBytes(getBytes(rowId, count), rowId, count); } + @Override + public ByteBuffer getByteBuffer(int rowId, int count) { + throw new UnsupportedOperationException("getByteBuffer is not supported"); + } + // // APIs dealing with Shorts // diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala index a4e4eaa0a877f7ee2e3401ecf4ee98fecfcb7314..ec075787233b025ad5caae8a8c80c473a573b5e1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, CustomShuffleReaderExec} +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ @@ -37,6 +37,9 @@ case class RowGuard(child: SparkPlan) extends SparkPlan { } def children: Seq[SparkPlan] = Seq(child) + + @Override override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + legacyWithNewChildren(newChildren) } case class ColumnarGuardRule() extends Rule[SparkPlan] { @@ -92,6 +95,8 @@ case class 
ColumnarGuardRule() extends Rule[SparkPlan] { if (!enableColumnarHashAgg) return false new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, plan.groupingExpressions, plan.aggregateExpressions, plan.aggregateAttributes, @@ -127,9 +132,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { left match { case exec: BroadcastExchangeExec => new ColumnarBroadcastExchangeExec(exec.mode, exec.child) - case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) => new ColumnarBroadcastExchangeExec(plan.mode, plan.child) - case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) => plan match { case ReusedExchangeExec(_, b: BroadcastExchangeExec) => new ColumnarBroadcastExchangeExec(b.mode, b.child) @@ -141,9 +146,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { right match { case exec: BroadcastExchangeExec => new ColumnarBroadcastExchangeExec(exec.mode, exec.child) - case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) => + case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) => new ColumnarBroadcastExchangeExec(plan.mode, plan.child) - case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) => + case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) => plan match { case ReusedExchangeExec(_, b: BroadcastExchangeExec) => new ColumnarBroadcastExchangeExec(b.mode, b.child) @@ -182,7 +187,8 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { plan.buildSide, plan.condition, plan.left, - plan.right).buildCheck() + plan.right, + plan.isSkewJoin).buildCheck() case plan: BroadcastNestedLoopJoinExec => return false case p => p @@ -237,7 +243,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] { case p if !supportCodegen(p) => // insert row guard them recursively p.withNewChildren(p.children.map(insertRowGuardOrNot)) - case p: CustomShuffleReaderExec => + case p: OmniAQEShuffleReadExec => p.withNewChildren(p.children.map(insertRowGuardOrNot)) case p: BroadcastQueryStageExec => p diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index d3fcbaf539493ff8a26b72b6e1b98c4c448b54c9..a94eb5d67d9612feee34c5c23194bf60537392d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery import org.apache.spark.sql.catalyst.expressions.aggregate.Partial import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ColumnarCustomShuffleReaderExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ @@ -247,6 +247,8 @@ 
case class ColumnarPreOverrides() extends Rule[SparkPlan] { case _ => new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, plan.groupingExpressions, plan.aggregateExpressions, plan.aggregateAttributes, @@ -257,6 +259,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, plan.groupingExpressions, plan.aggregateExpressions, plan.aggregateAttributes, @@ -267,6 +271,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } else { new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, + plan.isStreaming, + plan.numShufflePartitions, plan.groupingExpressions, plan.aggregateExpressions, plan.aggregateAttributes, @@ -311,7 +317,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { plan.buildSide, plan.condition, left, - right) + right, + plan.isSkewJoin) case plan: SortMergeJoinExec if enableColumnarSortMergeJoin => logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") val left = replaceWithColumnarPlan(plan.left) @@ -341,19 +348,19 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin) - case plan: CustomShuffleReaderExec if columnarConf.enableColumnarShuffle => + case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => plan.child match { case shuffle: ColumnarShuffleExchangeExec => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec) => + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs) - case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) => + OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs) + case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) => reused match { case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) => logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarCustomShuffleReaderExec( + OmniAQEShuffleReadExec( plan.child, plan.partitionSpecs) case _ => @@ -375,13 +382,15 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { curPlan.id, BroadcastExchangeExec( originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1))) + ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), + curPlan._canonicalized) case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => BroadcastQueryStageExec( curPlan.id, BroadcastExchangeExec( originalBroadcastPlan.mode, - ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1))) + ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)), + curPlan._canonicalized) case _ => curPlan } @@ -409,11 +418,26 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) => replaceWithColumnarPlan(child) case plan: ColumnarToRowExec => - val child = replaceWithColumnarPlan(plan.child) - if 
(conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { - OmniColumnarToRowExec(child) - } else { - ColumnarToRowExec(child) + plan.child match { + case child: BroadcastQueryStageExec => + child.plan match { + case originalBroadcastPlan: ColumnarBroadcastExchangeExec => + BroadcastQueryStageExec( + child.id, + BroadcastExchangeExec( + originalBroadcastPlan.mode, + ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), child._canonicalized) + case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) => + BroadcastQueryStageExec( + child.id, + BroadcastExchangeExec( + originalBroadcastPlan.mode, + ColumnarBroadcastExchangeAdaptorExec(child.plan, 1)), child._canonicalized) + case _ => + replaceColumnarToRow(plan, conf) + } + case _ => + replaceColumnarToRow(plan, conf) } case r: SparkPlan if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c => @@ -430,6 +454,15 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] { val children = p.children.map(replaceWithColumnarPlan) p.withNewChildren(children) } + + def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf) : SparkPlan = { + val child = replaceWithColumnarPlan(plan.child) + if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) { + OmniColumnarToRowExec(child) + } else { + ColumnarToRowExec(child) + } + } } case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index 29776a07ac820c9d243df4d7cc8fc6135f1b8db6..a698c81089f517d06eaeb011a6606a3a39b81054 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -153,7 +153,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .toBoolean val enableFusion: Boolean = conf - .getConfString("spark.omni.sql.columnar.fusion", "true") + .getConfString("spark.omni.sql.columnar.fusion", "false") .toBoolean // Pick columnar shuffle hash join if one side join count > = 0 to build local hash map, and is diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala index 2071420c9d219e4b4a029bd17eba114ac9d2dd7e..6b065552ceb1d46327fd3644001b0ca408d5d46b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala @@ -37,7 +37,7 @@ object ShuffleJoinStrategy extends Strategy ColumnarPluginConfig.getConf.columnarPreferShuffledHashJoinCBO def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, left, right, hint) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, _, left, right, hint) if columnarPreferShuffledHashJoin => val enable = getBroadcastBuildSide(left, right, joinType, hint, true, conf).isEmpty && !hintToSortMergeJoin(hint) && diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index da1a5b7479d8a3814ef6b07dde6b9ad8acdcd50e..c4307082aec6aa42fbd5754b42950fd0c3f201f4 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -668,9 +668,9 @@ object OmniExpressionAdaptor extends Logging { def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean = false): FunctionType = { agg.aggregateFunction match { - case Sum(_) => OMNI_AGGREGATION_TYPE_SUM + case Sum(_, _) => OMNI_AGGREGATION_TYPE_SUM case Max(_) => OMNI_AGGREGATION_TYPE_MAX - case Average(_) => OMNI_AGGREGATION_TYPE_AVG + case Average(_, _) => OMNI_AGGREGATION_TYPE_AVG case Min(_) => OMNI_AGGREGATION_TYPE_MIN case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => if (isFinal) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 7eca3427ec3f6c618f84e70aeb85ce98d0267176..4883203c99954c5fb596f3804e5903c57709db4a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -71,7 +71,7 @@ class ColumnarShuffleWriter[K, V]( override def write(records: Iterator[Product2[K, V]]): Unit = { if (!records.hasNext) { partitionLengths = new Array[Long](dep.partitioner.numPartitions) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null) + shuffleBlockResolver.writeMetadataFileAndCommit(dep.shuffleId, mapId, partitionLengths, Array[Long](), null) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) return } @@ -107,7 +107,7 @@ class ColumnarShuffleWriter[K, V]( jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch) dep.splitTime.add(System.nanoTime() - startTime) dep.numInputRows.add(cb.numRows) - writeMetrics.incRecordsWritten(1) + writeMetrics.incRecordsWritten(cb.numRows()) } } val startTime = System.nanoTime() @@ -122,10 +122,11 @@ class ColumnarShuffleWriter[K, V]( partitionLengths = splitResult.getPartitionLengths try { - shuffleBlockResolver.writeIndexFileAndCommit( + shuffleBlockResolver.writeMetadataFileAndCommit( dep.shuffleId, mapId, partitionLengths, + Array[Long](), dataTmp) } finally { if (dataTmp.exists() && !dataTmp.delete()) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala index e7c66ee726ae4b9090e41e5d71de386e4b94ed13..28427bba2842f77d53327121001966dbbdb17a01 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala @@ -99,7 +99,7 @@ class OmniColumnarShuffleManager(conf: 
SparkConf) extends ColumnarShuffleManager env.conf, metrics, shuffleExecutorComponents) - case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K@unchecked, V@unchecked] => + case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, bypassMergeSortHandle, @@ -107,9 +107,8 @@ class OmniColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager env.conf, metrics, shuffleExecutorComponents) - case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] => + case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index cb23b68f09bb085d86e133d3ef40628e8c5ca4c2..4dc6ede583fbb33048d7054b0e50f4a309ddd732 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS + import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil @@ -101,6 +102,9 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl |${ExplainUtils.generateFieldString("Input", child.output)} |""".stripMargin } + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec = + copy(child = newChild) } case class ColumnarFilterExec(condition: Expression, child: SparkPlan) @@ -109,6 +113,9 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) override def supportsColumnar: Boolean = true override def nodeName: String = "OmniColumnarFilter" + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarFilterExec = + copy(this.condition, newChild) + // Split out all the IsNotNulls from condition. private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet) @@ -116,7 +123,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) } // If one expression and its children are null intolerant, it is null intolerant. 
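Note on the `withNewChildInternal` overrides added above to ColumnarProjectExec and ColumnarFilterExec: they implement the child-rewrite hook that Spark 3.2+ plan nodes must provide (the old path survives only as `legacyWithNewChildren`, used elsewhere in this patch for multi-child nodes). A minimal sketch of the pattern for a hypothetical unary operator — `ExampleUnaryExec` is illustrative only and not part of this patch:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

    // Illustrative only: the Spark 3.3-style new-child hook that the patched
    // operators implement; a case class typically just copies itself with the
    // new child, exactly as the real overrides in this patch do.
    case class ExampleUnaryExec(child: SparkPlan) extends UnaryExecNode {
      override def output: Seq[Attribute] = child.output
      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException("sketch only, no execution")
      override protected def withNewChildInternal(newChild: SparkPlan): ExampleUnaryExec =
        copy(child = newChild)
    }

The `isNullIntolerant` change just below follows the same upgrade: Spark 3.3's PredicateHelper already defines this helper, so the formerly private method becomes an `override def` with the same recursion.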
- private def isNullIntolerant(expr: Expression): Boolean = expr match { + override def isNullIntolerant(expr: Expression): Boolean = expr match { case e: NullIntolerant => e.children.forall(isNullIntolerant) case _ => false } @@ -267,6 +274,9 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarConditionProjectExec = + copy(child = newChild) + override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), @@ -383,7 +393,7 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan { children.map(_.output).transpose.map { attrs => val firstAttr = attrs.head val nullable = attrs.exists(_.nullable) - val newDt = attrs.map(_.dataType).reduce(StructType.merge) + val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge) if (firstAttr.dataType == newDt) { firstAttr.withNullability(nullable) } else { @@ -393,6 +403,9 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan { } } + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + legacyWithNewChildren(newChildren) + def buildCheck(): Unit = { val inputTypes = new Array[DataType](output.size) output.zipWithIndex.foreach { @@ -420,7 +433,7 @@ class ColumnarRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") - sqlContext + session.sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) .mapPartitionsWithIndex { (i, _) => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala index d137388ab3c41c3ee103ac974cb594990379d394..1d236c16d3849906ec997726ca62ee5957ff0740 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala @@ -64,4 +64,7 @@ case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"), "processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar")) + + override protected def withNewChildInternal(newChild: SparkPlan): + ColumnarBroadcastExchangeAdaptorExec = copy(child = newChild) } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index 72d1aae05d8f9e4a22a7f0bc17e68aca8b157d74..8a29e0d2bc1531351210fe7ae77b7da0577e2fa2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -65,7 +65,7 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) @transient override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = { SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]]( - sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) { + session.sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) { try { // Setup a job group here so later it may get cancelled by groupId if necessary. sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", @@ -159,6 +159,9 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarBroadcastExchangeExec = + new ColumnarBroadcastExchangeExec(this.mode, newChild) + override protected def doPrepare(): Unit = { // Materialize the future. relationFuture diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index b1fd51f4867cf2c435c8ddd7036bf6f8b6818212..e88fec3a56ad6da646cabe73bc09146064e07718 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS - import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -31,8 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector} -import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils import nova.hetu.omniruntime.vector.Vec @@ -101,6 +100,7 @@ private object RowToColumnConverter { private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = { val core = dataType match { + case BinaryType => BinaryConverter case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -123,6 +123,13 @@ private object RowToColumnConverter { } } + private object BinaryConverter extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + val bytes = row.getBinary(column) + cv.appendByteArray(bytes, 0, bytes.length) + } + } + private object BooleanConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = cv.appendBoolean(row.getBoolean(column)) @@ 
-232,8 +239,11 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti "rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in row to OmniColumnar") ) + override protected def withNewChildInternal(newChild: SparkPlan): RowToOmniColumnarExec = + copy(child = newChild) + override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled + val enableOffHeapColumnVector = session.sqlContext.conf.offHeapColumnVectorEnabled val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime") @@ -313,6 +323,9 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches, omniColumnarToRowTime) } } + + override protected def withNewChildInternal(newChild: SparkPlan): + OmniColumnarToRowExec = copy(child = newChild) } object ColumnarBatchToInternalRow { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 27b05b16c017c43e73a0c3b6d4f05ea02d11f951..b25d97d604da1ae0cbaef04b34bbf53e61b8af83 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package org.apache.spark.sql.execution import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP @@ -161,4 +178,6 @@ case class ColumnarExpandExec( throw new UnsupportedOperationException(s"ColumnarExpandExec operator doesn't support doExecute().") } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExpandExec = + copy(child = newChild) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 73091d069cb311f129fef45078d936ad365e14e0..fb741f5effc94772dc95f0f485dd0e8eb58cce6c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.optimizer.BuildLeft import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat} @@ -54,6 +55,7 @@ import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.vectorized.ConstantColumnVector import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, StructType} @@ -74,13 +76,19 @@ abstract class BaseColumnarFileSourceScanExec( disableBucketedScan: Boolean = false) extends DataSourceScanExec { + lazy val metadataColumns: Seq[AttributeReference] = + output.collect { case FileSourceMetadataAttribute(attr) => attr } + override lazy val supportsColumnar: Boolean = true override def vectorTypes: Option[Seq[String]] = relation.fileFormat.vectorTypes( requiredSchema = requiredSchema, partitionSchema = relation.partitionSchema, - relation.sparkSession.sessionState.conf) + relation.sparkSession.sessionState.conf).map { vectorTypes => + // for column-based file format, append metadata column's vector type classes if any + vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[ConstantColumnVector].getName) + } private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty @@ -96,7 +104,7 @@ abstract class BaseColumnarFileSourceScanExec( } private def isDynamicPruningFilter(e: Expression): Boolean = - e.find(_.isInstanceOf[PlanExpression[_]]).isDefined + e.exists(_.isInstanceOf[PlanExpression[_]]) @transient lazy val selectedPartitions: Array[PartitionDirectory] = { val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L) @@ -223,7 +231,13 @@ abstract class BaseColumnarFileSourceScanExec( @transient private lazy val pushedDownFilters = { val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation) - 
dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) + // `dataFilters` should not include any metadata col filters + // because the metadata struct has been flatted in FileSourceStrategy + // and thus metadata col filters are invalid to be pushed down + dataFilters.filterNot(_.references.exists { + case FileSourceMetadataAttribute(_) => true + case _ => false + }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown)) } override protected def metadata: Map[String, String] = { @@ -242,22 +256,27 @@ abstract class BaseColumnarFileSourceScanExec( "DataFilters" -> seqToString(dataFilters), "Location" -> locationDesc) - // (SPARK-32986): Add bucketed scan info in explain output of FileSourceScanExec - if (bucketedScan) { - relation.bucketSpec.map { spec => + relation.bucketSpec.map { spec => + val bucketedKey = "Bucketed" + if (bucketedScan) { val numSelectedBuckets = optionalBucketSet.map { b => b.cardinality() } getOrElse { spec.numBuckets } - metadata + ("SelectedBucketsCount" -> - (s"$numSelectedBuckets out of ${spec.numBuckets}" + + metadata ++ Map( + bucketedKey -> "true", + "SelectedBucketsCount" -> (s"$numSelectedBuckets out of ${spec.numBuckets}" + optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)" }.getOrElse(""))) - } getOrElse { - metadata + } else if (!relation.sparkSession.sessionState.conf.bucketingEnabled) { + metadata + (bucketedKey -> "false (disabled by configuration)") + } else if (disableBucketedScan) { + metadata + (bucketedKey -> "false (disabled by query planner") + } else { + metadata + (bucketedKey -> "false (disabled column(s) not read)") } - } else { - metadata + } getOrElse { + metadata } } @@ -312,7 +331,7 @@ abstract class BaseColumnarFileSourceScanExec( createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions, relation) } else { - createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation) + createReadRDD(readFile, dynamicallySelectedPartitions, relation) } sendDriverMetrics() readRDD @@ -343,7 +362,7 @@ abstract class BaseColumnarFileSourceScanExec( driverMetrics("staticFilesNum") = filesNum driverMetrics("staticFilesSize") = filesSize } - if (relation.partitionSchemaOption.isDefined) { + if (relation.partitionSchema.nonEmpty) { driverMetrics("numPartitions") = partitions.length } } @@ -363,7 +382,7 @@ abstract class BaseColumnarFileSourceScanExec( None } } ++ { - if (relation.partitionSchemaOption.isDefined) { + if (relation.partitionSchema.nonEmpty) { Map( "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), "pruningTime" -> @@ -423,7 +442,7 @@ abstract class BaseColumnarFileSourceScanExec( /** * Create an RDD for bucketed reads. - * The non-bucketed variant of this function is [[createNonBucketedReadRDD]]. + * The non-bucketed variant of this function is [[createReadRDD]]. * * The algorithm is pretty simple: each RDD partition being returned should include all the files * with the same bucket id from all the given Hive partitions. 
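The doc comment above ("each RDD partition being returned should include all the files with the same bucket id from all the given Hive partitions") describes the bucketed-read grouping that the next hunks adjust, and the renamed `createReadRDD` additionally gains a bucket-pruning predicate. A plain-Scala sketch of both ideas, assuming a hypothetical `bucketIdOf` parser that stands in for Spark's `BucketingUtils.getBucketId`:

    // Illustrative sketch, not patch code: group data files by bucket id so one
    // RDD partition reads exactly one bucket, optionally dropping buckets the
    // query does not need. `bucketIdOf` is a stand-in for BucketingUtils.getBucketId.
    final case class DataFile(path: String, length: Long)

    def bucketIdOf(fileName: String): Option[Int] =
      "_(\\d+)\\.".r.findFirstMatchIn(fileName).map(_.group(1).toInt)

    def groupFilesByBucket(
        files: Seq[DataFile],
        selectedBuckets: Option[Set[Int]]): Map[Int, Seq[DataFile]] = {
      val grouped = files.groupBy { f =>
        bucketIdOf(f.path.split('/').last)
          .getOrElse(throw new IllegalStateException(s"Invalid bucket file ${f.path}"))
      }
      selectedBuckets match {
        case Some(keep) => grouped.filter { case (id, _) => keep.contains(id) }
        case None       => grouped
      }
    }

In the patched non-bucketed path that follows, the same pruning decision is expressed as a `shouldProcess: Path => Boolean` closure, and files whose bucket id cannot be parsed are kept rather than rejected (`forall` over the `Option`), whereas the bucketed path raises `QueryExecutionErrors.invalidBucketFile`.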
@@ -447,10 +466,9 @@ abstract class BaseColumnarFileSourceScanExec( }.groupBy { f => BucketingUtils .getBucketId(new Path(f.filePath).getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath)) } - // (SPARK-32985): Decouple bucket filter pruning and bucketed table scan val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) { val bucketSet = optionalBucketSet.get filesGroupedToBuckets.filter { @@ -475,7 +493,8 @@ abstract class BaseColumnarFileSourceScanExec( } } - new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions) + new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), metadataColumns) } /** @@ -486,7 +505,7 @@ abstract class BaseColumnarFileSourceScanExec( * @param selectedPartitions Hive-style partition that are part of the read. * @param fsRelation [[HadoopFsRelation]] associated with the read. */ - private def createNonBucketedReadRDD( + private def createReadRDD( readFile: (PartitionedFile) => Iterator[InternalRow], selectedPartitions: Array[PartitionDirectory], fsRelation: HadoopFsRelation): RDD[InternalRow] = { @@ -496,27 +515,43 @@ abstract class BaseColumnarFileSourceScanExec( logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + s"open cost is considered as scanning $openCostInBytes bytes.") + // Filter files with bucket pruning if possible + val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled + val shouldProcess: Path => Boolean = optionalBucketSet match { + case Some(bucketSet) if bucketingEnabled => + // Do not prune the file if bucket file name is invalid + filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get) + case _ => + _ => true + } + val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => // getPath() is very expensive so we only want to call it once in this block: val filePath = file.getPath - val isSplitable = relation.fileFormat.isSplitable( - relation.sparkSession, relation.options, filePath) - PartitionedFileUtil.splitFiles( - sparkSession = relation.sparkSession, - file = file, - filePath = filePath, - isSplitable = isSplitable, - maxSplitBytes = maxSplitBytes, - partitionValues = partition.values - ) + + if (shouldProcess(filePath)) { + val isSplitable = relation.fileFormat.isSplitable( + relation.sparkSession, relation.options, filePath) + PartitionedFileUtil.splitFiles( + sparkSession = relation.sparkSession, + file = file, + filePath = filePath, + isSplitable = isSplitable, + maxSplitBytes = maxSplitBytes, + partitionValues = partition.values + ) + } else { + Seq.empty + } } }.sortBy(_.length)(implicitly[Ordering[Long]].reverse) val partitions = FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes) - new FileScanRDD(fsRelation.sparkSession, readFile, partitions) + new FileScanRDD(fsRelation.sparkSession, readFile, partitions, + new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), metadataColumns) } // Filters unused DynamicPruningExpression expressions - one which has been replaced @@ -551,7 +586,7 @@ abstract class BaseColumnarFileSourceScanExec( throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp") } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) => + 
case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> @@ -569,7 +604,7 @@ abstract class BaseColumnarFileSourceScanExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) => + case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> @@ -815,7 +850,7 @@ case class ColumnarMultipleOperatorExec( None } } ++ { - if (relation.partitionSchemaOption.isDefined) { + if (relation.partitionSchema.nonEmpty) { Map( "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), "pruningTime" -> @@ -1162,7 +1197,7 @@ case class ColumnarMultipleOperatorExec1( None } } ++ { - if (relation.partitionSchemaOption.isDefined) { + if (relation.partitionSchema.nonEmpty) { Map( "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), "pruningTime" -> diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index e2618842a7dd6a2170b8738be2a299cca2a86d47..be2aa8f0cf8f57e2ac8c56c04d0195c25ba0d0fc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.concurrent.TimeUnit.NANOSECONDS + import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil @@ -32,8 +33,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData -import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.aggregate.{AggregateCodegenSupport, BaseAggregateExec} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -45,14 +47,18 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ case class ColumnarHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], + isStreaming: Boolean, + numShufflePartitions: Option[Int], groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends BaseAggregateExec - with AliasAwareOutputPartitioning { + extends AggregateCodegenSupport { + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarHashAggregateExec = + copy(child = newChild) override def verboseStringWithOperatorId(): String = { s""" @@ 
-77,6 +83,15 @@ case class ColumnarHashAggregateExec( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) + protected override def needHashTable: Boolean = true + + protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + throw new UnsupportedOperationException("ColumnarHashAgg code-gen does not support grouping keys") + } + + protected override def doProduceWithKeys(ctx: CodegenContext): String = { + throw new UnsupportedOperationException("ColumnarHashAgg code-gen does not support grouping keys") + } override def supportsColumnar: Boolean = true @@ -99,7 +114,7 @@ case class ColumnarHashAggregateExec( } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) omniAggChannels(index) = @@ -110,7 +125,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -125,7 +140,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -150,7 +165,7 @@ case class ColumnarHashAggregateExec( omniSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata) } - for (aggChannel <-omniAggChannels) { + for (aggChannel <- omniAggChannels) { if (!isSimpleColumnForAll(aggChannel)) { checkOmniJsonWhiteList("", aggChannel.toArray) } @@ -202,7 +217,7 @@ case class ColumnarHashAggregateExec( } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) @@ -214,7 +229,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -229,7 +244,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = 
toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -338,10 +353,3 @@ case class ColumnarHashAggregateExec( throw new UnsupportedOperationException("This operator doesn't support doExecute().") } } - -object ColumnarHashAggregateExec { - def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { - val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index cea0a1438b1c64a0d1372e2a272742aa9be08502..746e1898a23528b869c6097fdd4cc8bc799d59be 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.execution import com.huawei.boostkit.spark.ColumnarPluginConfig - -import java.util.Random import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import scala.collection.JavaConverters._ @@ -41,6 +39,7 @@ import org.apache.spark.shuffle.ColumnarShuffleDependency import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor @@ -53,8 +52,9 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair +import org.apache.spark.util.random.XORShiftRandom -class ColumnarShuffleExchangeExec( +case class ColumnarShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) @@ -62,7 +62,7 @@ class ColumnarShuffleExchangeExec( private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) - override lazy val readMetrics = + private[sql] lazy val readMetrics = SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics: Map[String, SQLMetric] = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), @@ -100,9 +100,19 @@ class ColumnarShuffleExchangeExec( override def numPartitions: Int = columnarShuffleDependency.partitioner.numPartitions + override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[ColumnarBatch] = { + new ShuffledColumnarRDD(columnarShuffleDependency, readMetrics, partitionSpecs) + } + + override def runtimeStatistics: Statistics = { + val dataSize = metrics("dataSize").value + val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value + Statistics(dataSize, Some(rowCount)) + } + @transient lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { - 
ColumnarShuffleExchangeExec.prepareShuffleDependency( + val dep = ColumnarShuffleExchangeExec.prepareShuffleDependency( inputColumnarRDD, child.output, outputPartitioning, @@ -113,8 +123,8 @@ class ColumnarShuffleExchangeExec( longMetric("numInputRows"), longMetric("splitTime"), longMetric("spillTime")) + dep } - var cachedShuffleRDD: ShuffledColumnarRDD = _ override def doExecute(): RDD[InternalRow] = { @@ -155,6 +165,8 @@ class ColumnarShuffleExchangeExec( cachedShuffleRDD } } + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarShuffleExchangeExec = + copy(child = newChild) } object ColumnarShuffleExchangeExec extends Logging { @@ -324,6 +336,7 @@ object ColumnarShuffleExchangeExec extends Logging { rdd.mapPartitionsWithIndexInternal((_, cbIter) => { cbIter.map { cb => (0, cb) } }, isOrderSensitive = isOrderSensitive) + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } val numCols = outputAttributes.size @@ -341,6 +354,7 @@ object ColumnarShuffleExchangeExec extends Logging { new PartitionInfo("hash", numPartitions, numCols, intputTypes) case RangePartitioning(ordering, numPartitions) => new PartitionInfo("range", numPartitions, numCols, intputTypes) + case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 7c7001dbc1c468465a0115946aeff9849d51a3df..49f2451112f66915d95b57e76dfdd8203a2af635 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -56,6 +56,9 @@ case class ColumnarSortExec( override def outputPartitioning: Partitioning = child.outputPartitioning + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarSortExec = + copy(child = newChild) + override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala index 6fec9f9a054f345a83bc20f278c2ed3be57e6dbd..0e5fac68c74aea7980f7f207dee4345565515a62 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala @@ -49,6 +49,9 @@ case class ColumnarTakeOrderedAndProjectExec( override def nodeName: String = "OmniColumnarTakeOrderedAndProject" + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarTakeOrderedAndProjectExec = + copy(child = newChild) + val serializer: Serializer = new ColumnarBatchSerializer( longMetric("avgReadBatchNumRows"), longMetric("numOutputRows")) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index e5534d3c67680a747f1b4d92fcb2c377c81577c9..63414c781030455c89ebc434cd54db0cbcbbd34a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -50,6 +50,9 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], override def supportsColumnar: Boolean = true + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarWindowExec = + copy(child = newChild) + override lazy val metrics = Map( "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), "numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"), @@ -59,25 +62,6 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs")) - override def output: Seq[Attribute] = - child.output ++ windowExpression.map(_.toAttribute) - - override def requiredChildDistribution: Seq[Distribution] = { - if (partitionSpec.isEmpty) { - // Only show warning when the number of bytes is larger than 100 MiB? - logWarning("No Partition Defined for Window operation! Moving all data to a single " - + "partition, this can cause serious performance degradation.") - AllTuples :: Nil - } else ClusteredDistribution(partitionSpec) :: Nil - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputPartitioning: Partitioning = child.outputPartitioning - override protected def doExecute(): RDD[InternalRow] = { throw new UnsupportedOperationException(s"This operator doesn't support doExecute().") } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala index 1e728239b988592aa7ef8e89cba4cccf0751c065..eb11d449c25483e231ee1c672295a0e83186214f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala @@ -24,6 +24,43 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch +sealed trait ShufflePartitionSpec + +// A partition that reads data of one or more reducers, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). +case class CoalescedPartitionSpec( + startReducerIndex: Int, + endReducerIndex: Int, + @transient dataSize: Option[Long] = None) extends ShufflePartitionSpec + +object CoalescedPartitionSpec { + def apply(startReducerIndex: Int, + endReducerIndex: Int, + dataSize: Long): CoalescedPartitionSpec = { + CoalescedPartitionSpec(startReducerIndex, endReducerIndex, Some(dataSize)) + } +} + +// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to +// `endMapIndex` (exclusive). 
+case class PartialReducerPartitionSpec( + reducerIndex: Int, + startMapIndex: Int, + endMapIndex: Int, + @transient dataSize: Long) extends ShufflePartitionSpec + +// A partition that reads partial data of one mapper, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). +case class PartialMapperPartitionSpec( + mapIndex: Int, + startReducerIndex: Int, + endReducerIndex: Int) extends ShufflePartitionSpec + +case class CoalescedMapperPartitionSpec( + startMapIndex: Int, + endMapIndex: Int, + numReducers: Int) extends ShufflePartitionSpec + /** * The [[Partition]] used by [[ShuffledRowRDD]]. */ @@ -70,7 +107,7 @@ class ShuffledColumnarRDD( override def getPreferredLocations(partition: Partition): Seq[String] = { val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] partition.asInstanceOf[ShuffledColumnarRDDPartition].spec match { - case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) } @@ -80,6 +117,9 @@ class ShuffledColumnarRDD( case PartialMapperPartitionSpec(mapIndex, _, _) => tracker.getMapLocation(dependency, mapIndex, mapIndex + 1) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) } } @@ -89,7 +129,7 @@ class ShuffledColumnarRDD( // as well as the `tempMetrics` for basic shuffle metrics. val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) val reader = split.asInstanceOf[ShuffledColumnarRDDPartition].spec match { - case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => SparkEnv.get.shuffleManager.getReader( dependency.shuffleHandle, startReducerIndex, @@ -116,7 +156,22 @@ class ShuffledColumnarRDD( endReducerIndex, context, sqlMetricsReporter) + + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) } reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) } + + override def clearDependencies(): Unit = { + super.clearDependencies() + dependency = null + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5de9dff4b303652ca1eb4fa08f4d07aef81d7bc --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase +import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL} +import org.apache.spark.sql.execution.ColumnarHashedRelation +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec +import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike} +import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys + +/** + * This rule runs in the AQE optimizer and optimizes more cases + * compared to [[PropagateEmptyRelationBase]]: + * 1. Join is single column NULL-aware anti join (NAAJ) + * Broadcasted [[HashedRelation]] is [[HashedRelationWithAllNullKeys]]. Eliminate join to an + * empty [[LocalRelation]]. + */ +object AQEPropagateEmptyRelation extends AQEPropagateEmptyRelationBase { + override protected def isEmpty(plan: LogicalPlan): Boolean = + super.isEmpty(plan) || (!isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0)) + + override protected def notEmpty(plan: LogicalPlan): Boolean = + super.notEmpty(plan) || getEstimatedRowCount(plan).exists(_ > 0) + + private def isRootRepartition(plan: LogicalPlan): Boolean = plan match { + case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined => true + case _ => false + } + + // The returned value follows: + // - 0 means the plan must produce 0 row + // - positive value means an estimated row count which can be over-estimated + // - none means the plan has not materialized or the plan can not be estimated + private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match { + case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMeterialized => + stage.getRuntimeStatistics.rowCount + + case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty && + agg.child.isInstanceOf[QueryStageExec] => + val stage = agg.child.asInstanceOf[QueryStageExec] + if (stage.isMeterialized) { + stage.getRuntimeStatistics.rowCount + } else { + None + } + + case _ => None + } + + private def isRelationWithAllNullKeys(plan: LogicalPlan): Boolean = plan match { + case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.isMeterialized => + if (stage.broadcast.supportsColumnar) { + val colRelation = stage.broadcast.relationFuture.get().value.asInstanceOf[ColumnarHashedRelation] + colRelation.relation == HashedRelationWithAllNullKeys + } else { + stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys + } + case _ => false + } + + private def eliminateSingleColumnarNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = { + case j @ ExtractSingleColumnarNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) => + empty(j) + } + + override protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match { + case LogicalQueryStage(_, 
ShuffleQueryStageExec(_, shuffle: ShuffleExchangeLike, _)) + if shuffle.shuffleOrigin == REPARTITION_BY_COL || + shuffle.shuffleOrigin == REPARTITION_BY_NUM => true + case _ => false + } + + override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning( + // LOCAL_RELATION and TRUE_OR_FALSE_LITERAL pattern are matched at + // `PropagateEmptyRelationBase.commonApplyFunc` + // LOGICAL_QUERY_STAGE pattern is matched at `PropagateEmptyRelationBase.commonAppleFunc` + // and `AQEPropagateEmptyRelation.eliminateSingleColumnarNullAwareAntiJoin` + // Note that, We can not specify ruleId here since the LogicalQueryStage is not immutable. + _.containsAnyPattern(LOGICAL_QUERY_STAGE, LOCAL_RELATION, TRUE_OR_FALSE_LITERAL)) { + eliminateSingleColumnarNullAwareAntiJoin.orElse(commonApplyFunc) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala deleted file mode 100644 index 4edf0f4f86cb79a3e2b3a5c2fc01c999a42349c8..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin -import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.ColumnarHashedRelation -import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys} - -/** - * This optimization rule detects and converts a Join to an empty [[LocalRelation]]: - * 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]] - * is [[HashedRelationWithAllNullKeys]]. - * - * 2. Join is inner or left semi join, and broadcasted [[HashedRelation]] - * is [[EmptyHashedRelation]]. - * This applies to all Joins (sort merge join, shuffled hash join, and broadcast hash join), - * because sort merge join and shuffled hash join will be changed to broadcast hash join with AQE - * at the first place. 
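Together with the removal of EliminateJoinToEmptyRelation that this hunk begins, the new AQEPropagateEmptyRelation shifts the work to generic empty-relation propagation: a join side is folded away whenever a materialized query stage's runtime statistics prove zero rows, with a columnar-aware check for the all-null-keys broadcast case. A Spark-free sketch of the row-count test it relies on; StageStats and the helper names are invented for illustration.

object EmptyRelationSketch extends App {
  final case class StageStats(rowCount: Option[BigInt])

  // Only a materialized stage whose row count is exactly 0 is provably empty;
  // an unknown or positive (possibly over-estimated) count is not.
  def provablyEmpty(stats: Option[StageStats]): Boolean =
    stats.flatMap(_.rowCount).contains(BigInt(0))

  def provablyNonEmpty(stats: Option[StageStats]): Boolean =
    stats.flatMap(_.rowCount).exists(_ > 0)

  assert(provablyEmpty(Some(StageStats(Some(BigInt(0))))))        // rewrite to an empty relation
  assert(!provablyEmpty(None))                                    // stage not materialized: keep
  assert(provablyNonEmpty(Some(StageStats(Some(BigInt(1000))))))  // keep, even if over-estimated
}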
- */ -object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] { - - private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match { - case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined - && stage.broadcast.relationFuture.get().value == relation => true - case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined - && stage.broadcast.supportsColumnar => { - val cr = stage.broadcast.relationFuture.get().value.asInstanceOf[ColumnarHashedRelation] - cr.relation == relation - } - case _ => false - } - - def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { - case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) - if canEliminate(j.right, HashedRelationWithAllNullKeys) => - LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) - - case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) || - canEliminate(j.right, EmptyHashedRelation) => - LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) - - case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) => - LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala similarity index 99% rename from omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala index d34b93e5b0da5b61ac35c0824acbf817f1a5e938..c26bed04f20f97821eb9a528cf943b0d2c240c92 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala @@ -36,7 +36,7 @@ import scala.collection.mutable.ArrayBuffer * node during canonicalization. * @param partitionSpecs The partition specs that defines the arrangement. 
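For orientation on the rename that follows: the reader's partitionSpecs argument is a per-output-partition description of how post-shuffle data is regrouped, in the same shapes added to ShuffledColumnarRDD.scala earlier in this patch (coalesced reducer ranges, partial reducers for skew splits, and so on). A small, Spark-free sketch with invented values; Spec, Coalesced and PartialReducer are stand-ins, not the real ShufflePartitionSpec hierarchy.

object PartitionSpecSketch extends App {
  sealed trait Spec
  // Read reducer outputs [startReducer, endReducer): the ordinary coalescing case.
  final case class Coalesced(startReducer: Int, endReducer: Int) extends Spec
  // Read only map outputs [startMap, endMap) of one skewed reducer.
  final case class PartialReducer(reducer: Int, startMap: Int, endMap: Int) extends Spec

  // The reader exposes one output partition per spec, in order.
  val specs: Seq[Spec] = Seq(Coalesced(0, 2), PartialReducer(2, 0, 4), PartialReducer(2, 4, 8))
  specs.zipWithIndex.foreach { case (spec, index) => println(s"output partition $index reads $spec") }
}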
*/ -case class ColumnarCustomShuffleReaderExec( +case class OmniAQEShuffleReaderExec( child: SparkPlan, partitionSpecs: Seq[ShufflePartitionSpec]) extends UnaryExecNode { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 0e5a7eae6efaac64d59c9effdcd8304d30c5c9fe..57ca9688df38b19cad0c70e17021ec0588a6dd13 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -51,7 +51,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ sparkSession: SparkSession, options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { - OrcUtils.inferSchema(sparkSession, files, options) + OmniOrcUtils.inferSchema(sparkSession, files, options) } override def buildReaderWithPartitionValues( @@ -82,18 +82,17 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ val fs = filePath.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) - val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => - OrcUtils.requestedColumnIds( - isCaseSensitive, dataSchema, requiredSchema, reader, conf) - } + val orcSchema = + Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema) + val resultedColPruneInfo = OmniOrcUtils.requestedColumnIds( + isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf) if (resultedColPruneInfo.isEmpty) { Iterator.empty } else { // ORC predicate pushdown - if (orcFilterPushDown) { - OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { + if (orcFilterPushDown && filters.nonEmpty) { + OmniOrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f => OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames) } @@ -101,12 +100,15 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ } val (requestedColIds, canPruneCols) = resultedColPruneInfo.get - val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols, + val resultSchemaString = OmniOrcUtils.orcResultSchemaString(canPruneCols, dataSchema, resultSchema, partitionSchema, conf) assert(requestedColIds.length == requiredSchema.length, "[BUG] requested column IDs do not match required schema") val taskConf = new Configuration(conf) + val includeColumns = requestedColIds.filter(_ != -1).sorted.mkString(",") + taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, includeColumns) + val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty) val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala similarity index 95% rename from 
omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala index 3392caa54f0cff52820b98196f9cbd0235151ef3..71b04ef489d7e3271a071ac439ccbdf0ba40decc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.SchemaMergeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, Utils} -object OrcUtils extends Logging { +object OmniOrcUtils extends Logging { // The extensions for ORC compression codecs val extensionsForCompressionCodecNames = Map( @@ -121,7 +121,7 @@ object OrcUtils extends Logging { def readOrcSchemasInParallel( files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean): Seq[StructType] = { ThreadUtils.parmap(files, "readingOrcSchemas", 8) { currentFile => - OrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles).map(toCatalystSchema) + OmniOrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles).map(toCatalystSchema) }.flatten } @@ -130,9 +130,9 @@ object OrcUtils extends Logging { val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf) if (orcOptions.mergeSchema) { SchemaMergeUtils.mergeSchemasInParallel( - sparkSession, options, files, OrcUtils.readOrcSchemasInParallel) + sparkSession, options, files, OmniOrcUtils.readOrcSchemasInParallel) } else { - OrcUtils.readSchema(sparkSession, files, options) + OmniOrcUtils.readSchema(sparkSession, files, options) } } @@ -246,9 +246,9 @@ object OrcUtils extends Logging { partitionSchema: StructType, conf: Configuration): String = { val resultSchemaString = if (canPruneCols) { - OrcUtils.orcTypeDescriptionString(resultSchema) + OmniOrcUtils.orcTypeDescriptionString(resultSchema) } else { - OrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields)) + OmniOrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields)) } OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString) resultSchemaString diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index a2ee977f979a873bfb3447c59abac52319e0e0a1..2c1271fb009f14ce324fda3537c618298d198a16 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -97,6 +97,9 @@ case class ColumnarBroadcastHashJoinExec( override def nodeName: String = "OmniColumnarBroadcastHashJoin" + override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): + ColumnarBroadcastHashJoinExec = copy(left = newLeft, right = newRight) + override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAwareAntiJoin) buildSide match { @@ -109,7 +112,7 @@ case class ColumnarBroadcastHashJoinExec( override lazy val 
outputPartitioning: Partitioning = { joinType match { - case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => + case _: InnerLike if session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 => streamedPlan.outputPartitioning match { case h: HashPartitioning => expandOutputPartitioning(h) case c: PartitioningCollection => expandOutputPartitioning(c) @@ -150,7 +153,7 @@ case class ColumnarBroadcastHashJoinExec( // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y"). // The expanded expressions are returned as PartitioningCollection. private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = { - val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit + val maxNumCombinations = session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit var currentNumCombinations = 0 def generateExprCombinations( diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index 9eb666fcc85df9295656973c4a833c52a472669e..263af0ddbeb6c6ac8ca7d64917eedf0e889782e9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -50,7 +50,8 @@ case class ColumnarShuffledHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) + right: SparkPlan, + isSkewJoin: Boolean) extends HashJoin with ShuffledJoin { override lazy val metrics = Map( @@ -81,6 +82,9 @@ case class ColumnarShuffledHashJoinExec( override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning + override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): + ColumnarShuffledHashJoinExec = copy(left = newLeft, right = newRight) + override def outputOrdering: Seq[SortOrder] = joinType match { case FullOuter => Nil case _ => super.outputOrdering diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index 59b763428b1f5955c581f85aee8765280dccac01..d55af2d9d7e8c0427ab7e27333c3e91a45a8be8d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -68,6 +68,12 @@ class ColumnarSortMergeJoinExec( if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin" } + override protected def withNewChildrenInternal(newLeft: SparkPlan, + newRight: SparkPlan): ColumnarSortMergeJoinExec = { + new ColumnarSortMergeJoinExec(this.leftKeys, this.rightKeys, this.joinType, + this.condition, newLeft, newRight, this.isSkewJoin) + } + val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 val SCAN_FINISH = 4 diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala deleted file mode 100644 index 0503b2b7b684537f5191585cccf8b55cf50997d8..0000000000000000000000000000000000000000 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import org.apache.hadoop.hive.common.StatsSetupConst - -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.analysis.CastSupport -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression} -import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.DataSourceStrategy - -/** - * Prune hive table partitions using partition filters on [[HiveTableRelation]]. The pruned - * partitions will be kept in [[HiveTableRelation.prunedPartitions]], and the statistics of - * the hive table relation will be updated based on pruned partitions. - * - * This rule is executed in optimization phase, so the statistics can be updated before physical - * planning, which is useful for some spark strategy, e.g. - * [[org.apache.spark.sql.execution.SparkStrategies.JoinSelection]]. - * - * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source. - */ -private[sql] class PruneHiveTablePartitions(session: SparkSession) - extends Rule[LogicalPlan] with CastSupport with PredicateHelper { - - /** - * Extract the partition filters from the filters on the table. - */ - private def getPartitionKeyFilters( - filters: Seq[Expression], - relation: HiveTableRelation): ExpressionSet = { - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output) - val partitionColumnSet = AttributeSet(relation.partitionCols) - ExpressionSet( - normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet))) - } - - /** - * Prune the hive table using filters on the partitions of the table. 
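The rule removed here was a local copy of Hive partition pruning carrying the SPARK-34119 statistics fix; its core steps were to extract the predicates that touch only partition columns, list the matching partitions from the metastore, and refresh the relation's statistics from the pruned set. A Spark-free sketch of that first filtering step; the Predicate model is invented for illustration and is not Catalyst's Expression.

object PartitionFilterSketch extends App {
  final case class Predicate(references: Set[String], deterministic: Boolean, hasSubquery: Boolean)

  // Keep only deterministic, subquery-free predicates that reference nothing but
  // partition columns: the subset that can safely be pushed to the metastore.
  def partitionKeyFilters(filters: Seq[Predicate], partitionCols: Set[String]): Seq[Predicate] =
    filters.filter(f => f.deterministic && !f.hasSubquery && f.references.subsetOf(partitionCols))

  val kept = partitionKeyFilters(
    Seq(Predicate(Set("dt"), deterministic = true, hasSubquery = false),
        Predicate(Set("user_id"), deterministic = true, hasSubquery = false)),
    partitionCols = Set("dt"))
  assert(kept.map(_.references) == Seq(Set("dt")))
}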
- */ - private def prunePartitions( - relation: HiveTableRelation, - partitionFilters: ExpressionSet): Seq[CatalogTablePartition] = { - if (conf.metastorePartitionPruning) { - session.sessionState.catalog.listPartitionsByFilter( - relation.tableMeta.identifier, partitionFilters.toSeq) - } else { - ExternalCatalogUtils.prunePartitionsByFilter(relation.tableMeta, - session.sessionState.catalog.listPartitions(relation.tableMeta.identifier), - partitionFilters.toSeq, conf.sessionLocalTimeZone) - } - } - - /** - * Update the statistics of the table. - */ - private def updateTableMeta( - relation: HiveTableRelation, - prunedPartitions: Seq[CatalogTablePartition], - partitionKeyFilters: ExpressionSet): CatalogTable = { - val sizeOfPartitions = prunedPartitions.map { partition => - val rawDataSize = partition.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) - val totalSize = partition.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) - if (rawDataSize.isDefined && rawDataSize.get > 0) { - rawDataSize.get - } else if (totalSize.isDefined && totalSize.get > 0L) { - totalSize.get - } else { - 0L - } - } - // Fix spark issue SPARK-34119(row 95-106) - if (sizeOfPartitions.forall(_ > 0)) { - val filteredStats = - FilterEstimation(Filter(partitionKeyFilters.reduce(And), relation)).estimate - val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) => - (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType)) - }) - relation.tableMeta.copy( - stats = Some(CatalogStatistics( - sizeInBytes = BigInt(sizeOfPartitions.sum), - rowCount = filteredStats.flatMap(_.rowCount), - colStats = colStats.getOrElse(Map.empty)))) - } else { - relation.tableMeta - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) - if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => - val partitionKeyFilters = getPartitionKeyFilters(filters, relation) - if (partitionKeyFilters.nonEmpty) { - val newPartitions = prunePartitions(relation, partitionKeyFilters) - // Fix spark issue SPARK-34119(row 117) - val newTableMeta = updateTableMeta(relation, newPartitions, partitionKeyFilters) - val newRelation = relation.copy( - tableMeta = newTableMeta, prunedPartitions = Some(newPartitions)) - // Keep partition filters so that they are visible in physical planning - Project(projections, Filter(filters.reduceLeft(And), newRelation)) - } else { - op - } - } -} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala index 237321f5921726da1b80119936b8bcadaa6f1c95..62a837953b5358a4e460f580a8048bb91fa5b759 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala @@ -107,14 +107,14 @@ class ColumnShuffleSerializerDisableCompressSuite extends SharedSparkSession { when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { 
outputFile.delete tmp.renameTo(outputFile) } null }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) } override def afterEach(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala index 8f0329248c9cf8e75b277b9cae4a3bd3e5a2e361..a8f287e1f77d8ed058239e610b679ac22533a583 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala @@ -108,14 +108,14 @@ class ColumnShuffleSerializerLz4Suite extends SharedSparkSession { when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) } override def afterEach(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala index 5b6811b03362294e35ca39a65de42592a9385aa8..df3004cce9479f43d81facc7a517f28ed18f02d9 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala @@ -108,14 +108,14 @@ class ColumnShuffleSerializerSnappySuite extends SharedSparkSession { when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) } override def afterEach(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala index a9924a95d42310d1f784088f1b67e015f45d1ca3..8c3b27914008e57217fc5a45628d793bc3ab9d6f 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala @@ -108,14 +108,14 @@ class ColumnShuffleSerializerZlibSuite extends 
SharedSparkSession { when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) } override def afterEach(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala index 00adf145979e33f7dd7b1c49873fd72cdff18756..d527c177805cd54baa53c18c98dddfd84870b953 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala @@ -107,14 +107,14 @@ class ColumnarShuffleWriterSuite extends SharedSparkSession { when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) } null }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) } override def afterEach(): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..d3cbaa8c41e2d133c8b2ebd450195118a5c293ed --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class CombiningLimitsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Column Pruning", FixedPoint(100), + ColumnPruning, + RemoveNoopOperators) :: + Batch("Eliminate Limit", FixedPoint(10), + EliminateLimits) :: + Batch("Constant Folding", FixedPoint(10), + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals) :: Nil + } + + val testRelation = LocalRelation.fromExternalRows( + Seq("a".attr.int, "b".attr.int, "c".attr.int), + 1.to(10).map(_ => Row(1, 2, 3)) + ) + val testRelation2 = LocalRelation.fromExternalRows( + Seq("x".attr.int, "y".attr.int, "z".attr.int), + Seq(Row(1, 2, 3), Row(2, 3, 4)) + ) + val testRelation3 = RelationWithoutMaxRows(Seq("i".attr.int)) + val testRelation4 = LongMaxRelation(Seq("j".attr.int)) + val testRelation5 = EmptyRelation(Seq("k".attr.int)) + + test("limits: combines two limits") { + val originalQuery = + testRelation + .select('a) + .limit(10) + .limit(5) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(5).analyze + + comparePlans(optimized, correctAnswer) + } + + test("limits: combines three limits") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .limit(7) + .limit(5) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("limits: combines two limits after ColumnPruning") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .select('a) + .limit(5) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-33442: Change Combine Limit to Eliminate limit using max row") { + // test child max row <= limit. + val query1 = testRelation.select().groupBy()(count(1)).limit(1).analyze + val optimized1 = Optimize.execute(query1) + val expected1 = testRelation.select().groupBy()(count(1)).analyze + comparePlans(optimized1, expected1) + + // test child max row > limit. 
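This CombiningLimitsSuite drives an optimizer batch of ColumnPruning, RemoveNoopOperators and EliminateLimits, and its cases reduce to two decisions: adjacent limits collapse to the smallest value, and a limit is dropped entirely once the child's known maxRows already satisfies it (as in the query2 and query3 cases that follow). A Spark-free sketch of those two decisions with invented helper names.

object LimitSketch extends App {
  // limit(b) applied over limit(a) behaves like a single limit of the smaller value.
  def combineLimits(outer: Int, inner: Int): Int = math.min(outer, inner)

  // A limit is redundant once the child's known maximum row count already satisfies it.
  def limitIsRedundant(limit: Int, childMaxRows: Option[Long]): Boolean =
    childMaxRows.exists(_ <= limit)

  assert(combineLimits(5, 10) == 5)      // limit(10).limit(5)  => limit(5)
  assert(limitIsRedundant(1, Some(1L)))  // a global count(...) caps maxRows at 1
  assert(!limitIsRedundant(1, None))     // unknown maxRows: the limit must stay
}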
+ val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze + val optimized2 = Optimize.execute(query2) + comparePlans(optimized2, query2) + + // test child max row is none + val query3 = testRelation.select(Symbol("a")).limit(1).analyze + val optimized3 = Optimize.execute(query3) + comparePlans(optimized3, query3) + + // test sort after limit + val query4 = testRelation.select().groupBy()(count(1)) + .orderBy(count(1).asc).limit(1).analyze + val optimized4 = Optimize.execute(query4) + // the top project has been removed, so we need optimize expected too + val expected4 = Optimize.execute( + testRelation.select().groupBy()(count(1)).orderBy(count(1).asc).analyze) + comparePlans(optimized4, expected4) + } + + test("SPARK-33497: Eliminate Limit if LocalRelation max rows not larger than Limit") { + checkPlanAndMaxRow( + testRelation.select().limit(10), + testRelation.select(), + 10 + ) + } + + test("SPARK-33497: Eliminate Limit if Range max rows not larger than Limit") { + checkPlanAndMaxRow( + Range(0, 100, 1, None).select().limit(200), + Range(0, 100, 1, None).select(), + 100 + ) + checkPlanAndMaxRow( + Range(-1, Long.MaxValue, 1, None).select().limit(1), + Range(-1, Long.MaxValue, 1, None).select().limit(1), + 1 + ) + } + + test("SPARK-33497: Eliminate Limit if Sample max rows not larger than Limit") { + checkPlanAndMaxRow( + testRelation.select().sample(0, 0.2, false, 1).limit(10), + testRelation.select().sample(0, 0.2, false, 1), + 10 + ) + } + + test("SPARK-38271: PoissonSampler may output more rows than child.maxRows") { + val query = testRelation.select().sample(0, 0.2, true, 1) + assert(query.maxRows.isEmpty) + val optimized = Optimize.execute(query.analyze) + assert(optimized.maxRows.isEmpty) + // can not eliminate Limit since Sample.maxRows is None + checkPlanAndMaxRow( + query.limit(10), + query.limit(10), + 10 + ) + } + + test("SPARK-33497: Eliminate Limit if Deduplicate max rows not larger than Limit") { + checkPlanAndMaxRow( + testRelation.deduplicate("a".attr).limit(10), + testRelation.deduplicate("a".attr), + 10 + ) + } + + test("SPARK-33497: Eliminate Limit if Repartition max rows not larger than Limit") { + checkPlanAndMaxRow( + testRelation.repartition(2).limit(10), + testRelation.repartition(2), + 10 + ) + checkPlanAndMaxRow( + testRelation.distribute("a".attr)(2).limit(10), + testRelation.distribute("a".attr)(2), + 10 + ) + } + + test("SPARK-33497: Eliminate Limit if Join max rows not larger than Limit") { + Seq(Inner, FullOuter, LeftOuter, RightOuter).foreach { joinType => + checkPlanAndMaxRow( + testRelation.join(testRelation2, joinType).limit(20), + testRelation.join(testRelation2, joinType), + 20 + ) + checkPlanAndMaxRow( + testRelation.join(testRelation2, joinType).limit(10), + testRelation.join(testRelation2, joinType).limit(10), + 10 + ) + // without maxRow + checkPlanAndMaxRow( + testRelation.join(testRelation3, joinType).limit(100), + testRelation.join(testRelation3, joinType).limit(100), + 100 + ) + // maxRow is not valid long + checkPlanAndMaxRow( + testRelation.join(testRelation4, joinType).limit(100), + testRelation.join(testRelation4, joinType).limit(100), + 100 + ) + } + + Seq(LeftSemi, LeftAnti).foreach { joinType => + checkPlanAndMaxRow( + testRelation.join(testRelation2, joinType).limit(5), + testRelation.join(testRelation2.select(), joinType).limit(5), + 5 + ) + checkPlanAndMaxRow( + testRelation.join(testRelation2, joinType).limit(10), + testRelation.join(testRelation2.select(), joinType), + 10 + ) + } + } + + test("SPARK-33497: 
Eliminate Limit if Window max rows not larger than Limit") { + checkPlanAndMaxRow( + testRelation.window( + Seq(count(1).as("c")), Seq("a".attr), Seq("b".attr.asc)).limit(20), + testRelation.window( + Seq(count(1).as("c")), Seq("a".attr), Seq("b".attr.asc)), + 10 + ) + } + + test("SPARK-34628: Remove GlobalLimit operator if its child max rows <= limit") { + val query = GlobalLimit(100, testRelation) + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, testRelation) + } + + test("SPARK-37064: Fix outer join return the wrong max rows if other side is empty") { + Seq(LeftOuter, FullOuter).foreach { joinType => + checkPlanAndMaxRow( + testRelation.join(testRelation5, joinType).limit(9), + testRelation.join(testRelation5, joinType).limit(9), + 9 + ) + + checkPlanAndMaxRow( + testRelation.join(testRelation5, joinType).limit(10), + testRelation.join(testRelation5, joinType), + 10 + ) + } + + Seq(RightOuter, FullOuter).foreach { joinType => + checkPlanAndMaxRow( + testRelation5.join(testRelation, joinType).limit(9), + testRelation5.join(testRelation, joinType).limit(9), + 9 + ) + + checkPlanAndMaxRow( + testRelation5.join(testRelation, joinType).limit(10), + testRelation5.join(testRelation, joinType), + 10 + ) + } + + Seq(Inner, Cross).foreach { joinType => + checkPlanAndMaxRow( + testRelation.join(testRelation5, joinType).limit(9), + testRelation.join(testRelation5, joinType), + 0 + ) + } + } + + private def checkPlanAndMaxRow( + optimized: LogicalPlan, expected: LogicalPlan, expectedMaxRow: Long): Unit = { + comparePlans(Optimize.execute(optimized.analyze), expected.analyze) + assert(expected.maxRows.get == expectedMaxRow) + } +} + +case class RelationWithoutMaxRows(output: Seq[Attribute]) extends LeafNode { + override def maxRows: Option[Long] = None +} + +case class LongMaxRelation(output: Seq[Attribute]) extends LeafNode { + override def maxRows: Option[Long] = Some(Long.MaxValue) +} + +case class EmptyRelation(output: Seq[Attribute]) extends LeafNode { + override def maxRows: Option[Long] = Some(0) +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..02b6eed9ed050e3e718dd855d07f01da4c8ddb0f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, LessThan, Literal, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{DataType, StructType} + + +class ConvertToLocalRelationSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("LocalRelation", FixedPoint(100), + ConvertToLocalRelation) :: Nil + } + + test("Project on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) + + val projectOnLocal = testRelation.select( + UnresolvedAttribute("a").as("a1"), + (UnresolvedAttribute("b") + 1).as("b1")) + + val optimized = Optimize.execute(projectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("Filter on LocalRelation should be turned into a single LocalRelation") { + val testRelation = LocalRelation( + LocalRelation('a.int, 'b.int).output, + InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a1.int, 'b1.int).output, + InternalRow(1, 3) :: Nil) + + val filterAndProjectOnLocal = testRelation + .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1")) + .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6))) + + val optimized = Optimize.execute(filterAndProjectOnLocal.analyze) + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-27798: Expression reusing output shouldn't override values in local relation") { + val testRelation = LocalRelation( + LocalRelation('a.int).output, + InternalRow(1) :: InternalRow(2) :: Nil) + + val correctAnswer = LocalRelation( + LocalRelation('a.struct('a1.int)).output, + InternalRow(InternalRow(1)) :: InternalRow(InternalRow(2)) :: Nil) + + val projected = testRelation.select(ExprReuseOutput(UnresolvedAttribute("a")).as("a")) + val optimized = Optimize.execute(projected.analyze) + + comparePlans(optimized, correctAnswer) + } +} + + +// Dummy expression used for testing. It reuses output row. Assumes child expr outputs an integer. 
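ExprReuseOutput, declared next, reproduces the SPARK-27798 condition tested above: an expression that keeps returning the same mutable row will corrupt the converted LocalRelation unless each result is copied. A Spark-free sketch of that hazard; ReusingEvaluator is an invented stand-in, not a Catalyst class.

object RowReuseSketch extends App {
  final class ReusingEvaluator {
    private val buffer = new Array[Any](1)
    def eval(value: Any): Array[Any] = { buffer(0) = value; buffer } // same array every call
  }

  val evaluator = new ReusingEvaluator
  val aliased = Seq(1, 2).map(evaluator.eval)               // both elements share one buffer
  val copied  = Seq(1, 2).map(v => evaluator.eval(v).clone) // defensive copy per row
  assert(aliased.map(_.toSeq) == Seq(Seq(2), Seq(2)))       // the earlier result was overwritten
  assert(copied.map(_.toSeq)  == Seq(Seq(1), Seq(2)))       // copying preserves both rows
}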
+case class ExprReuseOutput(child: Expression) extends UnaryExpression { + override def dataType: DataType = StructType.fromDDL("a1 int") + override def nullable: Boolean = true + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("Should not trigger codegen") + + private val row: InternalRow = new GenericInternalRow(1) + + override def eval(input: InternalRow): Any = { + row.update(0, child.eval(input)) + row + } + + override protected def withNewChildInternal(newChild: Expression): ExprReuseOutput = + copy(child = newChild) +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..3266febb9ed69d06b75e3a92855924c943ab6ec2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class OptimizeOneRowPlanSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", Once, ReplaceDistinctWithAggregate) :: + Batch("Eliminate Sorts", Once, EliminateSorts) :: + Batch("Optimize One Row Plan", FixedPoint(10), OptimizeOneRowPlan) :: Nil + } + + private val t1 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1))) + private val t2 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1), Row(2))) + + test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") { + comparePlans( + Optimize.execute(t2.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze, + t2.groupBy()(count(1).as("cnt")).analyze) + + comparePlans( + Optimize.execute(t2.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze, + t2.limit(Literal(1)).analyze) + } + + test("Remove sort") { + // remove local sort + val plan1 = LocalLimit(0, t1).union(LocalLimit(0, t2)).sortBy($"a".desc).analyze + val expected = LocalLimit(0, t1).union(LocalLimit(0, t2)).analyze + comparePlans(Optimize.execute(plan1), expected) + + // do not remove + val plan2 = t2.orderBy($"a".desc).analyze + comparePlans(Optimize.execute(plan2), plan2) + + val plan3 = t2.sortBy($"a".desc).analyze + comparePlans(Optimize.execute(plan3), plan3) + } + + test("Convert group only aggregate to project") { + val plan1 = t1.groupBy($"a")($"a").analyze + comparePlans(Optimize.execute(plan1), t1.select($"a").analyze) + + val plan2 = t1.groupBy($"a" + 1)($"a" + 1).analyze + comparePlans(Optimize.execute(plan2), t1.select($"a" + 1).analyze) + + // do not remove + val plan3 = t2.groupBy($"a")($"a").analyze + comparePlans(Optimize.execute(plan3), plan3) + + val plan4 = t1.groupBy($"a")(sum($"a")).analyze + comparePlans(Optimize.execute(plan4), plan4) + + val plan5 = t1.groupBy()(sum($"a")).analyze + comparePlans(Optimize.execute(plan5), plan5) + } + + test("Remove distinct in aggregate expression") { + val plan1 = t1.groupBy($"a")(sumDistinct($"a").as("s")).analyze + val expected1 = t1.groupBy($"a")(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan1), expected1) + + val plan2 = t1.groupBy()(sumDistinct($"a").as("s")).analyze + val expected2 = t1.groupBy()(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan2), expected2) + + // do not remove + val plan3 = t2.groupBy($"a")(sumDistinct($"a").as("s")).analyze + comparePlans(Optimize.execute(plan3), plan3) + } + + test("Remove in complex case") { + val plan1 = t1.groupBy($"a")($"a").orderBy($"a".asc).analyze + val expected1 = t1.select($"a").analyze + comparePlans(Optimize.execute(plan1), expected1) + + val plan2 = t1.groupBy($"a")(sumDistinct($"a").as("s")).orderBy($"s".asc).analyze + val expected2 = t1.groupBy($"a")(sum($"a").as("s")).analyze + comparePlans(Optimize.execute(plan2), expected2) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala 
index 9f4ae359e1cc8459841dc9757fb52a514a1cbfb4..ddf4d421f3e5d8a9e43c54c9ed31b16439b5c597 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -18,13 +18,16 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll + import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql._ import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.internal.SQLConf class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -53,23 +56,24 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl val numInputPartitions: Int = 10 def withSparkSession( - f: SparkSession => Unit, - targetPostShuffleInputSize: Int, - minNumPostShufflePartitions: Option[Int]): Unit = { + f: SparkSession => Unit, + targetPostShuffleInputSize: Int, + minNumPostShufflePartitions: Option[Int], + enableIOEncryption: Boolean = false): Unit = { val sparkConf = new SparkConf(false) .setMaster("local[*]") .setAppName("test") .set(UI_ENABLED, false) + .set(IO_ENCRYPTION_ENABLED, enableIOEncryption) .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key, "true") .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, targetPostShuffleInputSize.toString) - .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") - .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") minNumPostShufflePartitions match { case Some(numPartitions) => sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, numPartitions.toString) @@ -90,7 +94,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl } test(s"determining the number of reducers: aggregate operator$testNameNote") { - val test = { spark: SparkSession => + val test: SparkSession => Unit = { spark: SparkSession => val df = spark .range(0, 1000, 1, numInputPartitions) @@ -106,27 +110,27 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // by the ExchangeCoordinator. 
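The reworked assertions in this suite track how AQE coalesces small shuffle partitions: contiguous reducer partitions are merged until the advisory target size is reached, which is why the expected read counts change together with the targetPostShuffleInputSize values passed to withSparkSession. A simplified, Spark-free sketch of that merging; Spark's real ShufflePartitionsUtil additionally honours minimum partition sizes and skew handling.

object CoalesceSketch extends App {
  // Merge contiguous partitions greedily until adding one more would exceed the target.
  // Returns (startIndex, endIndexExclusive) ranges, mirroring CoalescedPartitionSpec.
  def coalesce(partitionBytes: Seq[Long], targetSize: Long): Seq[(Int, Int)] = {
    val ranges = scala.collection.mutable.ArrayBuffer.empty[(Int, Int)]
    var start = 0
    var acc = 0L
    partitionBytes.indices.foreach { i =>
      if (acc > 0 && acc + partitionBytes(i) > targetSize) {
        ranges += ((start, i))
        start = i
        acc = 0L
      }
      acc += partitionBytes(i)
    }
    if (start < partitionBytes.length) ranges += ((start, partitionBytes.length))
    ranges.toSeq
  }

  // With a 100-byte target, five 40-byte partitions coalesce into three reads.
  assert(coalesce(Seq(40L, 40L, 40L, 40L, 40L), 100L) == Seq((0, 2), (2, 4), (4, 5)))
}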
val finalPlan = agg.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + val shuffleReads = finalPlan.collect { + case r @ CoalescedShuffleRead() => r } - assert(shuffleReaders.length === 1) + minNumPostShufflePartitions match { case Some(numPartitions) => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === numPartitions) - } + assert(shuffleReads.isEmpty) + case None => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === 3) + assert(shuffleReads.length === 1) + shuffleReads.foreach { read => + assert(read.outputPartitioning.numPartitions === 3) } } } - // The number of coulmn partitions byte is small. smaller threshold value should be used - withSparkSession(test, 1500, minNumPostShufflePartitions) + + withSparkSession(test, 2000, minNumPostShufflePartitions) } test(s"determining the number of reducers: join operator$testNameNote") { - val test = { spark: SparkSession => + val test: SparkSession => Unit = { spark: SparkSession => val df1 = spark .range(0, 1000, 1, numInputPartitions) @@ -152,23 +156,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // by the ExchangeCoordinator. val finalPlan = join.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + val shuffleReads = finalPlan.collect { + case r @ CoalescedShuffleRead() => r } - assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { case Some(numPartitions) => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === numPartitions) - } + assert(shuffleReads.isEmpty) + case None => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === 2) + assert(shuffleReads.length === 2) + shuffleReads.foreach { read => + assert(read.outputPartitioning.numPartitions === 2) } } } - // The number of coulmn partitions byte is small. smaller threshold value should be used - withSparkSession(test, 11384, minNumPostShufflePartitions) + + withSparkSession(test, 16384, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 1$testNameNote") { @@ -203,23 +207,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // by the ExchangeCoordinator. val finalPlan = join.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + val shuffleReads = finalPlan.collect { + case r @ CoalescedShuffleRead() => r } - assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { case Some(numPartitions) => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === numPartitions) - } + assert(shuffleReads.isEmpty) + case None => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === 3) + assert(shuffleReads.length === 2) + shuffleReads.foreach { read => + assert(read.outputPartitioning.numPartitions === 2) } } } - // The number of coulmn partitions byte is small. 
smaller threshold value should be used - withSparkSession(test, 7384, minNumPostShufflePartitions) + + withSparkSession(test, 16384, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { @@ -254,23 +258,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // by the ExchangeCoordinator. val finalPlan = join.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + val shuffleReads = finalPlan.collect { + case r @ CoalescedShuffleRead() => r } - assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { case Some(numPartitions) => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === numPartitions) - } + assert(shuffleReads.isEmpty) + case None => - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === 2) + assert(shuffleReads.length === 2) + shuffleReads.foreach { read => + assert(read.outputPartitioning.numPartitions === 3) } } } - // The number of coulmn partitions byte is small. smaller threshold value should be used - withSparkSession(test, 10000, minNumPostShufflePartitions) + + withSparkSession(test, 12000, minNumPostShufflePartitions) } test(s"determining the number of reducers: plan already partitioned$testNameNote") { @@ -296,10 +300,10 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // Then, let's make sure we do not reduce number of post shuffle partitions. val finalPlan = join.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - val shuffleReaders = finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + val shuffleReads = finalPlan.collect { + case r @ CoalescedShuffleRead() => r } - assert(shuffleReaders.length === 0) + assert(shuffleReads.length === 0) } finally { spark.sql("drop table t") } @@ -308,10 +312,10 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl } } - ignore("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { val test: SparkSession => Unit = { spark: SparkSession => spark.sql("SET spark.sql.exchange.reuse=true") - val df = spark.range(1).selectExpr("id AS key", "id AS value") + val df = spark.range(0, 6, 1).selectExpr("id AS key", "id AS value") // test case 1: a query stage has 3 child stages but they are the same stage. 
// Final Stage 1 @@ -319,15 +323,15 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl // ReusedQueryStage 0 // ReusedQueryStage 0 val resultDf = df.join(df, "key").join(df, "key") - QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) + QueryTest.checkAnswer(resultDf, (0 to 5).map(i => Row(i, i, i, i))) val finalPlan = resultDf.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan assert(finalPlan.collect { - case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r + case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r }.length == 2) assert( finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + case r @ CoalescedShuffleRead() => r }.length == 3) @@ -340,7 +344,9 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl val grouped = df.groupBy("key").agg(max("value").as("value")) val resultDf2 = grouped.groupBy(col("key") + 1).max("value") .union(grouped.groupBy(col("key") + 2).max("value")) - QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil) + QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Row(2, 1) :: Row(3, 1) :: + Row(3, 2) :: Row(4, 2) :: Row(4, 3) :: Row(5, 3) :: Row(5, 4) :: Row(6, 4) :: Row(6, 5) :: + Row(7, 5) :: Nil) val finalPlan2 = resultDf2.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan @@ -349,6 +355,17 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl val level1Stages = finalPlan2.collect { case q: QueryStageExec => q } assert(level1Stages.length == 2) + assert( + finalPlan2.collect { + case r @ CoalescedShuffleRead() => r + }.length == 2, "finalPlan2") + + level1Stages.foreach(qs => + assert(qs.plan.collect { + case r @ CoalescedShuffleRead() => r + }.length == 1, + "Wrong CoalescedShuffleRead below " + qs.simpleString(3))) + val leafStages = level1Stages.flatMap { stage => // All of the child stages of result stage have only one child stage. val children = stage.plan.collect { case q: QueryStageExec => q } @@ -359,12 +376,12 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl val reusedStages = level1Stages.flatMap { stage => stage.plan.collect { - case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r + case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r } } assert(reusedStages.length == 1) } - withSparkSession(test, 4, None) + withSparkSession(test, 400, None) } test("Do not reduce the number of shuffle partition for repartition") { @@ -378,7 +395,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl .asInstanceOf[AdaptiveSparkPlanExec].executedPlan assert( finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r + case r @ CoalescedShuffleRead() => r }.isEmpty) } withSparkSession(test, 200, None) @@ -393,21 +410,40 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i))) + // Shuffle partition coalescing of the join is performed independent of the non-grouping + // aggregate on the other side of the union. val finalPlan = resultDf.queryExecution.executedPlan .asInstanceOf[AdaptiveSparkPlanExec].executedPlan - // As the pre-shuffle partition number are different, we will skip reducing - // the shuffle partition numbers. 
assert( finalPlan.collect { - case r @ ColumnarCoalescedShuffleReader() => r - }.isEmpty) + case r @ CoalescedShuffleRead() => r + }.size == 2) } withSparkSession(test, 100, None) } + + test("SPARK-34790: enable IO encryption in AQE partition coalescing") { + val test: SparkSession => Unit = { spark: SparkSession => + val ds = spark.range(0, 100, 1, numInputPartitions) + val resultDf = ds.repartition(ds.col("id")) + resultDf.collect() + + val finalPlan = resultDf.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert( + finalPlan.collect { + case r @ CoalescedShuffleRead() => r + }.isDefinedAt(0)) + } + Seq(true, false).foreach { enableIOEncryption => + // Before SPARK-34790, it will throw an exception when io encryption enabled. + withSparkSession(test, Int.MaxValue, None, enableIOEncryption) + } + } } -object ColumnarCoalescedShuffleReader { - def unapply(reader: ColumnarCustomShuffleReaderExec): Boolean = { - !reader.isLocalReader && !reader.hasSkewedPartition && reader.hasCoalescedPartition +object CoalescedShuffleRead { + def unapply(read: AQEShuffleReadExec): Boolean = { + !read.isLocalRead && !read.hasSkewedPartition && read.hasCoalescedPartition } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala index 16ab589578aacc68a1964566ef05f3898d0e406a..fd5649c4486d0def0131ddb9a42102db11a38718 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala @@ -31,6 +31,7 @@ private[sql] abstract class ColumnarSparkPlanTest extends SparkPlanTest with Sha .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") .set("spark.executorEnv.OMNI_CONNECTED_ENGINE", "Spark") .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false") protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { val analyzedDF = try df catch { diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala index cf2537484aefdcd68214db0046877652847bb34b..0055b94fa06626c48b19ef98d770d9109ace0726 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -17,34 +17,41 @@ package org.apache.spark.sql.execution.adaptive -import org.apache.log4j.Level -import org.apache.spark.Partition -import org.apache.spark.rdd.RDD +import java.io.File +import java.net.URI + +import org.apache.logging.log4j.Level +import org.scalatest.PrivateMethodTester +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} -import org.apache.spark.sql.{Dataset, Row, SparkSession, Strategy} +import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import 
org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec} +import org.apache.spark.sql.execution.aggregate.BaseAggregateExec import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec -import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSparkPlanTest, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledColumnarRDD, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} -import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate -import org.apache.spark.sql.functions.{sum, when} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.SQLTestData.TestData import org.apache.spark.sql.types.{IntegerType, StructType} import org.apache.spark.sql.util.QueryExecutionListener -import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils -import java.io.File -import java.net.URI - -class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest - with AdaptiveSparkPlanHelper { +class AdaptiveQueryExecSuite + extends QueryTest + with SharedSparkSession + with AdaptiveSparkPlanHelper + with PrivateMethodTester { import testImplicits._ @@ -98,10 +105,9 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - private def findTopLevelColumnarBroadcastHashJoin(plan: SparkPlan) - : Seq[ColumnarBroadcastHashJoinExec] = { + def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = { collect(plan) { - case j: ColumnarBroadcastHashJoinExec => j + case j: BroadcastNestedLoopJoinExec => j } } @@ -111,9 +117,9 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - private def findTopLevelColumnarSortMergeJoin(plan: SparkPlan): Seq[ColumnarSortMergeJoinExec] = { + private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = { collect(plan) { - case j: ColumnarSortMergeJoinExec => j + case j: ShuffledHashJoinExec => j } } @@ -123,10 +129,28 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } + private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = { + collect(plan) { + case s: SortExec => s + } + } + + private def findTopLevelAggregate(plan: 
SparkPlan): Seq[BaseAggregateExec] = { + collect(plan) { + case agg: BaseAggregateExec => agg + } + } + + private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = { + collect(plan) { + case l: CollectLimitExec => l + } + } + private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { collectWithSubqueries(plan) { - case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e - case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e + case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _) => e } } @@ -136,28 +160,21 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - private def checkNumLocalShuffleReaders( - plan: SparkPlan, numShufflesWithoutLocalReader: Int = 0): Unit = { + private def checkNumLocalShuffleReads( + plan: SparkPlan, numShufflesWithoutLocalRead: Int = 0): Unit = { val numShuffles = collect(plan) { case s: ShuffleQueryStageExec => s }.length - val numLocalReaders = collect(plan) { - case rowReader: CustomShuffleReaderExec if rowReader.isLocalReader => rowReader - case colReader: ColumnarCustomShuffleReaderExec if colReader.isLocalReader => colReader + val numLocalReads = collect(plan) { + case read: AQEShuffleReadExec if read.isLocalRead => read } - numLocalReaders.foreach { - case rowCus: CustomShuffleReaderExec => - val rdd = rowCus.execute() - val parts = rdd.partitions - assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) - case r => - val columnarCus = r.asInstanceOf[ColumnarCustomShuffleReaderExec] - val rdd: RDD[ColumnarBatch] = columnarCus.executeColumnar() - val parts: Array[Partition] = rdd.partitions - assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + numLocalReads.foreach { r => + val rdd = r.execute() + val parts = rdd.partitions + assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) } - assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader)) + assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead)) } private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = { @@ -173,20 +190,42 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Change merge join to broadcast join") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData join testData2 ON key = a where value = '1'") - val smj: Seq[SortMergeJoinExec] = findTopLevelSortMergeJoin(plan) + val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj: Seq[ColumnarBroadcastHashJoinExec] = - findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) + checkNumLocalShuffleReads(adaptivePlan) } } - test("Reuse the parallelism of CoalescedShuffleReaderExec in LocalShuffleReaderExec") { + test("Change broadcast join to merge join") { + withTable("t1", "t2") { + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10000", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + sql("CREATE TABLE t1 USING PARQUET AS SELECT 1 c1") + sql("CREATE TABLE t2 USING PARQUET AS SELECT 1 c1") + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM ( + | 
SELECT distinct c1 from t1 + | ) tmp1 JOIN ( + | SELECT distinct c1 from t2 + | ) tmp2 ON tmp1.c1 = tmp2.c1 + |""".stripMargin) + assert(findTopLevelBroadcastHashJoin(plan).size == 1) + assert(findTopLevelBroadcastHashJoin(adaptivePlan).isEmpty) + assert(findTopLevelSortMergeJoin(adaptivePlan).size == 1) + } + } + } + + test("Reuse the parallelism of coalesced shuffle in local shuffle read") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", @@ -195,30 +234,30 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest "SELECT * FROM testData join testData2 ON key = a where value = '1'") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - val localReaders = collect(adaptivePlan) { - case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader + val localReads = collect(adaptivePlan) { + case read: AQEShuffleReadExec if read.isLocalRead => read } - assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + assert(localReads.length == 2) + val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] + val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] // The pre-shuffle partition size is [0, 0, 0, 72, 0] // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1 - // and the partitions length is 1 * numMappers = 2 - assert(localShuffleRDD0.getPartitions.length == 2) + // advisoryParallelism = 1 since advisoryParallelism < numMappers + // and the partitions length is 1 + assert(localShuffleRDD0.getPartitions.length == 1) // The pre-shuffle partition size is [0, 72, 0, 72, 126] // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 // the final parallelism is - // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1 + // advisoryParallelism / numMappers: 3/2 = 1 since advisoryParallelism >= numMappers // and the partitions length is 1 * numMappers = 2 assert(localShuffleRDD1.getPartitions.length == 2) } } - test("Reuse the default parallelism in LocalShuffleReaderExec") { + test("Reuse the default parallelism in local shuffle read") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", @@ -227,14 +266,14 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest "SELECT * FROM testData join testData2 ON key = a where value = '1'") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - val localReaders = collect(adaptivePlan) { - case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader + val localReads = collect(adaptivePlan) { + case read: AQEShuffleReadExec if read.isLocalRead => read } - assert(localReaders.length == 2) - val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] - val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + assert(localReads.length == 2) + val 
localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD] + val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD] // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 // and the partitions length is 2 * numMappers = 4 assert(localShuffleRDD0.getPartitions.length == 4) @@ -247,73 +286,75 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Empty stage coalesced to 1-partition RDD") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { - val df1 = spark.range(10).withColumn("a", 'id) - val df2 = spark.range(10).withColumn("b", 'id) + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + val df1 = spark.range(10).withColumn("a", Symbol("id")) + val df2 = spark.range(10).withColumn("b", Symbol("id")) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = df1.where(Symbol("a") > 10) + .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) - val coalescedReaders = collect(plan) { - case r: ColumnarCustomShuffleReaderExec => r + val coalescedReads = collect(plan) { + case r: AQEShuffleReadExec => r } - assert(coalescedReaders.length == 3) - coalescedReaders.foreach(r => assert(r.partitionSpecs.length == 1)) + assert(coalescedReads.length == 3) + coalescedReads.foreach(r => assert(r.partitionSpecs.length == 1)) } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = df1.where(Symbol("a") > 10) + .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan - print(plan) - assert(find(plan)(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined) - val coalescedReaders = collect(plan) { - case r: ColumnarCustomShuffleReaderExec => r + assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + val coalescedReads = collect(plan) { + case r: AQEShuffleReadExec => r } - assert(coalescedReaders.length == 3, s"$plan") - coalescedReaders.foreach(r => assert(r.isLocalReader || r.partitionSpecs.length == 1)) + assert(coalescedReads.length == 3, s"$plan") + coalescedReads.foreach(r => assert(r.isLocalRead || r.partitionSpecs.length == 1)) } } } test("Scalar subquery") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData join testData2 ON key = a " + - "where value = (SELECT max(a) from testData3)") + "where value = (SELECT max(a) from testData3)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) + checkNumLocalShuffleReads(adaptivePlan) } } - // Currently, OmniFilterExec will fall 
back to Filter, if AQE is enabled, it will cause error - ignore("Scalar subquery in later stages") { + test("Scalar subquery in later stages") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData join testData2 ON key = a " + - "where (value + a) = (SELECT max(a) from testData3)") + "where (value + a) = (SELECT max(a) from testData3)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) + + checkNumLocalShuffleReads(adaptivePlan) } } test("multiple joins") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( """ |WITH t4 AS ( @@ -326,7 +367,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest """.stripMargin) val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) // A possible resulting query plan: @@ -347,18 +388,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest // +-LocalShuffleReader* // +- ShuffleExchange - // After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four - // shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'. + // After applied the 'OptimizeShuffleWithLocalRead' rule, we can convert all the four + // shuffle read to local shuffle read in the bottom two 'BroadcastHashJoin'. // For the top level 'BroadcastHashJoin', the probe side is not shuffle query stage - // and the build side shuffle query stage is also converted to local shuffle reader. - checkNumLocalShuffleReaders(adaptivePlan) + // and the build side shuffle query stage is also converted to local shuffle read. + checkNumLocalShuffleReads(adaptivePlan) } } test("multiple joins with aggregate") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( """ |WITH t4 AS ( @@ -373,7 +414,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest """.stripMargin) val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) // A possible resulting query plan: @@ -395,15 +436,15 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest // +- CoalescedShuffleReader // +- ShuffleExchange - // The shuffle added by Aggregate can't apply local reader. - checkNumLocalShuffleReaders(adaptivePlan, 1) + // The shuffle added by Aggregate can't apply local read. 
+ checkNumLocalShuffleReads(adaptivePlan, 1) } } test("multiple joins with aggregate 2") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( """ |WITH t4 AS ( @@ -418,8 +459,8 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest """.stripMargin) val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 2) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 3) // A possible resulting query plan: // BroadcastHashJoin @@ -441,25 +482,25 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest // +-LocalShuffleReader* // +- ShuffleExchange - // The shuffle added by Aggregate can't apply local reader. - checkNumLocalShuffleReaders(adaptivePlan, 1) + // The shuffle added by Aggregate can't apply local read. + checkNumLocalShuffleReads(adaptivePlan, 1) } } test("Exchange reuse") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT value FROM testData join testData2 ON key = a " + - "join (SELECT value v from testData join testData3 ON key = a) on value = v") + "join (SELECT value v from testData join testData3 ON key = a) on value = v") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 3) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) - assert(bhj.size == 3) - // There is no SMJ - checkNumLocalShuffleReaders(adaptivePlan, 0) - // Even with local shuffle reader, the query stage reuse can also work. + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 2) + // There is still a SMJ, and its two shuffles can't apply local read. + checkNumLocalShuffleReads(adaptivePlan, 2) + // Even with local shuffle read, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.size == 1) } @@ -467,17 +508,17 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Exchange reuse with subqueries") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT a FROM testData join testData2 ON key = a " + - "where value = (SELECT max(a) from testData join testData2 ON key = a)") + "where value = (SELECT max(a) from testData join testData2 ON key = a)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. + checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. 
val ex = findReusedExchange(adaptivePlan) assert(ex.size == 1) } @@ -485,19 +526,19 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Exchange reuse across subqueries") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", - SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + - "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. + checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.nonEmpty) val sub = findReusedSubquery(adaptivePlan) @@ -507,18 +548,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Subquery reuse") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + - "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. + checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. 
val ex = findReusedExchange(adaptivePlan) assert(ex.isEmpty) val sub = findReusedSubquery(adaptivePlan) @@ -528,24 +569,24 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest test("Broadcast exchange reuse across subqueries") { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", - SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT a FROM testData join testData2 ON key = a " + - "where value >= (" + - "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + - "and a <= (" + - "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") + "where value >= (" + + "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + + "and a <= (" + + "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan) - // Even with local shuffle reader, the query stage reuse can also work. + checkNumLocalShuffleReads(adaptivePlan) + // Even with local shuffle read, the query stage reuse can also work. val ex = findReusedExchange(adaptivePlan) assert(ex.nonEmpty) - assert(ex.head.child.isInstanceOf[ColumnarBroadcastExchangeExec]) + assert(ex.head.child.isInstanceOf[BroadcastExchangeExec]) val sub = findReusedSubquery(adaptivePlan) assert(sub.isEmpty) } @@ -591,7 +632,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData " + @@ -604,11 +645,11 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - test("Change merge join to broadcast join without local shuffle reader") { + test("Change merge join to broadcast join without local shuffle read") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25") { + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40") { val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( """ |SELECT * FROM testData t1 join testData2 t2 @@ -618,9 +659,10 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest ) val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 2) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - checkNumLocalShuffleReaders(adaptivePlan, 2) + // There is still a SMJ, and its two shuffles can't apply local read. 
+ checkNumLocalShuffleReads(adaptivePlan, 2) } } @@ -643,12 +685,53 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest "SELECT * FROM testData join testData2 ON key = a where value = '1'") val smj = findTopLevelSortMergeJoin(plan) assert(smj.size == 1) - val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) assert(bhj.head.buildSide == BuildRight) } } } + test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" + + " partitions on broadcast side") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { + // `testData` is small enough to be broadcast but has empty partition ratio over the config. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM (select * from testData where value = '1') td" + + " right outer join testData2 ON key = a") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + } + } + } + + test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" + + " partitions on outer/left side") { + // if the right side is completed first and the left side is still being executed, + // the right side does not know whether there are many empty partitions on the left side, + // so there is no demote, and then the right side is broadcast in the planning stage. + // so retry several times here to avoid unit test failure. + eventually(timeout(15.seconds), interval(500.milliseconds)) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { + // `testData` is small enough to be broadcast but has empty partition ratio over the config. 
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM (select * from testData where value = '1') td" + + " left outer join testData2 ON key = a") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.isEmpty) + } + } + } + } test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { var numStages = 0 @@ -688,7 +771,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = { val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) - val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin) } @@ -701,65 +784,75 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - ignore("SPARK-29544: adaptive skew join with different join types") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", - SQLConf.SHUFFLE_PARTITIONS.key -> "100", - SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") - .createOrReplaceTempView("skewData1") - spark - .range(0, 1000, 1, 10) - .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") - .createOrReplaceTempView("skewData2") - - def checkSkewJoin( - joins: Seq[SortMergeJoinExec], - leftSkewNum: Int, - rightSkewNum: Int): Unit = { - assert(joins.size == 1 && joins.head.isSkewJoin) - assert(joins.head.left.collect { - case r: ColumnarCustomShuffleReaderExec => r - }.head.partitionSpecs.collect { - case p: PartialReducerPartitionSpec => p.reducerIndex - }.distinct.length == leftSkewNum) - assert(joins.head.right.collect { - case r: ColumnarCustomShuffleReaderExec => r - }.head.partitionSpecs.collect { - case p: PartialReducerPartitionSpec => p.reducerIndex - }.distinct.length == rightSkewNum) - } - - // skewed inner join optimization - val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) - checkSkewJoin(innerSmj, 1, 1) + test("SPARK-29544: adaptive skew join with different join types") { + Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint => + def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") { + findTopLevelSortMergeJoin(plan) + } else { + findTopLevelShuffledHashJoin(plan) + } + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .select( + when(Symbol("id") < 250, 249) + .when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") + 
.createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .select( + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") + .createOrReplaceTempView("skewData2") - // skewed left outer join optimization - val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") - val leftSmj = findTopLevelColumnarSortMergeJoin(leftAdaptivePlan) - checkSkewJoin(leftSmj, 2, 0) + def checkSkewJoin( + joins: Seq[ShuffledJoin], + leftSkewNum: Int, + rightSkewNum: Int): Unit = { + assert(joins.size == 1 && joins.head.isSkewJoin) + assert(joins.head.left.collect { + case r: AQEShuffleReadExec => r + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p.reducerIndex + }.distinct.length == leftSkewNum) + assert(joins.head.right.collect { + case r: AQEShuffleReadExec => r + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p.reducerIndex + }.distinct.length == rightSkewNum) + } - // skewed right outer join optimization - val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") - val rightSmj = findTopLevelColumnarSortMergeJoin(rightAdaptivePlan) - checkSkewJoin(rightSmj, 0, 1) + // skewed inner join optimization + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " + + "JOIN skewData2 ON key1 = key2") + val inner = getJoinNode(innerAdaptivePlan) + checkSkewJoin(inner, 2, 1) + + // skewed left outer join optimization + val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( + s"SELECT /*+ $joinHint(skewData2) */ * FROM skewData1 " + + "LEFT OUTER JOIN skewData2 ON key1 = key2") + val leftJoin = getJoinNode(leftAdaptivePlan) + checkSkewJoin(leftJoin, 2, 0) + + // skewed right outer join optimization + val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( + s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " + + "RIGHT OUTER JOIN skewData2 ON key1 = key2") + val rightJoin = getJoinNode(rightAdaptivePlan) + checkSkewJoin(rightJoin, 0, 1) + } } } } @@ -770,18 +863,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest withTable("bucketed_table") { val df1 = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") - df1.write.format("orc").bucketBy(8, "i").saveAsTable("bucketed_table") + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath val tableDir = new File(warehouseFilePath, "bucketed_table") Utils.deleteRecursively(tableDir) - df1.write.orc(tableDir.getAbsolutePath) + df1.write.parquet(tableDir.getAbsolutePath) val aggregated = spark.table("bucketed_table").groupBy("i").count() - val error = intercept[Exception] { + val error = intercept[SparkException] { aggregated.count() } - assert(error.getCause.toString contains "Invalid bucket file") - assert(error.getSuppressed.size === 0) + assert(error.getErrorClass === "INVALID_BUCKET_FILE") + assert(error.getMessage contains "Invalid bucket file") } } } @@ -794,409 +887,430 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest } } - test("force apply AQE") { + test("force apply AQE") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + 
assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + + test("SPARK-30719: do not log warning if intentionally skip AQE") { + val testAppender = new LogAppender("aqe logging warning test when skip") + withLogAppender(testAppender) { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val plan = sql("SELECT * FROM testData").queryExecution.executedPlan - assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) } } + assert(!testAppender.loggingEvents + .exists(msg => msg.getMessage.getFormattedMessage.contains( + s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + + s" enabled but is not supported for"))) + } - test("SPARK-30719: do not log warning if intentionally skip AQE") { - val testAppender = new LogAppender("aqe logging warning test when skip") - withLogAppender(testAppender) { + test("test log level") { + def verifyLog(expectedLevel: Level): Unit = { + val logAppender = new LogAppender("adaptive execution") + logAppender.setThreshold(expectedLevel) + withLogAppender( + logAppender, + loggerNames = Seq(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), + level = Some(Level.TRACE)) { withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val plan = sql("SELECT * FROM testData").queryExecution.executedPlan - assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect() } } - assert(!testAppender.loggingEvents - .exists(msg => msg.getRenderedMessage.contains( - s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + - s" enabled but is not supported for"))) + Seq("Plan changed", "Final plan").foreach { msg => + assert( + logAppender.loggingEvents.exists { event => + event.getMessage.getFormattedMessage.contains(msg) && event.getLevel == expectedLevel + }) + } } - test("test log level") { - def verifyLog(expectedLevel: Level): Unit = { - val logAppender = new LogAppender("adaptive execution") - withLogAppender( - logAppender, - loggerName = Some(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), - level = Some(Level.TRACE)) { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect() - } - } - Seq("Plan changed", "Final plan").foreach { msg => - assert( - logAppender.loggingEvents.exists { event => - event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel - }) - } + // Verify default log level + verifyLog(Level.DEBUG) + + // Verify custom log level + val levels = Seq( + "TRACE" -> Level.TRACE, + "trace" -> Level.TRACE, + "DEBUG" -> Level.DEBUG, + "debug" -> Level.DEBUG, + "INFO" -> Level.INFO, + "info" -> Level.INFO, + "WARN" -> Level.WARN, + "warn" -> Level.WARN, + "ERROR" -> Level.ERROR, + "error" -> Level.ERROR, + "deBUG" -> Level.DEBUG) + + levels.foreach { level => + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) { + verifyLog(level._2) } + } + } - // Verify default log level - verifyLog(Level.DEBUG) - - // Verify custom log level - val levels = Seq( - "TRACE" -> Level.TRACE, - "trace" -> Level.TRACE, - "DEBUG" -> Level.DEBUG, - "debug" -> Level.DEBUG, - "INFO" -> Level.INFO, - "info" -> Level.INFO, - "WARN" -> Level.WARN, - "warn" -> 
Level.WARN, - "ERROR" -> Level.ERROR, - "error" -> Level.ERROR, - "deBUG" -> Level.DEBUG) - - levels.foreach { level => - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) { - verifyLog(level._2) - } - } + test("tree string output") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'") + val planBefore = df.queryExecution.executedPlan + assert(!planBefore.toString.contains("== Current Plan ==")) + assert(!planBefore.toString.contains("== Initial Plan ==")) + df.collect() + val planAfter = df.queryExecution.executedPlan + assert(planAfter.toString.contains("== Final Plan ==")) + assert(planAfter.toString.contains("== Initial Plan ==")) } + } - test("tree string output") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'") - val planBefore = df.queryExecution.executedPlan - assert(!planBefore.toString.contains("== Current Plan ==")) - assert(!planBefore.toString.contains("== Initial Plan ==")) - df.collect() - val planAfter = df.queryExecution.executedPlan - assert(planAfter.toString.contains("== Final Plan ==")) - assert(planAfter.toString.contains("== Initial Plan ==")) + test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTempView("t2") { + // create DataFrame with 0 partition + spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) + .createOrReplaceTempView("t2") + // should run successfully without NPE + runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b") } } + } - test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - withTempView("t2") { - // create DataFrame with 0 partition - spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) - .createOrReplaceTempView("t2") - // should run successfully without NPE - runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 join t2 ON t1.a=t2.b") - } + test("SPARK-34682: AQEShuffleReadExec operating on canonicalized plan") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT key FROM testData GROUP BY key") + val reads = collect(adaptivePlan) { + case r: AQEShuffleReadExec => r + } + assert(reads.length == 1) + val read = reads.head + val c = read.canonicalized.asInstanceOf[AQEShuffleReadExec] + // we can't just call execute() because that has separate checks for canonicalized plans + val ex = intercept[IllegalStateException] { + val doExecute = PrivateMethod[Unit](Symbol("doExecute")) + c.invokePrivate(doExecute()) } + assert(ex.getMessage === "operating on canonicalized plan") } + } - ignore("metrics of the shuffle reader") { + test("metrics of the shuffle read") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT key FROM testData GROUP BY key") - val readers = collect(adaptivePlan) { - case r: ColumnarCustomShuffleReaderExec => r - } - print(readers.length) - assert(readers.length == 1) - val reader = readers.head - assert(!reader.isLocalReader) - 
assert(!reader.hasSkewedPartition) - assert(reader.hasCoalescedPartition) - assert(reader.metrics.keys.toSeq.sorted == Seq( - "numPartitions", "partitionDataSize")) - assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) - assert(reader.metrics("partitionDataSize").value > 0) + val reads = collect(adaptivePlan) { + case r: AQEShuffleReadExec => r + } + assert(reads.length == 1) + val read = reads.head + assert(!read.isLocalRead) + assert(!read.hasSkewedPartition) + assert(read.hasCoalescedPartition) + assert(read.metrics.keys.toSeq.sorted == Seq( + "numCoalescedPartitions", "numPartitions", "partitionDataSize")) + assert(read.metrics("numCoalescedPartitions").value == 1) + assert(read.metrics("numPartitions").value == read.partitionSpecs.length) + assert(read.metrics("partitionDataSize").value > 0) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM testData join testData2 ON key = a where value = '1'") val join = collect(adaptivePlan) { - case j: ColumnarBroadcastHashJoinExec => j + case j: BroadcastHashJoinExec => j }.head assert(join.buildSide == BuildLeft) - val readers = collect(join.right) { - case r: ColumnarCustomShuffleReaderExec => r + val reads = collect(join.right) { + case r: AQEShuffleReadExec => r } - assert(readers.length == 1) - val reader = readers.head - assert(reader.isLocalReader) - assert(reader.metrics.keys.toSeq == Seq("numPartitions")) - assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) + assert(reads.length == 1) + val read = reads.head + assert(read.isLocalRead) + assert(read.metrics.keys.toSeq == Seq("numPartitions")) + assert(read.metrics("numPartitions").value == read.partitionSpecs.length) } withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.SHUFFLE_PARTITIONS.key -> "100", SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", - SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1000") { withTempView("skewData1", "skewData2") { spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") + when(Symbol("id") < 250, 249) + .when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") .createOrReplaceTempView("skewData2") val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val readers = collect(adaptivePlan) { - case r: CustomShuffleReaderExec => r + val reads = collect(adaptivePlan) { + case r: AQEShuffleReadExec => r } - readers.foreach { reader => - assert(!reader.isLocalReader) - assert(reader.hasCoalescedPartition) - assert(reader.hasSkewedPartition) - assert(reader.metrics.contains("numSkewedPartitions")) + reads.foreach { read => + assert(!read.isLocalRead) + assert(read.hasCoalescedPartition) + assert(read.hasSkewedPartition) + assert(read.metrics.contains("numSkewedPartitions")) } - print(readers(1).metrics("numSkewedPartitions")) - print(readers(1).metrics("numSkewedSplits")) - assert(readers(0).metrics("numSkewedPartitions").value == 2) - assert(readers(0).metrics("numSkewedSplits").value == 15) - 
assert(readers(1).metrics("numSkewedPartitions").value == 1) - assert(readers(1).metrics("numSkewedSplits").value == 12) + assert(reads(0).metrics("numSkewedPartitions").value == 2) + assert(reads(0).metrics("numSkewedSplits").value == 11) + assert(reads(1).metrics("numSkewedPartitions").value == 1) + assert(reads(1).metrics("numSkewedSplits").value == 9) } } } } - test("control a plan explain mode in listeners via SQLConf") { - - def checkPlanDescription(mode: String, expected: Seq[String]): Unit = { - var checkDone = false - val listener = new SparkListener { - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) => - assert(expected.forall(planDescription.contains)) - checkDone = true - case _ => // ignore other events - } + test("control a plan explain mode in listeners via SQLConf") { + + def checkPlanDescription(mode: String, expected: Seq[String]): Unit = { + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) => + assert(expected.forall(planDescription.contains)) + checkDone = true + case _ => // ignore other events } } - spark.sparkContext.addSparkListener(listener) - withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode, + } + spark.sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode, SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'") - try { - checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil) - spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) - } finally { - spark.sparkContext.removeSparkListener(listener) - } + val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'") + try { + checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil) + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) } } + } - Seq(("simple", Seq("== Physical Plan ==")), + Seq(("simple", Seq("== Physical Plan ==")), ("extended", Seq("== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", "== Optimized Logical Plan ==", "== Physical Plan ==")), ("codegen", Seq("WholeStageCodegen subtrees")), ("cost", Seq("== Optimized Logical Plan ==", "Statistics(sizeInBytes")), ("formatted", Seq("== Physical Plan ==", "Output", "Arguments"))).foreach { - case (mode, expected) => - checkPlanDescription(mode, expected) - } + case (mode, expected) => + checkPlanDescription(mode, expected) } + } - test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - withTable("t1") { - val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan - assert(plan.isInstanceOf[DataWritingCommandExec]) - assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) - } + test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + withTable("t1") { + val plan = sql("CREATE TABLE t1 
USING parquet AS SELECT 1 col").queryExecution.executedPlan + assert(plan.isInstanceOf[CommandResultExec]) + val commandResultExec = plan.asInstanceOf[CommandResultExec] + assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec]) + assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec] + .child.isInstanceOf[AdaptiveSparkPlanExec]) } } + } - test("AQE should set active session during execution") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.range(10).select(sum('id)) - assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) - SparkSession.setActiveSession(null) - checkAnswer(df, Seq(Row(45))) - SparkSession.setActiveSession(spark) // recover the active session. - } - } - - test("No deadlock in UI update") { - object TestStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case _: Aggregate => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - spark.range(5).rdd - } - Nil - case _ => Nil - } + test("AQE should set active session during execution") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = spark.range(10).select(sum(Symbol("id"))) + assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) + SparkSession.setActiveSession(null) + checkAnswer(df, Seq(Row(45))) + SparkSession.setActiveSession(spark) // recover the active session. + } + } + + test("No deadlock in UI update") { + object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case _: Aggregate => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + spark.range(5).rdd + } + Nil + case _ => Nil } + } - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - try { - spark.experimental.extraStrategies = TestStrategy :: Nil - val df = spark.range(10).groupBy('id).count() - df.collect() - } finally { - spark.experimental.extraStrategies = Nil - } + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + try { + spark.experimental.extraStrategies = TestStrategy :: Nil + val df = spark.range(10).groupBy(Symbol("id")).count() + df.collect() + } finally { + spark.experimental.extraStrategies = Nil } } + } - test("SPARK-31658: SQL UI should show write commands") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { - withTable("t1") { - var checkDone = false - val listener = new SparkListener { - override def onOtherEvent(event: SparkListenerEvent): Unit = { - event match { - case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => - assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") - checkDone = true - case _ => // ignore other events - } + test("SPARK-31658: SQL UI should show write commands") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + withTable("t1") { + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => + assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + 
checkDone = true + case _ => // ignore other events } } - spark.sparkContext.addSparkListener(listener) - try { - sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() - spark.sparkContext.listenerBus.waitUntilEmpty() - assert(checkDone) - } finally { - spark.sparkContext.removeSparkListener(listener) - } + } + spark.sparkContext.addSparkListener(listener) + try { + sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) } } } + } - test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { - Seq(true, false).foreach { enableAQE => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", - SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - - val df1 = spark.range(10).repartition($"id") - val df2 = spark.range(10).repartition($"id" + 1) - - val partitionsNum1 = df1.rdd.collectPartitions().length - val partitionsNum2 = df2.rdd.collectPartitions().length - - if (enableAQE) { - assert(partitionsNum1 < 10) - assert(partitionsNum2 < 10) + test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).repartition($"id") + val df2 = spark.range(10).repartition($"id" + 1) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + checkInitialPartitionNum(df1, 10) + checkInitialPartitionNum(df2, 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) + } - checkInitialPartitionNum(df1, 10) - checkInitialPartitionNum(df2, 10) - } else { - assert(partitionsNum1 === 10) - assert(partitionsNum2 === 10) - } + // Don't coalesce partitions if the number of partitions is specified. + val df3 = spark.range(10).repartition(10, $"id") + val df4 = spark.range(10).repartition(10) + assert(df3.rdd.collectPartitions().length == 10) + assert(df4.rdd.collectPartitions().length == 10) + } + } + } - // Don't coalesce partitions if the number of partitions is specified. 
- val df3 = spark.range(10).repartition(10, $"id") - val df4 = spark.range(10).repartition(10) - assert(df3.rdd.collectPartitions().length == 10) - assert(df4.rdd.collectPartitions().length == 10) + test("SPARK-31220, SPARK-32056: repartition by range with AQE") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).toDF.repartitionByRange($"id".asc) + val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + checkInitialPartitionNum(df1, 10) + checkInitialPartitionNum(df2, 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) } + + // Don't coalesce partitions if the number of partitions is specified. + val df3 = spark.range(10).repartitionByRange(10, $"id".asc) + assert(df3.rdd.collectPartitions().length == 10) } } + } - test("SPARK-31220, SPARK-32056: repartition by range with AQE") { - Seq(true, false).foreach { enableAQE => + test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") { + Seq(true, false).foreach { enableAQE => + withTempView("test") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - val df1 = spark.range(10).toDF.repartitionByRange($"id".asc) - val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc) + spark.range(10).toDF.createTempView("test") + + val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test") + val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test") + val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id") + val df4 = spark.sql("SELECT * from test CLUSTER BY id") val partitionsNum1 = df1.rdd.collectPartitions().length val partitionsNum2 = df2.rdd.collectPartitions().length + val partitionsNum3 = df3.rdd.collectPartitions().length + val partitionsNum4 = df4.rdd.collectPartitions().length if (enableAQE) { assert(partitionsNum1 < 10) assert(partitionsNum2 < 10) + assert(partitionsNum3 < 10) + assert(partitionsNum4 < 10) checkInitialPartitionNum(df1, 10) checkInitialPartitionNum(df2, 10) + checkInitialPartitionNum(df3, 10) + checkInitialPartitionNum(df4, 10) } else { assert(partitionsNum1 === 10) assert(partitionsNum2 === 10) + assert(partitionsNum3 === 10) + assert(partitionsNum4 === 10) } // Don't coalesce partitions if the number of partitions is specified. 
- val df3 = spark.range(10).repartitionByRange(10, $"id".asc) - assert(df3.rdd.collectPartitions().length == 10) - } - } - } - - test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") { - Seq(true, false).foreach { enableAQE => - withTempView("test") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", - SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", - SQLConf.SHUFFLE_PARTITIONS.key -> "10") { - - spark.range(10).toDF.createTempView("test") - - val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test") - val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test") - val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id") - val df4 = spark.sql("SELECT * from test CLUSTER BY id") - - val partitionsNum1 = df1.rdd.collectPartitions().length - val partitionsNum2 = df2.rdd.collectPartitions().length - val partitionsNum3 = df3.rdd.collectPartitions().length - val partitionsNum4 = df4.rdd.collectPartitions().length - - if (enableAQE) { - assert(partitionsNum1 < 10) - assert(partitionsNum2 < 10) - assert(partitionsNum3 < 10) - assert(partitionsNum4 < 10) - - checkInitialPartitionNum(df1, 10) - checkInitialPartitionNum(df2, 10) - checkInitialPartitionNum(df3, 10) - checkInitialPartitionNum(df4, 10) - } else { - assert(partitionsNum1 === 10) - assert(partitionsNum2 === 10) - assert(partitionsNum3 === 10) - assert(partitionsNum4 === 10) - } - - // Don't coalesce partitions if the number of partitions is specified. - val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test") - val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test") - assert(df5.rdd.collectPartitions().length == 10) - assert(df6.rdd.collectPartitions().length == 10) - } + val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test") + val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test") + assert(df5.rdd.collectPartitions().length == 10) + assert(df6.rdd.collectPartitions().length == 10) } } } + } test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") { withSQLConf( @@ -1208,149 +1322,373 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest assert(bhj.size == 1) val join = findTopLevelBaseJoin(adaptivePlan) assert(join.isEmpty) - checkNumLocalShuffleReaders(adaptivePlan) + checkNumLocalShuffleReads(adaptivePlan) } } - test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, - // This test is a copy of test(SPARK-32573), in order to test the configuration - // `spark.sql.adaptive.optimizer.excludedRules` works as expect. - SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateJoinToEmptyRelation.ruleName) { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.size == 1) + test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, + // This test is a copy of test(SPARK-32573), in order to test the configuration + // `spark.sql.adaptive.optimizer.excludedRules` works as expect. 
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.size == 1) + val join = findTopLevelBaseJoin(adaptivePlan) + // this is different compares to test(SPARK-32573) due to the rule + // `EliminateUnnecessaryJoin` has been excluded. + assert(join.nonEmpty) + checkNumLocalShuffleReads(adaptivePlan) + } + } + + test("SPARK-32649: Eliminate inner and semi join to empty relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + Seq( + // inner join (small table at right side) + "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", + // inner join (small table at left side) + "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", + // left semi join + "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" + ).foreach(query => { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) val join = findTopLevelBaseJoin(adaptivePlan) - // this is different compares to test(SPARK-32573) due to the rule - // `EliminateJoinToEmptyRelation` has been excluded. - assert(join.nonEmpty) - checkNumLocalShuffleReaders(adaptivePlan) + assert(join.isEmpty) + checkNumLocalShuffleReads(adaptivePlan) + }) + } + } + + test("SPARK-34533: Eliminate left anti join to empty relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + Seq( + // broadcast non-empty right side + ("SELECT /*+ broadcast(testData3) */ * FROM testData LEFT ANTI JOIN testData3", true), + // broadcast empty right side + ("SELECT /*+ broadcast(emptyTestData) */ * FROM testData LEFT ANTI JOIN emptyTestData", + true), + // broadcast left side + ("SELECT /*+ broadcast(testData) */ * FROM testData LEFT ANTI JOIN testData3", false) + ).foreach { case (query, isEliminated) => + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(findTopLevelBaseJoin(plan).size == 1) + assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated) } } + } - test("SPARK-32649: Eliminate inner to empty relation") { - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - Seq( - // inner join (small table at right side) - "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", - // inner join (small table at left side) - "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", - // left semi join : left join do not has omni impl - // "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" - ).foreach(query => { - val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.size == 1) - val join = findTopLevelBaseJoin(adaptivePlan) - assert(join.isEmpty) - checkNumLocalShuffleReaders(adaptivePlan) - }) + test("SPARK-34781: Eliminate left semi/anti join to its left side") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + Seq( + // left semi join and non-empty right side + ("SELECT * FROM testData LEFT SEMI JOIN testData3", true), + // left semi join, non-empty right side and non-empty join condition + ("SELECT * FROM testData t1 LEFT SEMI JOIN testData3 t2 ON 
t1.key = t2.a", false), + // left anti join and empty right side + ("SELECT * FROM testData LEFT ANTI JOIN emptyTestData", true), + // left anti join, empty right side and non-empty join condition + ("SELECT * FROM testData t1 LEFT ANTI JOIN emptyTestData t2 ON t1.key = t2.key", true) + ).foreach { case (query, isEliminated) => + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(findTopLevelBaseJoin(plan).size == 1) + assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated) } } + } - test("SPARK-32753: Only copy tags to node with no tags") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - withTempView("v1") { - spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1") + test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " + + "- single join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + // left semi join and empty left side + ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT SEMI JOIN " + + "testData2 t2 ON t1.key = t2.a", true), + // left anti join and empty left side + ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT ANTI JOIN " + + "testData2 t2 ON t1.key = t2.a", true), + // left outer join and empty left side + ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 LEFT JOIN testData2 t2 ON " + + "t1.key = t2.a", true), + // left outer join and non-empty left side + ("SELECT * FROM testData t1 LEFT JOIN testData2 t2 ON " + + "t1.key = t2.a", false), + // right outer join and empty right side + ("SELECT * FROM testData t1 RIGHT JOIN (SELECT * FROM testData2 WHERE b = 0)t2 ON " + + "t1.key = t2.a", true), + // right outer join and non-empty right side + ("SELECT * FROM testData t1 RIGHT JOIN testData2 t2 ON " + + "t1.key = t2.a", false), + // full outer join and both side empty + ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " + + "(SELECT * FROM testData2 WHERE b = 0)t2 ON t1.key = t2.a", true), + // full outer join and left side empty right side non-empty + ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " + + "testData2 t2 ON t1.key = t2.a", true) + ).foreach { case (query, isEliminated) => + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + assert(findTopLevelBaseJoin(plan).size == 1) + assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated, adaptivePlan) + } + } + } - val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") - assert(collect(adaptivePlan) { - case s: ShuffleExchangeExec => s - }.length == 1) - } + test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " + + "- multi join") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq( + """ + |SELECT * FROM testData t1 + | JOIN (SELECT * FROM testData2 WHERE b = 0) t2 ON t1.key = t2.a + | LEFT JOIN testData2 t3 ON t1.key = t3.a + |""".stripMargin, + """ + |SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1 + | LEFT ANTI JOIN testData2 t2 + | FULL JOIN (SELECT * FROM testData2 WHERE b = 0) t3 ON t1.key = t3.a + |""".stripMargin, + """ + |SELECT * FROM testData t1 + | LEFT SEMI JOIN (SELECT * FROM testData2 WHERE b = 0) + | RIGHT JOIN testData2 t3 on t1.key = t3.a + |""".stripMargin + ).foreach { query => + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + 
assert(findTopLevelBaseJoin(plan).size == 2) + assert(findTopLevelBaseJoin(adaptivePlan).isEmpty) } } + } - test("Logging plan changes for AQE") { - val testAppender = new LogAppender("plan changes") - withLogAppender(testAppender) { - withSQLConf( + test("SPARK-35585: Support propagate empty relation through project/filter") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( + "SELECT key FROM testData WHERE key = 0 ORDER BY key, value") + assert(findTopLevelSort(plan1).size == 1) + assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) + + val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( + "SELECT key FROM (SELECT * FROM testData WHERE value = 'no_match' ORDER BY key)" + + " WHERE key > rand()") + assert(findTopLevelSort(plan2).size == 1) + assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) + } + } + + test("SPARK-35442: Support propagate empty relation through aggregate") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( + "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key") + assert(!plan1.isInstanceOf[LocalTableScanExec]) + assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) + + val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( + "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1") + assert(!plan2.isInstanceOf[LocalTableScanExec]) + assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) + + val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult( + "SELECT count(*) FROM testData WHERE value = 'no_match'") + assert(!plan3.isInstanceOf[LocalTableScanExec]) + assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec]) + } + } + + test("SPARK-35442: Support propagate empty relation through union") { + def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = { + assert( + collect(plan) { + case u: UnionExec => u + }.size == numUnion) + } + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult( + """ + |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key + |UNION ALL + |SELECT key, 1 FROM testData + |""".stripMargin) + checkNumUnion(plan1, 1) + checkNumUnion(adaptivePlan1, 0) + assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec]) + + val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult( + """ + |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key + |UNION ALL + |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match' + |""".stripMargin) + checkNumUnion(plan2, 1) + checkNumUnion(adaptivePlan2, 0) + assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec]) + } + } + + test("SPARK-32753: Only copy tags to node with no tags") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withTempView("v1") { + spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1") + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") + assert(collect(adaptivePlan) { + case s: ShuffleExchangeExec => s + }.length == 1) + } + } + } + + test("Logging plan changes for AQE") { + val testAppender = new LogAppender("plan changes") + withLogAppender(testAppender) { + withSQLConf( SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "INFO", 
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { - sql("SELECT * FROM testData JOIN testData2 ON key = a " + - "WHERE value = (SELECT max(a) FROM testData3)").collect() - } - Seq("=== Result of Batch AQE Preparations ===", + sql("SELECT * FROM testData JOIN testData2 ON key = a " + + "WHERE value = (SELECT max(a) FROM testData3)").collect() + } + Seq("=== Result of Batch AQE Preparations ===", "=== Result of Batch AQE Post Stage Creation ===", "=== Result of Batch AQE Replanning ===", - "=== Result of Batch AQE Query Stage Optimization ===", - "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg => - assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) - } + "=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg => + assert(testAppender.loggingEvents.exists( + _.getMessage.getFormattedMessage.contains(expectedMsg))) } } + } - test("SPARK-32932: Do not use local shuffle reader at final stage on write command") { - withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val data = for ( - i <- 1L to 10L; - j <- 1L to 3L - ) yield (i, j) - - val df = data.toDF("i", "j").repartition($"j") - var noLocalReader: Boolean = false - val listener = new QueryExecutionListener { - override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { - qe.executedPlan match { - case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => - assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) - noLocalReader = collect(plan) { - case exec: CustomShuffleReaderExec if exec.isLocalReader => exec - }.isEmpty - case _ => // ignore other events - } + test("SPARK-32932: Do not use local shuffle read at final stage on write command") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val data = for ( + i <- 1L to 10L; + j <- 1L to 3L + ) yield (i, j) + + val df = data.toDF("i", "j").repartition($"j") + var noLocalread: Boolean = false + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + qe.executedPlan match { + case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => + assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) + noLocalread = collect(plan) { + case exec: AQEShuffleReadExec if exec.isLocalRead => exec + }.isEmpty + case _ => // ignore other events } - override def onFailure(funcName: String, qe: QueryExecution, - exception: Exception): Unit = {} - } - spark.listenerManager.register(listener) - - withTable("t") { - df.write.partitionBy("j").saveAsTable("t") - sparkContext.listenerBus.waitUntilEmpty() - assert(noLocalReader) - noLocalReader = false } + override def onFailure(funcName: String, qe: QueryExecution, + exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) - // Test DataSource v2 - val format = classOf[NoopDataSource].getName - df.write.format(format).mode("overwrite").save() + withTable("t") { + df.write.partitionBy("j").saveAsTable("t") sparkContext.listenerBus.waitUntilEmpty() - assert(noLocalReader) - noLocalReader = false - - spark.listenerManager.unregister(listener) + 
assert(noLocalread) + noLocalread = false } + + // Test DataSource v2 + val format = classOf[NoopDataSource].getName + df.write.format(format).mode("overwrite").save() + sparkContext.listenerBus.waitUntilEmpty() + assert(noLocalread) + noLocalread = false + + spark.listenerManager.unregister(listener) } + } - test("SPARK-33494: Do not use local shuffle reader for repartition") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.table("testData").repartition('key) - df.collect() - // local shuffle reader breaks partitioning and shouldn't be used for repartition operation - // which is specified by users. - checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1) - } + test("SPARK-33494: Do not use local shuffle read for repartition") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = spark.table("testData").repartition(Symbol("key")) + df.collect() + // local shuffle read breaks partitioning and shouldn't be used for repartition operation + // which is specified by users. + checkNumLocalShuffleReads(df.queryExecution.executedPlan, numShufflesWithoutLocalRead = 1) } + } - test("SPARK-33551: Do not use custom shuffle reader for repartition") { + test("SPARK-33551: Do not use AQE shuffle read for repartition") { def hasRepartitionShuffle(plan: SparkPlan): Boolean = { find(plan) { case s: ShuffleExchangeLike => - s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM + s.shuffleOrigin == REPARTITION_BY_COL || s.shuffleOrigin == REPARTITION_BY_NUM case _ => false }.isDefined } + def checkBHJ( + df: Dataset[Row], + optimizeOutRepartition: Boolean, + probeSideLocalRead: Boolean, + probeSideCoalescedRead: Boolean): Unit = { + df.collect() + val plan = df.queryExecution.executedPlan + // There should be only one shuffle that can't do local read, which is either the top shuffle + // from repartition, or BHJ probe side shuffle. + checkNumLocalShuffleReads(plan, 1) + assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.length == 1) + + // Build side should do local read. 
+ val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec]) + assert(buildSide.isDefined) + assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) + + val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec]) + if (probeSideLocalRead || probeSideCoalescedRead) { + assert(probeSide.isDefined) + if (probeSideLocalRead) { + assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead) + } else { + assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition) + } + } else { + assert(probeSide.isEmpty) + } + } + + def checkSMJ( + df: Dataset[Row], + optimizeOutRepartition: Boolean, + optimizeSkewJoin: Boolean, + coalescedRead: Boolean): Unit = { + df.collect() + val plan = df.queryExecution.executedPlan + assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.length == 1) + assert(smj.head.isSkewJoin == optimizeSkewJoin) + val aqeReads = collect(smj.head) { + case c: AQEShuffleReadExec => c + } + if (coalescedRead || optimizeSkewJoin) { + assert(aqeReads.length == 2) + if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition)) + } else { + assert(aqeReads.isEmpty) + } + } + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.SHUFFLE_PARTITIONS.key -> "5") { val df = sql( @@ -1359,50 +1697,30 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest | SELECT * FROM testData WHERE key = 1 |) |RIGHT OUTER JOIN testData2 - |ON value = b - """.stripMargin) + |ON CAST(value AS INT) = b + """.stripMargin) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val bhj = findTopLevelBroadcastHashJoin(plan) - assert(bhj.length == 1) - checkNumLocalShuffleReaders(plan, 1) - // Probe side is coalesced. - val customReader = bhj.head.right.find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) - assert(customReader.isDefined) - assert(customReader.get.asInstanceOf[ColumnarCustomShuffleReaderExec].hasCoalescedPartition) - - // Repartition with partition default num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(planWithNum)) - val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum) - assert(bhjWithNum.length == 1) - checkNumLocalShuffleReaders(planWithNum, 1) - // Probe side is not coalesced. - assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty) - - // Repartition with partition non-default num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out, and this is the only shuffle that - // does not have local shuffle reader. 
- assert(hasRepartitionShuffle(planWithNum2)) - val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2) - assert(bhjWithNum2.length == 1) - checkNumLocalShuffleReaders(planWithNum2, 1) - val customReader2 = bhjWithNum2.head.right - .find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) - assert(customReader2.isDefined) - assert(customReader2.get.asInstanceOf[ColumnarCustomShuffleReaderExec].isLocalReader) + checkBHJ(df.repartition(Symbol("b")), + // The top shuffle from repartition is optimized out. + optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true) + + // Repartition with default partition num (5 in test env) specified. + checkBHJ(df.repartition(5, Symbol("b")), + // The top shuffle from repartition is optimized out + // The final plan must have 5 partitions, no optimization can be made to the probe side. + optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false) + + // Repartition with non-default partition num specified. + checkBHJ(df.repartition(4, Symbol("b")), + // The top shuffle from repartition is not optimized out + optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) + + // Repartition by col and project away the partition cols + checkBHJ(df.repartition(Symbol("b")).select(Symbol("key")), + // The top shuffle from repartition is not optimized out + optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true) } // Force skew join @@ -1412,108 +1730,941 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) - dfRepartition.collect() - val plan = dfRepartition.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(plan)) - val smj = findTopLevelSortMergeJoin(plan) - assert(smj.length == 1) - // No skew join due to the repartition. - assert(!smj.head.isSkewJoin) - // Both sides are coalesced. - val customReaders = collect(smj.head) { - case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c - case c: ColumnarCustomShuffleReaderExec if c.hasCoalescedPartition => c + checkSMJ(df.repartition(Symbol("b")), + // The top shuffle from repartition is optimized out. + optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true) + + // Repartition with default partition num (5 in test env) specified. + checkSMJ(df.repartition(5, Symbol("b")), + // The top shuffle from repartition is optimized out. + // The final plan must have 5 partitions, can't do coalesced read. + optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false) + + // Repartition with non-default partition num specified. + checkSMJ(df.repartition(4, Symbol("b")), + // The top shuffle from repartition is not optimized out. + optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + + // Repartition by col and project away the partition cols + checkSMJ(df.repartition(Symbol("b")).select(Symbol("key")), + // The top shuffle from repartition is not optimized out. 
+ optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false) + } + } + } + + test("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "10", + SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") { + withTable("t1") { + spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1") + val query = "SELECT SUM(a) FROM t1 GROUP BY a" + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED + val blocksFetchedMetric = collectFirst(adaptivePlan) { + case p if p.metrics.contains(metricName) => p.metrics(metricName) + } + assert(blocksFetchedMetric.isDefined) + val blocksFetched = blocksFetchedMetric.get.value + withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") { + val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query) + val blocksFetchedMetric2 = collectFirst(adaptivePlan2) { + case p if p.metrics.contains(metricName) => p.metrics(metricName) + } + assert(blocksFetchedMetric2.isDefined) + val blocksFetched2 = blocksFetchedMetric2.get.value + assert(blocksFetched < blocksFetched2) + } + } + } + } + + test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") { + val testAppender = new LogAppender("aqe query stage materialization order test") + testAppender.setThreshold(Level.DEBUG) + val df = spark.range(1000).select($"id" % 26, $"id" % 10) + .toDF("index", "pv") + val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) + .toDF("index", "name") + val testDf = df.groupBy("index") + .agg(sum($"pv").alias("pv")) + .join(dim, Seq("index")) + val loggerNames = + Seq(classOf[BroadcastQueryStageExec].getName, classOf[ShuffleQueryStageExec].getName) + withLogAppender(testAppender, loggerNames, level = Some(Level.DEBUG)) { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val result = testDf.collect() + assert(result.length == 26) + } + } + val materializeLogs = testAppender.loggingEvents + .map(_.getMessage.getFormattedMessage) + .filter(_.startsWith("Materialize query stage")) + .toArray + assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec")) + assert(materializeLogs(1).startsWith("Materialize query stage ShuffleQueryStageExec")) + } + + test("SPARK-34899: Use origin plan if we can not coalesce shuffle partition") { + def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = { + assert(collect(ds.queryExecution.executedPlan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s + }.size == 1) + ds.collect() + val plan = ds.queryExecution.executedPlan + assert(collect(plan) { + case c: AQEShuffleReadExec => c + }.isEmpty) + assert(collect(plan) { + case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s + }.size == 1) + checkAnswer(ds, testData) + } + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + // Pick a small value so that no coalesce can happen. 
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "2") { + val df = spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString)), 10).toDF() + + // partition size [1420, 1420] + checkNoCoalescePartitions(df.repartition($"key"), REPARTITION_BY_COL) + // partition size [1140, 1119] + checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS) + } + } + + test("SPARK-34980: Support coalesce partition through union") { + def checkResultPartition( + df: Dataset[Row], + numUnion: Int, + numShuffleReader: Int, + numPartition: Int): Unit = { + df.collect() + assert(collect(df.queryExecution.executedPlan) { + case u: UnionExec => u + }.size == numUnion) + assert(collect(df.queryExecution.executedPlan) { + case r: AQEShuffleReadExec => r + }.size === numShuffleReader) + assert(df.rdd.partitions.length === numPartition) + } + + Seq(true, false).foreach { combineUnionEnabled => + val combineUnionConfig = if (combineUnionEnabled) { + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "" + } else { + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> + "org.apache.spark.sql.catalyst.optimizer.CombineUnions" + } + // advisory partition size 1048576 has no special meaning, just a big enough value + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10", + combineUnionConfig) { + withTempView("t1", "t2") { + spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 2) + .toDF().createOrReplaceTempView("t1") + spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 4) + .toDF().createOrReplaceTempView("t2") + + // positive test that could be coalesced + checkResultPartition( + sql(""" + |SELECT key, count(*) FROM t1 GROUP BY key + |UNION ALL + |SELECT * FROM t2 + """.stripMargin), + numUnion = 1, + numShuffleReader = 1, + numPartition = 1 + 4) + + checkResultPartition( + sql(""" + |SELECT key, count(*) FROM t1 GROUP BY key + |UNION ALL + |SELECT * FROM t2 + |UNION ALL + |SELECT * FROM t1 + """.stripMargin), + numUnion = if (combineUnionEnabled) 1 else 2, + numShuffleReader = 1, + numPartition = 1 + 4 + 2) + + checkResultPartition( + sql(""" + |SELECT /*+ merge(t2) */ t1.key, t2.key FROM t1 JOIN t2 ON t1.key = t2.key + |UNION ALL + |SELECT key, count(*) FROM t2 GROUP BY key + |UNION ALL + |SELECT * FROM t1 + """.stripMargin), + numUnion = if (combineUnionEnabled) 1 else 2, + numShuffleReader = 3, + numPartition = 1 + 1 + 2) + + // negative test + checkResultPartition( + sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2"), + numUnion = if (combineUnionEnabled) 1 else 1, + numShuffleReader = 0, + numPartition = 2 + 4 + ) } - assert(customReaders.length == 2) - - // Repartition with default partition num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) - dfRepartitionWithNum.collect() - val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan - // The top shuffle from repartition is optimized out. - assert(!hasRepartitionShuffle(planWithNum)) - val smjWithNum = findTopLevelSortMergeJoin(planWithNum) - assert(smjWithNum.length == 1) - // No skew join due to the repartition. - assert(!smjWithNum.head.isSkewJoin) - // No coalesce due to the num in repartition. 
- val customReadersWithNum = collect(smjWithNum.head) { - case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c + } + } + } + + test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") { + withTable("t") { + withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "2", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + spark.sql("CREATE TABLE t (c1 int) USING PARQUET") + val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1") + assert( + collect(adaptive) { + case c @ AQEShuffleReadExec(_, partitionSpecs) if partitionSpecs.length == 1 => + assert(c.hasCoalescedPartition) + c + }.length == 1 + ) + } + } + } + + test("SPARK-35264: Support AQE side broadcastJoin threshold") { + withTempView("t1", "t2") { + def checkJoinStrategy(shouldBroadcast: Boolean): Unit = { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val (origin, adaptive) = runAdaptiveAndVerifyResult( + "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") + assert(findTopLevelSortMergeJoin(origin).size == 1) + if (shouldBroadcast) { + assert(findTopLevelBroadcastHashJoin(adaptive).size == 1) + } else { + assert(findTopLevelSortMergeJoin(adaptive).size == 1) + } } - assert(customReadersWithNum.isEmpty) + } + + // t1: 1600 bytes + // t2: 160 bytes + spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString)), 10) + .toDF("c1", "c2").createOrReplaceTempView("t1") + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString)), 5) + .toDF("c1", "c2").createOrReplaceTempView("t2") + + checkJoinStrategy(false) + withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + checkJoinStrategy(false) + } - // Repartition with default non-partition num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) - dfRepartitionWithNum2.collect() - val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan - // The top shuffle from repartition is not optimized out. - assert(hasRepartitionShuffle(planWithNum2)) - val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2) - assert(smjWithNum2.length == 1) - // Skew join can apply as the repartition is not optimized out. 
- assert(smjWithNum2.head.isSkewJoin) + withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "160") { + checkJoinStrategy(true) } } } - ignore("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") { + test("SPARK-35264: Support AQE side shuffled hash join formula") { + withTempView("t1", "t2") { + def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = { + Seq("100", "100000").foreach { size => + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> size) { + val (origin1, adaptive1) = runAdaptiveAndVerifyResult( + "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") + assert(findTopLevelSortMergeJoin(origin1).size === 1) + if (shouldShuffleHashJoin && size.toInt < 100000) { + val shj = findTopLevelShuffledHashJoin(adaptive1) + assert(shj.size === 1) + assert(shj.head.buildSide == BuildRight) + } else { + assert(findTopLevelSortMergeJoin(adaptive1).size === 1) + } + } + } + // respect user specified join hint + val (origin2, adaptive2) = runAdaptiveAndVerifyResult( + "SELECT /*+ MERGE(t1) */ t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1") + assert(findTopLevelSortMergeJoin(origin2).size === 1) + assert(findTopLevelSortMergeJoin(adaptive2).size === 1) + } + + spark.sparkContext.parallelize( + (1 to 100).map(i => TestData(i, i.toString)), 10) + .toDF("c1", "c2").createOrReplaceTempView("t1") + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString)), 5) + .toDF("c1", "c2").createOrReplaceTempView("t2") + + // t1 partition size: [926, 729, 731] + // t2 partition size: [318, 120, 0] + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.PREFER_SORTMERGEJOIN.key -> "true") { + // check default value + checkJoinStrategy(false) + withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "400") { + checkJoinStrategy(true) + } + withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "300") { + checkJoinStrategy(false) + } + withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "1000") { + checkJoinStrategy(true) + } + } + } + } + + test("SPARK-35650: Coalesce number of partitions by AEQ") { + withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { + Seq("REPARTITION", "REBALANCE(key)") + .foreach {repartition => + val query = s"SELECT /*+ $repartition */ * FROM testData" + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + collect(adaptivePlan) { + case r: AQEShuffleReadExec => r + } match { + case Seq(aqeShuffleRead) => + assert(aqeShuffleRead.partitionSpecs.size === 1) + assert(!aqeShuffleRead.isLocalRead) + case _ => + fail("There should be a AQEShuffleReadExec") + } + } + } + } + + test("SPARK-35650: Use local shuffle read if can not coalesce number of partitions") { + withSQLConf(SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { + val query = "SELECT /*+ REPARTITION */ * FROM testData" + val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) + collect(adaptivePlan) { + case r: AQEShuffleReadExec => r + } match { + case Seq(aqeShuffleRead) => + assert(aqeShuffleRead.partitionSpecs.size === 4) + assert(aqeShuffleRead.isLocalRead) + case _ => + fail("There should be a AQEShuffleReadExec") + } + } + } + + test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") { + withTempView("v") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.SHUFFLE_PARTITIONS.key -> "10000", - SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") { - 
withTable("t1") { - spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1") - val query = "SELECT SUM(a) FROM t1 GROUP BY a" - val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query) - val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED - val blocksFetchedMetric = collectFirst(adaptivePlan) { - case p if p.metrics.contains(metricName) => p.metrics(metricName) + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { + + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3) + .toDF("c1", "c2").createOrReplaceTempView("v") + + def checkPartitionNumber( + query: String, skewedPartitionNumber: Int, totalNumber: Int): Unit = { + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + val read = collect(adaptive) { + case read: AQEShuffleReadExec => read } - assert(blocksFetchedMetric.isDefined) - val blocksFetched = blocksFetchedMetric.get.value - withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") { - val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query) - val blocksFetchedMetric2 = collectFirst(adaptivePlan2) { - case p if p.metrics.contains(metricName) => p.metrics(metricName) + assert(read.size == 1) + assert(read.head.partitionSpecs.count(_.isInstanceOf[PartialReducerPartitionSpec]) == + skewedPartitionNumber) + assert(read.head.partitionSpecs.size == totalNumber) + } + + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") { + // partition size [0,258,72,72,72] + checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4) + // partition size [144,72,144,72,72,144,72] + checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 6, 7) + } + + // no skewed partition should be optimized + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10000") { + checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 0, 1) + } + } + } + } + + test("SPARK-35888: join with a 0-partition table") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + withTempView("t2") { + // create a temp view with 0 partition + spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) + .createOrReplaceTempView("t2") + val (_, adaptive) = + runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b") + val aqeReads = collect(adaptive) { + case c: AQEShuffleReadExec => c + } + assert(aqeReads.length == 2) + aqeReads.foreach { c => + val stats = c.child.asInstanceOf[QueryStageExec].getRuntimeStatistics + assert(stats.sizeInBytes >= 0) + assert(stats.rowCount.get >= 0) + } + } + } + } + + test("SPARK-33832: Support optimize skew join even if introduce extra shuffle") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10", + 
SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + // check if optimized skewed join does not satisfy the required distribution + Seq(true, false).foreach { hasRequiredDistribution => + Seq(true, false).foreach { hasPartitionNumber => + val repartition = if (hasRequiredDistribution) { + s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) */" + } else { + "" + } + + // check required distribution and extra shuffle + val (_, adaptive1) = + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2 GROUP BY key1") + val shuffles1 = collect(adaptive1) { + case s: ShuffleExchangeExec => s } - assert(blocksFetchedMetric2.isDefined) - val blocksFetched2 = blocksFetchedMetric2.get.value - assert(blocksFetched < blocksFetched2) + assert(shuffles1.size == 3) + // shuffles1.head is the top-level shuffle under the Aggregate operator + assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS) + val smj1 = findTopLevelSortMergeJoin(adaptive1) + assert(smj1.size == 1 && smj1.head.isSkewJoin) + + // only check required distribution + val (_, adaptive2) = + runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " + + s"JOIN skewData2 ON key1 = key2") + val shuffles2 = collect(adaptive2) { + case s: ShuffleExchangeExec => s + } + if (hasRequiredDistribution) { + assert(shuffles2.size == 3) + val finalShuffle = shuffles2.head + if (hasPartitionNumber) { + assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM) + } else { + assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL) + } + } else { + assert(shuffles2.size == 2) + } + val smj2 = findTopLevelSortMergeJoin(adaptive2) + assert(smj2.size == 1 && smj2.head.isSkewJoin) } } } } + } + + test("SPARK-35968: AQE coalescing should not produce too small partitions by default") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (_, adaptive) = + runAdaptiveAndVerifyResult("SELECT sum(id) FROM RANGE(10) GROUP BY id % 3") + val coalesceRead = collect(adaptive) { + case r: AQEShuffleReadExec if r.hasCoalescedPartition => r + } + assert(coalesceRead.length == 1) + // RANGE(10) is a very small dataset and AQE coalescing should produce one partition. 
+ assert(coalesceRead.head.partitionSpecs.length == 1) + } + } + + test("SPARK-35794: Allow custom plugin for cost evaluator") { + CostEvaluator.instantiate( + classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + intercept[IllegalArgumentException] { + CostEvaluator.instantiate( + classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf) + } + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'" + + withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> + "org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReads(adaptivePlan) + } + + withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key -> + "org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") { + intercept[IllegalArgumentException] { + runAdaptiveAndVerifyResult(query) + } + } + } + } - test("Do not use column shuffle in AQE") { - def findCustomShuffleReader(plan: SparkPlan): Seq[CustomShuffleReaderExec] ={ - collect(plan) { - case j: CustomShuffleReaderExec => j + test("SPARK-36020: Check logical link in remove redundant projects") { + withTempView("t") { + spark.range(10).selectExpr("id % 10 as key", "cast(id * 2 as int) as a", + "cast(id * 3 as int) as b", "array(id, id + 1, id + 3) as c").createOrReplaceTempView("t") + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "800") { + val query = + """ + |WITH tt AS ( + | SELECT key, a, b, explode(c) AS c FROM t + |) + |SELECT t1.key, t1.c, t2.key, t2.c + |FROM (SELECT a, b, c, key FROM tt WHERE a > 1) t1 + |JOIN (SELECT a, b, c, key FROM tt) t2 + | ON t1.key = t2.key + |""".stripMargin + val (origin, adaptive) = runAdaptiveAndVerifyResult(query) + assert(findTopLevelSortMergeJoin(origin).size == 1) + assert(findTopLevelBroadcastHashJoin(adaptive).size == 1) } } - def findShuffleExchange(plan: SparkPlan): Seq[ShuffleExchangeExec] ={ - collect(plan) { - case j: ShuffleExchangeExec => j + } + + test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))" + runAdaptiveAndVerifyResult(query) + } + } + + test("SPARK-36032: Use inputPlan instead of currentPhysicalPlan to initialize logical link") { + withTempView("v") { + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, i.toString)), 2) + .toDF("c1", "c2").createOrReplaceTempView("v") + + Seq("-1", "10000").foreach { aqeBhj => + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> aqeBhj, + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + val (origin, adaptive) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM v t1 JOIN ( + | SELECT c1 + 1 as c3 FROM v + |)t2 ON t1.c1 = t2.c3 + |SORT BY c1 + """.stripMargin) + if (aqeBhj.toInt < 0) { + // 1 sort since spark plan has no shuffle for SMJ + assert(findTopLevelSort(origin).size == 1) + // 2 sorts in SMJ + 
assert(findTopLevelSort(adaptive).size == 2) + } else { + assert(findTopLevelSort(origin).size == 1) + // 1 sort at top node and BHJ has no sort + assert(findTopLevelSort(adaptive).size == 1) + } + } } } + } + + test("SPARK-36424: Support eliminate limits in AQE Optimizer") { + withTempView("v") { + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(i, if (i > 2) "2" else i.toString)), 2) + .toDF("c1", "c2").createOrReplaceTempView("v") + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "3") { + val (origin1, adaptive1) = runAdaptiveAndVerifyResult( + """ + |SELECT c2, sum(c1) FROM v GROUP BY c2 LIMIT 5 + """.stripMargin) + assert(findTopLevelLimit(origin1).size == 1) + assert(findTopLevelLimit(adaptive1).isEmpty) + + // eliminate limit through filter + val (origin2, adaptive2) = runAdaptiveAndVerifyResult( + """ + |SELECT c2, sum(c1) FROM v GROUP BY c2 HAVING sum(c1) > 1 LIMIT 5 + """.stripMargin) + assert(findTopLevelLimit(origin2).size == 1) + assert(findTopLevelLimit(adaptive2).isEmpty) + + // The strategy of Eliminate Limits batch should be fixedPoint + val (origin3, adaptive3) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20 + """.stripMargin + ) + assert(findTopLevelLimit(origin3).size == 1) + assert(findTopLevelLimit(adaptive3).isEmpty) + } + } + } + + test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") { + spark.sparkContext.parallelize( + (1 to 10).map(i => TestData(if (i > 2) 2 else i, i.toString)), 2) + .toDF("c1", "c2").createOrReplaceTempView("v") + + def checkRebalance(query: String, numShufflePartitions: Int): Unit = { + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + assert(adaptive.collect { + case sort: SortExec => sort + }.size == 1) + val read = collect(adaptive) { + case read: AQEShuffleReadExec => read + } + assert(read.size == 1) + assert(read.head.partitionSpecs.forall(_.isInstanceOf[PartialReducerPartitionSpec])) + assert(read.head.partitionSpecs.size == numShufflePartitions) + } + + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "50") { + checkRebalance("SELECT /*+ REBALANCE(c1) */ * FROM v SORT BY c1", 2) + checkRebalance("SELECT /*+ REBALANCE */ * FROM v SORT BY c1", 2) + } + } + } + } + + test("SPARK-37357: Add small partition factor for rebalance partitions") { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + spark.sparkContext.parallelize( + (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3) + .toDF("c1", "c2").createOrReplaceTempView("v") + + def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = { + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + assert( + collect(adaptive) { + case read: AQEShuffleReadExec => read + }.nonEmpty == exists) + } + + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") { + withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") { + // block size: [88, 97, 97] + checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false) + } + withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") { + // 
block size: [88, 97, 97] + checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true) + } + } + } + } + } + + test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584", + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) { + // Spark estimates a string column as 20 bytes so with 60k rows, these relations should be + // estimated at ~120m bytes which is greater than the broadcast join threshold. + val joinKeyOne = "00112233445566778899" + val joinKeyTwo = "11223344556677889900" + Seq.fill(60000)(joinKeyOne).toDF("key") + .createOrReplaceTempView("temp") + Seq.fill(60000)(joinKeyTwo).toDF("key") + .createOrReplaceTempView("temp2") + + Seq(joinKeyOne).toDF("key").createOrReplaceTempView("smallTemp") + spark.sql("SELECT key as newKey FROM temp").persist() + + // This query is trying to set up a situation where there are three joins. + // The first join will join the cached relation with a smaller relation. + // The first join is expected to be a broadcast join since the smaller relation will + // fit under the broadcast join threshold. + // The second join will join the first join with another relation and is expected + // to remain as a sort-merge join. + // The third join will join the cached relation with another relation and is expected + // to remain as a sort-merge join. + val query = + s""" + |SELECT t3.newKey + |FROM + | (SELECT t1.newKey + | FROM (SELECT key as newKey FROM temp) as t1 + | JOIN + | (SELECT key FROM smallTemp) as t2 + | ON t1.newKey = t2.key + | ) as t3 + | JOIN + | (SELECT key FROM temp2) as t4 + | ON t3.newKey = t4.key + |UNION + |SELECT t1.newKey + |FROM + | (SELECT key as newKey FROM temp) as t1 + | JOIN + | (SELECT key FROM temp2) as t2 + | ON t1.newKey = t2.key + |""".stripMargin + val df = spark.sql(query) + df.collect() + val adaptivePlan = df.queryExecution.executedPlan + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.length == 1) + } + } + + test("SPARK-37328: skew join with 3 tables") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - "spark.shuffle.manager"-> "sort", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100", - SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { - spark - .range(1, 1000, 1).where("id > 995").createOrReplaceTempView("t1") - spark - .range(1, 5, 1)createOrReplaceTempView("t2") + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + withTempView("skewData1", "skewData2", "skewData3") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id % 3 as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key3", "id as value3") + .createOrReplaceTempView("skewData3") + + // skewedJoin doesn't happen in last stage + val (_, adaptive1) = + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value2 = value3") + val shuffles1 = collect(adaptive1) { + case s: ShuffleExchangeExec => s + } + assert(shuffles1.size == 4) + val smj1 = findTopLevelSortMergeJoin(adaptive1) + assert(smj1.size == 2 && 
smj1.last.isSkewJoin && !smj1.head.isSkewJoin) + + // Query has two skewJoin in two continuous stages. + val (_, adaptive2) = + runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "JOIN skewData3 ON value1 = value3") + val shuffles2 = collect(adaptive2) { + case s: ShuffleExchangeExec => s + } + assert(shuffles2.size == 4) + val smj2 = findTopLevelSortMergeJoin(adaptive2) + assert(smj2.size == 2 && smj2.forall(_.isSkewJoin)) + } + } + } + + test("SPARK-37652: optimize skewed join through union") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = { + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) + val joins = findTopLevelSortMergeJoin(innerAdaptivePlan) + val optimizeSkewJoins = joins.filter(_.isSkewJoin) + assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums) + } + + // skewJoin union skewJoin + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2) + + // skewJoin union aggregate + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " + + "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1) + + // skewJoin1 union (skewJoin2 join aggregate) + // skewJoin2 will lead to extra shuffles, but skew1 cannot be optimized + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " + + "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " + + "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0) + } + } + } + + test("SPARK-38162: Optimize one row plan in AQE Optimizer") { + withTempView("v") { + spark.sparkContext.parallelize( + (1 to 4).map(i => TestData(i, i.toString)), 2) + .toDF("c1", "c2").createOrReplaceTempView("v") + + // remove sort + val (origin1, adaptive1) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM v where c1 = 1 order by c1, c2 + |""".stripMargin) + assert(findTopLevelSort(origin1).size == 1) + assert(findTopLevelSort(adaptive1).isEmpty) + + // convert group only aggregate to project + val (origin2, adaptive2) = runAdaptiveAndVerifyResult( + """ + |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) + |""".stripMargin) + assert(findTopLevelAggregate(origin2).size == 2) + assert(findTopLevelAggregate(adaptive2).isEmpty) + + // remove distinct in aggregate + val (origin3, adaptive3) = runAdaptiveAndVerifyResult( + """ + |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1) + |""".stripMargin) + assert(findTopLevelAggregate(origin3).size == 4) + assert(findTopLevelAggregate(adaptive3).size == 2) + + // do not optimize if the aggregate is inside query stage + val (origin4, adaptive4) = runAdaptiveAndVerifyResult( + """ + |SELECT distinct c1 FROM v where c1 = 1 + |""".stripMargin) + assert(findTopLevelAggregate(origin4).size == 2) + assert(findTopLevelAggregate(adaptive4).size == 2) + 
+ val (origin5, adaptive5) = runAdaptiveAndVerifyResult( + """ + |SELECT sum(distinct c1) FROM v where c1 = 1 + |""".stripMargin) + assert(findTopLevelAggregate(origin5).size == 4) + assert(findTopLevelAggregate(adaptive5).size == 4) + } + } + + test("SPARK-39551: Invalid plan check - invalid broadcast query stage") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { val (_, adaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM t1 JOIN t2 ON t1.id = t2.id") - val shuffleNum = findShuffleExchange(adaptivePlan) - assert(shuffleNum.length == 2) - val shuffleReaderNum = findCustomShuffleReader(adaptivePlan) - assert(shuffleReaderNum.length == 2) + """ + |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1 + |INNER JOIN testData2 t2 + |ON t1.b = t2.b AND t1.a = 0 + |RIGHT OUTER JOIN testData2 t3 + |ON t1.a > t3.a + |GROUP BY t3.b + """.stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + } + } + + test("SPARK-39915: Dataset.repartition(N) may not create N partitions") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") { + // partitioning: HashPartitioning + // shuffleOrigin: REPARTITION_BY_NUM + assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5) + // shuffleOrigin: REPARTITION_BY_COL + // The minimum partition number after AQE coalesce is 1 + assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1) + // through project + assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") + .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5) + + // partitioning: RangePartitioning + // shuffleOrigin: REPARTITION_BY_NUM + // The minimum partition number of RangePartitioner is 1 + assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1) + // shuffleOrigin: REPARTITION_BY_COL + assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1) + + // partitioning: RoundRobinPartitioning + // shuffleOrigin: REPARTITION_BY_NUM + assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5) + // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE + assert(spark.range(0).repartition().rdd.getNumPartitions == 0) + // through project + assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2") + .repartition(5).select($"c2").rdd.getNumPartitions == 5) + + // partitioning: SinglePartition + assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1) + } + } + test("SPARK-39915: Ensure the output partitioning is user-specified") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df1 = spark.range(1).selectExpr("id as c1") + val df2 = spark.range(1).selectExpr("id as c2") + val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1")) + assert(df.rdd.getNumPartitions == 3) } } } + +/** + * Invalid implementation class for [[CostEvaluator]]. + */ +private class InvalidCostEvaluator() {} + +/** + * A simple [[CostEvaluator]] to count number of [[ShuffleExchangeLike]] and [[SortExec]]. 
+ */ +private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator { + override def evaluateCost(plan: SparkPlan): Cost = { + val cost = plan.collect { + case s: ShuffleExchangeLike => s + case s: SortExec => s + }.size + SimpleCost(cost) + } +} diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index 026fc59977b443256c933202f1ebb1dbc19ce3d7..fab207f793b948648cad6b9f6c6f5ed6d585af08 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,14 +8,14 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.1.1-1.1.0 + 3.3.1-1.1.0 BoostKit Spark Native Sql Engine Extension Parent Pom 2.12.10 2.12 - 3.1.1 + 3.3.1 3.2.2 UTF-8 UTF-8 @@ -55,6 +55,18 @@ org.apache.curator curator-recipes + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.core + jackson-databind +
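
For reference on the cost-evaluator hook exercised by the SPARK-35794 test and the SimpleShuffleSortCostEvaluator / InvalidCostEvaluator helpers above: outside of this suite a custom evaluator is enabled through the same SQLConf entries the test sets. The sketch below is illustrative only and not part of this patch; the object name, app name and demo query are assumptions, and any CostEvaluator implementation reachable on the driver classpath could stand in for the suite-local class named here.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object CustomCostEvaluatorDemo {
  def main(args: Array[String]): Unit = {
    // Build a session with AQE enabled and a custom cost evaluator registered
    // (same config keys the SPARK-35794 test toggles via withSQLConf).
    val spark = SparkSession.builder()
      .master("local[*]")                           // assumed local run
      .appName("custom-cost-evaluator-demo")        // hypothetical app name
      .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
      .config(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key,
        "org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator")
      .getOrCreate()
    import spark.implicits._

    // Any query that introduces a shuffle lets AQE consult the configured
    // evaluator when deciding whether to adopt a re-optimized plan.
    Seq.tabulate(1000)(i => (i % 10, i)).toDF("k", "v").createOrReplaceTempView("t")
    spark.sql("SELECT k, count(*) FROM t GROUP BY k").collect()
    spark.stop()
  }
}

If the configured class cannot be instantiated as a CostEvaluator, planning fails with an IllegalArgumentException, which is the behaviour the SPARK-35794 test pins down for InvalidCostEvaluator.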