From e329b525b2148c1b861493820606ae5e610e9dd3 Mon Sep 17 00:00:00 2001 From: linlong_job Date: Wed, 17 Apr 2024 11:24:39 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90spark=20extension=E3=80=91spark331=20b?= =?UTF-8?q?roadcast=20join=20support=20L0=20memory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../com/huawei/boostkit/spark/ColumnarPluginConfig.scala | 8 ++++++-- .../execution/joins/ColumnarBroadcastHashJoinExec.scala | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index e87122e87..46c4d6cf9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -170,7 +170,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // columnar sort spill threshold val columnarSortSpillRowThreshold: Integer = - conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt + conf.getConfString("spark.omni.sql.columnar.sortSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt // enable or disable columnar window spill val enableWindowSpill: Boolean = conf @@ -186,7 +186,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { // columnar hash aggregate spill threshold val columnarHashAggSpillRowThreshold: Integer = - conf.getConfString("spark.omni.sql.columnar.hashAggSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt + conf.getConfString("spark.omni.sql.columnar.hashAggSpill.rowThreshold", Integer.MAX_VALUE.toString).toInt // enable or disable columnar shuffledHashJoin val enableShuffledHashJoin: Boolean = conf @@ -259,6 +259,10 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val radixSortThreshold: Int = conf.getConfString("spark.omni.sql.columnar.radixSortThreshold", "1000000").toInt + + val enableL0BroadcastJoin: Boolean = + conf.getConfString("spark.omni.sql.columnar.L0.broadcastJoin.enabled", "false").toBoolean + } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index ed3ca244b..df2bef1b9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -28,6 +28,7 @@ import com.huawei.boostkit.spark.util.OmniAdaptorUtil import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.MemoryType import nova.hetu.omniruntime.operator.OmniOperator import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} @@ -299,6 +300,8 @@ case class ColumnarBroadcastHashJoinExec( val enableShareBuildOp: Boolean = columnarConf.enableShareBroadcastJoinHashTable val enableJoinBatchMerge: Boolean = columnarConf.enableJoinBatchMerge + val enableL0BroadcastJoin: Boolean = columnarConf.enableL0BroadcastJoin + // {0}, buildKeys: col1#12 val buildOutputCols: Array[Int] = joinType match { case Inner | LeftOuter | RightOuter => @@ -350,7 +353,8 @@ case class ColumnarBroadcastHashJoinExec( val op = opFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) - val deserializer = VecBatchSerializerFactory.create() + val memoryType: MemoryType = if (enableL0BroadcastJoin) MemoryType.L0 else MemoryType.DDR + val deserializer = VecBatchSerializerFactory.create(memoryType) relation.value.buildData.foreach { input => val startBuildInput = System.nanoTime() op.addInput(deserializer.deserialize(input)) -- Gitee