diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index af2515af1193e54666cbc861aa054b415612404d..93972f0effb770953dd5e2d96af6b8410249d717 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.aggregate.PushOrderedLimitThroughAgg case class ColumnarPreOverrides() extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf @@ -765,6 +766,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) + extensions.injectQueryStagePrepRule(_ => PushOrderedLimitThroughAgg) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index a6c15e104f574b0fd66b84f3eda44d0544a963a6..9b5b06026b34fe34547e34c15376841a5b44f0b9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -239,6 +239,8 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val topNPushDownForWindowEnable: Boolean = conf.getConfString("spark.sql.execution.topNPushDownForWindow.enabled", "true").toBoolean + val pushOrderedLimitThroughAggEnable: Boolean = conf.getConfString("spark.sql.execution.pushOrderedLimitThroughAggEnable.enabled", "false").toBoolean + var pushOrderedLimitThroughAggApplied: Boolean = false; // enable or disable deduplicate the right side of left semi join val enableDedupLeftSemiJoin: Boolean = conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoin", "false").toBoolean diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 2599a14109c1819ac0c438d3ddea0f7f41fc19bd..1bacb90de235202c94a709617f4edb8b707149b1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdat import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkFatalException, ThreadUtils} +import com.huawei.boostkit.spark.ColumnarPluginConfig /** * A root node to execute the query plan adaptively. It splits the query plan into independent @@ -634,7 +635,11 @@ case class AdaptiveSparkPlanExec( val newLogicalPlan = logicalPlan.transformDown { case p if p.eq(logicalNode) => newLogicalNode } - logicalPlan = newLogicalPlan + if (!ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggApplied) { + logicalPlan = newLogicalPlan + } else { + logicalPlan + } case _ => // Ignore those earlier stages that have been wrapped in later stages. } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala new file mode 100644 index 0000000000000000000000000000000000000000..8743cb9d966a7f78692b6edf364c4f1493479e96 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate; + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LocalLimitExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +object PushOrderedLimitThroughAgg extends Rule[SparkPlan] with PredicateHelper { + override def apply(plan: SparkPlan): SparkPlan = { + if (!ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggEnable) { + return plan + } + + val enableColumnarTopNSort: Boolean = conf.getConfString("spark.omni.sql.columnar.topNSort", "true").toBoolean + + plan.transform { + case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild) => { + orderAndProjectChild match { + case finalAgg @ HashAggregateExec(_, _, _, _, _, _, _, _, finalAggChild) => + finalAggChild match { + case shuffleExchange @ ShuffleExchangeExec(_, shuffleExchangeChild, _) => + shuffleExchangeChild match { + case partialAgg @ HashAggregateExec(_, _, _, partialAggGroupingExpressions, _, _, _, _, _) => + val validSortOrder = sortOrder.takeWhile { order => + partialAggGroupingExpressions.exists(attr => order.child.references.exists(ref => ref.name == attr.name)) + } + if(validSortOrder.nonEmpty) { + val newTopNSort = if (enableColumnarTopNSort) { + TopNSortExec(limit, strictTopN = false, validSortOrder.take(0), validSortOrder, global = false, child = partialAgg); + } else { + val newSortExec = SortExec( + validSortOrder, + global = false, + child = partialAgg + ) + LocalLimitExec(limit, child = newSortExec) + } + ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggApplied = true + TakeOrderedAndProjectExec( + limit, sortOrder, projectList, + child = HashAggregateExec( + finalAgg.requiredChildDistributionExpressions, + finalAgg.isStreaming, + finalAgg.numShufflePartitions, + finalAgg.groupingExpressions, + finalAgg.aggregateExpressions, + finalAgg.aggregateAttributes, + finalAgg.initialInputBufferOffset, + finalAgg.resultExpressions, + child = ShuffleExchangeExec( + shuffleExchange.outputPartitioning, + child = newTopNSort, + shuffleExchange.shuffleOrigin + ) + ) + ) + } else { + orderAndProject + } + + case _ => orderAndProject + } + case _ => orderAndProject + } + case _ => orderAndProject + } + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala index 0b4a51d7f1c7eb389f522b516c1e2edc04773b78..5b7ce48c610a2dadae3027da4d01523e9e7a364e 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala @@ -152,4 +152,59 @@ class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { Seq(Row(0)) ) } + + test("Test No PushOrderedLimitThroughAgg") { + val sql1 = """ + SELECT a, b, SUM(c) AS sum_agg + FROM df_tbl + GROUP BY a, b + ORDER BY a, sum_agg, b + LIMIT 10 + """ + assertPushColumnarTopNSortThroughAggNotEffective(sql1) + } + + private def assertPushColumnarTopNSortThroughAggNotEffective(sql: String): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + assert(!patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg and topNSort, omniPlan:${omniPlan}" + ) + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val omniPlanWithSortResult = spark.sql(sql) + omniPlanWithSortResult.collect() + val omniPlanWithSort = omniPlanWithSortResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + assert(!patternCloseTopn.findFirstIn(omniPlanWithSort).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg, omniPlan:${omniPlanWithSort}" + ) + } + + private def assertPushThroughAggAndSparkResultEqual(sql: String, pushTopNSortThroughAgg: Boolean = true): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + if (!pushTopNSortThroughAgg) { + assert(!patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv no PushColumnarTopNSortThroughAgg, omniPlan:${omniPlan}" + ) + } + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val sparkResult = spark.sql(sql) + sparkResult.collect() + val sparkPlan = sparkResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + if (!pushTopNSortThroughAgg) { + assert(!patternCloseTopn.findFirstIn(sparkPlan).isDefined, + s"SQL:${sql}\n@SparkEnv no PushColumnarTopNSortThroughAgg, sparkPlan:${sparkPlan}" + ) + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..ff72cc30ea7a7fd0bc7b3dd3227c90c085faa4b5 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.functions.{avg, count, first, max, min, sum} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +class PushOrderedLimitThroughAggSuite extends ColumnarSparkPlanTest { + private var df: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.omni.sql.columnar.topNSort", true) + spark.conf.set("spark.sql.execution.pushOrderedLimitThroughAggEnable.enabled", true) + spark.conf.set("spark.sql.adaptive.enabled", true) + df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, 1L, "a"), + Row(1, 2.0, 2L, null), + Row(2, 1.0, 3L, "c"), + Row(null, null, 6L, "e"), + Row(null, 5.0, 7L, "f") + )), new StructType().add("a", IntegerType).add("b", DoubleType) + .add("c", LongType).add("d", StringType)) + df.createOrReplaceTempView("df_tbl") + } + + test("Test PushOrderedLimitThroughAgg") { + val sql1 = """ + SELECT a, b, SUM(c) AS sum_agg + FROM df_tbl + GROUP BY a, b + ORDER BY a, sum_agg, b + LIMIT 10 + """ + assertPushColumnarTopNSortThroughAggEffective(sql1, true) + } + + private def assertPushColumnarTopNSortThroughAggEffective(sql: String, pushTopNSortThroughAgg: Boolean = true): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + assert(patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg and topNSort, omniPlan:${omniPlan}" + ) + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val omniPlanWithSortResult = spark.sql(sql) + omniPlanWithSortResult.collect() + val omniPlanWithSort = omniPlanWithSortResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + assert(patternCloseTopn.findFirstIn(omniPlanWithSort).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg, omniPlan:${omniPlanWithSort}" + ) + } +}