diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml
index caafa313fbd2cb88b124e370f2d73460199b7051..3e3175bab7d6d61499e28ea729810f9865cf4474 100644
--- a/omnioperator/omniop-spark-extension/java/pom.xml
+++ b/omnioperator/omniop-spark-extension/java/pom.xml
@@ -7,7 +7,7 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
- 3.1.1-1.1.0
+ 3.3.1-1.1.0
../pom.xml
@@ -103,20 +103,20 @@
spark-core_${scala.binary.version}
test-jar
test
- 3.1.1
+ 3.3.1
org.apache.spark
spark-catalyst_${scala.binary.version}
test-jar
test
- 3.1.1
+ 3.3.1
org.apache.spark
spark-sql_${scala.binary.version}
test-jar
- 3.1.1
+ 3.3.1
test
@@ -127,7 +127,7 @@
org.apache.spark
spark-hive_${scala.binary.version}
- 3.1.1
+ 3.3.1
provided
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
index 808f96e1fb666def4ff9fc224f01020a81a5baf7..cc750a371cdb64c1e60eac27f5b5881a964b5ac4 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/vectorized/OmniColumnVector.java
@@ -354,6 +354,11 @@ public class OmniColumnVector extends WritableColumnVector {
}
}
+ @Override
+ public void putBooleans(int rowId, byte src) {
+ throw new UnsupportedOperationException("putBooleans is not supported");
+ }
+
@Override
public boolean getBoolean(int rowId) {
if (dictionaryData != null) {
@@ -453,6 +458,11 @@ public class OmniColumnVector extends WritableColumnVector {
return UTF8String.fromBytes(getBytes(rowId, count), rowId, count);
}
+ @Override
+ public ByteBuffer getByteBuffer(int rowId, int count) {
+ throw new UnsupportedOperationException("getByteBuffer is not supported");
+ }
+
//
// APIs dealing with Shorts
//
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
index a4e4eaa0a877f7ee2e3401ecf4ee98fecfcb7314..ec075787233b025ad5caae8a8c80c473a573b5e1 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarGuardRule.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, CustomShuffleReaderExec}
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
@@ -37,6 +37,9 @@ case class RowGuard(child: SparkPlan) extends SparkPlan {
}
def children: Seq[SparkPlan] = Seq(child)
+
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+ legacyWithNewChildren(newChildren)
}
case class ColumnarGuardRule() extends Rule[SparkPlan] {
@@ -92,6 +95,8 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
if (!enableColumnarHashAgg) return false
new ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
+ plan.isStreaming,
+ plan.numShufflePartitions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
@@ -127,9 +132,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
left match {
case exec: BroadcastExchangeExec =>
new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
- case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
+ case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) =>
new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
- case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
+ case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) =>
plan match {
case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
new ColumnarBroadcastExchangeExec(b.mode, b.child)
@@ -141,9 +146,9 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
right match {
case exec: BroadcastExchangeExec =>
new ColumnarBroadcastExchangeExec(exec.mode, exec.child)
- case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec) =>
+ case BroadcastQueryStageExec(_, plan: BroadcastExchangeExec, _) =>
new ColumnarBroadcastExchangeExec(plan.mode, plan.child)
- case BroadcastQueryStageExec(_, plan: ReusedExchangeExec) =>
+ case BroadcastQueryStageExec(_, plan: ReusedExchangeExec, _) =>
plan match {
case ReusedExchangeExec(_, b: BroadcastExchangeExec) =>
new ColumnarBroadcastExchangeExec(b.mode, b.child)
@@ -182,7 +187,8 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
plan.buildSide,
plan.condition,
plan.left,
- plan.right).buildCheck()
+ plan.right,
+ plan.isSkewJoin).buildCheck()
case plan: BroadcastNestedLoopJoinExec => return false
case p =>
p
@@ -237,7 +243,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
case p if !supportCodegen(p) =>
// insert row guard them recursively
p.withNewChildren(p.children.map(insertRowGuardOrNot))
- case p: CustomShuffleReaderExec =>
+ case p: OmniAQEShuffleReadExec =>
p.withNewChildren(p.children.map(insertRowGuardOrNot))
case p: BroadcastQueryStageExec =>
p
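
The withNewChildrenInternal / withNewChildInternal overrides added in this file (and throughout the patch) satisfy the child-handling contract that Spark 3.2+ tree nodes must implement; legacyWithNewChildren is the transition helper used for nodes with a variable number of children such as RowGuard. A minimal sketch of the pattern for a unary node, with an illustrative class name that is not part of this repository:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}

    // Illustrative operator, not part of this patch.
    case class ExampleColumnarExec(child: SparkPlan) extends UnaryExecNode {
      override def output: Seq[Attribute] = child.output

      override protected def doExecute(): RDD[InternalRow] =
        throw new UnsupportedOperationException("columnar-only operator")

      // Required since Spark 3.2: rebuild this node around the replacement child.
      override protected def withNewChildInternal(newChild: SparkPlan): ExampleColumnarExec =
        copy(child = newChild)
    }
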
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
index d3fcbaf539493ff8a26b72b6e1b98c4c448b54c9..a94eb5d67d9612feee34c5c23194bf60537392d9 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.DynamicPruningSubquery
import org.apache.spark.sql.catalyst.expressions.aggregate.Partial
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _}
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ColumnarCustomShuffleReaderExec, CustomShuffleReaderExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
@@ -247,6 +247,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
case _ =>
new ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
+ plan.isStreaming,
+ plan.numShufflePartitions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
@@ -257,6 +259,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
} else {
new ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
+ plan.isStreaming,
+ plan.numShufflePartitions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
@@ -267,6 +271,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
} else {
new ColumnarHashAggregateExec(
plan.requiredChildDistributionExpressions,
+ plan.isStreaming,
+ plan.numShufflePartitions,
plan.groupingExpressions,
plan.aggregateExpressions,
plan.aggregateAttributes,
@@ -311,7 +317,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
plan.buildSide,
plan.condition,
left,
- right)
+ right,
+ plan.isSkewJoin)
case plan: SortMergeJoinExec if enableColumnarSortMergeJoin =>
logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.")
val left = replaceWithColumnarPlan(plan.left)
@@ -341,19 +348,19 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
val child = replaceWithColumnarPlan(plan.child)
logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.")
new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin)
- case plan: CustomShuffleReaderExec if columnarConf.enableColumnarShuffle =>
+ case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle =>
plan.child match {
case shuffle: ColumnarShuffleExchangeExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs)
- case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec) =>
+ OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs)
+ case ShuffleQueryStageExec(_, shuffle: ColumnarShuffleExchangeExec, _) =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ColumnarCustomShuffleReaderExec(plan.child, plan.partitionSpecs)
- case ShuffleQueryStageExec(_, reused: ReusedExchangeExec) =>
+ OmniAQEShuffleReadExec(plan.child, plan.partitionSpecs)
+ case ShuffleQueryStageExec(_, reused: ReusedExchangeExec, _) =>
reused match {
case ReusedExchangeExec(_, shuffle: ColumnarShuffleExchangeExec) =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- ColumnarCustomShuffleReaderExec(
+ OmniAQEShuffleReadExec(
plan.child,
plan.partitionSpecs)
case _ =>
@@ -375,13 +382,15 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
curPlan.id,
BroadcastExchangeExec(
originalBroadcastPlan.mode,
- ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)))
+ ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)),
+ curPlan._canonicalized)
case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) =>
BroadcastQueryStageExec(
curPlan.id,
BroadcastExchangeExec(
originalBroadcastPlan.mode,
- ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)))
+ ColumnarBroadcastExchangeAdaptorExec(curPlan.plan, 1)),
+ curPlan._canonicalized)
case _ =>
curPlan
}
@@ -409,11 +418,26 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] {
case ColumnarToRowExec(child: ColumnarBroadcastExchangeExec) =>
replaceWithColumnarPlan(child)
case plan: ColumnarToRowExec =>
- val child = replaceWithColumnarPlan(plan.child)
- if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) {
- OmniColumnarToRowExec(child)
- } else {
- ColumnarToRowExec(child)
+ plan.child match {
+ case child: BroadcastQueryStageExec =>
+ child.plan match {
+ case originalBroadcastPlan: ColumnarBroadcastExchangeExec =>
+ BroadcastQueryStageExec(
+ child.id,
+ BroadcastExchangeExec(
+ originalBroadcastPlan.mode,
+ ColumnarBroadcastExchangeAdaptorExec(originalBroadcastPlan, 1)), child._canonicalized)
+ case ReusedExchangeExec(_, originalBroadcastPlan: ColumnarBroadcastExchangeExec) =>
+ BroadcastQueryStageExec(
+ child.id,
+ BroadcastExchangeExec(
+ originalBroadcastPlan.mode,
+ ColumnarBroadcastExchangeAdaptorExec(child.plan, 1)), child._canonicalized)
+ case _ =>
+ replaceColumnarToRow(plan, conf)
+ }
+ case _ =>
+ replaceColumnarToRow(plan, conf)
}
case r: SparkPlan
if !r.isInstanceOf[QueryStageExec] && !r.supportsColumnar && r.children.exists(c =>
@@ -430,6 +454,15 @@ case class ColumnarPostOverrides() extends Rule[SparkPlan] {
val children = p.children.map(replaceWithColumnarPlan)
p.withNewChildren(children)
}
+
+ def replaceColumnarToRow(plan: ColumnarToRowExec, conf: SQLConf) : SparkPlan = {
+ val child = replaceWithColumnarPlan(plan.child)
+ if (conf.getConfString("spark.omni.sql.columnar.columnarToRow", "true").toBoolean) {
+ OmniColumnarToRowExec(child)
+ } else {
+ ColumnarToRowExec(child)
+ }
+ }
}
case class ColumnarOverrideRules(session: SparkSession) extends ColumnarRule with Logging {
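
The new replaceColumnarToRow helper keeps the columnar-to-row conversion behind the same switch as before; a hypothetical usage sketch, assuming spark is an active SparkSession:

    // Fall back to vanilla ColumnarToRowExec instead of OmniColumnarToRowExec.
    spark.conf.set("spark.omni.sql.columnar.columnarToRow", "false")
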
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
index 29776a07ac820c9d243df4d7cc8fc6135f1b8db6..a698c81089f517d06eaeb011a6606a3a39b81054 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
@@ -153,7 +153,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
.toBoolean
val enableFusion: Boolean = conf
- .getConfString("spark.omni.sql.columnar.fusion", "true")
+ .getConfString("spark.omni.sql.columnar.fusion", "false")
.toBoolean
// Pick columnar shuffle hash join if one side join count > = 0 to build local hash map, and is
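
With the default of spark.omni.sql.columnar.fusion flipped to false, operator fusion now has to be enabled explicitly; a hypothetical opt-in sketch at session construction (application name is illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("omni-fusion-example")                    // illustrative name
      .config("spark.omni.sql.columnar.fusion", "true")  // opt back in to fusion
      .getOrCreate()
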
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala
index 2071420c9d219e4b4a029bd17eba114ac9d2dd7e..6b065552ceb1d46327fd3644001b0ca408d5d46b 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ShuffleJoinStrategy.scala
@@ -37,7 +37,7 @@ object ShuffleJoinStrategy extends Strategy
ColumnarPluginConfig.getConf.columnarPreferShuffledHashJoinCBO
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, left, right, hint)
+ case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, _, left, right, hint)
if columnarPreferShuffledHashJoin =>
val enable = getBroadcastBuildSide(left, right, joinType, hint, true, conf).isEmpty &&
!hintToSortMergeJoin(hint) &&
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index da1a5b7479d8a3814ef6b07dde6b9ad8acdcd50e..c4307082aec6aa42fbd5754b42950fd0c3f201f4 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -668,9 +668,9 @@ object OmniExpressionAdaptor extends Logging {
def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean = false): FunctionType = {
agg.aggregateFunction match {
- case Sum(_) => OMNI_AGGREGATION_TYPE_SUM
+ case Sum(_, _) => OMNI_AGGREGATION_TYPE_SUM
case Max(_) => OMNI_AGGREGATION_TYPE_MAX
- case Average(_) => OMNI_AGGREGATION_TYPE_AVG
+ case Average(_, _) => OMNI_AGGREGATION_TYPE_AVG
case Min(_) => OMNI_AGGREGATION_TYPE_MIN
case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) =>
if (isFinal) {
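
The widened patterns account for the extra constructor argument that Sum and Average carry in Spark 3.3 (an ANSI/overflow-related flag). A minimal sketch of extracting the aggregated child while ignoring that flag; the helper name is illustrative:

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, Average, Sum}

    def aggregatedChild(fn: AggregateFunction): Expression = fn match {
      case Sum(child, _)     => child   // second field is the Spark 3.3 flag, ignored here
      case Average(child, _) => child
      case other             => other.children.head  // fallback for this sketch only
    }
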
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index 7eca3427ec3f6c618f84e70aeb85ce98d0267176..4883203c99954c5fb596f3804e5903c57709db4a 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -71,7 +71,7 @@ class ColumnarShuffleWriter[K, V](
override def write(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
partitionLengths = new Array[Long](dep.partitioner.numPartitions)
- shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, null)
+ shuffleBlockResolver.writeMetadataFileAndCommit(dep.shuffleId, mapId, partitionLengths, Array[Long](), null)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
return
}
@@ -107,7 +107,7 @@ class ColumnarShuffleWriter[K, V](
jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch)
dep.splitTime.add(System.nanoTime() - startTime)
dep.numInputRows.add(cb.numRows)
- writeMetrics.incRecordsWritten(1)
+ writeMetrics.incRecordsWritten(cb.numRows())
}
}
val startTime = System.nanoTime()
@@ -122,10 +122,11 @@ class ColumnarShuffleWriter[K, V](
partitionLengths = splitResult.getPartitionLengths
try {
- shuffleBlockResolver.writeIndexFileAndCommit(
+ shuffleBlockResolver.writeMetadataFileAndCommit(
dep.shuffleId,
mapId,
partitionLengths,
+ Array[Long](),
dataTmp)
} finally {
if (dataTmp.exists() && !dataTmp.delete()) {
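
writeMetadataFileAndCommit is the newer-Spark replacement for writeIndexFileAndCommit and additionally accepts per-partition checksums; the empty array passed above records none. A condensed, commented restatement of the empty-input path for reference:

    // No records: commit zero-length partitions with no checksums and no data file.
    partitionLengths = new Array[Long](dep.partitioner.numPartitions)
    shuffleBlockResolver.writeMetadataFileAndCommit(
      dep.shuffleId,
      mapId,
      partitionLengths,
      Array.empty[Long],  // no per-partition checksums
      null)               // no temporary data file in the empty case
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
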
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala
index e7c66ee726ae4b9090e41e5d71de386e4b94ed13..28427bba2842f77d53327121001966dbbdb17a01 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/sort/OmniColumnarShuffleManager.scala
@@ -99,7 +99,7 @@ class OmniColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager
env.conf,
metrics,
shuffleExecutorComponents)
- case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K@unchecked, V@unchecked] =>
+ case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
bypassMergeSortHandle,
@@ -107,9 +107,8 @@ class OmniColumnarShuffleManager(conf: SparkConf) extends ColumnarShuffleManager
env.conf,
metrics,
shuffleExecutorComponents)
- case other: BaseShuffleHandle[K@unchecked, V@unchecked, _] =>
+ case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(
- shuffleBlockResolver,
other,
mapId,
context,
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
index cb23b68f09bb085d86e133d3ef40628e8c5ca4c2..4dc6ede583fbb33048d7054b0e50f4a309ddd732 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.util.concurrent.TimeUnit.NANOSECONDS
+
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
@@ -101,6 +102,9 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl
|${ExplainUtils.generateFieldString("Input", child.output)}
|""".stripMargin
}
+
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarProjectExec =
+ copy(child = newChild)
}
case class ColumnarFilterExec(condition: Expression, child: SparkPlan)
@@ -109,6 +113,9 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan)
override def supportsColumnar: Boolean = true
override def nodeName: String = "OmniColumnarFilter"
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarFilterExec =
+ copy(child = newChild)
+
// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
@@ -116,7 +123,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan)
}
// If one expression and its children are null intolerant, it is null intolerant.
- private def isNullIntolerant(expr: Expression): Boolean = expr match {
+ override def isNullIntolerant(expr: Expression): Boolean = expr match {
case e: NullIntolerant => e.children.forall(isNullIntolerant)
case _ => false
}
@@ -267,6 +274,9 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression],
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarConditionProjectExec =
+ copy(child = newChild)
+
override lazy val metrics = Map(
"addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"),
"numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"),
@@ -383,7 +393,7 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
- val newDt = attrs.map(_.dataType).reduce(StructType.merge)
+ val newDt = attrs.map(_.dataType).reduce(StructType.unionLikeMerge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
@@ -393,6 +403,9 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
}
}
+ override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan =
+ legacyWithNewChildren(newChildren)
+
def buildCheck(): Unit = {
val inputTypes = new Array[DataType](output.size)
output.zipWithIndex.foreach {
@@ -420,7 +433,7 @@ class ColumnarRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range
protected override def doExecuteColumnar(): RDD[ColumnarBatch] = {
val numOutputRows = longMetric("numOutputRows")
- sqlContext
+ session.sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
.mapPartitionsWithIndex { (i, _) =>
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala
index d137388ab3c41c3ee103ac974cb594990379d394..1d236c16d3849906ec997726ca62ee5957ff0740 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeAdaptorExec.scala
@@ -64,4 +64,7 @@ case class ColumnarBroadcastExchangeAdaptorExec(child: SparkPlan, numPartitions:
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "output_batches"),
"processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_datatoarrowcolumnar"))
+
+ override protected def withNewChildInternal(newChild: SparkPlan):
+ ColumnarBroadcastExchangeAdaptorExec = copy(child = newChild)
}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
index 72d1aae05d8f9e4a22a7f0bc17e68aca8b157d74..8a29e0d2bc1531351210fe7ae77b7da0577e2fa2 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala
@@ -65,7 +65,7 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
@transient
override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
- sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) {
+ session.sqlContext.sparkSession, ColumnarBroadcastExchangeExec.executionContext) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
@@ -159,6 +159,9 @@ class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan)
}
}
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarBroadcastExchangeExec =
+ new ColumnarBroadcastExchangeExec(this.mode, newChild)
+
override protected def doPrepare(): Unit = {
// Materialize the future.
relationFuture
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
index b1fd51f4867cf2c435c8ddd7036bf6f8b6818212..e88fec3a56ad6da646cabe73bc09146064e07718 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
@@ -18,10 +18,8 @@
package org.apache.spark.sql.execution
import java.util.concurrent.TimeUnit.NANOSECONDS
-
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
-
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -31,8 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.util.SparkMemoryUtils
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OmniColumnVector, WritableColumnVector}
-import org.apache.spark.sql.types.{BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DecimalType, DoubleType, IntegerType, LongType, ShortType, StringType, StructType, TimestampType}
import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
import nova.hetu.omniruntime.vector.Vec
@@ -101,6 +100,7 @@ private object RowToColumnConverter {
private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = {
val core = dataType match {
+ case BinaryType => BinaryConverter
case BooleanType => BooleanConverter
case ByteType => ByteConverter
case ShortType => ShortConverter
@@ -123,6 +123,13 @@ private object RowToColumnConverter {
}
}
+ private object BinaryConverter extends TypeConverter {
+ override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+ val bytes = row.getBinary(column)
+ cv.appendByteArray(bytes, 0, bytes.length)
+ }
+ }
+
private object BooleanConverter extends TypeConverter {
override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
cv.appendBoolean(row.getBoolean(column))
@@ -232,8 +239,11 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti
"rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in row to OmniColumnar")
)
+ override protected def withNewChildInternal(newChild: SparkPlan): RowToOmniColumnarExec =
+ copy(child = newChild)
+
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
- val enableOffHeapColumnVector = sqlContext.conf.offHeapColumnVectorEnabled
+ val enableOffHeapColumnVector = session.sqlContext.conf.offHeapColumnVectorEnabled
val numInputRows = longMetric("numInputRows")
val numOutputBatches = longMetric("numOutputBatches")
val rowToOmniColumnarTime = longMetric("rowToOmniColumnarTime")
@@ -313,6 +323,9 @@ case class OmniColumnarToRowExec(child: SparkPlan) extends ColumnarToRowTransiti
ColumnarBatchToInternalRow.convert(localOutput, batches, numOutputRows, numInputBatches, omniColumnarToRowTime)
}
}
+
+ override protected def withNewChildInternal(newChild: SparkPlan):
+ OmniColumnarToRowExec = copy(child = newChild)
}
object ColumnarBatchToInternalRow {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
index 27b05b16c017c43e73a0c3b6d4f05ea02d11f951..b25d97d604da1ae0cbaef04b34bbf53e61b8af83 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.sql.execution
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
@@ -161,4 +178,6 @@ case class ColumnarExpandExec(
throw new UnsupportedOperationException(s"ColumnarExpandExec operator doesn't support doExecute().")
}
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarExpandExec =
+ copy(child = newChild)
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index 73091d069cb311f129fef45078d936ad365e14e0..fb741f5effc94772dc95f0f485dd0e8eb58cce6c 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.optimizer.BuildLeft
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat}
@@ -54,6 +55,7 @@ import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.util.SparkMemoryUtils
import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.execution.vectorized.OmniColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, StructType}
@@ -74,13 +76,19 @@ abstract class BaseColumnarFileSourceScanExec(
disableBucketedScan: Boolean = false)
extends DataSourceScanExec {
+ lazy val metadataColumns: Seq[AttributeReference] =
+ output.collect { case FileSourceMetadataAttribute(attr) => attr }
+
override lazy val supportsColumnar: Boolean = true
override def vectorTypes: Option[Seq[String]] =
relation.fileFormat.vectorTypes(
requiredSchema = requiredSchema,
partitionSchema = relation.partitionSchema,
- relation.sparkSession.sessionState.conf)
+ relation.sparkSession.sessionState.conf).map { vectorTypes =>
+ // for column-based file format, append metadata column's vector type classes if any
+ vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[ConstantColumnVector].getName)
+ }
private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
@@ -96,7 +104,7 @@ abstract class BaseColumnarFileSourceScanExec(
}
private def isDynamicPruningFilter(e: Expression): Boolean =
- e.find(_.isInstanceOf[PlanExpression[_]]).isDefined
+ e.exists(_.isInstanceOf[PlanExpression[_]])
@transient lazy val selectedPartitions: Array[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
@@ -223,7 +231,13 @@ abstract class BaseColumnarFileSourceScanExec(
@transient
private lazy val pushedDownFilters = {
val supportNestedPredicatePushdown = DataSourceUtils.supportNestedPredicatePushdown(relation)
- dataFilters.flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
+ // `dataFilters` should not include any metadata col filters
+ // because the metadata struct has been flattened in FileSourceStrategy
+ // and thus metadata col filters are invalid to be pushed down
+ dataFilters.filterNot(_.references.exists {
+ case FileSourceMetadataAttribute(_) => true
+ case _ => false
+ }).flatMap(DataSourceStrategy.translateFilter(_, supportNestedPredicatePushdown))
}
override protected def metadata: Map[String, String] = {
@@ -242,22 +256,27 @@ abstract class BaseColumnarFileSourceScanExec(
"DataFilters" -> seqToString(dataFilters),
"Location" -> locationDesc)
- // (SPARK-32986): Add bucketed scan info in explain output of FileSourceScanExec
- if (bucketedScan) {
- relation.bucketSpec.map { spec =>
+ relation.bucketSpec.map { spec =>
+ val bucketedKey = "Bucketed"
+ if (bucketedScan) {
val numSelectedBuckets = optionalBucketSet.map { b =>
b.cardinality()
} getOrElse {
spec.numBuckets
}
- metadata + ("SelectedBucketsCount" ->
- (s"$numSelectedBuckets out of ${spec.numBuckets}" +
+ metadata ++ Map(
+ bucketedKey -> "true",
+ "SelectedBucketsCount" -> (s"$numSelectedBuckets out of ${spec.numBuckets}" +
optionalNumCoalescedBuckets.map { b => s" (Coalesced to $b)" }.getOrElse("")))
- } getOrElse {
- metadata
+ } else if (!relation.sparkSession.sessionState.conf.bucketingEnabled) {
+ metadata + (bucketedKey -> "false (disabled by configuration)")
+ } else if (disableBucketedScan) {
+ metadata + (bucketedKey -> "false (disabled by query planner")
+ } else {
+ metadata + (bucketedKey -> "false (disabled column(s) not read)")
}
- } else {
- metadata
+ } getOrElse {
+ metadata
}
}
@@ -312,7 +331,7 @@ abstract class BaseColumnarFileSourceScanExec(
createBucketedReadRDD(relation.bucketSpec.get, readFile, dynamicallySelectedPartitions,
relation)
} else {
- createNonBucketedReadRDD(readFile, dynamicallySelectedPartitions, relation)
+ createReadRDD(readFile, dynamicallySelectedPartitions, relation)
}
sendDriverMetrics()
readRDD
@@ -343,7 +362,7 @@ abstract class BaseColumnarFileSourceScanExec(
driverMetrics("staticFilesNum") = filesNum
driverMetrics("staticFilesSize") = filesSize
}
- if (relation.partitionSchemaOption.isDefined) {
+ if (relation.partitionSchema.nonEmpty) {
driverMetrics("numPartitions") = partitions.length
}
}
@@ -363,7 +382,7 @@ abstract class BaseColumnarFileSourceScanExec(
None
}
} ++ {
- if (relation.partitionSchemaOption.isDefined) {
+ if (relation.partitionSchema.nonEmpty) {
Map(
"numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"),
"pruningTime" ->
@@ -423,7 +442,7 @@ abstract class BaseColumnarFileSourceScanExec(
/**
* Create an RDD for bucketed reads.
- * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
+ * The non-bucketed variant of this function is [[createReadRDD]].
*
* The algorithm is pretty simple: each RDD partition being returned should include all the files
* with the same bucket id from all the given Hive partitions.
@@ -447,10 +466,9 @@ abstract class BaseColumnarFileSourceScanExec(
}.groupBy { f =>
BucketingUtils
.getBucketId(new Path(f.filePath).getName)
- .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
+ .getOrElse(throw QueryExecutionErrors.invalidBucketFile(f.filePath))
}
- // (SPARK-32985): Decouple bucket filter pruning and bucketed table scan
val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
val bucketSet = optionalBucketSet.get
filesGroupedToBuckets.filter {
@@ -475,7 +493,8 @@ abstract class BaseColumnarFileSourceScanExec(
}
}
- new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
+ new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions,
+ new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), metadataColumns)
}
/**
@@ -486,7 +505,7 @@ abstract class BaseColumnarFileSourceScanExec(
* @param selectedPartitions Hive-style partition that are part of the read.
* @param fsRelation [[HadoopFsRelation]] associated with the read.
*/
- private def createNonBucketedReadRDD(
+ private def createReadRDD(
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Array[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
@@ -496,27 +515,43 @@ abstract class BaseColumnarFileSourceScanExec(
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
s"open cost is considered as scanning $openCostInBytes bytes.")
+ // Filter files with bucket pruning if possible
+ val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled
+ val shouldProcess: Path => Boolean = optionalBucketSet match {
+ case Some(bucketSet) if bucketingEnabled =>
+ // Do not prune the file if bucket file name is invalid
+ filePath => BucketingUtils.getBucketId(filePath.getName).forall(bucketSet.get)
+ case _ =>
+ _ => true
+ }
+
val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
// getPath() is very expensive so we only want to call it once in this block:
val filePath = file.getPath
- val isSplitable = relation.fileFormat.isSplitable(
- relation.sparkSession, relation.options, filePath)
- PartitionedFileUtil.splitFiles(
- sparkSession = relation.sparkSession,
- file = file,
- filePath = filePath,
- isSplitable = isSplitable,
- maxSplitBytes = maxSplitBytes,
- partitionValues = partition.values
- )
+
+ if (shouldProcess(filePath)) {
+ val isSplitable = relation.fileFormat.isSplitable(
+ relation.sparkSession, relation.options, filePath)
+ PartitionedFileUtil.splitFiles(
+ sparkSession = relation.sparkSession,
+ file = file,
+ filePath = filePath,
+ isSplitable = isSplitable,
+ maxSplitBytes = maxSplitBytes,
+ partitionValues = partition.values
+ )
+ } else {
+ Seq.empty
+ }
}
}.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
val partitions =
FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
- new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
+ new FileScanRDD(fsRelation.sparkSession, readFile, partitions,
+ new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), metadataColumns)
}
// Filters unused DynamicPruningExpression expressions - one which has been replaced
@@ -551,7 +586,7 @@ abstract class BaseColumnarFileSourceScanExec(
throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp")
} else if (exp.mode == Partial) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) =>
+ case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) =>
val aggExp = exp.aggregateFunction.children.head
omniOutputExressionOrder += {
exp.aggregateFunction.inputAggBufferAttributes.head.exprId ->
@@ -569,7 +604,7 @@ abstract class BaseColumnarFileSourceScanExec(
}
} else if (exp.mode == PartialMerge) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Average(_) | Max(_) | Count(_) | First(_, _) =>
+ case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) =>
val aggExp = exp.aggregateFunction.children.head
omniOutputExressionOrder += {
exp.aggregateFunction.inputAggBufferAttributes.head.exprId ->
@@ -815,7 +850,7 @@ case class ColumnarMultipleOperatorExec(
None
}
} ++ {
- if (relation.partitionSchemaOption.isDefined) {
+ if (relation.partitionSchema.nonEmpty) {
Map(
"numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"),
"pruningTime" ->
@@ -1162,7 +1197,7 @@ case class ColumnarMultipleOperatorExec1(
None
}
} ++ {
- if (relation.partitionSchemaOption.isDefined) {
+ if (relation.partitionSchema.nonEmpty) {
Map(
"numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"),
"pruningTime" ->
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
index e2618842a7dd6a2170b8738be2a299cca2a86d47..be2aa8f0cf8f57e2ac8c56c04d0195c25ba0d0fc 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.util.concurrent.TimeUnit.NANOSECONDS
+
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
@@ -32,8 +33,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.ColumnarProjection.dealPartitionData
-import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.aggregate.{AggregateCodegenSupport, BaseAggregateExec}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.execution.util.SparkMemoryUtils
import org.apache.spark.sql.execution.vectorized.OmniColumnVector
@@ -45,14 +47,18 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
*/
case class ColumnarHashAggregateExec(
requiredChildDistributionExpressions: Option[Seq[Expression]],
+ isStreaming: Boolean,
+ numShufflePartitions: Option[Int],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression],
aggregateAttributes: Seq[Attribute],
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends BaseAggregateExec
- with AliasAwareOutputPartitioning {
+ extends AggregateCodegenSupport {
+
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarHashAggregateExec =
+ copy(child = newChild)
override def verboseStringWithOperatorId(): String = {
s"""
@@ -77,6 +83,15 @@ case class ColumnarHashAggregateExec(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"))
+ protected override def needHashTable: Boolean = true
+
+ protected override def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ throw new UnsupportedOperationException("ColumnarHashAgg code-gen does not support grouping keys")
+ }
+
+ protected override def doProduceWithKeys(ctx: CodegenContext): String = {
+ throw new UnsupportedOperationException("ColumnarHashAgg code-gen does not support grouping keys")
+ }
override def supportsColumnar: Boolean = true
@@ -99,7 +114,7 @@ case class ColumnarHashAggregateExec(
}
if (exp.mode == Final) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true)
omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType)
omniAggChannels(index) =
@@ -110,7 +125,7 @@ case class ColumnarHashAggregateExec(
}
} else if (exp.mode == PartialMerge) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true)
omniAggOutputTypes(index) =
toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes)
@@ -125,7 +140,7 @@ case class ColumnarHashAggregateExec(
}
} else if (exp.mode == Partial) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true)
omniAggOutputTypes(index) =
toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes)
@@ -150,7 +165,7 @@ case class ColumnarHashAggregateExec(
omniSourceTypes(i) = sparkTypeToOmniType(attr.dataType, attr.metadata)
}
- for (aggChannel <-omniAggChannels) {
+ for (aggChannel <- omniAggChannels) {
if (!isSimpleColumnForAll(aggChannel)) {
checkOmniJsonWhiteList("", aggChannel.toArray)
}
@@ -202,7 +217,7 @@ case class ColumnarHashAggregateExec(
}
if (exp.mode == Final) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true)
omniAggOutputTypes(index) =
toOmniAggInOutType(exp.aggregateFunction.dataType)
@@ -214,7 +229,7 @@ case class ColumnarHashAggregateExec(
}
} else if (exp.mode == PartialMerge) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true)
omniAggOutputTypes(index) =
toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes)
@@ -229,7 +244,7 @@ case class ColumnarHashAggregateExec(
}
} else if (exp.mode == Partial) {
exp.aggregateFunction match {
- case Sum(_) | Min(_) | Max(_) | Count(_) | Average(_) | First(_,_) =>
+ case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) =>
omniAggFunctionTypes(index) = toOmniAggFunType(exp, true)
omniAggOutputTypes(index) =
toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes)
@@ -338,10 +353,3 @@ case class ColumnarHashAggregateExec(
throw new UnsupportedOperationException("This operator doesn't support doExecute().")
}
}
-
-object ColumnarHashAggregateExec {
- def supportsAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
- val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema)
- }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index cea0a1438b1c64a0d1372e2a272742aa9be08502..746e1898a23528b869c6097fdd4cc8bc799d59be 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -18,8 +18,6 @@
package org.apache.spark.sql.execution
import com.huawei.boostkit.spark.ColumnarPluginConfig
-
-import java.util.Random
import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
import scala.collection.JavaConverters._
@@ -41,6 +39,7 @@ import org.apache.spark.shuffle.ColumnarShuffleDependency
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor
@@ -53,8 +52,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
+import org.apache.spark.util.random.XORShiftRandom
-class ColumnarShuffleExchangeExec(
+case class ColumnarShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
@@ -62,7 +62,7 @@ class ColumnarShuffleExchangeExec(
private lazy val writeMetrics =
SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
- override lazy val readMetrics =
+ private[sql] lazy val readMetrics =
SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics: Map[String, SQLMetric] = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
@@ -100,9 +100,19 @@ class ColumnarShuffleExchangeExec(
override def numPartitions: Int = columnarShuffleDependency.partitioner.numPartitions
+ override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[ColumnarBatch] = {
+ new ShuffledColumnarRDD(columnarShuffleDependency, readMetrics, partitionSpecs)
+ }
+
+ override def runtimeStatistics: Statistics = {
+ val dataSize = metrics("dataSize").value
+ val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
+ Statistics(dataSize, Some(rowCount))
+ }
+
@transient
lazy val columnarShuffleDependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
- ColumnarShuffleExchangeExec.prepareShuffleDependency(
+ val dep = ColumnarShuffleExchangeExec.prepareShuffleDependency(
inputColumnarRDD,
child.output,
outputPartitioning,
@@ -113,8 +123,8 @@ class ColumnarShuffleExchangeExec(
longMetric("numInputRows"),
longMetric("splitTime"),
longMetric("spillTime"))
+ dep
}
-
var cachedShuffleRDD: ShuffledColumnarRDD = _
override def doExecute(): RDD[InternalRow] = {
@@ -155,6 +165,8 @@ class ColumnarShuffleExchangeExec(
cachedShuffleRDD
}
}
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarShuffleExchangeExec =
+ copy(child = newChild)
}
object ColumnarShuffleExchangeExec extends Logging {
@@ -324,6 +336,7 @@ object ColumnarShuffleExchangeExec extends Logging {
rdd.mapPartitionsWithIndexInternal((_, cbIter) => {
cbIter.map { cb => (0, cb) }
}, isOrderSensitive = isOrderSensitive)
+ case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}
val numCols = outputAttributes.size
@@ -341,6 +354,7 @@ object ColumnarShuffleExchangeExec extends Logging {
new PartitionInfo("hash", numPartitions, numCols, intputTypes)
case RangePartitioning(ordering, numPartitions) =>
new PartitionInfo("range", numPartitions, numCols, intputTypes)
+ case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning")
}
new ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
index 7c7001dbc1c468465a0115946aeff9849d51a3df..49f2451112f66915d95b57e76dfdd8203a2af635 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
@@ -56,6 +56,9 @@ case class ColumnarSortExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarSortExec =
+ copy(child = newChild)
+
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala
index 6fec9f9a054f345a83bc20f278c2ed3be57e6dbd..0e5fac68c74aea7980f7f207dee4345565515a62 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala
@@ -49,6 +49,9 @@ case class ColumnarTakeOrderedAndProjectExec(
override def nodeName: String = "OmniColumnarTakeOrderedAndProject"
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarTakeOrderedAndProjectExec =
+ copy(child = newChild)
+
val serializer: Serializer = new ColumnarBatchSerializer(
longMetric("avgReadBatchNumRows"),
longMetric("numOutputRows"))
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
index e5534d3c67680a747f1b4d92fcb2c377c81577c9..63414c781030455c89ebc434cd54db0cbcbbd34a 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
@@ -50,6 +50,9 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
override def supportsColumnar: Boolean = true
+ override protected def withNewChildInternal(newChild: SparkPlan): ColumnarWindowExec =
+ copy(child = newChild)
+
override lazy val metrics = Map(
"addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"),
"numInputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of input vecBatchs"),
@@ -59,25 +62,6 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"))
- override def output: Seq[Attribute] =
- child.output ++ windowExpression.map(_.toAttribute)
-
- override def requiredChildDistribution: Seq[Distribution] = {
- if (partitionSpec.isEmpty) {
- // Only show warning when the number of bytes is larger than 100 MiB?
- logWarning("No Partition Defined for Window operation! Moving all data to a single "
- + "partition, this can cause serious performance degradation.")
- AllTuples :: Nil
- } else ClusteredDistribution(partitionSpec) :: Nil
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] =
- Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
-
- override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
- override def outputPartitioning: Partitioning = child.outputPartitioning
-
override protected def doExecute(): RDD[InternalRow] = {
throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala
index 1e728239b988592aa7ef8e89cba4cccf0751c065..eb11d449c25483e231ee1c672295a0e83186214f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ShuffledColumnarRDD.scala
@@ -24,6 +24,43 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
+sealed trait ShufflePartitionSpec
+
+// A partition that reads data of one or more reducers, from `startReducerIndex` (inclusive) to
+// `endReducerIndex` (exclusive).
+case class CoalescedPartitionSpec(
+ startReducerIndex: Int,
+ endReducerIndex: Int,
+ @transient dataSize: Option[Long] = None) extends ShufflePartitionSpec
+
+object CoalescedPartitionSpec {
+ def apply(startReducerIndex: Int,
+ endReducerIndex: Int,
+ dataSize: Long): CoalescedPartitionSpec = {
+ CoalescedPartitionSpec(startReducerIndex, endReducerIndex, Some(dataSize))
+ }
+}
+
+// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to
+// `endMapIndex` (exclusive).
+case class PartialReducerPartitionSpec(
+ reducerIndex: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ @transient dataSize: Long) extends ShufflePartitionSpec
+
+// A partition that reads partial data of one mapper, from `startReducerIndex` (inclusive) to
+// `endReducerIndex` (exclusive).
+case class PartialMapperPartitionSpec(
+ mapIndex: Int,
+ startReducerIndex: Int,
+ endReducerIndex: Int) extends ShufflePartitionSpec
+
+case class CoalescedMapperPartitionSpec(
+ startMapIndex: Int,
+ endMapIndex: Int,
+ numReducers: Int) extends ShufflePartitionSpec
+
/**
* The [[Partition]] used by [[ShuffledRowRDD]].
*/
@@ -70,7 +107,7 @@ class ShuffledColumnarRDD(
override def getPreferredLocations(partition: Partition): Seq[String] = {
val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
partition.asInstanceOf[ShuffledColumnarRDDPartition].spec match {
- case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+ case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) =>
startReducerIndex.until(endReducerIndex).flatMap { reducerIndex =>
tracker.getPreferredLocationsForShuffle(dependency, reducerIndex)
}
@@ -80,6 +117,9 @@ class ShuffledColumnarRDD(
case PartialMapperPartitionSpec(mapIndex, _, _) =>
tracker.getMapLocation(dependency, mapIndex, mapIndex + 1)
+
+ case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) =>
+ tracker.getMapLocation(dependency, startMapIndex, endMapIndex)
}
}
@@ -89,7 +129,7 @@ class ShuffledColumnarRDD(
// as well as the `tempMetrics` for basic shuffle metrics.
val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
val reader = split.asInstanceOf[ShuffledColumnarRDDPartition].spec match {
- case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+ case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) =>
SparkEnv.get.shuffleManager.getReader(
dependency.shuffleHandle,
startReducerIndex,
@@ -116,7 +156,22 @@ class ShuffledColumnarRDD(
endReducerIndex,
context,
sqlMetricsReporter)
+
+ case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) =>
+ SparkEnv.get.shuffleManager.getReader(
+ dependency.shuffleHandle,
+ startMapIndex,
+ endMapIndex,
+ 0,
+ numReducers,
+ context,
+ sqlMetricsReporter)
}
reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2)
}
+
+ override def clearDependencies(): Unit = {
+ super.clearDependencies()
+ dependency = null
+ }
}
\ No newline at end of file
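The four partition specs above differ only in which slice of the shuffle output a task reads, and the new CoalescedMapperPartitionSpec case in compute() is what lets a columnar read cover a contiguous range of map outputs across all reducers. A minimal standalone sketch (not part of the patch; it uses simplified local types rather than the real ShufflePartitionSpec classes) of how each spec translates into the map/reducer ranges requested from the shuffle manager:

    // Toy mirror of the spec hierarchy above -- all names are local to this sketch.
    sealed trait ToySpec
    case class Coalesced(startReducer: Int, endReducer: Int) extends ToySpec
    case class PartialReducer(reducer: Int, startMap: Int, endMap: Int) extends ToySpec
    case class PartialMapper(map: Int, startReducer: Int, endReducer: Int) extends ToySpec
    case class CoalescedMapper(startMap: Int, endMap: Int, numReducers: Int) extends ToySpec

    // (startMapIndex, endMapIndex, startReducerIndex, endReducerIndex) handed to the reader.
    def readerRange(spec: ToySpec, numMappers: Int): (Int, Int, Int, Int) = spec match {
      case Coalesced(startR, endR)          => (0, numMappers, startR, endR) // whole reducer range
      case PartialReducer(r, startM, endM)  => (startM, endM, r, r + 1)      // skew split of one reducer
      case PartialMapper(m, startR, endR)   => (m, m + 1, startR, endR)      // local read of one mapper
      case CoalescedMapper(startM, endM, n) => (startM, endM, 0, n)          // new: mapper range, all reducers
    }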
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b5de9dff4b303652ca1eb4fa08f4d07aef81d7bc
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase
+import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
+import org.apache.spark.sql.execution.ColumnarHashedRelation
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
+import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike}
+import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
+
+/**
+ * This rule runs in the AQE optimizer and optimizes more cases
+ * compared to [[PropagateEmptyRelationBase]]:
+ * 1. Join is single column NULL-aware anti join (NAAJ)
+ * Broadcasted [[HashedRelation]] is [[HashedRelationWithAllNullKeys]]. Eliminate join to an
+ * empty [[LocalRelation]].
+ */
+object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
+ override protected def isEmpty(plan: LogicalPlan): Boolean =
+ super.isEmpty(plan) || (!isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0))
+
+ override protected def notEmpty(plan: LogicalPlan): Boolean =
+ super.notEmpty(plan) || getEstimatedRowCount(plan).exists(_ > 0)
+
+ private def isRootRepartition(plan: LogicalPlan): Boolean = plan match {
+ case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined => true
+ case _ => false
+ }
+
+ // The returned value follows:
+ // - 0 means the plan must produce 0 rows
+ // - a positive value means an estimated row count, which can be over-estimated
+ // - None means the plan has not been materialized or its row count cannot be estimated
+ private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
+ case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized =>
+ stage.getRuntimeStatistics.rowCount
+
+ case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty &&
+ agg.child.isInstanceOf[QueryStageExec] =>
+ val stage = agg.child.asInstanceOf[QueryStageExec]
+ if (stage.isMaterialized) {
+ stage.getRuntimeStatistics.rowCount
+ } else {
+ None
+ }
+
+ case _ => None
+ }
+
+ private def isRelationWithAllNullKeys(plan: LogicalPlan): Boolean = plan match {
+ case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.isMaterialized =>
+ if (stage.broadcast.supportsColumnar) {
+ val colRelation = stage.broadcast.relationFuture.get().value.asInstanceOf[ColumnarHashedRelation]
+ colRelation.relation == HashedRelationWithAllNullKeys
+ } else {
+ stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
+ }
+ case _ => false
+ }
+
+ private def eliminateSingleColumnarNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = {
+ case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) =>
+ empty(j)
+ }
+
+ override protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
+ case LogicalQueryStage(_, ShuffleQueryStageExec(_, shuffle: ShuffleExchangeLike, _))
+ if shuffle.shuffleOrigin == REPARTITION_BY_COL ||
+ shuffle.shuffleOrigin == REPARTITION_BY_NUM => true
+ case _ => false
+ }
+
+ override protected def applyInternal(p: LogicalPlan): LogicalPlan = p.transformUpWithPruning(
+ // The LOCAL_RELATION and TRUE_OR_FALSE_LITERAL patterns are matched in
+ // `PropagateEmptyRelationBase.commonApplyFunc`.
+ // The LOGICAL_QUERY_STAGE pattern is matched in `PropagateEmptyRelationBase.commonApplyFunc`
+ // and `AQEPropagateEmptyRelation.eliminateSingleColumnarNullAwareAntiJoin`.
+ // Note that we cannot specify a ruleId here since LogicalQueryStage is not immutable.
+ _.containsAnyPattern(LOGICAL_QUERY_STAGE, LOCAL_RELATION, TRUE_OR_FALSE_LITERAL)) {
+ eliminateSingleColumnarNullAwareAntiJoin.orElse(commonApplyFunc)
+ }
+}
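The Option returned by getEstimatedRowCount is what keeps the rule conservative: only a materialized stage can prove a subtree empty or non-empty. A short standalone sketch (not part of the patch) of that contract:

    // Some(0) => provably empty, Some(n > 0) => provably non-empty, None => unknown.
    def provedEmpty(rowCount: Option[BigInt]): Boolean    = rowCount.contains(0)
    def provedNonEmpty(rowCount: Option[BigInt]): Boolean = rowCount.exists(_ > 0)

    // An unmaterialized or unestimable stage yields None, so it is neither
    // assumed empty nor assumed non-empty:
    assert(!provedEmpty(None) && !provedNonEmpty(None))
    assert(provedEmpty(Some(BigInt(0))))
    assert(provedNonEmpty(Some(BigInt(42))))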
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala
deleted file mode 100644
index 4edf0f4f86cb79a3e2b3a5c2fc01c999a42349c8..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateJoinToEmptyRelation.scala
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.adaptive
-
-import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
-import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
-import org.apache.spark.sql.catalyst.plans.logical.{Join, LocalRelation, LogicalPlan}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.ColumnarHashedRelation
-import org.apache.spark.sql.execution.joins.{EmptyHashedRelation, HashedRelation, HashedRelationWithAllNullKeys}
-
-/**
- * This optimization rule detects and converts a Join to an empty [[LocalRelation]]:
- * 1. Join is single column NULL-aware anti join (NAAJ), and broadcasted [[HashedRelation]]
- * is [[HashedRelationWithAllNullKeys]].
- *
- * 2. Join is inner or left semi join, and broadcasted [[HashedRelation]]
- * is [[EmptyHashedRelation]].
- * This applies to all Joins (sort merge join, shuffled hash join, and broadcast hash join),
- * because sort merge join and shuffled hash join will be changed to broadcast hash join with AQE
- * at the first place.
- */
-object EliminateJoinToEmptyRelation extends Rule[LogicalPlan] {
-
- private def canEliminate(plan: LogicalPlan, relation: HashedRelation): Boolean = plan match {
- case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
- && stage.broadcast.relationFuture.get().value == relation => true
- case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
- && stage.broadcast.supportsColumnar => {
- val cr = stage.broadcast.relationFuture.get().value.asInstanceOf[ColumnarHashedRelation]
- cr.relation == relation
- }
- case _ => false
- }
-
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
- case j @ ExtractSingleColumnNullAwareAntiJoin(_, _)
- if canEliminate(j.right, HashedRelationWithAllNullKeys) =>
- LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
-
- case j @ Join(_, _, Inner, _, _) if canEliminate(j.left, EmptyHashedRelation) ||
- canEliminate(j.right, EmptyHashedRelation) =>
- LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
-
- case j @ Join(_, _, LeftSemi, _, _) if canEliminate(j.right, EmptyHashedRelation) =>
- LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
- }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala
similarity index 99%
rename from omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala
index d34b93e5b0da5b61ac35c0824acbf817f1a5e938..c26bed04f20f97821eb9a528cf943b0d2c240c92 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/OmniAQEShuffleReaderExec.scala
@@ -36,7 +36,7 @@ import scala.collection.mutable.ArrayBuffer
* node during canonicalization.
* @param partitionSpecs The partition specs that defines the arrangement.
*/
-case class ColumnarCustomShuffleReaderExec(
+case class OmniAQEShuffleReaderExec(
child: SparkPlan,
partitionSpecs: Seq[ShufflePartitionSpec])
extends UnaryExecNode {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
index 0e5a7eae6efaac64d59c9effdcd8304d30c5c9fe..57ca9688df38b19cad0c70e17021ec0588a6dd13 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala
@@ -51,7 +51,7 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- OrcUtils.inferSchema(sparkSession, files, options)
+ OmniOrcUtils.inferSchema(sparkSession, files, options)
}
override def buildReaderWithPartitionValues(
@@ -82,18 +82,17 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ
val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
- val resultedColPruneInfo =
- Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader =>
- OrcUtils.requestedColumnIds(
- isCaseSensitive, dataSchema, requiredSchema, reader, conf)
- }
+ val orcSchema =
+ Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema)
+ val resultedColPruneInfo = OmniOrcUtils.requestedColumnIds(
+ isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf)
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
// ORC predicate pushdown
- if (orcFilterPushDown) {
- OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach {
+ if (orcFilterPushDown && filters.nonEmpty) {
+ OmniOrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach {
fileSchema => OrcFilters.createFilter(fileSchema, filters).foreach { f =>
OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
}
@@ -101,12 +100,15 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ
}
val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
- val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
+ val resultSchemaString = OmniOrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
assert(requestedColIds.length == requiredSchema.length,
"[BUG] requested column IDs do not match required schema")
val taskConf = new Configuration(conf)
+ val includeColumns = requestedColIds.filter(_ != -1).sorted.mkString(",")
+ taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute, includeColumns)
+
val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
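The new INCLUDE_COLUMNS line is easy to misread: requestedColumnIds reports -1 for required columns that are not physically present in the ORC file (partition columns, for instance), and those ids must be dropped before the list is serialized. A tiny standalone sketch (not part of the patch, with made-up column ids) of what the added line produces:

    // Hypothetical ids for (c3, missing column, c0, c2); -1 marks a column absent from the file.
    val requestedColIds = Array(3, -1, 0, 2)
    val includeColumns = requestedColIds.filter(_ != -1).sorted.mkString(",")
    assert(includeColumns == "0,2,3") // value written into OrcConf.INCLUDE_COLUMNS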
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala
similarity index 95%
rename from omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala
index 3392caa54f0cff52820b98196f9cbd0235151ef3..71b04ef489d7e3271a071ac439ccbdf0ba40decc 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcUtils.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, Utils}
-object OrcUtils extends Logging {
+object OmniOrcUtils extends Logging {
// The extensions for ORC compression codecs
val extensionsForCompressionCodecNames = Map(
@@ -121,7 +121,7 @@ object OrcUtils extends Logging {
def readOrcSchemasInParallel(
files: Seq[FileStatus], conf: Configuration, ignoreCorruptFiles: Boolean): Seq[StructType] = {
ThreadUtils.parmap(files, "readingOrcSchemas", 8) { currentFile =>
- OrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles).map(toCatalystSchema)
+ OmniOrcUtils.readSchema(currentFile.getPath, conf, ignoreCorruptFiles).map(toCatalystSchema)
}.flatten
}
@@ -130,9 +130,9 @@ object OrcUtils extends Logging {
val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
if (orcOptions.mergeSchema) {
SchemaMergeUtils.mergeSchemasInParallel(
- sparkSession, options, files, OrcUtils.readOrcSchemasInParallel)
+ sparkSession, options, files, OmniOrcUtils.readOrcSchemasInParallel)
} else {
- OrcUtils.readSchema(sparkSession, files, options)
+ OmniOrcUtils.readSchema(sparkSession, files, options)
}
}
@@ -246,9 +246,9 @@ object OrcUtils extends Logging {
partitionSchema: StructType,
conf: Configuration): String = {
val resultSchemaString = if (canPruneCols) {
- OrcUtils.orcTypeDescriptionString(resultSchema)
+ OmniOrcUtils.orcTypeDescriptionString(resultSchema)
} else {
- OrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields))
+ OmniOrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields))
}
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
resultSchemaString
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
index a2ee977f979a873bfb3447c59abac52319e0e0a1..2c1271fb009f14ce324fda3537c618298d198a16 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
@@ -97,6 +97,9 @@ case class ColumnarBroadcastHashJoinExec(
override def nodeName: String = "OmniColumnarBroadcastHashJoin"
+ override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan):
+ ColumnarBroadcastHashJoinExec = copy(left = newLeft, right = newRight)
+
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAwareAntiJoin)
buildSide match {
@@ -109,7 +112,7 @@ case class ColumnarBroadcastHashJoinExec(
override lazy val outputPartitioning: Partitioning = {
joinType match {
- case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+ case _: InnerLike if session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
case h: HashPartitioning => expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
@@ -150,7 +153,7 @@ case class ColumnarBroadcastHashJoinExec(
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
- val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
+ val maxNumCombinations = session.sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0
def generateExprCombinations(
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
index 9eb666fcc85df9295656973c4a833c52a472669e..263af0ddbeb6c6ac8ca7d64917eedf0e889782e9 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
@@ -50,7 +50,8 @@ case class ColumnarShuffledHashJoinExec(
buildSide: BuildSide,
condition: Option[Expression],
left: SparkPlan,
- right: SparkPlan)
+ right: SparkPlan,
+ isSkewJoin: Boolean)
extends HashJoin with ShuffledJoin {
override lazy val metrics = Map(
@@ -81,6 +82,9 @@ case class ColumnarShuffledHashJoinExec(
override def outputPartitioning: Partitioning = super[ShuffledJoin].outputPartitioning
+ override protected def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan):
+ ColumnarShuffledHashJoinExec = copy(left = newLeft, right = newRight)
+
override def outputOrdering: Seq[SortOrder] = joinType match {
case FullOuter => Nil
case _ => super.outputOrdering
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
index 59b763428b1f5955c581f85aee8765280dccac01..d55af2d9d7e8c0427ab7e27333c3e91a45a8be8d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
@@ -68,6 +68,12 @@ class ColumnarSortMergeJoinExec(
if (isSkewJoin) "OmniColumnarSortMergeJoin(skew=true)" else "OmniColumnarSortMergeJoin"
}
+ override protected def withNewChildrenInternal(newLeft: SparkPlan,
+ newRight: SparkPlan): ColumnarSortMergeJoinExec = {
+ new ColumnarSortMergeJoinExec(this.leftKeys, this.rightKeys, this.joinType,
+ this.condition, newLeft, newRight, this.isSkewJoin)
+ }
+
val SMJ_NEED_ADD_STREAM_TBL_DATA = 2
val SMJ_NEED_ADD_BUFFERED_TBL_DATA = 3
val SCAN_FINISH = 4
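Unlike the case-class operators elsewhere in this patch, ColumnarSortMergeJoinExec is a plain class, so its withNewChildrenInternal has to rebuild the node explicitly instead of delegating to copy. A toy standalone sketch (hypothetical types, not Spark's TreeNode API) of the two shapes that contract takes:

    trait ToyPlan {
      def children: Seq[ToyPlan]
      protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyPlan]): ToyPlan
    }

    // Case class: copy(...) rebuilds the node and keeps every other field as-is.
    case class ToyUnary(name: String, child: ToyPlan) extends ToyPlan {
      def children: Seq[ToyPlan] = Seq(child)
      protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyPlan]): ToyPlan =
        copy(child = newChildren.head)
    }

    // Plain class: construct a new instance, threading the unchanged fields through by hand.
    class ToyJoin(val left: ToyPlan, val right: ToyPlan, val isSkewJoin: Boolean) extends ToyPlan {
      def children: Seq[ToyPlan] = Seq(left, right)
      protected def withNewChildrenInternal(newChildren: IndexedSeq[ToyPlan]): ToyPlan =
        new ToyJoin(newChildren(0), newChildren(1), isSkewJoin)
    }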
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
deleted file mode 100644
index 0503b2b7b684537f5191585cccf8b55cf50997d8..0000000000000000000000000000000000000000
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.execution
-
-import org.apache.hadoop.hive.common.StatsSetupConst
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.analysis.CastSupport
-import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-
-/**
- * Prune hive table partitions using partition filters on [[HiveTableRelation]]. The pruned
- * partitions will be kept in [[HiveTableRelation.prunedPartitions]], and the statistics of
- * the hive table relation will be updated based on pruned partitions.
- *
- * This rule is executed in optimization phase, so the statistics can be updated before physical
- * planning, which is useful for some spark strategy, e.g.
- * [[org.apache.spark.sql.execution.SparkStrategies.JoinSelection]].
- *
- * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
- */
-private[sql] class PruneHiveTablePartitions(session: SparkSession)
- extends Rule[LogicalPlan] with CastSupport with PredicateHelper {
-
- /**
- * Extract the partition filters from the filters on the table.
- */
- private def getPartitionKeyFilters(
- filters: Seq[Expression],
- relation: HiveTableRelation): ExpressionSet = {
- val normalizedFilters = DataSourceStrategy.normalizeExprs(
- filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
- val partitionColumnSet = AttributeSet(relation.partitionCols)
- ExpressionSet(
- normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet)))
- }
-
- /**
- * Prune the hive table using filters on the partitions of the table.
- */
- private def prunePartitions(
- relation: HiveTableRelation,
- partitionFilters: ExpressionSet): Seq[CatalogTablePartition] = {
- if (conf.metastorePartitionPruning) {
- session.sessionState.catalog.listPartitionsByFilter(
- relation.tableMeta.identifier, partitionFilters.toSeq)
- } else {
- ExternalCatalogUtils.prunePartitionsByFilter(relation.tableMeta,
- session.sessionState.catalog.listPartitions(relation.tableMeta.identifier),
- partitionFilters.toSeq, conf.sessionLocalTimeZone)
- }
- }
-
- /**
- * Update the statistics of the table.
- */
- private def updateTableMeta(
- relation: HiveTableRelation,
- prunedPartitions: Seq[CatalogTablePartition],
- partitionKeyFilters: ExpressionSet): CatalogTable = {
- val sizeOfPartitions = prunedPartitions.map { partition =>
- val rawDataSize = partition.parameters.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong)
- val totalSize = partition.parameters.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong)
- if (rawDataSize.isDefined && rawDataSize.get > 0) {
- rawDataSize.get
- } else if (totalSize.isDefined && totalSize.get > 0L) {
- totalSize.get
- } else {
- 0L
- }
- }
- // Fix spark issue SPARK-34119(row 95-106)
- if (sizeOfPartitions.forall(_ > 0)) {
- val filteredStats =
- FilterEstimation(Filter(partitionKeyFilters.reduce(And), relation)).estimate
- val colStats = filteredStats.map(_.attributeStats.map { case (attr, colStat) =>
- (attr.name, colStat.toCatalogColumnStat(attr.name, attr.dataType))
- })
- relation.tableMeta.copy(
- stats = Some(CatalogStatistics(
- sizeInBytes = BigInt(sizeOfPartitions.sum),
- rowCount = filteredStats.flatMap(_.rowCount),
- colStats = colStats.getOrElse(Map.empty))))
- } else {
- relation.tableMeta
- }
- }
-
- override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
- if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
- val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
- if (partitionKeyFilters.nonEmpty) {
- val newPartitions = prunePartitions(relation, partitionKeyFilters)
- // Fix spark issue SPARK-34119(row 117)
- val newTableMeta = updateTableMeta(relation, newPartitions, partitionKeyFilters)
- val newRelation = relation.copy(
- tableMeta = newTableMeta, prunedPartitions = Some(newPartitions))
- // Keep partition filters so that they are visible in physical planning
- Project(projections, Filter(filters.reduceLeft(And), newRelation))
- } else {
- op
- }
- }
-}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala
index 237321f5921726da1b80119936b8bcadaa6f1c95..62a837953b5358a4e460f580a8048bb91fa5b759 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerDisableCompressSuite.scala
@@ -107,14 +107,14 @@ class ColumnShuffleSerializerDisableCompressSuite extends SharedSparkSession {
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
}
override def afterEach(): Unit = {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala
index 8f0329248c9cf8e75b277b9cae4a3bd3e5a2e361..a8f287e1f77d8ed058239e610b679ac22533a583 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerLz4Suite.scala
@@ -108,14 +108,14 @@ class ColumnShuffleSerializerLz4Suite extends SharedSparkSession {
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
}
override def afterEach(): Unit = {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala
index 5b6811b03362294e35ca39a65de42592a9385aa8..df3004cce9479f43d81facc7a517f28ed18f02d9 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerSnappySuite.scala
@@ -108,14 +108,14 @@ class ColumnShuffleSerializerSnappySuite extends SharedSparkSession {
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
}
override def afterEach(): Unit = {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala
index a9924a95d42310d1f784088f1b67e015f45d1ca3..8c3b27914008e57217fc5a45628d793bc3ab9d6f 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnShuffleSerializerZlibSuite.scala
@@ -108,14 +108,14 @@ class ColumnShuffleSerializerZlibSuite extends SharedSparkSession {
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
}
override def afterEach(): Unit = {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala
index 00adf145979e33f7dd7b1c49873fd72cdff18756..d527c177805cd54baa53c18c98dddfd84870b953 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/ColumnarShuffleWriterSuite.scala
@@ -107,14 +107,14 @@ class ColumnarShuffleWriterSuite extends SharedSparkSession {
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
}
null
}.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
}
override def afterEach(): Unit = {
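All five shuffle suites make the same two-part change: the stubbed resolver method and the position of the temporary data file in the callback. The index moves from 3 to 4 because the Spark 3.3 commit call carries a checksums array between the partition lengths and the data file (treat that parameter order as an assumption recalled from Spark 3.3, not something this patch states). A minimal sketch with a hypothetical helper:

    // Argument layout assumed by the updated stubs:
    //   0: shuffleId, 1: mapId, 2: partitionLengths, 3: checksums (new), 4: temporary data file
    def dataTmpOf(mockArgs: Array[AnyRef]): java.io.File =
      mockArgs(4).asInstanceOf[java.io.File] // previously index 3, before the checksums array existed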
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..d3cbaa8c41e2d133c8b2ebd450195118a5c293ed
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class CombiningLimitsSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Column Pruning", FixedPoint(100),
+ ColumnPruning,
+ RemoveNoopOperators) ::
+ Batch("Eliminate Limit", FixedPoint(10),
+ EliminateLimits) ::
+ Batch("Constant Folding", FixedPoint(10),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ SimplifyConditionals) :: Nil
+ }
+
+ val testRelation = LocalRelation.fromExternalRows(
+ Seq("a".attr.int, "b".attr.int, "c".attr.int),
+ 1.to(10).map(_ => Row(1, 2, 3))
+ )
+ val testRelation2 = LocalRelation.fromExternalRows(
+ Seq("x".attr.int, "y".attr.int, "z".attr.int),
+ Seq(Row(1, 2, 3), Row(2, 3, 4))
+ )
+ val testRelation3 = RelationWithoutMaxRows(Seq("i".attr.int))
+ val testRelation4 = LongMaxRelation(Seq("j".attr.int))
+ val testRelation5 = EmptyRelation(Seq("k".attr.int))
+
+ test("limits: combines two limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(10)
+ .limit(5)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(5).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits: combines three limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(2)
+ .limit(7)
+ .limit(5)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits: combines two limits after ColumnPruning") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(2)
+ .select('a)
+ .limit(5)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-33442: Change Combine Limit to Eliminate limit using max row") {
+ // test child max row <= limit.
+ val query1 = testRelation.select().groupBy()(count(1)).limit(1).analyze
+ val optimized1 = Optimize.execute(query1)
+ val expected1 = testRelation.select().groupBy()(count(1)).analyze
+ comparePlans(optimized1, expected1)
+
+ // test child max row > limit.
+ val query2 = testRelation.select().groupBy()(count(1)).limit(0).analyze
+ val optimized2 = Optimize.execute(query2)
+ comparePlans(optimized2, query2)
+
+ // test child max row is none
+ val query3 = testRelation.select(Symbol("a")).limit(1).analyze
+ val optimized3 = Optimize.execute(query3)
+ comparePlans(optimized3, query3)
+
+ // test sort after limit
+ val query4 = testRelation.select().groupBy()(count(1))
+ .orderBy(count(1).asc).limit(1).analyze
+ val optimized4 = Optimize.execute(query4)
+ // the top project has been removed, so we need to optimize the expected plan too
+ val expected4 = Optimize.execute(
+ testRelation.select().groupBy()(count(1)).orderBy(count(1).asc).analyze)
+ comparePlans(optimized4, expected4)
+ }
+
+ test("SPARK-33497: Eliminate Limit if LocalRelation max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ testRelation.select().limit(10),
+ testRelation.select(),
+ 10
+ )
+ }
+
+ test("SPARK-33497: Eliminate Limit if Range max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ Range(0, 100, 1, None).select().limit(200),
+ Range(0, 100, 1, None).select(),
+ 100
+ )
+ checkPlanAndMaxRow(
+ Range(-1, Long.MaxValue, 1, None).select().limit(1),
+ Range(-1, Long.MaxValue, 1, None).select().limit(1),
+ 1
+ )
+ }
+
+ test("SPARK-33497: Eliminate Limit if Sample max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ testRelation.select().sample(0, 0.2, false, 1).limit(10),
+ testRelation.select().sample(0, 0.2, false, 1),
+ 10
+ )
+ }
+
+ test("SPARK-38271: PoissonSampler may output more rows than child.maxRows") {
+ val query = testRelation.select().sample(0, 0.2, true, 1)
+ assert(query.maxRows.isEmpty)
+ val optimized = Optimize.execute(query.analyze)
+ assert(optimized.maxRows.isEmpty)
+ // cannot eliminate Limit since Sample.maxRows is None
+ checkPlanAndMaxRow(
+ query.limit(10),
+ query.limit(10),
+ 10
+ )
+ }
+
+ test("SPARK-33497: Eliminate Limit if Deduplicate max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ testRelation.deduplicate("a".attr).limit(10),
+ testRelation.deduplicate("a".attr),
+ 10
+ )
+ }
+
+ test("SPARK-33497: Eliminate Limit if Repartition max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ testRelation.repartition(2).limit(10),
+ testRelation.repartition(2),
+ 10
+ )
+ checkPlanAndMaxRow(
+ testRelation.distribute("a".attr)(2).limit(10),
+ testRelation.distribute("a".attr)(2),
+ 10
+ )
+ }
+
+ test("SPARK-33497: Eliminate Limit if Join max rows not larger than Limit") {
+ Seq(Inner, FullOuter, LeftOuter, RightOuter).foreach { joinType =>
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation2, joinType).limit(20),
+ testRelation.join(testRelation2, joinType),
+ 20
+ )
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation2, joinType).limit(10),
+ testRelation.join(testRelation2, joinType).limit(10),
+ 10
+ )
+ // without maxRow
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation3, joinType).limit(100),
+ testRelation.join(testRelation3, joinType).limit(100),
+ 100
+ )
+ // maxRow is not valid long
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation4, joinType).limit(100),
+ testRelation.join(testRelation4, joinType).limit(100),
+ 100
+ )
+ }
+
+ Seq(LeftSemi, LeftAnti).foreach { joinType =>
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation2, joinType).limit(5),
+ testRelation.join(testRelation2.select(), joinType).limit(5),
+ 5
+ )
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation2, joinType).limit(10),
+ testRelation.join(testRelation2.select(), joinType),
+ 10
+ )
+ }
+ }
+
+ test("SPARK-33497: Eliminate Limit if Window max rows not larger than Limit") {
+ checkPlanAndMaxRow(
+ testRelation.window(
+ Seq(count(1).as("c")), Seq("a".attr), Seq("b".attr.asc)).limit(20),
+ testRelation.window(
+ Seq(count(1).as("c")), Seq("a".attr), Seq("b".attr.asc)),
+ 10
+ )
+ }
+
+ test("SPARK-34628: Remove GlobalLimit operator if its child max rows <= limit") {
+ val query = GlobalLimit(100, testRelation)
+ val optimized = Optimize.execute(query.analyze)
+ comparePlans(optimized, testRelation)
+ }
+
+ test("SPARK-37064: Fix outer join return the wrong max rows if other side is empty") {
+ Seq(LeftOuter, FullOuter).foreach { joinType =>
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation5, joinType).limit(9),
+ testRelation.join(testRelation5, joinType).limit(9),
+ 9
+ )
+
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation5, joinType).limit(10),
+ testRelation.join(testRelation5, joinType),
+ 10
+ )
+ }
+
+ Seq(RightOuter, FullOuter).foreach { joinType =>
+ checkPlanAndMaxRow(
+ testRelation5.join(testRelation, joinType).limit(9),
+ testRelation5.join(testRelation, joinType).limit(9),
+ 9
+ )
+
+ checkPlanAndMaxRow(
+ testRelation5.join(testRelation, joinType).limit(10),
+ testRelation5.join(testRelation, joinType),
+ 10
+ )
+ }
+
+ Seq(Inner, Cross).foreach { joinType =>
+ checkPlanAndMaxRow(
+ testRelation.join(testRelation5, joinType).limit(9),
+ testRelation.join(testRelation5, joinType),
+ 0
+ )
+ }
+ }
+
+ private def checkPlanAndMaxRow(
+ optimized: LogicalPlan, expected: LogicalPlan, expectedMaxRow: Long): Unit = {
+ comparePlans(Optimize.execute(optimized.analyze), expected.analyze)
+ assert(expected.maxRows.get == expectedMaxRow)
+ }
+}
+
+case class RelationWithoutMaxRows(output: Seq[Attribute]) extends LeafNode {
+ override def maxRows: Option[Long] = None
+}
+
+case class LongMaxRelation(output: Seq[Attribute]) extends LeafNode {
+ override def maxRows: Option[Long] = Some(Long.MaxValue)
+}
+
+case class EmptyRelation(output: Seq[Attribute]) extends LeafNode {
+ override def maxRows: Option[Long] = Some(0)
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..02b6eed9ed050e3e718dd855d07f01da4c8ddb0f
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, LessThan, Literal, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{DataType, StructType}
+
+
+class ConvertToLocalRelationSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("LocalRelation", FixedPoint(100),
+ ConvertToLocalRelation) :: Nil
+ }
+
+ test("Project on LocalRelation should be turned into a single LocalRelation") {
+ val testRelation = LocalRelation(
+ LocalRelation('a.int, 'b.int).output,
+ InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)
+
+ val correctAnswer = LocalRelation(
+ LocalRelation('a1.int, 'b1.int).output,
+ InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)
+
+ val projectOnLocal = testRelation.select(
+ UnresolvedAttribute("a").as("a1"),
+ (UnresolvedAttribute("b") + 1).as("b1"))
+
+ val optimized = Optimize.execute(projectOnLocal.analyze)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Filter on LocalRelation should be turned into a single LocalRelation") {
+ val testRelation = LocalRelation(
+ LocalRelation('a.int, 'b.int).output,
+ InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)
+
+ val correctAnswer = LocalRelation(
+ LocalRelation('a1.int, 'b1.int).output,
+ InternalRow(1, 3) :: Nil)
+
+ val filterAndProjectOnLocal = testRelation
+ .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1"))
+ .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6)))
+
+ val optimized = Optimize.execute(filterAndProjectOnLocal.analyze)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("SPARK-27798: Expression reusing output shouldn't override values in local relation") {
+ val testRelation = LocalRelation(
+ LocalRelation('a.int).output,
+ InternalRow(1) :: InternalRow(2) :: Nil)
+
+ val correctAnswer = LocalRelation(
+ LocalRelation('a.struct('a1.int)).output,
+ InternalRow(InternalRow(1)) :: InternalRow(InternalRow(2)) :: Nil)
+
+ val projected = testRelation.select(ExprReuseOutput(UnresolvedAttribute("a")).as("a"))
+ val optimized = Optimize.execute(projected.analyze)
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
+
+
+// Dummy expression used for testing. It reuses output row. Assumes child expr outputs an integer.
+case class ExprReuseOutput(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = StructType.fromDDL("a1 int")
+ override def nullable: Boolean = true
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ throw new UnsupportedOperationException("Should not trigger codegen")
+
+ private val row: InternalRow = new GenericInternalRow(1)
+
+ override def eval(input: InternalRow): Any = {
+ row.update(0, child.eval(input))
+ row
+ }
+
+ override protected def withNewChildInternal(newChild: Expression): ExprReuseOutput =
+ copy(child = newChild)
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..3266febb9ed69d06b75e3a92855924c943ab6ec2
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeOneRowPlanSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class OptimizeOneRowPlanSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Replace Operators", Once, ReplaceDistinctWithAggregate) ::
+ Batch("Eliminate Sorts", Once, EliminateSorts) ::
+ Batch("Optimize One Row Plan", FixedPoint(10), OptimizeOneRowPlan) :: Nil
+ }
+
+ private val t1 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1)))
+ private val t2 = LocalRelation.fromExternalRows(Seq($"a".int), data = Seq(Row(1), Row(2)))
+
+ test("SPARK-35906: Remove order by if the maximum number of rows less than or equal to 1") {
+ comparePlans(
+ Optimize.execute(t2.groupBy()(count(1).as("cnt")).orderBy('cnt.asc)).analyze,
+ t2.groupBy()(count(1).as("cnt")).analyze)
+
+ comparePlans(
+ Optimize.execute(t2.limit(Literal(1)).orderBy('a.asc).orderBy('a.asc)).analyze,
+ t2.limit(Literal(1)).analyze)
+ }
+
+ test("Remove sort") {
+ // remove local sort
+ val plan1 = LocalLimit(0, t1).union(LocalLimit(0, t2)).sortBy($"a".desc).analyze
+ val expected = LocalLimit(0, t1).union(LocalLimit(0, t2)).analyze
+ comparePlans(Optimize.execute(plan1), expected)
+
+ // do not remove
+ val plan2 = t2.orderBy($"a".desc).analyze
+ comparePlans(Optimize.execute(plan2), plan2)
+
+ val plan3 = t2.sortBy($"a".desc).analyze
+ comparePlans(Optimize.execute(plan3), plan3)
+ }
+
+ test("Convert group only aggregate to project") {
+ val plan1 = t1.groupBy($"a")($"a").analyze
+ comparePlans(Optimize.execute(plan1), t1.select($"a").analyze)
+
+ val plan2 = t1.groupBy($"a" + 1)($"a" + 1).analyze
+ comparePlans(Optimize.execute(plan2), t1.select($"a" + 1).analyze)
+
+ // do not remove
+ val plan3 = t2.groupBy($"a")($"a").analyze
+ comparePlans(Optimize.execute(plan3), plan3)
+
+ val plan4 = t1.groupBy($"a")(sum($"a")).analyze
+ comparePlans(Optimize.execute(plan4), plan4)
+
+ val plan5 = t1.groupBy()(sum($"a")).analyze
+ comparePlans(Optimize.execute(plan5), plan5)
+ }
+
+ test("Remove distinct in aggregate expression") {
+ val plan1 = t1.groupBy($"a")(sumDistinct($"a").as("s")).analyze
+ val expected1 = t1.groupBy($"a")(sum($"a").as("s")).analyze
+ comparePlans(Optimize.execute(plan1), expected1)
+
+ val plan2 = t1.groupBy()(sumDistinct($"a").as("s")).analyze
+ val expected2 = t1.groupBy()(sum($"a").as("s")).analyze
+ comparePlans(Optimize.execute(plan2), expected2)
+
+ // do not remove
+ val plan3 = t2.groupBy($"a")(sumDistinct($"a").as("s")).analyze
+ comparePlans(Optimize.execute(plan3), plan3)
+ }
+
+ test("Remove in complex case") {
+ val plan1 = t1.groupBy($"a")($"a").orderBy($"a".asc).analyze
+ val expected1 = t1.select($"a").analyze
+ comparePlans(Optimize.execute(plan1), expected1)
+
+ val plan2 = t1.groupBy($"a")(sumDistinct($"a").as("s")).orderBy($"s".asc).analyze
+ val expected2 = t1.groupBy($"a")(sum($"a").as("s")).analyze
+ comparePlans(Optimize.execute(plan2), expected2)
+ }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
index 9f4ae359e1cc8459841dc9757fb52a514a1cbfb4..ddf4d421f3e5d8a9e43c54c9ed31b16439b5c597 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/CoalesceShufflePartitionsSuite.scala
@@ -18,13 +18,16 @@
package org.apache.spark.sql.execution
import org.scalatest.BeforeAndAfterAll
+
import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql._
import org.apache.spark.sql.execution.adaptive._
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
+import org.apache.spark.sql.internal.SQLConf
class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAll {
@@ -53,23 +56,24 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
val numInputPartitions: Int = 10
def withSparkSession(
- f: SparkSession => Unit,
- targetPostShuffleInputSize: Int,
- minNumPostShufflePartitions: Option[Int]): Unit = {
+ f: SparkSession => Unit,
+ targetPostShuffleInputSize: Int,
+ minNumPostShufflePartitions: Option[Int],
+ enableIOEncryption: Boolean = false): Unit = {
val sparkConf =
new SparkConf(false)
.setMaster("local[*]")
.setAppName("test")
.set(UI_ENABLED, false)
+ .set(IO_ENCRYPTION_ENABLED, enableIOEncryption)
.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
.set(SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key, "5")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
+ .set(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
.set(
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key,
targetPostShuffleInputSize.toString)
- .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin")
- .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager")
minNumPostShufflePartitions match {
case Some(numPartitions) =>
sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, numPartitions.toString)
@@ -90,7 +94,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
}
test(s"determining the number of reducers: aggregate operator$testNameNote") {
- val test = { spark: SparkSession =>
+ val test: SparkSession => Unit = { spark: SparkSession =>
val df =
spark
.range(0, 1000, 1, numInputPartitions)
@@ -106,27 +110,27 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// by the ExchangeCoordinator.
val finalPlan = agg.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- val shuffleReaders = finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ val shuffleReads = finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
}
- assert(shuffleReaders.length === 1)
+
minNumPostShufflePartitions match {
case Some(numPartitions) =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === numPartitions)
- }
+ assert(shuffleReads.isEmpty)
+
case None =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === 3)
+ assert(shuffleReads.length === 1)
+ shuffleReads.foreach { read =>
+ assert(read.outputPartitioning.numPartitions === 3)
}
}
}
- // The number of coulmn partitions byte is small. smaller threshold value should be used
- withSparkSession(test, 1500, minNumPostShufflePartitions)
+
+ withSparkSession(test, 2000, minNumPostShufflePartitions)
}
test(s"determining the number of reducers: join operator$testNameNote") {
- val test = { spark: SparkSession =>
+ val test: SparkSession => Unit = { spark: SparkSession =>
val df1 =
spark
.range(0, 1000, 1, numInputPartitions)
@@ -152,23 +156,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// by the ExchangeCoordinator.
val finalPlan = join.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- val shuffleReaders = finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ val shuffleReads = finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
}
- assert(shuffleReaders.length === 2)
+
minNumPostShufflePartitions match {
case Some(numPartitions) =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === numPartitions)
- }
+ assert(shuffleReads.isEmpty)
+
case None =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === 2)
+ assert(shuffleReads.length === 2)
+ shuffleReads.foreach { read =>
+ assert(read.outputPartitioning.numPartitions === 2)
}
}
}
- // The number of coulmn partitions byte is small. smaller threshold value should be used
- withSparkSession(test, 11384, minNumPostShufflePartitions)
+
+ withSparkSession(test, 16384, minNumPostShufflePartitions)
}
test(s"determining the number of reducers: complex query 1$testNameNote") {
@@ -203,23 +207,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// by the ExchangeCoordinator.
val finalPlan = join.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- val shuffleReaders = finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ val shuffleReads = finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
}
- assert(shuffleReaders.length === 2)
+
minNumPostShufflePartitions match {
case Some(numPartitions) =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === numPartitions)
- }
+ assert(shuffleReads.isEmpty)
+
case None =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === 3)
+ assert(shuffleReads.length === 2)
+ shuffleReads.foreach { read =>
+ assert(read.outputPartitioning.numPartitions === 2)
}
}
}
- // The number of coulmn partitions byte is small. smaller threshold value should be used
- withSparkSession(test, 7384, minNumPostShufflePartitions)
+
+ withSparkSession(test, 16384, minNumPostShufflePartitions)
}
test(s"determining the number of reducers: complex query 2$testNameNote") {
@@ -254,23 +258,23 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// by the ExchangeCoordinator.
val finalPlan = join.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- val shuffleReaders = finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ val shuffleReads = finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
}
- assert(shuffleReaders.length === 2)
+
minNumPostShufflePartitions match {
case Some(numPartitions) =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === numPartitions)
- }
+ assert(shuffleReads.isEmpty)
+
case None =>
- shuffleReaders.foreach { reader =>
- assert(reader.outputPartitioning.numPartitions === 2)
+ assert(shuffleReads.length === 2)
+ shuffleReads.foreach { read =>
+ assert(read.outputPartitioning.numPartitions === 3)
}
}
}
- // The number of coulmn partitions byte is small. smaller threshold value should be used
- withSparkSession(test, 10000, minNumPostShufflePartitions)
+
+ withSparkSession(test, 12000, minNumPostShufflePartitions)
}
test(s"determining the number of reducers: plan already partitioned$testNameNote") {
@@ -296,10 +300,10 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// Then, let's make sure we do not reduce number of post shuffle partitions.
val finalPlan = join.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- val shuffleReaders = finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ val shuffleReads = finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
}
- assert(shuffleReaders.length === 0)
+ assert(shuffleReads.length === 0)
} finally {
spark.sql("drop table t")
}
@@ -308,10 +312,10 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
}
}
- ignore("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") {
+ test("SPARK-24705 adaptive query execution works correctly when exchange reuse enabled") {
val test: SparkSession => Unit = { spark: SparkSession =>
spark.sql("SET spark.sql.exchange.reuse=true")
- val df = spark.range(1).selectExpr("id AS key", "id AS value")
+ val df = spark.range(0, 6, 1).selectExpr("id AS key", "id AS value")
// test case 1: a query stage has 3 child stages but they are the same stage.
// Final Stage 1
@@ -319,15 +323,15 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
// ReusedQueryStage 0
// ReusedQueryStage 0
val resultDf = df.join(df, "key").join(df, "key")
- QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
+ QueryTest.checkAnswer(resultDf, (0 to 5).map(i => Row(i, i, i, i)))
val finalPlan = resultDf.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
assert(finalPlan.collect {
- case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
+ case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
}.length == 2)
assert(
finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ case r @ CoalescedShuffleRead() => r
}.length == 3)
@@ -340,7 +344,9 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
val grouped = df.groupBy("key").agg(max("value").as("value"))
val resultDf2 = grouped.groupBy(col("key") + 1).max("value")
.union(grouped.groupBy(col("key") + 2).max("value"))
- QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
+ QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Row(2, 1) :: Row(3, 1) ::
+ Row(3, 2) :: Row(4, 2) :: Row(4, 3) :: Row(5, 3) :: Row(5, 4) :: Row(6, 4) :: Row(6, 5) ::
+ Row(7, 5) :: Nil)
val finalPlan2 = resultDf2.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -349,6 +355,17 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
val level1Stages = finalPlan2.collect { case q: QueryStageExec => q }
assert(level1Stages.length == 2)
+ assert(
+ finalPlan2.collect {
+ case r @ CoalescedShuffleRead() => r
+ }.length == 2, "finalPlan2")
+
+ level1Stages.foreach(qs =>
+ assert(qs.plan.collect {
+ case r @ CoalescedShuffleRead() => r
+ }.length == 1,
+ "Wrong CoalescedShuffleRead below " + qs.simpleString(3)))
+
val leafStages = level1Stages.flatMap { stage =>
// All of the child stages of result stage have only one child stage.
val children = stage.plan.collect { case q: QueryStageExec => q }
@@ -359,12 +376,12 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
val reusedStages = level1Stages.flatMap { stage =>
stage.plan.collect {
- case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
+ case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
}
}
assert(reusedStages.length == 1)
}
- withSparkSession(test, 4, None)
+ withSparkSession(test, 400, None)
}
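The expected rows in the checkAnswer above can be derived by hand: df contains keys 0..5 with value equal to key, so `grouped` is (k, k) for each k, and each branch of the union shifts the grouping column by 1 or 2 while max(value) stays k. A small Spark-free enumeration of that reasoning (REPL-style, names invented for the sketch):

    // Assume keys 0..5 with value == key, as produced by spark.range(0, 6, 1) above.
    val grouped = (0 to 5).map(k => (k, k))                 // groupBy(key).agg(max(value))
    val branch1 = grouped.map { case (k, v) => (k + 1, v) } // groupBy(key + 1).max(value)
    val branch2 = grouped.map { case (k, v) => (k + 2, v) } // groupBy(key + 2).max(value)
    val expected = (branch1 ++ branch2).sorted
    println(expected.mkString(", ")) // the 12 rows asserted in checkAnswer above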
test("Do not reduce the number of shuffle partition for repartition") {
@@ -378,7 +395,7 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
assert(
finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
+ case r @ CoalescedShuffleRead() => r
}.isEmpty)
}
withSparkSession(test, 200, None)
@@ -393,21 +410,40 @@ class CoalesceShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterAl
QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
+ // Shuffle partition coalescing of the join is performed independently of the non-grouping

+ // aggregate on the other side of the union.
val finalPlan = resultDf.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
- // As the pre-shuffle partition number are different, we will skip reducing
- // the shuffle partition numbers.
assert(
finalPlan.collect {
- case r @ ColumnarCoalescedShuffleReader() => r
- }.isEmpty)
+ case r @ CoalescedShuffleRead() => r
+ }.size == 2)
}
withSparkSession(test, 100, None)
}
+
+ test("SPARK-34790: enable IO encryption in AQE partition coalescing") {
+ val test: SparkSession => Unit = { spark: SparkSession =>
+ val ds = spark.range(0, 100, 1, numInputPartitions)
+ val resultDf = ds.repartition(ds.col("id"))
+ resultDf.collect()
+
+ val finalPlan = resultDf.queryExecution.executedPlan
+ .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
+ assert(
+ finalPlan.collect {
+ case r @ CoalescedShuffleRead() => r
+ }.isDefinedAt(0))
+ }
+ Seq(true, false).foreach { enableIOEncryption =>
+ // Before SPARK-34790, this would throw an exception when IO encryption was enabled.
+ withSparkSession(test, Int.MaxValue, None, enableIOEncryption)
+ }
+ }
}
-object ColumnarCoalescedShuffleReader {
- def unapply(reader: ColumnarCustomShuffleReaderExec): Boolean = {
- !reader.isLocalReader && !reader.hasSkewedPartition && reader.hasCoalescedPartition
+object CoalescedShuffleRead {
+ def unapply(read: AQEShuffleReadExec): Boolean = {
+ !read.isLocalRead && !read.hasSkewedPartition && read.hasCoalescedPartition
}
}
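The Boolean-returning unapply above is what lets the suite write `case r @ CoalescedShuffleRead() => r` as a guard-style pattern in collect. A minimal, Spark-free sketch of that Scala mechanism (the Even object is invented for the illustration):

    object Even {
      // A Boolean unapply turns the object into a guard-style pattern,
      // analogous to CoalescedShuffleRead matching AQEShuffleReadExec nodes.
      def unapply(n: Int): Boolean = n % 2 == 0
    }

    // `n @ Even()` binds the matched value only when unapply returns true,
    // mirroring `case r @ CoalescedShuffleRead() => r` in the assertions above.
    val evens = Seq(1, 2, 3, 4).collect { case n @ Even() => n }
    assert(evens == Seq(2, 4))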
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala
index 16ab589578aacc68a1964566ef05f3898d0e406a..fd5649c4486d0def0131ddb9a42102db11a38718 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarSparkPlanTest.scala
@@ -31,6 +31,7 @@ private[sql] abstract class ColumnarSparkPlanTest extends SparkPlanTest with Sha
.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false")
.set("spark.executorEnv.OMNI_CONNECTED_ENGINE", "Spark")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager")
+ .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
val analyzedDF = try df catch {
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
index cf2537484aefdcd68214db0046877652847bb34b..0055b94fa06626c48b19ef98d770d9109ace0726 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/adaptive/ColumnarAdaptiveQueryExecSuite.scala
@@ -17,34 +17,41 @@
package org.apache.spark.sql.execution.adaptive
-import org.apache.log4j.Level
-import org.apache.spark.Partition
-import org.apache.spark.rdd.RDD
+import java.io.File
+import java.net.URI
+
+import org.apache.logging.log4j.Level
+import org.scalatest.PrivateMethodTester
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
-import org.apache.spark.sql.{Dataset, Row, SparkSession, Strategy}
+import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.execution.{CollectLimitExec, CommandResultExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, UnaryExecNode, UnionExec}
+import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec
-import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarSparkPlanTest, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledColumnarRDD, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.exchange.{Exchange, REPARTITION, REPARTITION_WITH_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike}
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarSortMergeJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ENSURE_REQUIREMENTS, Exchange, REPARTITION_BY_COL, REPARTITION_BY_NUM, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
-import org.apache.spark.sql.functions.{sum, when}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.SQLTestData.TestData
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.util.QueryExecutionListener
-import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
-import java.io.File
-import java.net.URI
-
-class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
- with AdaptiveSparkPlanHelper {
+class AdaptiveQueryExecSuite
+ extends QueryTest
+ with SharedSparkSession
+ with AdaptiveSparkPlanHelper
+ with PrivateMethodTester {
import testImplicits._
@@ -98,10 +105,9 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- private def findTopLevelColumnarBroadcastHashJoin(plan: SparkPlan)
- : Seq[ColumnarBroadcastHashJoinExec] = {
+ def findTopLevelBroadcastNestedLoopJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
collect(plan) {
- case j: ColumnarBroadcastHashJoinExec => j
+ case j: BroadcastNestedLoopJoinExec => j
}
}
@@ -111,9 +117,9 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- private def findTopLevelColumnarSortMergeJoin(plan: SparkPlan): Seq[ColumnarSortMergeJoinExec] = {
+ private def findTopLevelShuffledHashJoin(plan: SparkPlan): Seq[ShuffledHashJoinExec] = {
collect(plan) {
- case j: ColumnarSortMergeJoinExec => j
+ case j: ShuffledHashJoinExec => j
}
}
@@ -123,10 +129,28 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
+ private def findTopLevelSort(plan: SparkPlan): Seq[SortExec] = {
+ collect(plan) {
+ case s: SortExec => s
+ }
+ }
+
+ private def findTopLevelAggregate(plan: SparkPlan): Seq[BaseAggregateExec] = {
+ collect(plan) {
+ case agg: BaseAggregateExec => agg
+ }
+ }
+
+ private def findTopLevelLimit(plan: SparkPlan): Seq[CollectLimitExec] = {
+ collect(plan) {
+ case l: CollectLimitExec => l
+ }
+ }
+
private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
collectWithSubqueries(plan) {
- case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
- case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
+ case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
+ case BroadcastQueryStageExec(_, e: ReusedExchangeExec, _) => e
}
}
@@ -136,28 +160,21 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- private def checkNumLocalShuffleReaders(
- plan: SparkPlan, numShufflesWithoutLocalReader: Int = 0): Unit = {
+ private def checkNumLocalShuffleReads(
+ plan: SparkPlan, numShufflesWithoutLocalRead: Int = 0): Unit = {
val numShuffles = collect(plan) {
case s: ShuffleQueryStageExec => s
}.length
- val numLocalReaders = collect(plan) {
- case rowReader: CustomShuffleReaderExec if rowReader.isLocalReader => rowReader
- case colReader: ColumnarCustomShuffleReaderExec if colReader.isLocalReader => colReader
+ val numLocalReads = collect(plan) {
+ case read: AQEShuffleReadExec if read.isLocalRead => read
}
- numLocalReaders.foreach {
- case rowCus: CustomShuffleReaderExec =>
- val rdd = rowCus.execute()
- val parts = rdd.partitions
- assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
- case r =>
- val columnarCus = r.asInstanceOf[ColumnarCustomShuffleReaderExec]
- val rdd: RDD[ColumnarBatch] = columnarCus.executeColumnar()
- val parts: Array[Partition] = rdd.partitions
- assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
+ numLocalReads.foreach { r =>
+ val rdd = r.execute()
+ val parts = rdd.partitions
+ assert(parts.forall(rdd.preferredLocations(_).nonEmpty))
}
- assert(numShuffles === (numLocalReaders.length + numShufflesWithoutLocalReader))
+ assert(numShuffles === (numLocalReads.length + numShufflesWithoutLocalRead))
}
private def checkInitialPartitionNum(df: Dataset[_], numPartition: Int): Unit = {
@@ -173,20 +190,42 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Change merge join to broadcast join") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData join testData2 ON key = a where value = '1'")
- val smj: Seq[SortMergeJoinExec] = findTopLevelSortMergeJoin(plan)
+ val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj: Seq[ColumnarBroadcastHashJoinExec] =
- findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
+ checkNumLocalShuffleReads(adaptivePlan)
}
}
- test("Reuse the parallelism of CoalescedShuffleReaderExec in LocalShuffleReaderExec") {
+ test("Change broadcast join to merge join") {
+ withTable("t1", "t2") {
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10000",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ sql("CREATE TABLE t1 USING PARQUET AS SELECT 1 c1")
+ sql("CREATE TABLE t2 USING PARQUET AS SELECT 1 c1")
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM (
+ | SELECT distinct c1 from t1
+ | ) tmp1 JOIN (
+ | SELECT distinct c1 from t2
+ | ) tmp2 ON tmp1.c1 = tmp2.c1
+ |""".stripMargin)
+ assert(findTopLevelBroadcastHashJoin(plan).size == 1)
+ assert(findTopLevelBroadcastHashJoin(adaptivePlan).isEmpty)
+ assert(findTopLevelSortMergeJoin(adaptivePlan).size == 1)
+ }
+ }
+ }
+
+ test("Reuse the parallelism of coalesced shuffle in local shuffle read") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -195,30 +234,30 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
"SELECT * FROM testData join testData2 ON key = a where value = '1'")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- val localReaders = collect(adaptivePlan) {
- case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader
+ val localReads = collect(adaptivePlan) {
+ case read: AQEShuffleReadExec if read.isLocalRead => read
}
- assert(localReaders.length == 2)
- val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD]
- val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD]
+ assert(localReads.length == 2)
+ val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
+ val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// The pre-shuffle partition size is [0, 0, 0, 72, 0]
// We exclude the 0-size partitions, so only one partition, advisoryParallelism = 1
// the final parallelism is
- // math.max(1, advisoryParallelism / numMappers): math.max(1, 1/2) = 1
- // and the partitions length is 1 * numMappers = 2
- assert(localShuffleRDD0.getPartitions.length == 2)
+ // advisoryParallelism = 1 since advisoryParallelism < numMappers
+ // and the partitions length is 1
+ assert(localShuffleRDD0.getPartitions.length == 1)
// The pre-shuffle partition size is [0, 72, 0, 72, 126]
// We exclude the 0-size partitions, so only 3 partition, advisoryParallelism = 3
// the final parallelism is
- // math.max(1, advisoryParallelism / numMappers): math.max(1, 3/2) = 1
+ // advisoryParallelism / numMappers: 3/2 = 1 since advisoryParallelism >= numMappers
// and the partitions length is 1 * numMappers = 2
assert(localShuffleRDD1.getPartitions.length == 2)
}
}
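For orientation, the partition counts asserted above follow from the arithmetic spelled out in the comments; the sketch below is only a paraphrase of those expectations, not Spark's actual ShufflePartitionsUtil logic, which handles more cases.

    // Paraphrase of the commented expectations: with 2 mappers, a local shuffle read either
    // coalesces mappers (advisory parallelism below the mapper count) or splits each mapper.
    def localReadPartitions(advisoryParallelism: Int, numMappers: Int): Int =
      if (advisoryParallelism >= numMappers) (advisoryParallelism / numMappers) * numMappers
      else advisoryParallelism

    assert(localReadPartitions(advisoryParallelism = 1, numMappers = 2) == 1) // localShuffleRDD0
    assert(localReadPartitions(advisoryParallelism = 3, numMappers = 2) == 2) // localShuffleRDD1
    // The next test uses the default parallelism (numReducers = 5), giving (5 / 2) * 2 = 4.
    assert(localReadPartitions(advisoryParallelism = 5, numMappers = 2) == 4)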
- test("Reuse the default parallelism in LocalShuffleReaderExec") {
+ test("Reuse the default parallelism in local shuffle read") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
@@ -227,14 +266,14 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
"SELECT * FROM testData join testData2 ON key = a where value = '1'")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- val localReaders = collect(adaptivePlan) {
- case reader: ColumnarCustomShuffleReaderExec if reader.isLocalReader => reader
+ val localReads = collect(adaptivePlan) {
+ case read: AQEShuffleReadExec if read.isLocalRead => read
}
- assert(localReaders.length == 2)
- val localShuffleRDD0 = localReaders(0).executeColumnar().asInstanceOf[ShuffledColumnarRDD]
- val localShuffleRDD1 = localReaders(1).executeColumnar().asInstanceOf[ShuffledColumnarRDD]
+ assert(localReads.length == 2)
+ val localShuffleRDD0 = localReads(0).execute().asInstanceOf[ShuffledRowRDD]
+ val localShuffleRDD1 = localReads(1).execute().asInstanceOf[ShuffledRowRDD]
// the final parallelism is math.max(1, numReduces / numMappers): math.max(1, 5/2) = 2
// and the partitions length is 2 * numMappers = 4
assert(localShuffleRDD0.getPartitions.length == 4)
@@ -247,73 +286,75 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Empty stage coalesced to 1-partition RDD") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
- val df1 = spark.range(10).withColumn("a", 'id)
- val df2 = spark.range(10).withColumn("b", 'id)
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+ val df1 = spark.range(10).withColumn("a", Symbol("id"))
+ val df2 = spark.range(10).withColumn("b", Symbol("id"))
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
- .groupBy('a).count()
+ val testDf = df1.where(Symbol("a") > 10)
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer")
+ .groupBy(Symbol("a")).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
- val coalescedReaders = collect(plan) {
- case r: ColumnarCustomShuffleReaderExec => r
+ val coalescedReads = collect(plan) {
+ case r: AQEShuffleReadExec => r
}
- assert(coalescedReaders.length == 3)
- coalescedReaders.foreach(r => assert(r.partitionSpecs.length == 1))
+ assert(coalescedReads.length == 3)
+ coalescedReads.foreach(r => assert(r.partitionSpecs.length == 1))
}
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
- val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer")
- .groupBy('a).count()
+ val testDf = df1.where(Symbol("a") > 10)
+ .join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer")
+ .groupBy(Symbol("a")).count()
checkAnswer(testDf, Seq())
val plan = testDf.queryExecution.executedPlan
- print(plan)
- assert(find(plan)(_.isInstanceOf[ColumnarBroadcastHashJoinExec]).isDefined)
- val coalescedReaders = collect(plan) {
- case r: ColumnarCustomShuffleReaderExec => r
+ assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ val coalescedReads = collect(plan) {
+ case r: AQEShuffleReadExec => r
}
- assert(coalescedReaders.length == 3, s"$plan")
- coalescedReaders.foreach(r => assert(r.isLocalReader || r.partitionSpecs.length == 1))
+ assert(coalescedReads.length == 3, s"$plan")
+ coalescedReads.foreach(r => assert(r.isLocalRead || r.partitionSpecs.length == 1))
}
}
}
test("Scalar subquery") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData join testData2 ON key = a " +
- "where value = (SELECT max(a) from testData3)")
+ "where value = (SELECT max(a) from testData3)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
+ checkNumLocalShuffleReads(adaptivePlan)
}
}
- // Currently, OmniFilterExec will fall back to Filter, if AQE is enabled, it will cause error
- ignore("Scalar subquery in later stages") {
+ test("Scalar subquery in later stages") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData join testData2 ON key = a " +
- "where (value + a) = (SELECT max(a) from testData3)")
+ "where (value + a) = (SELECT max(a) from testData3)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
+
+ checkNumLocalShuffleReads(adaptivePlan)
}
}
test("multiple joins") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|WITH t4 AS (
@@ -326,7 +367,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
""".stripMargin)
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 3)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// A possible resulting query plan:
@@ -347,18 +388,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
// +-LocalShuffleReader*
// +- ShuffleExchange
- // After applied the 'OptimizeLocalShuffleReader' rule, we can convert all the four
- // shuffle reader to local shuffle reader in the bottom two 'BroadcastHashJoin'.
+ // After applying the 'OptimizeShuffleWithLocalRead' rule, we can convert all four
+ // shuffle reads to local shuffle reads in the bottom two 'BroadcastHashJoin' nodes.
// For the top level 'BroadcastHashJoin', the probe side is not shuffle query stage
- // and the build side shuffle query stage is also converted to local shuffle reader.
- checkNumLocalShuffleReaders(adaptivePlan)
+ // and the build side shuffle query stage is also converted to local shuffle read.
+ checkNumLocalShuffleReads(adaptivePlan)
}
}
test("multiple joins with aggregate") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|WITH t4 AS (
@@ -373,7 +414,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
""".stripMargin)
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 3)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 3)
// A possible resulting query plan:
@@ -395,15 +436,15 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
// +- CoalescedShuffleReader
// +- ShuffleExchange
- // The shuffle added by Aggregate can't apply local reader.
- checkNumLocalShuffleReaders(adaptivePlan, 1)
+ // The shuffle added by Aggregate can't apply local read.
+ checkNumLocalShuffleReads(adaptivePlan, 1)
}
}
test("multiple joins with aggregate 2") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|WITH t4 AS (
@@ -418,8 +459,8 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
""".stripMargin)
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 3)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
- assert(bhj.size == 2)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 3)
// A possible resulting query plan:
// BroadcastHashJoin
@@ -441,25 +482,25 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
// +-LocalShuffleReader*
// +- ShuffleExchange
- // The shuffle added by Aggregate can't apply local reader.
- checkNumLocalShuffleReaders(adaptivePlan, 1)
+ // The shuffle added by Aggregate can't apply local read.
+ checkNumLocalShuffleReads(adaptivePlan, 1)
}
}
test("Exchange reuse") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT value FROM testData join testData2 ON key = a " +
- "join (SELECT value v from testData join testData3 ON key = a) on value = v")
+ "join (SELECT value v from testData join testData3 ON key = a) on value = v")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 3)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
- assert(bhj.size == 3)
- // There is no SMJ
- checkNumLocalShuffleReaders(adaptivePlan, 0)
- // Even with local shuffle reader, the query stage reuse can also work.
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 2)
+ // There is still a SMJ, and its two shuffles can't apply local read.
+ checkNumLocalShuffleReads(adaptivePlan, 2)
+ // Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
@@ -467,17 +508,17 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Exchange reuse with subqueries") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT a FROM testData join testData2 ON key = a " +
- "where value = (SELECT max(a) from testData join testData2 ON key = a)")
+ "where value = (SELECT max(a) from testData join testData2 ON key = a)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
- // Even with local shuffle reader, the query stage reuse can also work.
+ checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.size == 1)
}
@@ -485,19 +526,19 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Exchange reuse across subqueries") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
- SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+ SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT a FROM testData join testData2 ON key = a " +
- "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
- "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+ "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
+ "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
- // Even with local shuffle reader, the query stage reuse can also work.
+ checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
val sub = findReusedSubquery(adaptivePlan)
@@ -507,18 +548,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Subquery reuse") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT a FROM testData join testData2 ON key = a " +
- "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
- "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
+ "where value >= (SELECT max(a) from testData join testData2 ON key = a) " +
+ "and a <= (SELECT max(a) from testData join testData2 ON key = a)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
- // Even with local shuffle reader, the query stage reuse can also work.
+ checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.isEmpty)
val sub = findReusedSubquery(adaptivePlan)
@@ -528,24 +569,24 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
test("Broadcast exchange reuse across subqueries") {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
- SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "20000000",
+ SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT a FROM testData join testData2 ON key = a " +
- "where value >= (" +
- "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " +
- "and a <= (" +
- "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)")
+ "where value >= (" +
+ "SELECT /*+ broadcast(testData2) */ max(key) from testData join testData2 ON key = a) " +
+ "and a <= (" +
+ "SELECT /*+ broadcast(testData2) */ max(value) from testData join testData2 ON key = a)")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan)
- // Even with local shuffle reader, the query stage reuse can also work.
+ checkNumLocalShuffleReads(adaptivePlan)
+ // Even with local shuffle read, the query stage reuse can also work.
val ex = findReusedExchange(adaptivePlan)
assert(ex.nonEmpty)
- assert(ex.head.child.isInstanceOf[ColumnarBroadcastExchangeExec])
+ assert(ex.head.child.isInstanceOf[BroadcastExchangeExec])
val sub = findReusedSubquery(adaptivePlan)
assert(sub.isEmpty)
}
@@ -591,7 +632,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData " +
@@ -604,11 +645,11 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- test("Change merge join to broadcast join without local shuffle reader") {
+ test("Change merge join to broadcast join without local shuffle read") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.LOCAL_SHUFFLE_READER_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "25") {
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40") {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"""
|SELECT * FROM testData t1 join testData2 t2
@@ -618,9 +659,10 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
)
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 2)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
- checkNumLocalShuffleReaders(adaptivePlan, 2)
+ // There is still a SMJ, and its two shuffles can't apply local read.
+ checkNumLocalShuffleReads(adaptivePlan, 2)
}
}
@@ -643,12 +685,53 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
"SELECT * FROM testData join testData2 ON key = a where value = '1'")
val smj = findTopLevelSortMergeJoin(plan)
assert(smj.size == 1)
- val bhj = findTopLevelColumnarBroadcastHashJoin(adaptivePlan)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
assert(bhj.size == 1)
assert(bhj.head.buildSide == BuildRight)
}
}
}
+ test("SPARK-37753: Allow changing outer join to broadcast join even if too many empty" +
+ " partitions on broadcast side") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") {
+ // `testData` is small enough to be broadcast, but its empty partition ratio exceeds the configured threshold.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT * FROM (select * from testData where value = '1') td" +
+ " right outer join testData2 ON key = a")
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+ }
+ }
+ }
+
+ test("SPARK-37753: Inhibit broadcast in left outer join when there are many empty" +
+ " partitions on outer/left side") {
+ // If the right side completes first while the left side is still executing, the right side
+ // does not know whether there are many empty partitions on the left side, so no demotion
+ // happens and the right side is broadcast in the planning stage. Retry several times here
+ // to avoid flaky unit test failures.
+ eventually(timeout(15.seconds), interval(500.milliseconds)) {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.NON_EMPTY_PARTITION_RATIO_FOR_BROADCAST_JOIN.key -> "0.5") {
+ // `testData` is small enough to be broadcast, but its empty partition ratio exceeds the configured threshold.
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "200") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT * FROM (select * from testData where value = '1') td" +
+ " left outer join testData2 ON key = a")
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.isEmpty)
+ }
+ }
+ }
+ }
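The retry wrapper above is ScalaTest's Eventually, which this suite already mixes in. Outside such a suite the same pattern looks roughly like the sketch below; the attempts counter only stands in for a condition that becomes true once both sides of the join have finished, as described in the comment.

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    var attempts = 0
    // Re-run the block every 500 ms until it stops throwing, giving up after 15 seconds.
    eventually(timeout(15.seconds), interval(500.milliseconds)) {
      attempts += 1
      assert(attempts >= 3) // placeholder for the flaky broadcast-demotion check
    }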
test("SPARK-29906: AQE should not introduce extra shuffle for outermost limit") {
var numStages = 0
@@ -688,7 +771,7 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
def checkSkewJoin(query: String, optimizeSkewJoin: Boolean): Unit = {
val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query)
- val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan)
+ val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan)
assert(innerSmj.size == 1 && innerSmj.head.isSkewJoin == optimizeSkewJoin)
}
@@ -701,65 +784,75 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- ignore("SPARK-29544: adaptive skew join with different join types") {
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
- SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
- SQLConf.SHUFFLE_PARTITIONS.key -> "100",
- SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
- SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
- withTempView("skewData1", "skewData2") {
- spark
- .range(0, 1000, 1, 10)
- .select(
- when('id < 250, 249)
- .when('id >= 750, 1000)
- .otherwise('id).as("key1"),
- 'id as "value1")
- .createOrReplaceTempView("skewData1")
- spark
- .range(0, 1000, 1, 10)
- .select(
- when('id < 250, 249)
- .otherwise('id).as("key2"),
- 'id as "value2")
- .createOrReplaceTempView("skewData2")
-
- def checkSkewJoin(
- joins: Seq[SortMergeJoinExec],
- leftSkewNum: Int,
- rightSkewNum: Int): Unit = {
- assert(joins.size == 1 && joins.head.isSkewJoin)
- assert(joins.head.left.collect {
- case r: ColumnarCustomShuffleReaderExec => r
- }.head.partitionSpecs.collect {
- case p: PartialReducerPartitionSpec => p.reducerIndex
- }.distinct.length == leftSkewNum)
- assert(joins.head.right.collect {
- case r: ColumnarCustomShuffleReaderExec => r
- }.head.partitionSpecs.collect {
- case p: PartialReducerPartitionSpec => p.reducerIndex
- }.distinct.length == rightSkewNum)
- }
-
- // skewed inner join optimization
- val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT * FROM skewData1 join skewData2 ON key1 = key2")
- val innerSmj = findTopLevelColumnarSortMergeJoin(innerAdaptivePlan)
- checkSkewJoin(innerSmj, 1, 1)
+ test("SPARK-29544: adaptive skew join with different join types") {
+ Seq("SHUFFLE_MERGE", "SHUFFLE_HASH").foreach { joinHint =>
+ def getJoinNode(plan: SparkPlan): Seq[ShuffledJoin] = if (joinHint == "SHUFFLE_MERGE") {
+ findTopLevelSortMergeJoin(plan)
+ } else {
+ findTopLevelShuffledHashJoin(plan)
+ }
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "100",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .select(
+ when(Symbol("id") < 250, 249)
+ .when(Symbol("id") >= 750, 1000)
+ .otherwise(Symbol("id")).as("key1"),
+ Symbol("id") as "value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .select(
+ when(Symbol("id") < 250, 249)
+ .otherwise(Symbol("id")).as("key2"),
+ Symbol("id") as "value2")
+ .createOrReplaceTempView("skewData2")
- // skewed left outer join optimization
- val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2")
- val leftSmj = findTopLevelColumnarSortMergeJoin(leftAdaptivePlan)
- checkSkewJoin(leftSmj, 2, 0)
+ def checkSkewJoin(
+ joins: Seq[ShuffledJoin],
+ leftSkewNum: Int,
+ rightSkewNum: Int): Unit = {
+ assert(joins.size == 1 && joins.head.isSkewJoin)
+ assert(joins.head.left.collect {
+ case r: AQEShuffleReadExec => r
+ }.head.partitionSpecs.collect {
+ case p: PartialReducerPartitionSpec => p.reducerIndex
+ }.distinct.length == leftSkewNum)
+ assert(joins.head.right.collect {
+ case r: AQEShuffleReadExec => r
+ }.head.partitionSpecs.collect {
+ case p: PartialReducerPartitionSpec => p.reducerIndex
+ }.distinct.length == rightSkewNum)
+ }
- // skewed right outer join optimization
- val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2")
- val rightSmj = findTopLevelColumnarSortMergeJoin(rightAdaptivePlan)
- checkSkewJoin(rightSmj, 0, 1)
+ // skewed inner join optimization
+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
+ s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+ "JOIN skewData2 ON key1 = key2")
+ val inner = getJoinNode(innerAdaptivePlan)
+ checkSkewJoin(inner, 2, 1)
+
+ // skewed left outer join optimization
+ val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult(
+ s"SELECT /*+ $joinHint(skewData2) */ * FROM skewData1 " +
+ "LEFT OUTER JOIN skewData2 ON key1 = key2")
+ val leftJoin = getJoinNode(leftAdaptivePlan)
+ checkSkewJoin(leftJoin, 2, 0)
+
+ // skewed right outer join optimization
+ val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult(
+ s"SELECT /*+ $joinHint(skewData1) */ * FROM skewData1 " +
+ "RIGHT OUTER JOIN skewData2 ON key1 = key2")
+ val rightJoin = getJoinNode(rightAdaptivePlan)
+ checkSkewJoin(rightJoin, 0, 1)
+ }
}
}
}
@@ -770,18 +863,18 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
withTable("bucketed_table") {
val df1 =
(0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k").as("df1")
- df1.write.format("orc").bucketBy(8, "i").saveAsTable("bucketed_table")
+ df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
val warehouseFilePath = new URI(spark.sessionState.conf.warehousePath).getPath
val tableDir = new File(warehouseFilePath, "bucketed_table")
Utils.deleteRecursively(tableDir)
- df1.write.orc(tableDir.getAbsolutePath)
+ df1.write.parquet(tableDir.getAbsolutePath)
val aggregated = spark.table("bucketed_table").groupBy("i").count()
- val error = intercept[Exception] {
+ val error = intercept[SparkException] {
aggregated.count()
}
- assert(error.getCause.toString contains "Invalid bucket file")
- assert(error.getSuppressed.size === 0)
+ assert(error.getErrorClass === "INVALID_BUCKET_FILE")
+ assert(error.getMessage contains "Invalid bucket file")
}
}
}
@@ -794,409 +887,430 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
}
}
- test("force apply AQE") {
+ test("force apply AQE") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ val plan = sql("SELECT * FROM testData").queryExecution.executedPlan
+ assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+ }
+ }
+
+ test("SPARK-30719: do not log warning if intentionally skip AQE") {
+ val testAppender = new LogAppender("aqe logging warning test when skip")
+ withLogAppender(testAppender) {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val plan = sql("SELECT * FROM testData").queryExecution.executedPlan
- assert(plan.isInstanceOf[AdaptiveSparkPlanExec])
+ assert(!plan.isInstanceOf[AdaptiveSparkPlanExec])
}
}
+ assert(!testAppender.loggingEvents
+ .exists(msg => msg.getMessage.getFormattedMessage.contains(
+ s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" +
+ s" enabled but is not supported for")))
+ }
- test("SPARK-30719: do not log warning if intentionally skip AQE") {
- val testAppender = new LogAppender("aqe logging warning test when skip")
- withLogAppender(testAppender) {
+ test("test log level") {
+ def verifyLog(expectedLevel: Level): Unit = {
+ val logAppender = new LogAppender("adaptive execution")
+ logAppender.setThreshold(expectedLevel)
+ withLogAppender(
+ logAppender,
+ loggerNames = Seq(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)),
+ level = Some(Level.TRACE)) {
withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val plan = sql("SELECT * FROM testData").queryExecution.executedPlan
- assert(!plan.isInstanceOf[AdaptiveSparkPlanExec])
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect()
}
}
- assert(!testAppender.loggingEvents
- .exists(msg => msg.getRenderedMessage.contains(
- s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" +
- s" enabled but is not supported for")))
+ Seq("Plan changed", "Final plan").foreach { msg =>
+ assert(
+ logAppender.loggingEvents.exists { event =>
+ event.getMessage.getFormattedMessage.contains(msg) && event.getLevel == expectedLevel
+ })
+ }
}
- test("test log level") {
- def verifyLog(expectedLevel: Level): Unit = {
- val logAppender = new LogAppender("adaptive execution")
- withLogAppender(
- logAppender,
- loggerName = Some(AdaptiveSparkPlanExec.getClass.getName.dropRight(1)),
- level = Some(Level.TRACE)) {
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
- sql("SELECT * FROM testData join testData2 ON key = a where value = '1'").collect()
- }
- }
- Seq("Plan changed", "Final plan").foreach { msg =>
- assert(
- logAppender.loggingEvents.exists { event =>
- event.getRenderedMessage.contains(msg) && event.getLevel == expectedLevel
- })
- }
+ // Verify default log level
+ verifyLog(Level.DEBUG)
+
+ // Verify custom log level
+ val levels = Seq(
+ "TRACE" -> Level.TRACE,
+ "trace" -> Level.TRACE,
+ "DEBUG" -> Level.DEBUG,
+ "debug" -> Level.DEBUG,
+ "INFO" -> Level.INFO,
+ "info" -> Level.INFO,
+ "WARN" -> Level.WARN,
+ "warn" -> Level.WARN,
+ "ERROR" -> Level.ERROR,
+ "error" -> Level.ERROR,
+ "deBUG" -> Level.DEBUG)
+
+ levels.foreach { level =>
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) {
+ verifyLog(level._2)
}
+ }
+ }
- // Verify default log level
- verifyLog(Level.DEBUG)
-
- // Verify custom log level
- val levels = Seq(
- "TRACE" -> Level.TRACE,
- "trace" -> Level.TRACE,
- "DEBUG" -> Level.DEBUG,
- "debug" -> Level.DEBUG,
- "INFO" -> Level.INFO,
- "info" -> Level.INFO,
- "WARN" -> Level.WARN,
- "warn" -> Level.WARN,
- "ERROR" -> Level.ERROR,
- "error" -> Level.ERROR,
- "deBUG" -> Level.DEBUG)
-
- levels.foreach { level =>
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_LOG_LEVEL.key -> level._1) {
- verifyLog(level._2)
- }
- }
+ test("tree string output") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'")
+ val planBefore = df.queryExecution.executedPlan
+ assert(!planBefore.toString.contains("== Current Plan =="))
+ assert(!planBefore.toString.contains("== Initial Plan =="))
+ df.collect()
+ val planAfter = df.queryExecution.executedPlan
+ assert(planAfter.toString.contains("== Final Plan =="))
+ assert(planAfter.toString.contains("== Initial Plan =="))
}
+ }
- test("tree string output") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val df = sql("SELECT * FROM testData join testData2 ON key = a where value = '1'")
- val planBefore = df.queryExecution.executedPlan
- assert(!planBefore.toString.contains("== Current Plan =="))
- assert(!planBefore.toString.contains("== Initial Plan =="))
- df.collect()
- val planAfter = df.queryExecution.executedPlan
- assert(planAfter.toString.contains("== Final Plan =="))
- assert(planAfter.toString.contains("== Initial Plan =="))
+ test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ withTempView("t2") {
+ // create DataFrame with 0 partition
+ spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType))
+ .createOrReplaceTempView("t2")
+ // should run successfully without NPE
+ runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b")
}
}
+ }
- test("SPARK-31384: avoid NPE in OptimizeSkewedJoin when there's 0 partition plan") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
- withTempView("t2") {
- // create DataFrame with 0 partition
- spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType))
- .createOrReplaceTempView("t2")
- // should run successfully without NPE
- runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 join t2 ON t1.a=t2.b")
- }
+ test("SPARK-34682: AQEShuffleReadExec operating on canonicalized plan") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT key FROM testData GROUP BY key")
+ val reads = collect(adaptivePlan) {
+ case r: AQEShuffleReadExec => r
+ }
+ assert(reads.length == 1)
+ val read = reads.head
+ val c = read.canonicalized.asInstanceOf[AQEShuffleReadExec]
+ // we can't just call execute() because that has separate checks for canonicalized plans
+ val ex = intercept[IllegalStateException] {
+ val doExecute = PrivateMethod[Unit](Symbol("doExecute"))
+ c.invokePrivate(doExecute())
}
+ assert(ex.getMessage === "operating on canonicalized plan")
}
+ }
- ignore("metrics of the shuffle reader") {
+ test("metrics of the shuffle read") {
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT key FROM testData GROUP BY key")
- val readers = collect(adaptivePlan) {
- case r: ColumnarCustomShuffleReaderExec => r
- }
- print(readers.length)
- assert(readers.length == 1)
- val reader = readers.head
- assert(!reader.isLocalReader)
- assert(!reader.hasSkewedPartition)
- assert(reader.hasCoalescedPartition)
- assert(reader.metrics.keys.toSeq.sorted == Seq(
- "numPartitions", "partitionDataSize"))
- assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length)
- assert(reader.metrics("partitionDataSize").value > 0)
+ val reads = collect(adaptivePlan) {
+ case r: AQEShuffleReadExec => r
+ }
+ assert(reads.length == 1)
+ val read = reads.head
+ assert(!read.isLocalRead)
+ assert(!read.hasSkewedPartition)
+ assert(read.hasCoalescedPartition)
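+ // Coalesced reads also report a numCoalescedPartitions metric, checked below
+ // together with numPartitions and partitionDataSize.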
+ assert(read.metrics.keys.toSeq.sorted == Seq(
+ "numCoalescedPartitions", "numPartitions", "partitionDataSize"))
+ assert(read.metrics("numCoalescedPartitions").value == 1)
+ assert(read.metrics("numPartitions").value == read.partitionSpecs.length)
+ assert(read.metrics("partitionDataSize").value > 0)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData join testData2 ON key = a where value = '1'")
val join = collect(adaptivePlan) {
- case j: ColumnarBroadcastHashJoinExec => j
+ case j: BroadcastHashJoinExec => j
}.head
assert(join.buildSide == BuildLeft)
- val readers = collect(join.right) {
- case r: ColumnarCustomShuffleReaderExec => r
+ val reads = collect(join.right) {
+ case r: AQEShuffleReadExec => r
}
- assert(readers.length == 1)
- val reader = readers.head
- assert(reader.isLocalReader)
- assert(reader.metrics.keys.toSeq == Seq("numPartitions"))
- assert(reader.metrics("numPartitions").value == reader.partitionSpecs.length)
+ assert(reads.length == 1)
+ val read = reads.head
+ assert(read.isLocalRead)
+ assert(read.metrics.keys.toSeq == Seq("numPartitions"))
+ assert(read.metrics("numPartitions").value == read.partitionSpecs.length)
}
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.SHUFFLE_PARTITIONS.key -> "100",
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
- SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1000") {
withTempView("skewData1", "skewData2") {
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .when('id >= 750, 1000)
- .otherwise('id).as("key1"),
- 'id as "value1")
+ when(Symbol("id") < 250, 249)
+ .when(Symbol("id") >= 750, 1000)
+ .otherwise(Symbol("id")).as("key1"),
+ Symbol("id") as "value1")
.createOrReplaceTempView("skewData1")
spark
.range(0, 1000, 1, 10)
.select(
- when('id < 250, 249)
- .otherwise('id).as("key2"),
- 'id as "value2")
+ when(Symbol("id") < 250, 249)
+ .otherwise(Symbol("id")).as("key2"),
+ Symbol("id") as "value2")
.createOrReplaceTempView("skewData2")
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
- val readers = collect(adaptivePlan) {
- case r: CustomShuffleReaderExec => r
+ val reads = collect(adaptivePlan) {
+ case r: AQEShuffleReadExec => r
}
- readers.foreach { reader =>
- assert(!reader.isLocalReader)
- assert(reader.hasCoalescedPartition)
- assert(reader.hasSkewedPartition)
- assert(reader.metrics.contains("numSkewedPartitions"))
+ reads.foreach { read =>
+ assert(!read.isLocalRead)
+ assert(read.hasCoalescedPartition)
+ assert(read.hasSkewedPartition)
+ assert(read.metrics.contains("numSkewedPartitions"))
}
- print(readers(1).metrics("numSkewedPartitions"))
- print(readers(1).metrics("numSkewedSplits"))
- assert(readers(0).metrics("numSkewedPartitions").value == 2)
- assert(readers(0).metrics("numSkewedSplits").value == 15)
- assert(readers(1).metrics("numSkewedPartitions").value == 1)
- assert(readers(1).metrics("numSkewedSplits").value == 12)
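+ // The expected split counts follow from the 1000-byte advisory partition size set
+ // above: skewed partitions are split into roughly advisory-sized chunks.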
+ assert(reads(0).metrics("numSkewedPartitions").value == 2)
+ assert(reads(0).metrics("numSkewedSplits").value == 11)
+ assert(reads(1).metrics("numSkewedPartitions").value == 1)
+ assert(reads(1).metrics("numSkewedSplits").value == 9)
}
}
}
}
- test("control a plan explain mode in listeners via SQLConf") {
-
- def checkPlanDescription(mode: String, expected: Seq[String]): Unit = {
- var checkDone = false
- val listener = new SparkListener {
- override def onOtherEvent(event: SparkListenerEvent): Unit = {
- event match {
- case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) =>
- assert(expected.forall(planDescription.contains))
- checkDone = true
- case _ => // ignore other events
- }
+ test("control a plan explain mode in listeners via SQLConf") {
+
+ def checkPlanDescription(mode: String, expected: Seq[String]): Unit = {
+ var checkDone = false
+ val listener = new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case SparkListenerSQLAdaptiveExecutionUpdate(_, planDescription, _) =>
+ assert(expected.forall(planDescription.contains))
+ checkDone = true
+ case _ => // ignore other events
}
}
- spark.sparkContext.addSparkListener(listener)
- withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode,
+ }
+ spark.sparkContext.addSparkListener(listener)
+ withSQLConf(SQLConf.UI_EXPLAIN_MODE.key -> mode,
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
- val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'")
- try {
- checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil)
- spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(checkDone)
- } finally {
- spark.sparkContext.removeSparkListener(listener)
- }
+ val dfAdaptive = sql("SELECT * FROM testData JOIN testData2 ON key = a WHERE value = '1'")
+ try {
+ checkAnswer(dfAdaptive, Row(1, "1", 1, 1) :: Row(1, "1", 1, 2) :: Nil)
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(checkDone)
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
}
}
+ }
- Seq(("simple", Seq("== Physical Plan ==")),
+ Seq(("simple", Seq("== Physical Plan ==")),
("extended", Seq("== Parsed Logical Plan ==", "== Analyzed Logical Plan ==",
"== Optimized Logical Plan ==", "== Physical Plan ==")),
("codegen", Seq("WholeStageCodegen subtrees")),
("cost", Seq("== Optimized Logical Plan ==", "Statistics(sizeInBytes")),
("formatted", Seq("== Physical Plan ==", "Output", "Arguments"))).foreach {
- case (mode, expected) =>
- checkPlanDescription(mode, expected)
- }
+ case (mode, expected) =>
+ checkPlanDescription(mode, expected)
}
+ }
- test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
- withTable("t1") {
- val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
- assert(plan.isInstanceOf[DataWritingCommandExec])
- assert(plan.asInstanceOf[DataWritingCommandExec].child.isInstanceOf[AdaptiveSparkPlanExec])
- }
+ test("SPARK-30953: InsertAdaptiveSparkPlan should apply AQE on child plan of write commands") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ withTable("t1") {
+ val plan = sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").queryExecution.executedPlan
+ assert(plan.isInstanceOf[CommandResultExec])
+ val commandResultExec = plan.asInstanceOf[CommandResultExec]
+ assert(commandResultExec.commandPhysicalPlan.isInstanceOf[DataWritingCommandExec])
+ assert(commandResultExec.commandPhysicalPlan.asInstanceOf[DataWritingCommandExec]
+ .child.isInstanceOf[AdaptiveSparkPlanExec])
}
}
+ }
- test("AQE should set active session during execution") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val df = spark.range(10).select(sum('id))
- assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
- SparkSession.setActiveSession(null)
- checkAnswer(df, Seq(Row(45)))
- SparkSession.setActiveSession(spark) // recover the active session.
- }
- }
-
- test("No deadlock in UI update") {
- object TestStrategy extends Strategy {
- def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case _: Aggregate =>
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
- spark.range(5).rdd
- }
- Nil
- case _ => Nil
- }
+ test("AQE should set active session during execution") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val df = spark.range(10).select(sum(Symbol("id")))
+ assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec])
+ SparkSession.setActiveSession(null)
+ checkAnswer(df, Seq(Row(45)))
+ SparkSession.setActiveSession(spark) // recover the active session.
+ }
+ }
+
+ test("No deadlock in UI update") {
+ object TestStrategy extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case _: Aggregate =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ spark.range(5).rdd
+ }
+ Nil
+ case _ => Nil
}
+ }
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
- try {
- spark.experimental.extraStrategies = TestStrategy :: Nil
- val df = spark.range(10).groupBy('id).count()
- df.collect()
- } finally {
- spark.experimental.extraStrategies = Nil
- }
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ try {
+ spark.experimental.extraStrategies = TestStrategy :: Nil
+ val df = spark.range(10).groupBy(Symbol("id")).count()
+ df.collect()
+ } finally {
+ spark.experimental.extraStrategies = Nil
}
}
+ }
- test("SPARK-31658: SQL UI should show write commands") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
- withTable("t1") {
- var checkDone = false
- val listener = new SparkListener {
- override def onOtherEvent(event: SparkListenerEvent): Unit = {
- event match {
- case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
- assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
- checkDone = true
- case _ => // ignore other events
- }
+ test("SPARK-31658: SQL UI should show write commands") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") {
+ withTable("t1") {
+ var checkDone = false
+ val listener = new SparkListener {
+ override def onOtherEvent(event: SparkListenerEvent): Unit = {
+ event match {
+ case SparkListenerSQLAdaptiveExecutionUpdate(_, _, planInfo) =>
+ assert(planInfo.nodeName == "Execute CreateDataSourceTableAsSelectCommand")
+ checkDone = true
+ case _ => // ignore other events
}
}
- spark.sparkContext.addSparkListener(listener)
- try {
- sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
- spark.sparkContext.listenerBus.waitUntilEmpty()
- assert(checkDone)
- } finally {
- spark.sparkContext.removeSparkListener(listener)
- }
+ }
+ spark.sparkContext.addSparkListener(listener)
+ try {
+ sql("CREATE TABLE t1 USING parquet AS SELECT 1 col").collect()
+ spark.sparkContext.listenerBus.waitUntilEmpty()
+ assert(checkDone)
+ } finally {
+ spark.sparkContext.removeSparkListener(listener)
}
}
}
+ }
- test("SPARK-31220, SPARK-32056: repartition by expression with AQE") {
- Seq(true, false).foreach { enableAQE =>
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
- SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
- SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
- SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-
- val df1 = spark.range(10).repartition($"id")
- val df2 = spark.range(10).repartition($"id" + 1)
-
- val partitionsNum1 = df1.rdd.collectPartitions().length
- val partitionsNum2 = df2.rdd.collectPartitions().length
-
- if (enableAQE) {
- assert(partitionsNum1 < 10)
- assert(partitionsNum2 < 10)
+ test("SPARK-31220, SPARK-32056: repartition by expression with AQE") {
+ Seq(true, false).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+ val df1 = spark.range(10).repartition($"id")
+ val df2 = spark.range(10).repartition($"id" + 1)
+
+ val partitionsNum1 = df1.rdd.collectPartitions().length
+ val partitionsNum2 = df2.rdd.collectPartitions().length
+
+ if (enableAQE) {
+ assert(partitionsNum1 < 10)
+ assert(partitionsNum2 < 10)
+
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
+ } else {
+ assert(partitionsNum1 === 10)
+ assert(partitionsNum2 === 10)
+ }
- checkInitialPartitionNum(df1, 10)
- checkInitialPartitionNum(df2, 10)
- } else {
- assert(partitionsNum1 === 10)
- assert(partitionsNum2 === 10)
- }
+ // Don't coalesce partitions if the number of partitions is specified.
+ val df3 = spark.range(10).repartition(10, $"id")
+ val df4 = spark.range(10).repartition(10)
+ assert(df3.rdd.collectPartitions().length == 10)
+ assert(df4.rdd.collectPartitions().length == 10)
+ }
+ }
+ }
- // Don't coalesce partitions if the number of partitions is specified.
- val df3 = spark.range(10).repartition(10, $"id")
- val df4 = spark.range(10).repartition(10)
- assert(df3.rdd.collectPartitions().length == 10)
- assert(df4.rdd.collectPartitions().length == 10)
+ test("SPARK-31220, SPARK-32056: repartition by range with AQE") {
+ Seq(true, false).foreach { enableAQE =>
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+
+ val df1 = spark.range(10).toDF.repartitionByRange($"id".asc)
+ val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc)
+
+ val partitionsNum1 = df1.rdd.collectPartitions().length
+ val partitionsNum2 = df2.rdd.collectPartitions().length
+
+ if (enableAQE) {
+ assert(partitionsNum1 < 10)
+ assert(partitionsNum2 < 10)
+
+ checkInitialPartitionNum(df1, 10)
+ checkInitialPartitionNum(df2, 10)
+ } else {
+ assert(partitionsNum1 === 10)
+ assert(partitionsNum2 === 10)
}
+
+ // Don't coalesce partitions if the number of partitions is specified.
+ val df3 = spark.range(10).repartitionByRange(10, $"id".asc)
+ assert(df3.rdd.collectPartitions().length == 10)
}
}
+ }
- test("SPARK-31220, SPARK-32056: repartition by range with AQE") {
- Seq(true, false).foreach { enableAQE =>
+ test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
+ Seq(true, false).foreach { enableAQE =>
+ withTempView("test") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
- val df1 = spark.range(10).toDF.repartitionByRange($"id".asc)
- val df2 = spark.range(10).toDF.repartitionByRange(($"id" + 1).asc)
+ spark.range(10).toDF.createTempView("test")
+
+ val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
+ val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
+ val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
+ val df4 = spark.sql("SELECT * from test CLUSTER BY id")
val partitionsNum1 = df1.rdd.collectPartitions().length
val partitionsNum2 = df2.rdd.collectPartitions().length
+ val partitionsNum3 = df3.rdd.collectPartitions().length
+ val partitionsNum4 = df4.rdd.collectPartitions().length
if (enableAQE) {
assert(partitionsNum1 < 10)
assert(partitionsNum2 < 10)
+ assert(partitionsNum3 < 10)
+ assert(partitionsNum4 < 10)
checkInitialPartitionNum(df1, 10)
checkInitialPartitionNum(df2, 10)
+ checkInitialPartitionNum(df3, 10)
+ checkInitialPartitionNum(df4, 10)
} else {
assert(partitionsNum1 === 10)
assert(partitionsNum2 === 10)
+ assert(partitionsNum3 === 10)
+ assert(partitionsNum4 === 10)
}
// Don't coalesce partitions if the number of partitions is specified.
- val df3 = spark.range(10).repartitionByRange(10, $"id".asc)
- assert(df3.rdd.collectPartitions().length == 10)
- }
- }
- }
-
- test("SPARK-31220, SPARK-32056: repartition using sql and hint with AQE") {
- Seq(true, false).foreach { enableAQE =>
- withTempView("test") {
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> enableAQE.toString,
- SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
- SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10",
- SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
-
- spark.range(10).toDF.createTempView("test")
-
- val df1 = spark.sql("SELECT /*+ REPARTITION(id) */ * from test")
- val df2 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(id) */ * from test")
- val df3 = spark.sql("SELECT * from test DISTRIBUTE BY id")
- val df4 = spark.sql("SELECT * from test CLUSTER BY id")
-
- val partitionsNum1 = df1.rdd.collectPartitions().length
- val partitionsNum2 = df2.rdd.collectPartitions().length
- val partitionsNum3 = df3.rdd.collectPartitions().length
- val partitionsNum4 = df4.rdd.collectPartitions().length
-
- if (enableAQE) {
- assert(partitionsNum1 < 10)
- assert(partitionsNum2 < 10)
- assert(partitionsNum3 < 10)
- assert(partitionsNum4 < 10)
-
- checkInitialPartitionNum(df1, 10)
- checkInitialPartitionNum(df2, 10)
- checkInitialPartitionNum(df3, 10)
- checkInitialPartitionNum(df4, 10)
- } else {
- assert(partitionsNum1 === 10)
- assert(partitionsNum2 === 10)
- assert(partitionsNum3 === 10)
- assert(partitionsNum4 === 10)
- }
-
- // Don't coalesce partitions if the number of partitions is specified.
- val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
- val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
- assert(df5.rdd.collectPartitions().length == 10)
- assert(df6.rdd.collectPartitions().length == 10)
- }
+ val df5 = spark.sql("SELECT /*+ REPARTITION(10, id) */ * from test")
+ val df6 = spark.sql("SELECT /*+ REPARTITION_BY_RANGE(10, id) */ * from test")
+ assert(df5.rdd.collectPartitions().length == 10)
+ assert(df6.rdd.collectPartitions().length == 10)
}
}
}
+ }
test("SPARK-32573: Eliminate NAAJ when BuildSide is HashedRelationWithAllNullKeys") {
withSQLConf(
@@ -1208,149 +1322,373 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
assert(bhj.size == 1)
val join = findTopLevelBaseJoin(adaptivePlan)
assert(join.isEmpty)
- checkNumLocalShuffleReaders(adaptivePlan)
+ checkNumLocalShuffleReads(adaptivePlan)
}
}
- test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") {
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
- // This test is a copy of test(SPARK-32573), in order to test the configuration
- // `spark.sql.adaptive.optimizer.excludedRules` works as expect.
- SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> EliminateJoinToEmptyRelation.ruleName) {
- val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)")
- val bhj = findTopLevelBroadcastHashJoin(plan)
- assert(bhj.size == 1)
+ test("SPARK-32717: AQEOptimizer should respect excludedRules configuration") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString,
+ // This test is a copy of test(SPARK-32573), in order to verify that the configuration
+ // `spark.sql.adaptive.optimizer.excludedRules` works as expected.
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)")
+ val bhj = findTopLevelBroadcastHashJoin(plan)
+ assert(bhj.size == 1)
+ val join = findTopLevelBaseJoin(adaptivePlan)
+ // this differs from test(SPARK-32573) because the rule
+ // `EliminateUnnecessaryJoin` has been excluded.
+ assert(join.nonEmpty)
+ checkNumLocalShuffleReads(adaptivePlan)
+ }
+ }
+
+ test("SPARK-32649: Eliminate inner and semi join to empty relation") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ Seq(
+ // inner join (small table at right side)
+ "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1",
+ // inner join (small table at left side)
+ "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1",
+ // left semi join
+ "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1"
+ ).foreach(query => {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
val join = findTopLevelBaseJoin(adaptivePlan)
- // this is different compares to test(SPARK-32573) due to the rule
- // `EliminateJoinToEmptyRelation` has been excluded.
- assert(join.nonEmpty)
- checkNumLocalShuffleReaders(adaptivePlan)
+ assert(join.isEmpty)
+ checkNumLocalShuffleReads(adaptivePlan)
+ })
+ }
+ }
+
+ test("SPARK-34533: Eliminate left anti join to empty relation") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ Seq(
+ // broadcast non-empty right side
+ ("SELECT /*+ broadcast(testData3) */ * FROM testData LEFT ANTI JOIN testData3", true),
+ // broadcast empty right side
+ ("SELECT /*+ broadcast(emptyTestData) */ * FROM testData LEFT ANTI JOIN emptyTestData",
+ true),
+ // broadcast left side
+ ("SELECT /*+ broadcast(testData) */ * FROM testData LEFT ANTI JOIN testData3", false)
+ ).foreach { case (query, isEliminated) =>
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ assert(findTopLevelBaseJoin(plan).size == 1)
+ assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated)
}
}
+ }
- test("SPARK-32649: Eliminate inner to empty relation") {
- withSQLConf(
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
- Seq(
- // inner join (small table at right side)
- "SELECT * FROM testData t1 join testData3 t2 ON t1.key = t2.a WHERE t2.b = 1",
- // inner join (small table at left side)
- "SELECT * FROM testData3 t1 join testData t2 ON t1.a = t2.key WHERE t1.b = 1",
- // left semi join : left join do not has omni impl
- // "SELECT * FROM testData t1 left semi join testData3 t2 ON t1.key = t2.a AND t2.b = 1"
- ).foreach(query => {
- val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
- val smj = findTopLevelSortMergeJoin(plan)
- assert(smj.size == 1)
- val join = findTopLevelBaseJoin(adaptivePlan)
- assert(join.isEmpty)
- checkNumLocalShuffleReaders(adaptivePlan)
- })
+ test("SPARK-34781: Eliminate left semi/anti join to its left side") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ Seq(
+ // left semi join and non-empty right side
+ ("SELECT * FROM testData LEFT SEMI JOIN testData3", true),
+ // left semi join, non-empty right side and non-empty join condition
+ ("SELECT * FROM testData t1 LEFT SEMI JOIN testData3 t2 ON t1.key = t2.a", false),
+ // left anti join and empty right side
+ ("SELECT * FROM testData LEFT ANTI JOIN emptyTestData", true),
+ // left anti join, empty right side and non-empty join condition
+ ("SELECT * FROM testData t1 LEFT ANTI JOIN emptyTestData t2 ON t1.key = t2.key", true)
+ ).foreach { case (query, isEliminated) =>
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ assert(findTopLevelBaseJoin(plan).size == 1)
+ assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated)
}
}
+ }
- test("SPARK-32753: Only copy tags to node with no tags") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- withTempView("v1") {
- spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1")
+ test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
+ "- single join") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ Seq(
+ // left semi join and empty left side
+ ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT SEMI JOIN " +
+ "testData2 t2 ON t1.key = t2.a", true),
+ // left anti join and empty left side
+ ("SELECT * FROM (SELECT * FROM testData WHERE value = '0')t1 LEFT ANTI JOIN " +
+ "testData2 t2 ON t1.key = t2.a", true),
+ // left outer join and empty left side
+ ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 LEFT JOIN testData2 t2 ON " +
+ "t1.key = t2.a", true),
+ // left outer join and non-empty left side
+ ("SELECT * FROM testData t1 LEFT JOIN testData2 t2 ON " +
+ "t1.key = t2.a", false),
+ // right outer join and empty right side
+ ("SELECT * FROM testData t1 RIGHT JOIN (SELECT * FROM testData2 WHERE b = 0)t2 ON " +
+ "t1.key = t2.a", true),
+ // right outer join and non-empty right side
+ ("SELECT * FROM testData t1 RIGHT JOIN testData2 t2 ON " +
+ "t1.key = t2.a", false),
+ // full outer join and both side empty
+ ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " +
+ "(SELECT * FROM testData2 WHERE b = 0)t2 ON t1.key = t2.a", true),
+ // full outer join and left side empty right side non-empty
+ ("SELECT * FROM (SELECT * FROM testData WHERE key = 0)t1 FULL JOIN " +
+ "testData2 t2 ON t1.key = t2.a", true)
+ ).foreach { case (query, isEliminated) =>
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ assert(findTopLevelBaseJoin(plan).size == 1)
+ assert(findTopLevelBaseJoin(adaptivePlan).isEmpty == isEliminated, adaptivePlan)
+ }
+ }
+ }
- val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
- assert(collect(adaptivePlan) {
- case s: ShuffleExchangeExec => s
- }.length == 1)
- }
+ test("SPARK-35455: Unify empty relation optimization between normal and AQE optimizer " +
+ "- multi join") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ Seq(
+ """
+ |SELECT * FROM testData t1
+ | JOIN (SELECT * FROM testData2 WHERE b = 0) t2 ON t1.key = t2.a
+ | LEFT JOIN testData2 t3 ON t1.key = t3.a
+ |""".stripMargin,
+ """
+ |SELECT * FROM (SELECT * FROM testData WHERE key = 0) t1
+ | LEFT ANTI JOIN testData2 t2
+ | FULL JOIN (SELECT * FROM testData2 WHERE b = 0) t3 ON t1.key = t3.a
+ |""".stripMargin,
+ """
+ |SELECT * FROM testData t1
+ | LEFT SEMI JOIN (SELECT * FROM testData2 WHERE b = 0)
+ | RIGHT JOIN testData2 t3 on t1.key = t3.a
+ |""".stripMargin
+ ).foreach { query =>
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ assert(findTopLevelBaseJoin(plan).size == 2)
+ assert(findTopLevelBaseJoin(adaptivePlan).isEmpty)
}
}
+ }
- test("Logging plan changes for AQE") {
- val testAppender = new LogAppender("plan changes")
- withLogAppender(testAppender) {
- withSQLConf(
+ test("SPARK-35585: Support propagate empty relation through project/filter") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+ "SELECT key FROM testData WHERE key = 0 ORDER BY key, value")
+ assert(findTopLevelSort(plan1).size == 1)
+ assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+ "SELECT key FROM (SELECT * FROM testData WHERE value = 'no_match' ORDER BY key)" +
+ " WHERE key > rand()")
+ assert(findTopLevelSort(plan2).size == 1)
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+ }
+ }
+
+ test("SPARK-35442: Support propagate empty relation through aggregate") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key")
+ assert(!plan1.isInstanceOf[LocalTableScanExec])
+ assert(stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+ "SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key limit 1")
+ assert(!plan2.isInstanceOf[LocalTableScanExec])
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+
+ val (plan3, adaptivePlan3) = runAdaptiveAndVerifyResult(
+ "SELECT count(*) FROM testData WHERE value = 'no_match'")
+ assert(!plan3.isInstanceOf[LocalTableScanExec])
+ assert(!stripAQEPlan(adaptivePlan3).isInstanceOf[LocalTableScanExec])
+ }
+ }
+
+ test("SPARK-35442: Support propagate empty relation through union") {
+ def checkNumUnion(plan: SparkPlan, numUnion: Int): Unit = {
+ assert(
+ collect(plan) {
+ case u: UnionExec => u
+ }.size == numUnion)
+ }
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (plan1, adaptivePlan1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+ |UNION ALL
+ |SELECT key, 1 FROM testData
+ |""".stripMargin)
+ checkNumUnion(plan1, 1)
+ checkNumUnion(adaptivePlan1, 0)
+ assert(!stripAQEPlan(adaptivePlan1).isInstanceOf[LocalTableScanExec])
+
+ val (plan2, adaptivePlan2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT key, count(*) FROM testData WHERE value = 'no_match' GROUP BY key
+ |UNION ALL
+ |SELECT /*+ REPARTITION */ key, 1 FROM testData WHERE value = 'no_match'
+ |""".stripMargin)
+ checkNumUnion(plan2, 1)
+ checkNumUnion(adaptivePlan2, 0)
+ assert(stripAQEPlan(adaptivePlan2).isInstanceOf[LocalTableScanExec])
+ }
+ }
+
+ test("SPARK-32753: Only copy tags to node with no tags") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ withTempView("v1") {
+ spark.range(10).union(spark.range(10)).createOrReplaceTempView("v1")
+
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
+ "SELECT id FROM v1 GROUP BY id DISTRIBUTE BY id")
+ assert(collect(adaptivePlan) {
+ case s: ShuffleExchangeExec => s
+ }.length == 1)
+ }
+ }
+ }
+
+ test("Logging plan changes for AQE") {
+ val testAppender = new LogAppender("plan changes")
+ withLogAppender(testAppender) {
+ withSQLConf(
SQLConf.PLAN_CHANGE_LOG_LEVEL.key -> "INFO",
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
- sql("SELECT * FROM testData JOIN testData2 ON key = a " +
- "WHERE value = (SELECT max(a) FROM testData3)").collect()
- }
- Seq("=== Result of Batch AQE Preparations ===",
+ sql("SELECT * FROM testData JOIN testData2 ON key = a " +
+ "WHERE value = (SELECT max(a) FROM testData3)").collect()
+ }
+ Seq("=== Result of Batch AQE Preparations ===",
"=== Result of Batch AQE Post Stage Creation ===",
"=== Result of Batch AQE Replanning ===",
- "=== Result of Batch AQE Query Stage Optimization ===",
- "=== Result of Batch AQE Final Query Stage Optimization ===").foreach { expectedMsg =>
- assert(testAppender.loggingEvents.exists(_.getRenderedMessage.contains(expectedMsg)))
- }
+ "=== Result of Batch AQE Query Stage Optimization ===").foreach { expectedMsg =>
+ assert(testAppender.loggingEvents.exists(
+ _.getMessage.getFormattedMessage.contains(expectedMsg)))
}
}
+ }
- test("SPARK-32932: Do not use local shuffle reader at final stage on write command") {
- withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
- SQLConf.SHUFFLE_PARTITIONS.key -> "5",
- SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val data = for (
- i <- 1L to 10L;
- j <- 1L to 3L
- ) yield (i, j)
-
- val df = data.toDF("i", "j").repartition($"j")
- var noLocalReader: Boolean = false
- val listener = new QueryExecutionListener {
- override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
- qe.executedPlan match {
- case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
- assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
- noLocalReader = collect(plan) {
- case exec: CustomShuffleReaderExec if exec.isLocalReader => exec
- }.isEmpty
- case _ => // ignore other events
- }
+ test("SPARK-32932: Do not use local shuffle read at final stage on write command") {
+ withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString,
+ SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val data = for (
+ i <- 1L to 10L;
+ j <- 1L to 3L
+ ) yield (i, j)
+
+ val df = data.toDF("i", "j").repartition($"j")
+ var noLocalread: Boolean = false
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ qe.executedPlan match {
+ case plan@(_: DataWritingCommandExec | _: V2TableWriteExec) =>
+ assert(plan.asInstanceOf[UnaryExecNode].child.isInstanceOf[AdaptiveSparkPlanExec])
+ noLocalread = collect(plan) {
+ case exec: AQEShuffleReadExec if exec.isLocalRead => exec
+ }.isEmpty
+ case _ => // ignore other events
}
- override def onFailure(funcName: String, qe: QueryExecution,
- exception: Exception): Unit = {}
- }
- spark.listenerManager.register(listener)
-
- withTable("t") {
- df.write.partitionBy("j").saveAsTable("t")
- sparkContext.listenerBus.waitUntilEmpty()
- assert(noLocalReader)
- noLocalReader = false
}
+ override def onFailure(funcName: String, qe: QueryExecution,
+ exception: Exception): Unit = {}
+ }
+ spark.listenerManager.register(listener)
- // Test DataSource v2
- val format = classOf[NoopDataSource].getName
- df.write.format(format).mode("overwrite").save()
+ withTable("t") {
+ df.write.partitionBy("j").saveAsTable("t")
sparkContext.listenerBus.waitUntilEmpty()
- assert(noLocalReader)
- noLocalReader = false
-
- spark.listenerManager.unregister(listener)
+ assert(noLocalread)
+ noLocalread = false
}
+
+ // Test DataSource v2
+ val format = classOf[NoopDataSource].getName
+ df.write.format(format).mode("overwrite").save()
+ sparkContext.listenerBus.waitUntilEmpty()
+ assert(noLocalread)
+ noLocalread = false
+
+ spark.listenerManager.unregister(listener)
}
+ }
- test("SPARK-33494: Do not use local shuffle reader for repartition") {
- withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
- val df = spark.table("testData").repartition('key)
- df.collect()
- // local shuffle reader breaks partitioning and shouldn't be used for repartition operation
- // which is specified by users.
- checkNumLocalShuffleReaders(df.queryExecution.executedPlan, numShufflesWithoutLocalReader = 1)
- }
+ test("SPARK-33494: Do not use local shuffle read for repartition") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val df = spark.table("testData").repartition(Symbol("key"))
+ df.collect()
+ // local shuffle read breaks partitioning and shouldn't be used for repartition operation
+ // which is specified by users.
+ checkNumLocalShuffleReads(df.queryExecution.executedPlan, numShufflesWithoutLocalRead = 1)
}
+ }
- test("SPARK-33551: Do not use custom shuffle reader for repartition") {
+ test("SPARK-33551: Do not use AQE shuffle read for repartition") {
def hasRepartitionShuffle(plan: SparkPlan): Boolean = {
find(plan) {
case s: ShuffleExchangeLike =>
- s.shuffleOrigin == REPARTITION || s.shuffleOrigin == REPARTITION_WITH_NUM
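+ // The repartition shuffle origins are named REPARTITION_BY_COL / REPARTITION_BY_NUM
+ // in this Spark version.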
+ s.shuffleOrigin == REPARTITION_BY_COL || s.shuffleOrigin == REPARTITION_BY_NUM
case _ => false
}.isDefined
}
+ def checkBHJ(
+ df: Dataset[Row],
+ optimizeOutRepartition: Boolean,
+ probeSideLocalRead: Boolean,
+ probeSideCoalescedRead: Boolean): Unit = {
+ df.collect()
+ val plan = df.queryExecution.executedPlan
+ // There should be only one shuffle that can't do local read, which is either the top shuffle
+ // from repartition, or BHJ probe side shuffle.
+ checkNumLocalShuffleReads(plan, 1)
+ assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+ val bhj = findTopLevelBroadcastHashJoin(plan)
+ assert(bhj.length == 1)
+
+ // Build side should do local read.
+ val buildSide = find(bhj.head.left)(_.isInstanceOf[AQEShuffleReadExec])
+ assert(buildSide.isDefined)
+ assert(buildSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+
+ val probeSide = find(bhj.head.right)(_.isInstanceOf[AQEShuffleReadExec])
+ if (probeSideLocalRead || probeSideCoalescedRead) {
+ assert(probeSide.isDefined)
+ if (probeSideLocalRead) {
+ assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].isLocalRead)
+ } else {
+ assert(probeSide.get.asInstanceOf[AQEShuffleReadExec].hasCoalescedPartition)
+ }
+ } else {
+ assert(probeSide.isEmpty)
+ }
+ }
+
+ def checkSMJ(
+ df: Dataset[Row],
+ optimizeOutRepartition: Boolean,
+ optimizeSkewJoin: Boolean,
+ coalescedRead: Boolean): Unit = {
+ df.collect()
+ val plan = df.queryExecution.executedPlan
+ assert(hasRepartitionShuffle(plan) == !optimizeOutRepartition)
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.length == 1)
+ assert(smj.head.isSkewJoin == optimizeSkewJoin)
+ val aqeReads = collect(smj.head) {
+ case c: AQEShuffleReadExec => c
+ }
+ if (coalescedRead || optimizeSkewJoin) {
+ assert(aqeReads.length == 2)
+ if (coalescedRead) assert(aqeReads.forall(_.hasCoalescedPartition))
+ } else {
+ assert(aqeReads.isEmpty)
+ }
+ }
+
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.SHUFFLE_PARTITIONS.key -> "5") {
val df = sql(
@@ -1359,50 +1697,30 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
| SELECT * FROM testData WHERE key = 1
|)
|RIGHT OUTER JOIN testData2
- |ON value = b
- """.stripMargin)
+ |ON CAST(value AS INT) = b
+ """.stripMargin)
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
// Repartition with no partition num specified.
- val dfRepartition = df.repartition('b)
- dfRepartition.collect()
- val plan = dfRepartition.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(plan))
- val bhj = findTopLevelBroadcastHashJoin(plan)
- assert(bhj.length == 1)
- checkNumLocalShuffleReaders(plan, 1)
- // Probe side is coalesced.
- val customReader = bhj.head.right.find(_.isInstanceOf[ColumnarCustomShuffleReaderExec])
- assert(customReader.isDefined)
- assert(customReader.get.asInstanceOf[ColumnarCustomShuffleReaderExec].hasCoalescedPartition)
-
- // Repartition with partition default num specified.
- val dfRepartitionWithNum = df.repartition(5, 'b)
- dfRepartitionWithNum.collect()
- val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(planWithNum))
- val bhjWithNum = findTopLevelBroadcastHashJoin(planWithNum)
- assert(bhjWithNum.length == 1)
- checkNumLocalShuffleReaders(planWithNum, 1)
- // Probe side is not coalesced.
- assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty)
-
- // Repartition with partition non-default num specified.
- val dfRepartitionWithNum2 = df.repartition(3, 'b)
- dfRepartitionWithNum2.collect()
- val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out, and this is the only shuffle that
- // does not have local shuffle reader.
- assert(hasRepartitionShuffle(planWithNum2))
- val bhjWithNum2 = findTopLevelBroadcastHashJoin(planWithNum2)
- assert(bhjWithNum2.length == 1)
- checkNumLocalShuffleReaders(planWithNum2, 1)
- val customReader2 = bhjWithNum2.head.right
- .find(_.isInstanceOf[ColumnarCustomShuffleReaderExec])
- assert(customReader2.isDefined)
- assert(customReader2.get.asInstanceOf[ColumnarCustomShuffleReaderExec].isLocalReader)
+ checkBHJ(df.repartition(Symbol("b")),
+ // The top shuffle from repartition is optimized out.
+ optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = true)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkBHJ(df.repartition(5, Symbol("b")),
+ // The top shuffle from repartition is optimized out
+ // The final plan must have 5 partitions, no optimization can be made to the probe side.
+ optimizeOutRepartition = true, probeSideLocalRead = false, probeSideCoalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkBHJ(df.repartition(4, Symbol("b")),
+ // The top shuffle from repartition is not optimized out
+ optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
+
+ // Repartition by col and project away the partition cols
+ checkBHJ(df.repartition(Symbol("b")).select(Symbol("key")),
+ // The top shuffle from repartition is not optimized out
+ optimizeOutRepartition = false, probeSideLocalRead = true, probeSideCoalescedRead = true)
}
// Force skew join
@@ -1412,108 +1730,941 @@ class ColumnarAdaptiveQueryExecSuite extends ColumnarSparkPlanTest
SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") {
// Repartition with no partition num specified.
- val dfRepartition = df.repartition('b)
- dfRepartition.collect()
- val plan = dfRepartition.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(plan))
- val smj = findTopLevelSortMergeJoin(plan)
- assert(smj.length == 1)
- // No skew join due to the repartition.
- assert(!smj.head.isSkewJoin)
- // Both sides are coalesced.
- val customReaders = collect(smj.head) {
- case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
- case c: ColumnarCustomShuffleReaderExec if c.hasCoalescedPartition => c
+ checkSMJ(df.repartition(Symbol("b")),
+ // The top shuffle from repartition is optimized out.
+ optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = true)
+
+ // Repartition with default partition num (5 in test env) specified.
+ checkSMJ(df.repartition(5, Symbol("b")),
+ // The top shuffle from repartition is optimized out.
+ // The final plan must have 5 partitions, can't do coalesced read.
+ optimizeOutRepartition = true, optimizeSkewJoin = false, coalescedRead = false)
+
+ // Repartition with non-default partition num specified.
+ checkSMJ(df.repartition(4, Symbol("b")),
+ // The top shuffle from repartition is not optimized out.
+ optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+
+ // Repartition by col and project away the partition cols
+ checkSMJ(df.repartition(Symbol("b")).select(Symbol("key")),
+ // The top shuffle from repartition is not optimized out.
+ optimizeOutRepartition = false, optimizeSkewJoin = true, coalescedRead = false)
+ }
+ }
+ }
+
+ test("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+ SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") {
+ withTable("t1") {
+ spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1")
+ val query = "SELECT SUM(a) FROM t1 GROUP BY a"
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED
+ val blocksFetchedMetric = collectFirst(adaptivePlan) {
+ case p if p.metrics.contains(metricName) => p.metrics(metricName)
+ }
+ assert(blocksFetchedMetric.isDefined)
+ val blocksFetched = blocksFetchedMetric.get.value
+ withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") {
+ val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query)
+ val blocksFetchedMetric2 = collectFirst(adaptivePlan2) {
+ case p if p.metrics.contains(metricName) => p.metrics(metricName)
+ }
+ assert(blocksFetchedMetric2.isDefined)
+ val blocksFetched2 = blocksFetchedMetric2.get.value
+ assert(blocksFetched < blocksFetched2)
+ }
+ }
+ }
+ }
+
+ test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") {
+ val testAppender = new LogAppender("aqe query stage materialization order test")
+ testAppender.setThreshold(Level.DEBUG)
+ val df = spark.range(1000).select($"id" % 26, $"id" % 10)
+ .toDF("index", "pv")
+ val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
+ .toDF("index", "name")
+ val testDf = df.groupBy("index")
+ .agg(sum($"pv").alias("pv"))
+ .join(dim, Seq("index"))
+ val loggerNames =
+ Seq(classOf[BroadcastQueryStageExec].getName, classOf[ShuffleQueryStageExec].getName)
+ withLogAppender(testAppender, loggerNames, level = Some(Level.DEBUG)) {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val result = testDf.collect()
+ assert(result.length == 26)
+ }
+ }
+ val materializeLogs = testAppender.loggingEvents
+ .map(_.getMessage.getFormattedMessage)
+ .filter(_.startsWith("Materialize query stage"))
+ .toArray
+ assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec"))
+ assert(materializeLogs(1).startsWith("Materialize query stage ShuffleQueryStageExec"))
+ }
+
+ test("SPARK-34899: Use origin plan if we can not coalesce shuffle partition") {
+ def checkNoCoalescePartitions(ds: Dataset[Row], origin: ShuffleOrigin): Unit = {
+ assert(collect(ds.queryExecution.executedPlan) {
+ case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ }.size == 1)
+ ds.collect()
+ val plan = ds.queryExecution.executedPlan
+ assert(collect(plan) {
+ case c: AQEShuffleReadExec => c
+ }.isEmpty)
+ assert(collect(plan) {
+ case s: ShuffleExchangeExec if s.shuffleOrigin == origin && s.numPartitions == 2 => s
+ }.size == 1)
+ checkAnswer(ds, testData)
+ }
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ // Pick a small value so that no coalesce can happen.
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2") {
+ val df = spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)), 10).toDF()
+
+ // partition size [1420, 1420]
+ checkNoCoalescePartitions(df.repartition($"key"), REPARTITION_BY_COL)
+ // partition size [1140, 1119]
+ checkNoCoalescePartitions(df.sort($"key"), ENSURE_REQUIREMENTS)
+ }
+ }
+
+ test("SPARK-34980: Support coalesce partition through union") {
+ def checkResultPartition(
+ df: Dataset[Row],
+ numUnion: Int,
+ numShuffleReader: Int,
+ numPartition: Int): Unit = {
+ df.collect()
+ assert(collect(df.queryExecution.executedPlan) {
+ case u: UnionExec => u
+ }.size == numUnion)
+ assert(collect(df.queryExecution.executedPlan) {
+ case r: AQEShuffleReadExec => r
+ }.size === numShuffleReader)
+ assert(df.rdd.partitions.length === numPartition)
+ }
+
+ Seq(true, false).foreach { combineUnionEnabled =>
+ val combineUnionConfig = if (combineUnionEnabled) {
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ""
+ } else {
+ SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+ "org.apache.spark.sql.catalyst.optimizer.CombineUnions"
+ }
+ // advisory partition size 1048576 has no special meaning, just a big enough value
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "1048576",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+ combineUnionConfig) {
+ withTempView("t1", "t2") {
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 2)
+ .toDF().createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize((1 to 10).map(i => TestData(i, i.toString)), 4)
+ .toDF().createOrReplaceTempView("t2")
+
+ // positive test that could be coalesced
+ checkResultPartition(
+ sql("""
+ |SELECT key, count(*) FROM t1 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t2
+ """.stripMargin),
+ numUnion = 1,
+ numShuffleReader = 1,
+ numPartition = 1 + 4)
+
+ checkResultPartition(
+ sql("""
+ |SELECT key, count(*) FROM t1 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t2
+ |UNION ALL
+ |SELECT * FROM t1
+ """.stripMargin),
+ numUnion = if (combineUnionEnabled) 1 else 2,
+ numShuffleReader = 1,
+ numPartition = 1 + 4 + 2)
+
+ checkResultPartition(
+ sql("""
+ |SELECT /*+ merge(t2) */ t1.key, t2.key FROM t1 JOIN t2 ON t1.key = t2.key
+ |UNION ALL
+ |SELECT key, count(*) FROM t2 GROUP BY key
+ |UNION ALL
+ |SELECT * FROM t1
+ """.stripMargin),
+ numUnion = if (combineUnionEnabled) 1 else 2,
+ numShuffleReader = 3,
+ numPartition = 1 + 1 + 2)
+
+ // negative test
+ checkResultPartition(
+ sql("SELECT * FROM t1 UNION ALL SELECT * FROM t2"),
+ numUnion = if (combineUnionEnabled) 1 else 1,
+ numShuffleReader = 0,
+ numPartition = 2 + 4
+ )
}
- assert(customReaders.length == 2)
-
- // Repartition with default partition num specified.
- val dfRepartitionWithNum = df.repartition(5, 'b)
- dfRepartitionWithNum.collect()
- val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan
- // The top shuffle from repartition is optimized out.
- assert(!hasRepartitionShuffle(planWithNum))
- val smjWithNum = findTopLevelSortMergeJoin(planWithNum)
- assert(smjWithNum.length == 1)
- // No skew join due to the repartition.
- assert(!smjWithNum.head.isSkewJoin)
- // No coalesce due to the num in repartition.
- val customReadersWithNum = collect(smjWithNum.head) {
- case c: CustomShuffleReaderExec if c.hasCoalescedPartition => c
+ }
+ }
+ }
+
+ test("SPARK-35239: Coalesce shuffle partition should handle empty input RDD") {
+ withTable("t") {
+ withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+ spark.sql("CREATE TABLE t (c1 int) USING PARQUET")
+ val (_, adaptive) = runAdaptiveAndVerifyResult("SELECT c1, count(*) FROM t GROUP BY c1")
+ assert(
+ collect(adaptive) {
+ case c @ AQEShuffleReadExec(_, partitionSpecs) if partitionSpecs.length == 1 =>
+ assert(c.hasCoalescedPartition)
+ c
+ }.length == 1
+ )
+ }
+ }
+ }
+
+ test("SPARK-35264: Support AQE side broadcastJoin threshold") {
+ withTempView("t1", "t2") {
+ def checkJoinStrategy(shouldBroadcast: Boolean): Unit = {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val (origin, adaptive) = runAdaptiveAndVerifyResult(
+ "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+ assert(findTopLevelSortMergeJoin(origin).size == 1)
+ if (shouldBroadcast) {
+ assert(findTopLevelBroadcastHashJoin(adaptive).size == 1)
+ } else {
+ assert(findTopLevelSortMergeJoin(adaptive).size == 1)
+ }
}
- assert(customReadersWithNum.isEmpty)
+ }
+
+ // t1: 1600 bytes
+ // t2: 160 bytes
+ spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)), 10)
+ .toDF("c1", "c2").createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString)), 5)
+ .toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ checkJoinStrategy(false)
+ withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ checkJoinStrategy(false)
+ }
- // Repartition with default non-partition num specified.
- val dfRepartitionWithNum2 = df.repartition(3, 'b)
- dfRepartitionWithNum2.collect()
- val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan
- // The top shuffle from repartition is not optimized out.
- assert(hasRepartitionShuffle(planWithNum2))
- val smjWithNum2 = findTopLevelSortMergeJoin(planWithNum2)
- assert(smjWithNum2.length == 1)
- // Skew join can apply as the repartition is not optimized out.
- assert(smjWithNum2.head.isSkewJoin)
+ withSQLConf(SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "160") {
+ checkJoinStrategy(true)
}
}
}
- ignore("SPARK-34091: Batch shuffle fetch in AQE partition coalescing") {
+ test("SPARK-35264: Support AQE side shuffled hash join formula") {
+ withTempView("t1", "t2") {
+ def checkJoinStrategy(shouldShuffleHashJoin: Boolean): Unit = {
+ Seq("100", "100000").foreach { size =>
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> size) {
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+ "SELECT t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+ assert(findTopLevelSortMergeJoin(origin1).size === 1)
+ if (shouldShuffleHashJoin && size.toInt < 100000) {
+ val shj = findTopLevelShuffledHashJoin(adaptive1)
+ assert(shj.size === 1)
+ assert(shj.head.buildSide == BuildRight)
+ } else {
+ assert(findTopLevelSortMergeJoin(adaptive1).size === 1)
+ }
+ }
+ }
+ // respect user specified join hint
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+ "SELECT /*+ MERGE(t1) */ t1.c1, t2.c1 FROM t1 JOIN t2 ON t1.c1 = t2.c1")
+ assert(findTopLevelSortMergeJoin(origin2).size === 1)
+ assert(findTopLevelSortMergeJoin(adaptive2).size === 1)
+ }
+
+ spark.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString)), 10)
+ .toDF("c1", "c2").createOrReplaceTempView("t1")
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString)), 5)
+ .toDF("c1", "c2").createOrReplaceTempView("t2")
+
+ // t1 partition size: [926, 729, 731]
+ // t2 partition size: [318, 120, 0]
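+ // The expectations below flip between SMJ and SHJ as the threshold moves above or
+ // below t2's largest partition size (318 bytes).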
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "true") {
+ // check default value
+ checkJoinStrategy(false)
+ withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "400") {
+ checkJoinStrategy(true)
+ }
+ withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "300") {
+ checkJoinStrategy(false)
+ }
+ withSQLConf(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "1000") {
+ checkJoinStrategy(true)
+ }
+ }
+ }
+ }
+
+ test("SPARK-35650: Coalesce number of partitions by AEQ") {
+ withSQLConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+ Seq("REPARTITION", "REBALANCE(key)")
+ .foreach {repartition =>
+ val query = s"SELECT /*+ $repartition */ * FROM testData"
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ collect(adaptivePlan) {
+ case r: AQEShuffleReadExec => r
+ } match {
+ case Seq(aqeShuffleRead) =>
+ assert(aqeShuffleRead.partitionSpecs.size === 1)
+ assert(!aqeShuffleRead.isLocalRead)
+ case _ =>
+ fail("There should be a AQEShuffleReadExec")
+ }
+ }
+ }
+ }
+
+ test("SPARK-35650: Use local shuffle read if can not coalesce number of partitions") {
+ withSQLConf(SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
+ val query = "SELECT /*+ REPARTITION */ * FROM testData"
+ val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ collect(adaptivePlan) {
+ case r: AQEShuffleReadExec => r
+ } match {
+ case Seq(aqeShuffleRead) =>
+ assert(aqeShuffleRead.partitionSpecs.size === 4)
+ assert(aqeShuffleRead.isLocalRead)
+ case _ =>
+ fail("There should be a AQEShuffleReadExec")
+ }
+ }
+ }
+
+ test("SPARK-35725: Support optimize skewed partitions in RebalancePartitions") {
+ withTempView("v") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- SQLConf.SHUFFLE_PARTITIONS.key -> "10000",
- SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "true") {
- withTable("t1") {
- spark.range(100).selectExpr("id + 1 as a").write.format("parquet").saveAsTable("t1")
- val query = "SELECT SUM(a) FROM t1 GROUP BY a"
- val (_, adaptivePlan) = runAdaptiveAndVerifyResult(query)
- val metricName = SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED
- val blocksFetchedMetric = collectFirst(adaptivePlan) {
- case p if p.metrics.contains(metricName) => p.metrics(metricName)
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "5",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(if (i > 4) 5 else i, i.toString)), 3)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ def checkPartitionNumber(
+ query: String, skewedPartitionNumber: Int, totalNumber: Int): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ val read = collect(adaptive) {
+ case read: AQEShuffleReadExec => read
}
- assert(blocksFetchedMetric.isDefined)
- val blocksFetched = blocksFetchedMetric.get.value
- withSQLConf(SQLConf.FETCH_SHUFFLE_BLOCKS_IN_BATCH.key -> "false") {
- val (_, adaptivePlan2) = runAdaptiveAndVerifyResult(query)
- val blocksFetchedMetric2 = collectFirst(adaptivePlan2) {
- case p if p.metrics.contains(metricName) => p.metrics(metricName)
+ assert(read.size == 1)
+ assert(read.head.partitionSpecs.count(_.isInstanceOf[PartialReducerPartitionSpec]) ==
+ skewedPartitionNumber)
+ assert(read.head.partitionSpecs.size == totalNumber)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "150") {
+ // partition size [0,258,72,72,72]
+ checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 2, 4)
+ // partition size [144,72,144,72,72,144,72]
+ checkPartitionNumber("SELECT /*+ REBALANCE */ * FROM v", 6, 7)
+ }
+
+ // no skewed partition should be optimized
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10000") {
+ checkPartitionNumber("SELECT /*+ REBALANCE(c1) */ * FROM v", 0, 1)
+ }
+ }
+ }
+ }
+
+ test("SPARK-35888: join with a 0-partition table") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+ withTempView("t2") {
+ // create a temp view with 0 partitions
+ spark.createDataFrame(sparkContext.emptyRDD[Row], new StructType().add("b", IntegerType))
+ .createOrReplaceTempView("t2")
+ val (_, adaptive) =
+ runAdaptiveAndVerifyResult("SELECT * FROM testData2 t1 left semi join t2 ON t1.a=t2.b")
+ val aqeReads = collect(adaptive) {
+ case c: AQEShuffleReadExec => c
+ }
+ assert(aqeReads.length == 2)
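+ // Each AQE read sits on top of a query stage; even for the 0-partition side the
+ // runtime statistics should still be well-defined (non-negative size and row count).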
+ aqeReads.foreach { c =>
+ val stats = c.child.asInstanceOf[QueryStageExec].getRuntimeStatistics
+ assert(stats.sizeInBytes >= 0)
+ assert(stats.rowCount.get >= 0)
+ }
+ }
+ }
+ }
+
+ test("SPARK-33832: Support optimize skew join even if introduce extra shuffle") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "false",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10",
+ SQLConf.ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN.key -> "true") {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+
+ // check that the skew join is optimized even when it does not satisfy the required distribution
+ Seq(true, false).foreach { hasRequiredDistribution =>
+ Seq(true, false).foreach { hasPartitionNumber =>
+ val repartition = if (hasRequiredDistribution) {
+ s"/*+ repartition(${ if (hasPartitionNumber) "10," else ""}key1) */"
+ } else {
+ ""
+ }
+
+ // check required distribution and extra shuffle
+ val (_, adaptive1) =
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
+ s"JOIN skewData2 ON key1 = key2 GROUP BY key1")
+ val shuffles1 = collect(adaptive1) {
+ case s: ShuffleExchangeExec => s
}
- assert(blocksFetchedMetric2.isDefined)
- val blocksFetched2 = blocksFetchedMetric2.get.value
- assert(blocksFetched < blocksFetched2)
+ assert(shuffles1.size == 3)
+ // shuffles1.head is the top-level shuffle under the Aggregate operator
+ assert(shuffles1.head.shuffleOrigin == ENSURE_REQUIREMENTS)
+ val smj1 = findTopLevelSortMergeJoin(adaptive1)
+ assert(smj1.size == 1 && smj1.head.isSkewJoin)
+
+ // only check required distribution
+ val (_, adaptive2) =
+ runAdaptiveAndVerifyResult(s"SELECT $repartition key1 FROM skewData1 " +
+ s"JOIN skewData2 ON key1 = key2")
+ val shuffles2 = collect(adaptive2) {
+ case s: ShuffleExchangeExec => s
+ }
+ if (hasRequiredDistribution) {
+ assert(shuffles2.size == 3)
+ val finalShuffle = shuffles2.head
+ if (hasPartitionNumber) {
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_NUM)
+ } else {
+ assert(finalShuffle.shuffleOrigin == REPARTITION_BY_COL)
+ }
+ } else {
+ assert(shuffles2.size == 2)
+ }
+ val smj2 = findTopLevelSortMergeJoin(adaptive2)
+ assert(smj2.size == 1 && smj2.head.isSkewJoin)
}
}
}
}
+ }
+
+ test("SPARK-35968: AQE coalescing should not produce too small partitions by default") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val (_, adaptive) =
+ runAdaptiveAndVerifyResult("SELECT sum(id) FROM RANGE(10) GROUP BY id % 3")
+ val coalesceRead = collect(adaptive) {
+ case r: AQEShuffleReadExec if r.hasCoalescedPartition => r
+ }
+ assert(coalesceRead.length == 1)
+ // RANGE(10) is a very small dataset and AQE coalescing should produce one partition.
+ assert(coalesceRead.head.partitionSpecs.length == 1)
+ }
+ }
+
+ test("SPARK-35794: Allow custom plugin for cost evaluator") {
+ CostEvaluator.instantiate(
+ classOf[SimpleShuffleSortCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
+ intercept[IllegalArgumentException] {
+ CostEvaluator.instantiate(
+ classOf[InvalidCostEvaluator].getCanonicalName, spark.sparkContext.getConf)
+ }
+
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+ val query = "SELECT * FROM testData join testData2 ON key = a where value = '1'"
+
+ withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
+ "org.apache.spark.sql.execution.adaptive.SimpleShuffleSortCostEvaluator") {
+ val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val smj = findTopLevelSortMergeJoin(plan)
+ assert(smj.size == 1)
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.size == 1)
+ checkNumLocalShuffleReads(adaptivePlan)
+ }
+
+ withSQLConf(SQLConf.ADAPTIVE_CUSTOM_COST_EVALUATOR_CLASS.key ->
+ "org.apache.spark.sql.execution.adaptive.InvalidCostEvaluator") {
+ intercept[IllegalArgumentException] {
+ runAdaptiveAndVerifyResult(query)
+ }
+ }
+ }
+ }
- test("Do not use column shuffle in AQE") {
- def findCustomShuffleReader(plan: SparkPlan): Seq[CustomShuffleReaderExec] ={
- collect(plan) {
- case j: CustomShuffleReaderExec => j
+ test("SPARK-36020: Check logical link in remove redundant projects") {
+ withTempView("t") {
+ spark.range(10).selectExpr("id % 10 as key", "cast(id * 2 as int) as a",
+ "cast(id * 3 as int) as b", "array(id, id + 1, id + 3) as c").createOrReplaceTempView("t")
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "800") {
+ val query =
+ """
+ |WITH tt AS (
+ | SELECT key, a, b, explode(c) AS c FROM t
+ |)
+ |SELECT t1.key, t1.c, t2.key, t2.c
+ |FROM (SELECT a, b, c, key FROM tt WHERE a > 1) t1
+ |JOIN (SELECT a, b, c, key FROM tt) t2
+ | ON t1.key = t2.key
+ |""".stripMargin
+ val (origin, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(findTopLevelSortMergeJoin(origin).size == 1)
+ assert(findTopLevelBroadcastHashJoin(adaptive).size == 1)
}
}
- def findShuffleExchange(plan: SparkPlan): Seq[ShuffleExchangeExec] ={
- collect(plan) {
- case j: ShuffleExchangeExec => j
+ }
+
+ test("SPARK-35874: AQE Shuffle should wait for its subqueries to finish before materializing") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+ val query = "SELECT b FROM testData2 DISTRIBUTE BY (b, (SELECT max(key) FROM testData))"
+ runAdaptiveAndVerifyResult(query)
+ }
+ }
+
+ test("SPARK-36032: Use inputPlan instead of currentPhysicalPlan to initialize logical link") {
+ withTempView("v") {
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ Seq("-1", "10000").foreach { aqeBhj =>
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> aqeBhj,
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ val (origin, adaptive) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM v t1 JOIN (
+ | SELECT c1 + 1 as c3 FROM v
+ |)t2 ON t1.c1 = t2.c3
+ |SORT BY c1
+ """.stripMargin)
+ if (aqeBhj.toInt < 0) {
+ // 1 sort since spark plan has no shuffle for SMJ
+ assert(findTopLevelSort(origin).size == 1)
+ // 2 sorts in SMJ
+ assert(findTopLevelSort(adaptive).size == 2)
+ } else {
+ assert(findTopLevelSort(origin).size == 1)
+ // 1 sort at top node and BHJ has no sort
+ assert(findTopLevelSort(adaptive).size == 1)
+ }
+ }
}
}
+ }
+
+ test("SPARK-36424: Support eliminate limits in AQE Optimizer") {
+ withTempView("v") {
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(i, if (i > 2) "2" else i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 LIMIT 5
+ """.stripMargin)
+ assert(findTopLevelLimit(origin1).size == 1)
+ assert(findTopLevelLimit(adaptive1).isEmpty)
+
+ // eliminate limit through filter
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT c2, sum(c1) FROM v GROUP BY c2 HAVING sum(c1) > 1 LIMIT 5
+ """.stripMargin)
+ assert(findTopLevelLimit(origin2).size == 1)
+ assert(findTopLevelLimit(adaptive2).isEmpty)
+
+ // The strategy of Eliminate Limits batch should be fixedPoint
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM (SELECT c1 + c2 FROM (SELECT DISTINCT * FROM v LIMIT 10086)) LIMIT 20
+ """.stripMargin
+ )
+ assert(findTopLevelLimit(origin3).size == 1)
+ assert(findTopLevelLimit(adaptive3).isEmpty)
+ }
+ }
+ }
+
+ test("SPARK-37063: OptimizeSkewInRebalancePartitions support optimize non-root node") {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+ spark.sparkContext.parallelize(
+ (1 to 10).map(i => TestData(if (i > 2) 2 else i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
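+ // Helper: the adaptive plan should keep a single SortExec and contain one AQEShuffleReadExec
+ // whose partition specs are all skew splits, matching the expected number of partitions.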
+ def checkRebalance(query: String, numShufflePartitions: Int): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(adaptive.collect {
+ case sort: SortExec => sort
+ }.size == 1)
+ val read = collect(adaptive) {
+ case read: AQEShuffleReadExec => read
+ }
+ assert(read.size == 1)
+ assert(read.head.partitionSpecs.forall(_.isInstanceOf[PartialReducerPartitionSpec]))
+ assert(read.head.partitionSpecs.size == numShufflePartitions)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "50") {
+ checkRebalance("SELECT /*+ REBALANCE(c1) */ * FROM v SORT BY c1", 2)
+ checkRebalance("SELECT /*+ REBALANCE */ * FROM v SORT BY c1", 2)
+ }
+ }
+ }
+ }
+
+ test("SPARK-37357: Add small partition factor for rebalance partitions") {
+ withTempView("v") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ spark.sparkContext.parallelize(
+ (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
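+ // Helper: checks whether an AQEShuffleReadExec was inserted at all, i.e. whether the rebalance
+ // read was applied under the configured small-partition factor.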
+ def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = {
+ val (_, adaptive) = runAdaptiveAndVerifyResult(query)
+ assert(
+ collect(adaptive) {
+ case read: AQEShuffleReadExec => read
+ }.nonEmpty == exists)
+ }
+
+ withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") {
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") {
+ // block size: [88, 97, 97]
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false)
+ }
+ withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") {
+ // block size: [88, 97, 97]
+ checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true)
+ }
+ }
+ }
+ }
+ }
+
+ test("SPARK-37742: AQE reads invalid InMemoryRelation stats and mistakenly plans BHJ") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1048584",
+ SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> AQEPropagateEmptyRelation.ruleName) {
+ // Spark estimates a string column as 20 bytes, so with 60k rows these relations should be
+ // estimated at over 1MB, which is greater than the broadcast join threshold (1048584 bytes).
+ val joinKeyOne = "00112233445566778899"
+ val joinKeyTwo = "11223344556677889900"
+ Seq.fill(60000)(joinKeyOne).toDF("key")
+ .createOrReplaceTempView("temp")
+ Seq.fill(60000)(joinKeyTwo).toDF("key")
+ .createOrReplaceTempView("temp2")
+
+ Seq(joinKeyOne).toDF("key").createOrReplaceTempView("smallTemp")
+ spark.sql("SELECT key as newKey FROM temp").persist()
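+ // Cache the projection over "temp" so that, per this test's scenario, AQE plans the joins
+ // below against the InMemoryRelation's statistics rather than the underlying temp view.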
+
+ // This query is trying to set up a situation where there are three joins.
+ // The first join will join the cached relation with a smaller relation.
+ // The first join is expected to be a broadcast join since the smaller relation will
+ // fit under the broadcast join threshold.
+ // The second join will join the first join with another relation and is expected
+ // to remain as a sort-merge join.
+ // The third join will join the cached relation with another relation and is expected
+ // to remain as a sort-merge join.
+ val query =
+ s"""
+ |SELECT t3.newKey
+ |FROM
+ | (SELECT t1.newKey
+ | FROM (SELECT key as newKey FROM temp) as t1
+ | JOIN
+ | (SELECT key FROM smallTemp) as t2
+ | ON t1.newKey = t2.key
+ | ) as t3
+ | JOIN
+ | (SELECT key FROM temp2) as t4
+ | ON t3.newKey = t4.key
+ |UNION
+ |SELECT t1.newKey
+ |FROM
+ | (SELECT key as newKey FROM temp) as t1
+ | JOIN
+ | (SELECT key FROM temp2) as t2
+ | ON t1.newKey = t2.key
+ |""".stripMargin
+ val df = spark.sql(query)
+ df.collect()
+ val adaptivePlan = df.queryExecution.executedPlan
+ val bhj = findTopLevelBroadcastHashJoin(adaptivePlan)
+ assert(bhj.length == 1)
+ }
+ }
+
+ test("SPARK-37328: skew join with 3 tables") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
- "spark.shuffle.manager"-> "sort",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100",
- SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
- spark
- .range(1, 1000, 1).where("id > 995").createOrReplaceTempView("t1")
- spark
- .range(1, 5, 1)createOrReplaceTempView("t2")
+ SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ withTempView("skewData1", "skewData2", "skewData3") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id % 3 as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key3", "id as value3")
+ .createOrReplaceTempView("skewData3")
+
+ // the skew join optimization doesn't happen in the last stage
+ val (_, adaptive1) =
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value2 = value3")
+ val shuffles1 = collect(adaptive1) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles1.size == 4)
+ val smj1 = findTopLevelSortMergeJoin(adaptive1)
+ assert(smj1.size == 2 && smj1.last.isSkewJoin && !smj1.head.isSkewJoin)
+
+ // The query has two skew joins in two consecutive stages.
+ val (_, adaptive2) =
+ runAdaptiveAndVerifyResult("SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "JOIN skewData3 ON value1 = value3")
+ val shuffles2 = collect(adaptive2) {
+ case s: ShuffleExchangeExec => s
+ }
+ assert(shuffles2.size == 4)
+ val smj2 = findTopLevelSortMergeJoin(adaptive2)
+ assert(smj2.size == 2 && smj2.forall(_.isSkewJoin))
+ }
+ }
+ }
+
+ test("SPARK-37652: optimize skewed join through union") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+ SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "100",
+ SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "100") {
+ withTempView("skewData1", "skewData2") {
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 3 as key1", "id as value1")
+ .createOrReplaceTempView("skewData1")
+ spark
+ .range(0, 1000, 1, 10)
+ .selectExpr("id % 1 as key2", "id as value2")
+ .createOrReplaceTempView("skewData2")
+
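+ // Helper: counts the sort-merge joins in the adaptive plan and how many of them
+ // were optimized as skew joins.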
+ def checkSkewJoin(query: String, joinNums: Int, optimizeSkewJoinNums: Int): Unit = {
+ val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(query)
+ val joins = findTopLevelSortMergeJoin(innerAdaptivePlan)
+ val optimizeSkewJoins = joins.filter(_.isSkewJoin)
+ assert(joins.size == joinNums && optimizeSkewJoins.size == optimizeSkewJoinNums)
+ }
+
+ // skewJoin union skewJoin
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData1 JOIN skewData2 ON key1 = key2", 2, 2)
+
+ // skewJoin union aggregate
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 " +
+ "UNION ALL SELECT key2 FROM skewData2 GROUP BY key2", 1, 1)
+
+ // skewJoin1 union (skewJoin2 join aggregate)
+ // skewJoin2 will lead to extra shuffles, but skewJoin1 cannot be optimized
+ checkSkewJoin(
+ "SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2 UNION ALL " +
+ "SELECT key1 from (SELECT key1 FROM skewData1 JOIN skewData2 ON key1 = key2) tmp1 " +
+ "JOIN (SELECT key2 FROM skewData2 GROUP BY key2) tmp2 ON key1 = key2", 3, 0)
+ }
+ }
+ }
+
+ test("SPARK-38162: Optimize one row plan in AQE Optimizer") {
+ withTempView("v") {
+ spark.sparkContext.parallelize(
+ (1 to 4).map(i => TestData(i, i.toString)), 2)
+ .toDF("c1", "c2").createOrReplaceTempView("v")
+
+ // remove sort
+ val (origin1, adaptive1) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT * FROM v where c1 = 1 order by c1, c2
+ |""".stripMargin)
+ assert(findTopLevelSort(origin1).size == 1)
+ assert(findTopLevelSort(adaptive1).isEmpty)
+
+ // convert group only aggregate to project
+ val (origin2, adaptive2) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT distinct c1 FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin2).size == 2)
+ assert(findTopLevelAggregate(adaptive2).isEmpty)
+
+ // remove distinct in aggregate
+ val (origin3, adaptive3) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT sum(distinct c1) FROM (SELECT /*+ repartition(c1) */ * FROM v where c1 = 1)
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin3).size == 4)
+ assert(findTopLevelAggregate(adaptive3).size == 2)
+
+ // do not optimize if the aggregate is inside a query stage
+ val (origin4, adaptive4) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT distinct c1 FROM v where c1 = 1
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin4).size == 2)
+ assert(findTopLevelAggregate(adaptive4).size == 2)
+
+ val (origin5, adaptive5) = runAdaptiveAndVerifyResult(
+ """
+ |SELECT sum(distinct c1) FROM v where c1 = 1
+ |""".stripMargin)
+ assert(findTopLevelAggregate(origin5).size == 4)
+ assert(findTopLevelAggregate(adaptive5).size == 4)
+ }
+ }
+
+ test("SPARK-39551: Invalid plan check - invalid broadcast query stage") {
+ withSQLConf(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(
- "SELECT * FROM t1 JOIN t2 ON t1.id = t2.id")
- val shuffleNum = findShuffleExchange(adaptivePlan)
- assert(shuffleNum.length == 2)
- val shuffleReaderNum = findCustomShuffleReader(adaptivePlan)
- assert(shuffleReaderNum.length == 2)
+ """
+ |SELECT /*+ BROADCAST(t3) */ t3.b, count(t3.a) FROM testData2 t1
+ |INNER JOIN testData2 t2
+ |ON t1.b = t2.b AND t1.a = 0
+ |RIGHT OUTER JOIN testData2 t3
+ |ON t1.a > t3.a
+ |GROUP BY t3.b
+ """.stripMargin
+ )
+ assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1)
+ }
+ }
+
+ test("SPARK-39915: Dataset.repartition(N) may not create N partitions") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "6") {
+ // partitioning: HashPartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ assert(spark.range(0).repartition(5, $"id").rdd.getNumPartitions == 5)
+ // shuffleOrigin: REPARTITION_BY_COL
+ // The minimum partition number after AQE coalesce is 1
+ assert(spark.range(0).repartition($"id").rdd.getNumPartitions == 1)
+ // through project
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+ .repartition(5, $"c1").select($"c2").rdd.getNumPartitions == 5)
+
+ // partitioning: RangePartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ // The minimum partition number of RangePartitioner is 1
+ assert(spark.range(0).repartitionByRange(5, $"id").rdd.getNumPartitions == 1)
+ // shuffleOrigin: REPARTITION_BY_COL
+ assert(spark.range(0).repartitionByRange($"id").rdd.getNumPartitions == 1)
+
+ // partitioning: RoundRobinPartitioning
+ // shuffleOrigin: REPARTITION_BY_NUM
+ assert(spark.range(0).repartition(5).rdd.getNumPartitions == 5)
+ // shuffleOrigin: REBALANCE_PARTITIONS_BY_NONE
+ assert(spark.range(0).repartition().rdd.getNumPartitions == 0)
+ // through project
+ assert(spark.range(0).selectExpr("id % 3 as c1", "id % 7 as c2")
+ .repartition(5).select($"c2").rdd.getNumPartitions == 5)
+
+ // partitioning: SinglePartition
+ assert(spark.range(0).repartition(1).rdd.getNumPartitions == 1)
+ }
+ }
+
+ test("SPARK-39915: Ensure the output partitioning is user-specified") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ val df1 = spark.range(1).selectExpr("id as c1")
+ val df2 = spark.range(1).selectExpr("id as c2")
+ val df = df1.join(df2, col("c1") === col("c2")).repartition(3, col("c1"))
+ assert(df.rdd.getNumPartitions == 3)
}
}
}
+
+/**
+ * Invalid implementation class for [[CostEvaluator]].
+ */
+private class InvalidCostEvaluator() {}
+
+/**
+ * A simple [[CostEvaluator]] that counts the number of [[ShuffleExchangeLike]] and [[SortExec]] nodes.
+ */
+private case class SimpleShuffleSortCostEvaluator() extends CostEvaluator {
+ override def evaluateCost(plan: SparkPlan): Cost = {
+ val cost = plan.collect {
+ case s: ShuffleExchangeLike => s
+ case s: SortExec => s
+ }.size
+ SimpleCost(cost)
+ }
+}
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index 026fc59977b443256c933202f1ebb1dbc19ce3d7..fab207f793b948648cad6b9f6c6f5ed6d585af08 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,14 +8,14 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
pom
- 3.1.1-1.1.0
+ 3.3.1-1.1.0
BoostKit Spark Native Sql Engine Extension Parent Pom
2.12.10
2.12
- 3.1.1
+ 3.3.1
3.2.2
UTF-8
UTF-8
@@ -55,6 +55,18 @@
org.apache.curator
curator-recipes
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+