From 49b0dc36d5244855837dac834849cb96abd22117 Mon Sep 17 00:00:00 2001 From: ycsongcs Date: Sun, 29 Sep 2024 18:19:58 +0800 Subject: [PATCH] [spark-extension] perf:Push ordered limit through agg --- .../boostkit/spark/ColumnarPlugin.scala | 2 + .../boostkit/spark/ColumnarPluginConfig.scala | 2 + .../adaptive/AdaptiveSparkPlanExec.scala | 7 +- .../PushOrderedLimitThroughAgg.scala | 88 +++++++++++++++++++ .../ColumnarHashAggregateExecSuite.scala | 55 ++++++++++++ .../PushOrderedLimitThroughAggSuite.scala | 75 ++++++++++++++++ 6 files changed, 228 insertions(+), 1 deletion(-) create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala create mode 100644 omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index af2515af1..93972f0ef 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.aggregate.PushOrderedLimitThroughAgg case class ColumnarPreOverrides() extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf @@ -765,6 +766,7 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectQueryStagePrepRule(session => FallbackBroadcastExchange(session)) extensions.injectQueryStagePrepRule(session => DedupLeftSemiJoinAQE(session)) extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) + extensions.injectQueryStagePrepRule(_ => PushOrderedLimitThroughAgg) } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index a6c15e104..9b5b06026 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -239,6 +239,8 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val topNPushDownForWindowEnable: Boolean = conf.getConfString("spark.sql.execution.topNPushDownForWindow.enabled", "true").toBoolean + val pushOrderedLimitThroughAggEnable: Boolean = conf.getConfString("spark.sql.execution.pushOrderedLimitThroughAggEnable.enabled", "false").toBoolean + var pushOrderedLimitThroughAggApplied: Boolean = false; // enable or disable deduplicate the right side of left semi join val enableDedupLeftSemiJoin: Boolean = conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoin", "false").toBoolean diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 2599a1410..1bacb90de 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.ui.{SparkListenerSQLAdaptiveExecutionUpdat import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SparkFatalException, ThreadUtils} +import com.huawei.boostkit.spark.ColumnarPluginConfig /** * A root node to execute the query plan adaptively. It splits the query plan into independent @@ -634,7 +635,11 @@ case class AdaptiveSparkPlanExec( val newLogicalPlan = logicalPlan.transformDown { case p if p.eq(logicalNode) => newLogicalNode } - logicalPlan = newLogicalPlan + if (!ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggApplied) { + logicalPlan = newLogicalPlan + } else { + logicalPlan + } case _ => // Ignore those earlier stages that have been wrapped in later stages. } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala new file mode 100644 index 000000000..8743cb9d9 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/aggregate/PushOrderedLimitThroughAgg.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate; + +import com.huawei.boostkit.spark.ColumnarPluginConfig +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LocalLimitExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec + +object PushOrderedLimitThroughAgg extends Rule[SparkPlan] with PredicateHelper { + override def apply(plan: SparkPlan): SparkPlan = { + if (!ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggEnable) { + return plan + } + + val enableColumnarTopNSort: Boolean = conf.getConfString("spark.omni.sql.columnar.topNSort", "true").toBoolean + + plan.transform { + case orderAndProject @ TakeOrderedAndProjectExec(limit, sortOrder, projectList, orderAndProjectChild) => { + orderAndProjectChild match { + case finalAgg @ HashAggregateExec(_, _, _, _, _, _, _, _, finalAggChild) => + finalAggChild match { + case shuffleExchange @ ShuffleExchangeExec(_, shuffleExchangeChild, _) => + shuffleExchangeChild match { + case partialAgg @ HashAggregateExec(_, _, _, partialAggGroupingExpressions, _, _, _, _, _) => + val validSortOrder = sortOrder.takeWhile { order => + partialAggGroupingExpressions.exists(attr => order.child.references.exists(ref => ref.name == attr.name)) + } + if(validSortOrder.nonEmpty) { + val newTopNSort = if (enableColumnarTopNSort) { + TopNSortExec(limit, strictTopN = false, validSortOrder.take(0), validSortOrder, global = false, child = partialAgg); + } else { + val newSortExec = SortExec( + validSortOrder, + global = false, + child = partialAgg + ) + LocalLimitExec(limit, child = newSortExec) + } + ColumnarPluginConfig.getConf.pushOrderedLimitThroughAggApplied = true + TakeOrderedAndProjectExec( + limit, sortOrder, projectList, + child = HashAggregateExec( + finalAgg.requiredChildDistributionExpressions, + finalAgg.isStreaming, + finalAgg.numShufflePartitions, + finalAgg.groupingExpressions, + finalAgg.aggregateExpressions, + finalAgg.aggregateAttributes, + finalAgg.initialInputBufferOffset, + finalAgg.resultExpressions, + child = ShuffleExchangeExec( + shuffleExchange.outputPartitioning, + child = newTopNSort, + shuffleExchange.shuffleOrigin + ) + ) + ) + } else { + orderAndProject + } + + case _ => orderAndProject + } + case _ => orderAndProject + } + case _ => orderAndProject + } + } + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala index 0b4a51d7f..5b7ce48c6 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala @@ -152,4 +152,59 @@ class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { Seq(Row(0)) ) } + + test("Test No PushOrderedLimitThroughAgg") { + val sql1 = """ + SELECT a, b, SUM(c) AS sum_agg + FROM df_tbl + GROUP BY a, b + ORDER BY a, sum_agg, b + LIMIT 10 + """ + assertPushColumnarTopNSortThroughAggNotEffective(sql1) + } + + private def assertPushColumnarTopNSortThroughAggNotEffective(sql: String): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + assert(!patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg and topNSort, omniPlan:${omniPlan}" + ) + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val omniPlanWithSortResult = spark.sql(sql) + omniPlanWithSortResult.collect() + val omniPlanWithSort = omniPlanWithSortResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + assert(!patternCloseTopn.findFirstIn(omniPlanWithSort).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg, omniPlan:${omniPlanWithSort}" + ) + } + + private def assertPushThroughAggAndSparkResultEqual(sql: String, pushTopNSortThroughAgg: Boolean = true): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + if (!pushTopNSortThroughAgg) { + assert(!patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv no PushColumnarTopNSortThroughAgg, omniPlan:${omniPlan}" + ) + } + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val sparkResult = spark.sql(sql) + sparkResult.collect() + val sparkPlan = sparkResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + if (!pushTopNSortThroughAgg) { + assert(!patternCloseTopn.findFirstIn(sparkPlan).isDefined, + s"SQL:${sql}\n@SparkEnv no PushColumnarTopNSortThroughAgg, sparkPlan:${sparkPlan}" + ) + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala new file mode 100644 index 000000000..ff72cc30e --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/PushOrderedLimitThroughAggSuite.scala @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.functions.{avg, count, first, max, min, sum} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +class PushOrderedLimitThroughAggSuite extends ColumnarSparkPlanTest { + private var df: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + spark.conf.set("spark.omni.sql.columnar.topNSort", true) + spark.conf.set("spark.sql.execution.pushOrderedLimitThroughAggEnable.enabled", true) + spark.conf.set("spark.sql.adaptive.enabled", true) + df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, 1L, "a"), + Row(1, 2.0, 2L, null), + Row(2, 1.0, 3L, "c"), + Row(null, null, 6L, "e"), + Row(null, 5.0, 7L, "f") + )), new StructType().add("a", IntegerType).add("b", DoubleType) + .add("c", LongType).add("d", StringType)) + df.createOrReplaceTempView("df_tbl") + } + + test("Test PushOrderedLimitThroughAgg") { + val sql1 = """ + SELECT a, b, SUM(c) AS sum_agg + FROM df_tbl + GROUP BY a, b + ORDER BY a, sum_agg, b + LIMIT 10 + """ + assertPushColumnarTopNSortThroughAggEffective(sql1, true) + } + + private def assertPushColumnarTopNSortThroughAggEffective(sql: String, pushTopNSortThroughAgg: Boolean = true): Unit = { + val omniResult = spark.sql(sql) + omniResult.collect(); + val omniPlan = omniResult.queryExecution.executedPlan.toString() + val patternOpenTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnarTopNSort.*?\n.*?OmniColumnarHashAggregate""".r + + assert(patternOpenTopn.findFirstIn(omniPlan).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg and topNSort, omniPlan:${omniPlan}" + ) + + spark.conf.set("spark.omni.sql.columnar.topNSort", false) + val omniPlanWithSortResult = spark.sql(sql) + omniPlanWithSortResult.collect() + val omniPlanWithSort = omniPlanWithSortResult.queryExecution.executedPlan.toString() + val patternCloseTopn = """(?s)OmniColumnarShuffleExchange.*?\n.*?OmniColumnar.*?Limit.*?\n.*?OmniColumnarSort.*?\n.*?OmniColumnarHashAggregate""".r + assert(patternCloseTopn.findFirstIn(omniPlanWithSort).isDefined, + s"SQL:${sql}\n@OmniEnv with PushColumnarTopNSortThroughAgg, omniPlan:${omniPlanWithSort}" + ) + } +} -- Gitee