diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
index c0e00761ec3c77120b8563924be8fa464339ba65..fcf82daa62d638e85c6270a4daf11d4c5ede176e 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
@@ -192,6 +192,34 @@ public class OmniColumnVector extends WritableColumnVector {
 
     @Override
     public boolean hasNull() {
+        if (dictionaryData != null) {
+            return dictionaryData.hasNullValue();
+        }
+        if (type instanceof BooleanType) {
+            return booleanDataVec.hasNullValue();
+        } else if (type instanceof ByteType) {
+            return charsTypeDataVec.hasNullValue();
+        } else if (type instanceof ShortType) {
+            return shortDataVec.hasNullValue();
+        } else if (type instanceof IntegerType) {
+            return intDataVec.hasNullValue();
+        } else if (type instanceof DecimalType) {
+            if (DecimalType.is64BitDecimalType(type)) {
+                return longDataVec.hasNullValue();
+            } else {
+                return decimal128DataVec.hasNullValue();
+            }
+        } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) {
+            return longDataVec.hasNullValue();
+        } else if (type instanceof FloatType) {
+            return false;
+        } else if (type instanceof DoubleType) {
+            return doubleDataVec.hasNullValue();
+        } else if (type instanceof StringType) {
+            return charsTypeDataVec.hasNullValue();
+        } else if (type instanceof DateType) {
+            return intDataVec.hasNullValue();
+        }
         throw new UnsupportedOperationException("hasNull is not supported");
     }
 
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
index f81ea4d0503c2fcb2875b0bf7b07ab866699eda8..1104ed4df65cddc260d43f45e187075085ba610a 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
@@ -117,7 +117,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
           ColumnarUnionExec(plan.children).buildCheck()
         case plan: ShuffleExchangeExec =>
           if (!enableColumnarShuffle) return false
-          new ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child).buildCheck()
+          new ColumnarShuffleExchangeExec(plan.outputPartitioning, plan.child, plan.shuffleOrigin).buildCheck()
         case plan: BroadcastHashJoinExec =>
           // We need to check if BroadcastExchangeExec can be converted to columnar-based.
           // If not, BHJ should also be row-based.
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1738766130c4d3b34c35cde70fdfe345cd84524f
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/BroadcastColumnarRDD.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import nova.hetu.omniruntime.vector.VecBatch
+import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.vectorized.OmniColumnVector
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.{Partition, SparkContext, TaskContext, broadcast}
+
+
+private final case class BroadcastColumnarRDDPartition(index: Int) extends Partition
+
+case class BroadcastColumnarRDD(
+    @transient private val sc: SparkContext,
+    metrics: Map[String, SQLMetric],
+    numPartitioning: Int,
+    inputByteBuf: broadcast.Broadcast[ColumnarHashedRelation],
+    localSchema: StructType)
+  extends RDD[ColumnarBatch](sc, Nil) {
+
+  override protected def getPartitions: Array[Partition] = {
+    (0 until numPartitioning).map { index => new BroadcastColumnarRDDPartition(index) }.toArray
+  }
+
+  private def vecBatchToColumnarBatch(vecBatch: VecBatch): ColumnarBatch = {
+    val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns(
+      vecBatch.getRowCount, localSchema, false)
+    vectors.zipWithIndex.foreach { case (vector, i) =>
+      vector.reset()
+      vector.setVec(vecBatch.getVectors()(i))
+    }
+    vecBatch.close()
+    new ColumnarBatch(vectors.toArray, vecBatch.getRowCount)
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
+    // val relation = inputByteBuf.value.asReadOnlyCopy
+    // new CloseableColumnBatchIterator(relation.getColumnarBatchAsIter)
+    val deserializer = VecBatchSerializerFactory.create()
+    new Iterator[ColumnarBatch] {
+      var idx = 0
+      val total_len = inputByteBuf.value.buildData.length
+
+      override def hasNext: Boolean = idx < total_len
+
+      override def next(): ColumnarBatch = {
+        val tmp_idx = idx
+        idx += 1
+        val batch: VecBatch = deserializer.deserialize(inputByteBuf.value.buildData(tmp_idx))
+        vecBatchToColumnarBatch(batch)
+      }
+    }
+  }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala
new file mode 100644
index 0000000000000000000000000000000000000000..dbeaa5380708bf12650847c414bbe678e9f7c991
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import nova.hetu.omniruntime.vector.Vec
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.util.SparkMemoryUtils
+import org.apache.spark.sql.execution.vectorized.OmniColumnVector
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ListBuffer
+
+case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions: Int)
+  extends UnaryExecNode {
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = UnknownPartitioning(numPartitions)
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows: SQLMetric = longMetric("numOutputRows")
+    val numOutputBatches: SQLMetric = longMetric("numOutputBatches")
+    val inputRdd: BroadcastColumnarRDD = BroadcastColumnarRDD(
+      sparkContext,
+      metrics,
+      numPartitions,
+      child.executeBroadcast(),
+      StructType.fromAttributes(child.output))
+    inputRdd.mapPartitions { batches =>
+
+      val toUnsafe = UnsafeProjection.create(output, output)
+      val vecsTmp = new ListBuffer[Vec]
+
+      val batchIter = batches.flatMap { batch =>
+        for (i <- 0 until batch.numCols()) {
+          batch.column(i) match {
+            case vector: OmniColumnVector =>
+              vecsTmp.append(vector.getVec)
+            case _ =>
+          }
+        }
+        numOutputBatches += 1
+        numOutputRows += batch.numRows()
+        batch.rowIterator().asScala.map(toUnsafe)
+      }
+
+      SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ =>
+        vecsTmp.foreach { vec =>
+          vec.close()
+        }
+      }
+      batchIter
+    }
+  }
+
+  override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+    child.executeBroadcast()
+  }
+
+  override def supportsColumnar: Boolean = true
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"),
+    "processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar"))
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index 8e57a73978edf318b50e554cc5099f3c78f57bfd..8ab542695470344090d8cebe71bc38eef64bd276 100644
---
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike} +import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelationBroadcastMode, HashedRelationWithAllNullKeys} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.unsafe.map.BytesToBytesMap @@ -69,12 +70,19 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) // Setup a job group here so later it may get cancelled by groupId if necessary. sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)", interruptOnCancel = true) + val nullBatchCount = sparkContext.longAccumulator("nullBatchCount") val beforeCollect = System.nanoTime() val numRows = longMetric("numOutputRows") val dataSize = longMetric("dataSize") // Use executeCollect/executeCollectIterator to avoid conversion to Scala types val input = child.executeColumnar().mapPartitions { iter => val serializer = VecBatchSerializerFactory.create() + var nullRelationFlag = false + mode match { + case hashRelMode: HashedRelationBroadcastMode => + nullRelationFlag = hashRelMode.isNullAware + case _ => + } new Iterator[Array[Byte]] { override def hasNext: Boolean = { iter.hasNext @@ -82,6 +90,18 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) override def next(): Array[Byte] = { val batch = iter.next() + var index = 0 + try { + while (nullRelationFlag && nullBatchCount.value == 0 && index < batch.numCols()) { + val vec = batch.column(index) + if (vec.hasNull) { + nullBatchCount.add(1) + } + index = index + 1 + } + } catch { + case e: Exception => logError(s"compute nullBatchCount error: ${e.getMessage}.") + } val vectors = transColBatchToOmniVecs(batch) val vecBatch = new VecBatch(vectors, batch.numRows()) numRows += vecBatch.getRowCount @@ -94,6 +114,8 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } } }.collect() + val relation = new ColumnarHashedRelation + relation.converterData(mode, nullBatchCount.value, input) val numOutputRows = numRows.value if (numOutputRows >= MAX_BROADCAST_TABLE_ROWS) { throw new SparkException(s"Cannot broadcast the table over " + @@ -104,7 +126,7 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) longMetric("collectTime") += NANOSECONDS.toMillis(beforeBroadcast - beforeCollect) // Broadcast the relation - val broadcasted: broadcast.Broadcast[Any] = sparkContext.broadcast(input) + val broadcasted: broadcast.Broadcast[Any] = sparkContext.broadcast(relation) longMetric("broadcastTime") += NANOSECONDS.toMillis( System.nanoTime() - beforeBroadcast) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) @@ -163,6 +185,21 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } } +class ColumnarHashedRelation extends Serializable { + var relation: Any = _ + var buildData: Array[Array[Byte]] = new Array[Array[Byte]](0) + + def converterData(mode: BroadcastMode, relationByte: Long, array: Array[Array[Byte]]): Unit = { + if 
(mode.isInstanceOf[HashedRelationBroadcastMode] && array.isEmpty) { + relation = EmptyHashedRelation + } + if (relationByte >= 1) { + relation = HashedRelationWithAllNullKeys + } + buildData = array + } +} + object ColumnarBroadcastExchangeExec { // Since the maximum number of keys that BytesToBytesMap supports is 1 << 29, // and only 70% of the slots can be used before growing in HashedRelation, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala new file mode 100644 index 0000000000000000000000000000000000000000..4edf0f4f86cb79a3e2b3a5c2fc01c999a42349c8 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin +import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.ColumnarHashedRelation +import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys} + +/** + * This optimization rule detects and converts a Join to an empty [[LocalRelation]]: + * 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]] + * is [[HashedRelationWithAllNullKeys]]. + * + * 2. Join is inner or left semi join, and broadcasted [[HashedRelation]] + * is [[EmptyHashedRelation]]. + * This applies to all Joins (sort merge join, shuffled hash join, and broadcast hash join), + * because sort merge join and shuffled hash join will be changed to broadcast hash join with AQE + * at the first place. 
+ */ +object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] { + + private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match { + case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined + && stage.broadcast.relationFuture.get().value == relation => true + case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined + && stage.broadcast.supportsColumnar => { + val cr = stage.broadcast.relationFuture.get().value.asInstanceOf[ColumnarHashedRelation] + cr.relation == relation + } + case _ => false + } + + def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { + case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) + if canEliminate(j.right, HashedRelationWithAllNullKeys) => + LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) + + case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) || + canEliminate(j.right, EmptyHashedRelation) => + LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) + + case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) => + LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming) + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index e632831a0cea3b640eace6f942030e4966c5b8ca..a71eaac030ef4fcdb622908063b52f8baea1e16a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} +import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.vectorized.OmniColumnVector @@ -267,7 +267,7 @@ case class ColumnarBroadcastHashJoinExec( OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) }.toArray - val buildData = buildPlan.executeBroadcast[Array[Array[Byte]]]() + val relation = buildPlan.executeBroadcast[ColumnarHashedRelation]() // TODO: check val buildOutputTypes = buildTypes // {1,1} @@ -295,7 +295,7 @@ case class ColumnarBroadcastHashJoinExec( buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) val deserializer = VecBatchSerializerFactory.create() - buildData.value.foreach { input => + relation.value.buildData.foreach { input => val startBuildInput = System.nanoTime() buildOp.addInput(deserializer.deserialize(input)) buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index dfe539662b0616781e96039d31b8e86147718df8..7af33efd0e5b88b937c668f531ec837dfc59dfe8 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -63,7 +63,9 @@ class ColumnarSortMergeJoinExec( override def supportCodegen: Boolean = false - override def nodeName: String = "OmniColumnarSortMergeJoin" + override def nodeName: String = { + if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin" + } val SMJ_NEED_ADD_STREAM_TBL_DATA = 2 val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3 diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..4ec953e8e8a2b247e3ec3e45fd0d372aeb688b44 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.UI.UI_ENABLED +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} + +class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAll { + + private var originalActiveSparkSession: Option[SparkSession] = _ + private var originalInstantiatedSparkSession: Option[SparkSession] = _ + + override protected def beforeAll(): Unit = { + super.beforeAll() + originalActiveSparkSession = SparkSession.getActiveSession + originalInstantiatedSparkSession = SparkSession.getDefaultSession + + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + override protected def afterAll(): Unit = { + try { + // Set these states back. 
+ originalActiveSparkSession.foreach(ctx => SparkSession.setActiveSession(ctx)) + originalInstantiatedSparkSession.foreach(ctx => SparkSession.setDefaultSession(ctx)) + } finally { + super.afterAll() + } + } + + val numInputPartitions: Int = 10 + + def withSparkSession( + f: SparkSession => Unit, + targetPostShuffleInputSize: Int, + minNumPostShufflePartitions: Option[Int]): Unit = { + val sparkConf = + new SparkConf(false) + .setMaster("local[*]") + .setAppName("test") + .set(UI_ENABLED, false) + .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") + .set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key, "5") + .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") + .set( + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key, + targetPostShuffleInputSize.toString) + .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager") + minNumPostShufflePartitions match { + case Some(numPartitions) => + sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, numPartitions.toString) + case None => + sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "1") + } + + val spark = SparkSession.builder() + .config(sparkConf) + .getOrCreate() + try f(spark) finally spark.stop() + } + + Seq(Some(5), None).foreach { minNumPostShufflePartitions => + val testNameNote = minNumPostShufflePartitions match { + case Some(numPartitions) => "(minNumPostShufflePartitions: " + numPartitions + ")" + case None => "" + } + + test(s"determining the number of reducers: aggregate operator$testNameNote") { + val test = { spark: SparkSession => + val df = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 20 as key", "id as value") + val agg = df.groupBy("key").count() + + // Check the answer first. + QueryTest.checkAnswer( + agg, + spark.range(0, 20).selectExpr("id", "50 as cnt").collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val finalPlan = agg.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val shuffleReaders = finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + } + assert(shuffleReaders.length === 1) + minNumPostShufflePartitions match { + case Some(numPartitions) => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === numPartitions) + } + case None => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === 3) + } + } + } + + withSparkSession(test, 1500, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: join operator$testNameNote") { + val test = { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "id as value") + .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value")) + QueryTest.checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. 
+ val finalPlan = join.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val shuffleReaders = finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + } + assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { + case Some(numPartitions) => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === numPartitions) + } + case None => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === 2) + } + } + } + + withSparkSession(test, 11384, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 1$testNameNote") { + val test: (SparkSession) => Unit = { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count() + .toDF("key1", "cnt1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + .groupBy("key2") + .count() + .toDF("key2", "cnt2") + + val join = df1.join(df2, col("key1") === col("key2")).select(col("key1"), col("cnt2")) + + // Check the answer first. + val expectedAnswer = + spark + .range(0, 500) + .selectExpr("id", "2 as cnt") + QueryTest.checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. + val finalPlan = join.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val shuffleReaders = finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + } + assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { + case Some(numPartitions) => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === numPartitions) + } + case None => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === 3) + } + } + } + + withSparkSession(test, 5384, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: complex query 2$testNameNote") { + val test: (SparkSession) => Unit = { spark: SparkSession => + val df1 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key1", "id as value1") + .groupBy("key1") + .count() + .toDF("key1", "cnt1") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = + df1 + .join(df2, col("key1") === col("key2")) + .select(col("key1"), col("cnt1"), col("value2")) + + // Check the answer first. + val expectedAnswer = + spark + .range(0, 1000) + .selectExpr("id % 500 as key", "2 as cnt", "id as value") + QueryTest.checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's look at the number of post-shuffle partitions estimated + // by the ExchangeCoordinator. 
+ val finalPlan = join.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val shuffleReaders = finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + } + assert(shuffleReaders.length === 2) + minNumPostShufflePartitions match { + case Some(numPartitions) => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === numPartitions) + } + case None => + shuffleReaders.foreach { reader => + assert(reader.outputPartitioning.numPartitions === 2) + } + } + } + + withSparkSession(test, 10000, minNumPostShufflePartitions) + } + + test(s"determining the number of reducers: plan already partitioned$testNameNote") { + val test: SparkSession => Unit = { spark: SparkSession => + try { + spark.range(1000).write.bucketBy(30, "id").saveAsTable("t") + // `df1` is hash partitioned by `id`. + val df1 = spark.read.table("t") + val df2 = + spark + .range(0, 1000, 1, numInputPartitions) + .selectExpr("id % 500 as key2", "id as value2") + + val join = df1.join(df2, col("id") === col("key2")).select(col("id"), col("value2")) + + // Check the answer first. + val expectedAnswer = spark.range(0, 500).selectExpr("id % 500", "id as value") + .union(spark.range(500, 1000).selectExpr("id % 500", "id as value")) + QueryTest.checkAnswer( + join, + expectedAnswer.collect()) + + // Then, let's make sure we do not reduce number of post shuffle partitions. + val finalPlan = join.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + val shuffleReaders = finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + } + assert(shuffleReaders.length === 0) + } finally { + spark.sql("drop table t") + } + } + withSparkSession(test, 12000, minNumPostShufflePartitions) + } + } + + test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") { + val test: SparkSession => Unit = { spark: SparkSession => + spark.sql("SET spark.sql.exchange.reuse=true") + val df = spark.range(1).selectExpr("id AS key", "id AS value") + + // test case 1: a query stage has 3 child stages but they are the same stage. + // Final Stage 1 + // ShuffleQueryStage 0 + // ReusedQueryStage 0 + // ReusedQueryStage 0 + val resultDf = df.join(df, "key").join(df, "key") + QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil) + val finalPlan = resultDf.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert(finalPlan.collect { + case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r + }.length == 2) + assert( + finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + }.length == 3) + + + // test case 2: a query stage has 2 parent stages. + // Final Stage 3 + // ShuffleQueryStage 1 + // ShuffleQueryStage 0 + // ShuffleQueryStage 2 + // ReusedQueryStage 0 + val grouped = df.groupBy("key").agg(max("value").as("value")) + val resultDf2 = grouped.groupBy(col("key") + 1).max("value") + .union(grouped.groupBy(col("key") + 2).max("value")) + QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil) + + val finalPlan2 = resultDf2.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + + // The result stage has 2 children + val level1Stages = finalPlan2.collect { case q: QueryStageExec => q } + assert(level1Stages.length == 2) + + val leafStages = level1Stages.flatMap { stage => + // All of the child stages of result stage have only one child stage. 
+ val children = stage.plan.collect { case q: QueryStageExec => q } + assert(children.length == 1) + children + } + assert(leafStages.length == 2) + + val reusedStages = level1Stages.flatMap { stage => + stage.plan.collect { + case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r + } + } + assert(reusedStages.length == 1) + } + withSparkSession(test, 4, None) + } + + test("Do not reduce the number of shuffle partition for repartition") { + val test: SparkSession => Unit = { spark: SparkSession => + val ds = spark.range(3) + val resultDf = ds.repartition(2, ds.col("id")).toDF() + + QueryTest.checkAnswer(resultDf, + Seq(0, 1, 2).map(i => Row(i))) + val finalPlan = resultDf.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + assert( + finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + }.isEmpty) + } + withSparkSession(test, 200, None) + } + + test("Union two datasets with different pre-shuffle partition number") { + val test: SparkSession => Unit = { spark: SparkSession => + val df1 = spark.range(3).join(spark.range(3), "id").toDF() + val df2 = spark.range(3).groupBy().sum() + + val resultDf = df1.union(df2) + + QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i))) + + val finalPlan = resultDf.queryExecution.executedPlan + .asInstanceOf[AdaptiveSparkPlanExec].executedPlan + // As the pre-shuffle partition number are different, we will skip reducing + // the shuffle partition numbers. + assert( + finalPlan.collect { + case r @ ColumnarCoalescedShuffleReader() => r + }.isEmpty) + } + withSparkSession(test, 100, None) + } +} + +object ColumnarCoalescedShuffleReader { + def unapply(reader: ColumnarCustomShuffleReaderExec): Boolean = { + !reader.isLocalReader && !reader.hasSkewedPartition && reader.hasCoalescedPartition + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..e4da8c93ca3ab1eb14d06fe84831d528e69fcb8f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala @@ -0,0 +1,1490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.log4j.Level +import org.apache.spark.Partition +import org.apache.spark.rdd.RDD +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} +import org.apache.spark.sql.{Dataset, Row, SparkSession, Strategy} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.noop.NoopDataSource +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSparkPlanTest, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledColumnarRDD, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.{Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} +import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter +import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate +import org.apache.spark.sql.functions.{sum, when} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +import java.io.File +import java.net.URI + +class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest + with AdaptiveSparkPlanHelper { + + import testImplicits._ + + setupTestData() + + private def runAdaptiveAndVerifyResult(query: String): (SparkPlan, SparkPlan) = { + var finalPlanCnt = 0 + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, sparkPlanInfo) => + if (sparkPlanInfo.simpleString.startsWith( + "AdaptiveSparkPlan isFinalPlan=true")) { + finalPlanCnt += 1 + } + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + + val dfAdaptive = sql(query) + val planBefore = dfAdaptive.queryExecution.executedPlan + assert(planBefore.toString.startsWith("AdaptiveSparkPlan isFinalPlan=false")) + val result = dfAdaptive.collect() + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = sql(query) + checkAnswer(df, result) + } + val planAfter = dfAdaptive.queryExecution.executedPlan + assert(planAfter.toString.startsWith("AdaptiveSparkPlan isFinalPlan=true")) + val adaptivePlan = planAfter.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + + spark.sparkContext.listenerBus.waitUntilEmpty() + // AQE will post `SparkListenerSQLAdaptiveExecutionUpdate` twice in case of subqueries that + // exist out of query stages. 
+ val expectedFinalPlanCnt = adaptivePlan.find(_.subqueries.nonEmpty).map(_ => 2).getOrElse(1) + assert(finalPlanCnt == expectedFinalPlanCnt) + spark.sparkContext.removeSparkListener(listener) + + val exchanges = adaptivePlan.collect { + case e: Exchange => e + } + assert(exchanges.isEmpty, "The final plan should not contain any Exchange node.") + (dfAdaptive.queryExecution.sparkPlan, adaptivePlan) + } + + private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = { + collect(plan) { + case j: BroadcastHashJoinExec => j + } + } + + private def findTopLevelColumnarBroadcastHashJoin(plan: SparkPlan) + : Seq[ColumnarBroadcastHashJoinExec] = { + collect(plan) { + case j: ColumnarBroadcastHashJoinExec => j + } + } + + private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = { + collect(plan) { + case j: SortMergeJoinExec => j + } + } + + private def findTopLevelColumnarSortMergeJoin(plan: SparkPlan): Seq[ColumnarSortMergeJoinExec] = { + collect(plan) { + case j: ColumnarSortMergeJoinExec => j + } + } + + private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = { + collect(plan) { + case j: BaseJoinExec => j + } + } + + private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = { + collectWithSubqueries(plan) { + case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e + case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e + } + } + + private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = { + collectWithSubqueries(plan) { + case e: ReusedSubqueryExec => e + } + } + + private def checkNumLocalShuffleReaders( + plan: SparkPlan, numShufflesWithoutLocalReader: Int = 0): Unit = { + val numShuffles = collect(plan) { + case s: ShuffleQueryStageExec => s + }.length + + val numLocalReaders = collect(plan) { + case rowReader: CustomShuffleReaderExec if rowReader.isLocalReader => rowReader + case colReader: ColumnarCustomShuffleReaderExec if colReader.isLocalReader => colReader + } + numLocalReaders.foreach { + case rowCus: CustomShuffleReaderExec => + val rdd = rowCus.execute() + val parts = rdd.partitions + assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + case r => + val columnarCus = r.asInstanceOf[ColumnarCustomShuffleReaderExec] + val rdd: RDD[ColumnarBatch] = columnarCus.executeColumnar() + val parts: Array[Partition] = rdd.partitions + assert(parts.forall(rdd.preferredLocations(_).nonEmpty)) + } + assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader)) + } + + private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = { + // repartition obeys initialPartitionNum when adaptiveExecutionEnabled + val plan = df.queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + val shuffle = plan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan.collect { + case s: ShuffleExchangeExec => s + } + assert(shuffle.size == 1) + assert(shuffle(0).outputPartitioning.numPartitions == numPartition) + } + + test("Change merge join to broadcast join") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj: Seq[SortMergeJoinExec] = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj: Seq[ColumnarBroadcastHashJoinExec] = + findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + 
checkNumLocalShuffleReaders(adaptivePlan) + } + } + + test("Reuse the parallelism of CoalescedShuffleReaderExec in LocalShuffleReaderExec") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val localReaders = collect(adaptivePlan) { + case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader + } + assert(localReaders.length == 2) + val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + // The pre-shuffle partition size is [0, 0, 0, 72, 0] + // We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1 + // the final parallelism is + // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1 + // and the partitions length is 1 * numMappers = 2 + assert(localShuffleRDD0.getPartitions.length == 2) + // The pre-shuffle partition size is [0, 72, 0, 72, 126] + // We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3 + // the final parallelism is + // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1 + // and the partitions length is 1 * numMappers = 2 + assert(localShuffleRDD1.getPartitions.length == 2) + } + } + + test("Reuse the default parallelism in LocalShuffleReaderExec") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + val localReaders = collect(adaptivePlan) { + case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader + } + assert(localReaders.length == 2) + val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD] + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 + // and the partitions length is 2 * numMappers = 4 + assert(localShuffleRDD0.getPartitions.length == 4) + // the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2 + // and the partitions length is 2 * numMappers = 4 + assert(localShuffleRDD1.getPartitions.length == 4) + } + } + + test("Empty stage coalesced to 1-partition RDD") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { + val df1 = spark.range(10).withColumn("a", 'id) + val df2 = spark.range(10).withColumn("b", 'id) + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") + .groupBy('a).count() + checkAnswer(testDf, Seq()) + val plan = testDf.queryExecution.executedPlan + assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) + val 
coalescedReaders = collect(plan) { + case r: ColumnarCustomShuffleReaderExec => r + } + assert(coalescedReaders.length == 3) + coalescedReaders.foreach(r => assert(r.partitionSpecs.length == 1)) + } + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + // currently, we only support "inner" join type + val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") + .groupBy('a).count() + checkAnswer(testDf, Seq()) + val plan = testDf.queryExecution.executedPlan + print(plan) + assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) + val coalescedReaders = collect(plan) { + case r: ColumnarCustomShuffleReaderExec => r + } + assert(coalescedReaders.length == 3, s"$plan") + coalescedReaders.foreach(r => assert(r.isLocalReader || r.partitionSpecs.length == 1)) + } + } + } + + test("Scalar subquery") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a " + + "where value = (SELECT max(a) from testData3)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + } + } + + // Currently, OmniFilterExec will fall back to Filter, if AQE is enabled, it will cause error + ignore("Scalar subquery in later stages") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a " + + "where (value + a) = (SELECT max(a) from testData3)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + } + } + + test("multiple joins") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN testData3 t3 ON t2.n = t3.a where t2.n = '1' + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON t2.b = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 3) + + // A possible resulting query plan: + // BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastExchange + // +-LocalShuffleReader* + // +- ShuffleExchange + + // After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four + // shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'. + // For the top level 'BroadcastHashJoin', the probe side is not shuffle query stage + // and the build side shuffle query stage is also converted to local shuffle reader. 
+ checkNumLocalShuffleReaders(adaptivePlan) + } + } + + test("multiple joins with aggregate") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN ( + | select a, sum(b) from testData3 group by a + | ) t3 ON t2.n = t3.a where t2.n = '1' + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON t2.b = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 3) + + // A possible resulting query plan: + // BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastExchange + // +-HashAggregate + // +- CoalescedShuffleReader + // +- ShuffleExchange + + // The shuffle added by Aggregate can't apply local reader. + checkNumLocalShuffleReaders(adaptivePlan, 1) + } + } + + test("multiple joins with aggregate 2") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |WITH t4 AS ( + | SELECT * FROM lowercaseData t2 JOIN ( + | select a, max(b) b from testData2 group by a + | ) t3 ON t2.n = t3.b + |) + |SELECT * FROM testData + |JOIN testData2 t2 ON key = t2.a + |JOIN t4 ON value = t4.a + |WHERE value = 1 + """.stripMargin) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 2) + + // A possible resulting query plan: + // BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- BroadcastExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- LocalShuffleReader* + // +- ShuffleExchange + // +- BroadcastHashJoin + // +- Filter + // +- HashAggregate + // +- CoalescedShuffleReader + // +- ShuffleExchange + // +- BroadcastExchange + // +-LocalShuffleReader* + // +- ShuffleExchange + + // The shuffle added by Aggregate can't apply local reader. + checkNumLocalShuffleReaders(adaptivePlan, 1) + } + } + + test("Exchange reuse") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT value FROM testData join testData2 ON key = a " + + "join (SELECT value v from testData join testData3 ON key = a) on value = v") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 3) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 3) + // There is no SMJ + checkNumLocalShuffleReaders(adaptivePlan, 0) + // Even with local shuffle reader, the query stage reuse can also work. 
+ val ex = findReusedExchange(adaptivePlan) + assert(ex.size == 1) + } + } + + test("Exchange reuse with subqueries") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value = (SELECT max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + // Even with local shuffle reader, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.size == 1) + } + } + + test("Exchange reuse across subqueries") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + // Even with local shuffle reader, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.isEmpty) + } + } + + test("Subquery reuse") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (SELECT max(a) from testData join testData2 ON key = a) " + + "and a <= (SELECT max(a) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + // Even with local shuffle reader, the query stage reuse can also work. + val ex = findReusedExchange(adaptivePlan) + assert(ex.isEmpty) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.nonEmpty) + } + } + + test("Broadcast exchange reuse across subqueries") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000", + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT a FROM testData join testData2 ON key = a " + + "where value >= (" + + "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " + + "and a <= (" + + "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan) + // Even with local shuffle reader, the query stage reuse can also work. 
+ val ex = findReusedExchange(adaptivePlan) + assert(ex.nonEmpty) + assert(ex.head.child.isInstanceOf[ColumnarBroadcastExchangeExec]) + val sub = findReusedSubquery(adaptivePlan) + assert(sub.isEmpty) + } + } + + test("Union/Except/Intersect queries") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + runAdaptiveAndVerifyResult( + """ + |SELECT * FROM testData + |EXCEPT + |SELECT * FROM testData2 + |UNION ALL + |SELECT * FROM testData + |INTERSECT ALL + |SELECT * FROM testData2 + """.stripMargin) + } + } + + test("Subquery de-correlation in Union queries") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withTempView("a", "b") { + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a") + Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b") + + runAdaptiveAndVerifyResult( + """ + |SELECT id,num,source FROM ( + | SELECT id, num, 'a' as source FROM a + | UNION ALL + | SELECT id, num, 'b' as source FROM b + |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) + """.stripMargin) + } + } + } + + test("Avoid plan change if cost is greater") { + val origPlan = sql("SELECT * FROM testData " + + "join testData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.a = t3.a where t2.b = 1").queryExecution.executedPlan + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25", + SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData " + + "join testData2 t2 ON key = t2.a " + + "join testData2 t3 on t2.a = t3.a where t2.b = 1") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 2) + val smj2 = findTopLevelSortMergeJoin(adaptivePlan) + assert(smj2.size == 2, origPlan.toString) + } + } + + test("Change merge join to broadcast join without local shuffle reader") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM testData t1 join testData2 t2 + |ON t1.key = t2.a join testData3 t3 on t2.a = t3.a + |where t1.value = 1 + """.stripMargin + ) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 2) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + // TODO: check whey we don't have SMJ + checkNumLocalShuffleReaders(adaptivePlan, 2) + } + } + + test("Avoid changing merge join to broadcast join if too many empty partitions on build plan") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") { + // `testData` is small enough to be broadcast but has empty partition ratio over the config. + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.isEmpty) + } + // It is still possible to broadcast `testData2`. 
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + assert(bhj.head.buildSide == BuildRight) + } + } + } + + test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") { + var numStages = 0 + val listener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + numStages = jobStart.stageInfos.length + } + } + try { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + spark.sparkContext.addSparkListener(listener) + spark.range(0, 100, 1, numPartitions = 10).take(1) + spark.sparkContext.listenerBus.waitUntilEmpty() + // Should be only one stage since there is no shuffle. + assert(numStages == 1) + } + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + + test("SPARK-30524: Do not optimize skew join if introduce additional shuffle") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 3 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = { + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query) + val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) + assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin) + } + + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2", true) + // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization + checkSkewJoin( + "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 GROUP BY key1", false) + } + } + } + + ignore("SPARK-29544: adaptive skew join with different join types") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .select( + when('id < 250, 249) + .when('id >= 750, 1000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .select( + when('id < 250, 249) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin( + joins: Seq[SortMergeJoinExec], + leftSkewNum: Int, + rightSkewNum: Int): Unit = { + assert(joins.size == 1 && joins.head.isSkewJoin) + assert(joins.head.left.collect { + case r: ColumnarCustomShuffleReaderExec => r + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec => p.reducerIndex + }.distinct.length == leftSkewNum) + assert(joins.head.right.collect { + case r: ColumnarCustomShuffleReaderExec => r + }.head.partitionSpecs.collect { + case p: PartialReducerPartitionSpec 
=> p.reducerIndex + }.distinct.length == rightSkewNum) + } + + // skewed inner join optimization + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 join skewData2 ON key1 = key2") + val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan) + checkSkewJoin(innerSmj, 1, 1) + + // skewed left outer join optimization + val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") + val leftSmj = findTopLevelColumnarSortMergeJoin(leftAdaptivePlan) + checkSkewJoin(leftSmj, 2, 0) + + // skewed right outer join optimization + val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") + val rightSmj = findTopLevelColumnarSortMergeJoin(rightAdaptivePlan) + checkSkewJoin(rightSmj, 0, 1) + } + } + } + + test("SPARK-30291: AQE should catch the exceptions when doing materialize") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withTable("bucketed_table") { + val df1 = + (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1") + df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") + val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath + val tableDir = new File(warehouseFilePath, "bucketed_table") + Utils.deleteRecursively(tableDir) + df1.write.parquet(tableDir.getAbsolutePath) + + val aggregated = spark.table("bucketed_table").groupBy("i").sum() + val error = intercept[Exception] { + aggregated.head() + } + assert(error.getMessage contains "Invalid bucket file") + assert(error.getSuppressed.size === 0) + } + } + } + + test("SPARK-30403: AQE should handle InSubquery") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + runAdaptiveAndVerifyResult("SELECT * FROM testData LEFT OUTER join testData2" + + " ON key = a AND key NOT IN (select a from testData3) where value = '1'" + ) + } + } + + test("force apply AQE") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + + test("SPARK-30719: do not log warning if intentionally skip AQE") { + val testAppender = new LogAppender("aqe logging warning test when skip") + withLogAppender(testAppender) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + assert(!testAppender.loggingEvents + .exists(msg => msg.getRenderedMessage.contains( + s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + + s" enabled but is not supported for"))) + } + + test("test log level") { + def verifyLog(expectedLevel: Level): Unit = { + val logAppender = new LogAppender("adaptive execution") + withLogAppender( + logAppender, + loggerName = Some(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)), + level = Some(Level.TRACE)) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect() + } + } + Seq("Plan changed", "Final plan").foreach { msg => + assert( + logAppender.loggingEvents.exists { event => + event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel + }) + } + } + + // Verify default log 
level + verifyLog(Level.DEBUG) + + // Verify custom log level + val levels = Seq( + "TRACE" -> Level.TRACE, + "trace" -> Level.TRACE, + "DEBUG" -> Level.DEBUG, + "debug" -> Level.DEBUG, + "INFO" -> Level.INFO, + "info" -> Level.INFO, + "WARN" -> Level.WARN, + "warn" -> Level.WARN, + "ERROR" -> Level.ERROR, + "error" -> Level.ERROR, + "deBUG" -> Level.DEBUG) + + levels.foreach { level => + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) { + verifyLog(level._2) + } + } + } + + test("tree string output") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'") + val planBefore = df.queryExecution.executedPlan + assert(!planBefore.toString.contains("== Current Plan ==")) + assert(!planBefore.toString.contains("== Initial Plan ==")) + df.collect() + val planAfter = df.queryExecution.executedPlan + assert(planAfter.toString.contains("== Final Plan ==")) + assert(planAfter.toString.contains("== Initial Plan ==")) + } + } + + test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTempView("t2") { + // create DataFrame with 0 partition + spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType)) + .createOrReplaceTempView("t2") + // should run successfully without NPE + runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 join t2 ON t1.a=t2.b") + } + } + } + + ignore("metrics of the shuffle reader") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT key FROM testData GROUP BY key") + val readers = collect(adaptivePlan) { + case r: ColumnarCustomShuffleReaderExec => r + } + print(readers.length) + assert(readers.length == 1) + val reader = readers.head + assert(!reader.isLocalReader) + assert(!reader.hasSkewedPartition) + assert(reader.hasCoalescedPartition) + assert(reader.metrics.keys.toSeq.sorted == Seq( + "numPartitions", "partitionDataSize")) + assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) + assert(reader.metrics("partitionDataSize").value > 0) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData join testData2 ON key = a where value = '1'") + val join = collect(adaptivePlan) { + case j: ColumnarBroadcastHashJoinExec => j + }.head + assert(join.buildSide == BuildLeft) + + val readers = collect(join.right) { + case r: ColumnarCustomShuffleReaderExec => r + } + assert(readers.length == 1) + val reader = readers.head + assert(reader.isLocalReader) + assert(reader.metrics.keys.toSeq == Seq("numPartitions")) + assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length) + } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SHUFFLE_PARTITIONS.key -> "100", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .select( + when('id < 250, 249) + .when('id >= 750, 1000) + .otherwise('id).as("key1"), + 'id as "value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .select( + when('id < 250, 249) + .otherwise('id).as("key2"), + 'id as "value2") + .createOrReplaceTempView("skewData2") + val 
(_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 join skewData2 ON key1 = key2") + val readers = collect(adaptivePlan) { + case r: CustomShuffleReaderExec => r + } + readers.foreach { reader => + assert(!reader.isLocalReader) + assert(reader.hasCoalescedPartition) + assert(reader.hasSkewedPartition) + assert(reader.metrics.contains("numSkewedPartitions")) + } + print(readers(1).metrics("numSkewedPartitions")) + print(readers(1).metrics("numSkewedSplits")) + assert(readers(0).metrics("numSkewedPartitions").value == 2) + assert(readers(0).metrics("numSkewedSplits").value == 15) + assert(readers(1).metrics("numSkewedPartitions").value == 1) + assert(readers(1).metrics("numSkewedSplits").value == 12) + } + } + } + } + + test("control a plan explain mode in listeners via SQLConf") { + + def checkPlanDescription(mode: String, expected: Seq[String]): Unit = { + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) => + assert(expected.forall(planDescription.contains)) + checkDone = true + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode, + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'") + try { + checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil) + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + } + + Seq(("simple", Seq("== Physical Plan ==")), + ("extended", Seq("== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", "== Physical Plan ==")), + ("codegen", Seq("WholeStageCodegen subtrees")), + ("cost", Seq("== Optimized Logical Plan ==", "Statistics(sizeInBytes")), + ("formatted", Seq("== Physical Plan ==", "Output", "Arguments"))).foreach { + case (mode, expected) => + checkPlanDescription(mode, expected) + } + } + + test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + withTable("t1") { + val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan + assert(plan.isInstanceOf[DataWritingCommandExec]) + assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + } + + test("AQE should set active session during execution") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = spark.range(10).select(sum('id)) + assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) + SparkSession.setActiveSession(null) + checkAnswer(df, Seq(Row(45))) + SparkSession.setActiveSession(spark) // recover the active session. 
+ } + } + + test("No deadlock in UI update") { + object TestStrategy extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case _: Aggregate => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + spark.range(5).rdd + } + Nil + case _ => Nil + } + } + + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + try { + spark.experimental.extraStrategies = TestStrategy :: Nil + val df = spark.range(10).groupBy('id).count() + df.collect() + } finally { + spark.experimental.extraStrategies = Nil + } + } + } + + test("SPARK-31658: SQL UI should show write commands") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { + withTable("t1") { + var checkDone = false + val listener = new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = { + event match { + case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) => + assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand") + checkDone = true + case _ => // ignore other events + } + } + } + spark.sparkContext.addSparkListener(listener) + try { + sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(checkDone) + } finally { + spark.sparkContext.removeSparkListener(listener) + } + } + } + } + + test("SPARK-31220, SPARK-32056: repartition by expression with AQE") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).repartition($"id") + val df2 = spark.range(10).repartition($"id" + 1) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + checkInitialPartitionNum(df1, 10) + checkInitialPartitionNum(df2, 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) + } + + + // Don't coalesce partitions if the number of partitions is specified. + val df3 = spark.range(10).repartition(10, $"id") + val df4 = spark.range(10).repartition(10) + assert(df3.rdd.collectPartitions().length == 10) + assert(df4.rdd.collectPartitions().length == 10) + } + } + } + + test("SPARK-31220, SPARK-32056: repartition by range with AQE") { + Seq(true, false).foreach { enableAQE => + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + val df1 = spark.range(10).toDF.repartitionByRange($"id".asc) + val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc) + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + + checkInitialPartitionNum(df1, 10) + checkInitialPartitionNum(df2, 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) + } + + // Don't coalesce partitions if the number of partitions is specified. 
+ val df3 = spark.range(10).repartitionByRange(10, $"id".asc) + assert(df3.rdd.collectPartitions().length == 10) + } + } + } + + test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") { + Seq(true, false).foreach { enableAQE => + withTempView("test") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString, + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10", + SQLConf.SHUFFLE_PARTITIONS.key -> "10") { + + spark.range(10).toDF.createTempView("test") + + val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test") + val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test") + val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id") + val df4 = spark.sql("SELECT * from test CLUSTER BY id") + + val partitionsNum1 = df1.rdd.collectPartitions().length + val partitionsNum2 = df2.rdd.collectPartitions().length + val partitionsNum3 = df3.rdd.collectPartitions().length + val partitionsNum4 = df4.rdd.collectPartitions().length + + if (enableAQE) { + assert(partitionsNum1 < 10) + assert(partitionsNum2 < 10) + assert(partitionsNum3 < 10) + assert(partitionsNum4 < 10) + + checkInitialPartitionNum(df1, 10) + checkInitialPartitionNum(df2, 10) + checkInitialPartitionNum(df3, 10) + checkInitialPartitionNum(df4, 10) + } else { + assert(partitionsNum1 === 10) + assert(partitionsNum2 === 10) + assert(partitionsNum3 === 10) + assert(partitionsNum4 === 10) + } + + // Don't coalesce partitions if the number of partitions is specified. + val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test") + val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test") + assert(df5.rdd.collectPartitions().length == 10) + assert(df6.rdd.collectPartitions().length == 10) + } + } + } + } + + test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.size == 1) + val join = findTopLevelBaseJoin(adaptivePlan) + assert(join.isEmpty) + checkNumLocalShuffleReaders(adaptivePlan) + } + } + + test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString, + // This test is a copy of test(SPARK-32573), in order to test the configuration + // `spark.sql.adaptive.optimizer.excludedRules` works as expect. + SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateJoinToEmptyRelation.ruleName) { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)") + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.size == 1) + val join = findTopLevelBaseJoin(adaptivePlan) + // this is different compares to test(SPARK-32573) due to the rule + // `EliminateJoinToEmptyRelation` has been excluded. 
+ assert(join.nonEmpty) + checkNumLocalShuffleReaders(adaptivePlan) + } + } + + test("SPARK-32649: Eliminate inner to empty relation") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + Seq( + // inner join (small table at right side) + "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1", + // inner join (small table at left side) + "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1", + // left semi join : left join do not has omni impl + // "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1" + ).foreach(query => { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 1) + val join = findTopLevelBaseJoin(adaptivePlan) + assert(join.isEmpty) + checkNumLocalShuffleReaders(adaptivePlan) + }) + } + } + + test("SPARK-32753: Only copy tags to node with no tags") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withTempView("v1") { + spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1") + + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id") + assert(collect(adaptivePlan) { + case s: ShuffleExchangeExec => s + }.length == 1) + } + } + } + + test("Logging plan changes for AQE") { + val testAppender = new LogAppender("plan changes") + withLogAppender(testAppender) { + withSQLConf( + SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "INFO", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + sql("SELECT * FROM testData JOIN testData2 ON key = a " + + "WHERE value = (SELECT max(a) FROM testData3)").collect() + } + Seq("=== Result of Batch AQE Preparations ===", + "=== Result of Batch AQE Post Stage Creation ===", + "=== Result of Batch AQE Replanning ===", + "=== Result of Batch AQE Query Stage Optimization ===", + "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg => + assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg))) + } + } + } + + test("SPARK-32932: Do not use local shuffle reader at final stage on write command") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString, + SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val data = for ( + i <- 1L to 10L; + j <- 1L to 3L + ) yield (i, j) + + val df = data.toDF("i", "j").repartition($"j") + var noLocalReader: Boolean = false + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + qe.executedPlan match { + case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) => + assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec]) + noLocalReader = collect(plan) { + case exec: CustomShuffleReaderExec if exec.isLocalReader => exec + }.isEmpty + case _ => // ignore other events + } + } + override def onFailure(funcName: String, qe: QueryExecution, + exception: Exception): Unit = {} + } + spark.listenerManager.register(listener) + + withTable("t") { + df.write.partitionBy("j").saveAsTable("t") + sparkContext.listenerBus.waitUntilEmpty() + assert(noLocalReader) + noLocalReader = false + } + + // Test DataSource v2 + val format = classOf[NoopDataSource].getName + df.write.format(format).mode("overwrite").save() + 
sparkContext.listenerBus.waitUntilEmpty() + assert(noLocalReader) + noLocalReader = false + + spark.listenerManager.unregister(listener) + } + } + + test("SPARK-33494: Do not use local shuffle reader for repartition") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val df = spark.table("testData").repartition('key) + df.collect() + // local shuffle reader breaks partitioning and shouldn't be used for repartition operation + // which is specified by users. + checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1) + } + } + + test("SPARK-33551: Do not use custom shuffle reader for repartition") { + def hasRepartitionShuffle(plan: SparkPlan): Boolean = { + find(plan) { + case s: ShuffleExchangeLike => + s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM + case _ => false + }.isDefined + } + + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "5") { + val df = sql( + """ + |SELECT * FROM ( + | SELECT * FROM testData WHERE key = 1 + |) + |RIGHT OUTER JOIN testData2 + |ON value = b + """.stripMargin) + + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { + // Repartition with no partition num specified. + val dfRepartition = df.repartition('b) + dfRepartition.collect() + val plan = dfRepartition.queryExecution.executedPlan + // The top shuffle from repartition is optimized out. + assert(!hasRepartitionShuffle(plan)) + val bhj = findTopLevelBroadcastHashJoin(plan) + assert(bhj.length == 1) + checkNumLocalShuffleReaders(plan, 1) + // Probe side is coalesced. + val customReader = bhj.head.right.find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) + assert(customReader.isDefined) + assert(customReader.get.asInstanceOf[ColumnarCustomShuffleReaderExec].hasCoalescedPartition) + + // Repartition with partition default num specified. + val dfRepartitionWithNum = df.repartition(5, 'b) + dfRepartitionWithNum.collect() + val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan + // The top shuffle from repartition is optimized out. + assert(!hasRepartitionShuffle(planWithNum)) + val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum) + assert(bhjWithNum.length == 1) + checkNumLocalShuffleReaders(planWithNum, 1) + // Probe side is not coalesced. + assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty) + + // Repartition with partition non-default num specified. + val dfRepartitionWithNum2 = df.repartition(3, 'b) + dfRepartitionWithNum2.collect() + val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan + // The top shuffle from repartition is not optimized out, and this is the only shuffle that + // does not have local shuffle reader. + assert(hasRepartitionShuffle(planWithNum2)) + val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2) + assert(bhjWithNum2.length == 1) + checkNumLocalShuffleReaders(planWithNum2, 1) + val customReader2 = bhjWithNum2.head.right + .find(_.isInstanceOf[ColumnarCustomShuffleReaderExec]) + assert(customReader2.isDefined) + assert(customReader2.get.asInstanceOf[ColumnarCustomShuffleReaderExec].isLocalReader) + } + + // Force skew join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.SKEW_JOIN_ENABLED.key -> "true", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "1", + SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", + SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { + // Repartition with no partition num specified. 
+        val dfRepartition = df.repartition('b)
+        dfRepartition.collect()
+        val plan = dfRepartition.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(plan))
+        val smj = findTopLevelSortMergeJoin(plan)
+        assert(smj.length == 1)
+        // No skew join due to the repartition.
+        assert(!smj.head.isSkewJoin)
+        // Both sides are coalesced.
+        val customReaders = collect(smj.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+          case c: ColumnarCustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReaders.length == 2)
+
+        // Repartition with the default partition num specified.
+        val dfRepartitionWithNum = df.repartition(5, 'b)
+        dfRepartitionWithNum.collect()
+        val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
+        // The top shuffle from repartition is optimized out.
+        assert(!hasRepartitionShuffle(planWithNum))
+        val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
+        assert(smjWithNum.length == 1)
+        // No skew join due to the repartition.
+        assert(!smjWithNum.head.isSkewJoin)
+        // No coalesce due to the num in repartition.
+        val customReadersWithNum = collect(smjWithNum.head) {
+          case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+        }
+        assert(customReadersWithNum.isEmpty)
+
+        // Repartition with a non-default partition num specified.
+        val dfRepartitionWithNum2 = df.repartition(3, 'b)
+        dfRepartitionWithNum2.collect()
+        val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
+        // The top shuffle from repartition is not optimized out.
+        assert(hasRepartitionShuffle(planWithNum2))
+        val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
+        assert(smjWithNum2.length == 1)
+        // Skew join can apply as the repartition is not optimized out.
+        assert(smjWithNum2.head.isSkewJoin)
+      }
+    }
+  }
+
+  ignore("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> "10000",
+      SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") {
+      withTable("t1") {
+        spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1")
+        val query = "SELECT SUM(a) FROM t1 GROUP BY a"
+        val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+        val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED
+        val blocksFetchedMetric = collectFirst(adaptivePlan) {
+          case p if p.metrics.contains(metricName) => p.metrics(metricName)
+        }
+        assert(blocksFetchedMetric.isDefined)
+        val blocksFetched = blocksFetchedMetric.get.value
+        withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") {
+          val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query)
+          val blocksFetchedMetric2 = collectFirst(adaptivePlan2) {
+            case p if p.metrics.contains(metricName) => p.metrics(metricName)
+          }
+          assert(blocksFetchedMetric2.isDefined)
+          val blocksFetched2 = blocksFetchedMetric2.get.value
+          assert(blocksFetched < blocksFetched2)
+        }
+      }
+    }
+  }
+}