diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
index d19d1a46723c305ec28b897ef122d326ed693568..0cdd5ce4b0bc0522e3a8346a8fd4d78be8ad8f4d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
@@ -116,6 +116,18 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true)
       } else {
         ColumnarProjectExec(plan.projectList, child)
       }
+    case join: ColumnarBroadcastNestedLoopJoinExec =>
+      if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) {
+        ColumnarBroadcastNestedLoopJoinExec(
+          join.left,
+          join.right,
+          join.buildSide,
+          join.joinType,
+          join.condition,
+          plan.projectList)
+      } else {
+        ColumnarProjectExec(plan.projectList, child)
+      }
     case join: ColumnarShuffledHashJoinExec =>
       if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) {
         ColumnarShuffledHashJoinExec(
@@ -392,6 +404,16 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true)
         plan.condition,
         left,
         right)
+    case plan: BroadcastNestedLoopJoinExec =>
+      val left = replaceWithColumnarPlan(plan.left)
+      val right = replaceWithColumnarPlan(plan.right)
+      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+      ColumnarBroadcastNestedLoopJoinExec(
+        left,
+        right,
+        plan.buildSide,
+        plan.joinType,
+        plan.condition)
     case plan: ShuffledHashJoinExec if enableDedupLeftSemiJoin && !SQLConf.get.adaptiveExecutionEnabled => {
       plan.joinType match {
         case LeftSemi => {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
index e807f96fce22b6a8e35e1ce55e00961b948f71cb..ea846f188f73b456f237c6d9d8e63141dd1c2cad 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
@@ -59,6 +59,8 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
 
   def enableColumnarBroadcastJoin: Boolean = conf.getConf(ENABLE_COLUMNAR_BROADCAST_JOIN)
 
+  def enableColumnarBroadcastNestedJoin: Boolean = conf.getConf(ENABLE_COLUMNAR_BROADCAST_NESTED_JOIN)
+
   def enableShareBroadcastJoinHashTable: Boolean = conf.getConf(ENABLE_SHARE_BROADCAST_JOIN_HASH_TABLE)
 
   def enableHeuristicJoinReorder: Boolean = conf.getConf(ENABLE_HEURISTIC_JOIN_REORDER)
@@ -287,6 +289,12 @@ object ColumnarPluginConfig {
       .booleanConf
       .createWithDefault(true)
 
+  val ENABLE_COLUMNAR_BROADCAST_NESTED_JOIN = buildConf("spark.omni.sql.columnar.broadcastNestedJoin")
+    .internal()
+    .doc("enable or disable columnar BroadcastNestedLoopJoin")
+    .booleanConf
+    .createWithDefault(true)
+
   val ENABLE_SHARE_BROADCAST_JOIN_HASH_TABLE = buildConf("spark.omni.sql.columnar.broadcastJoin.sharehashtable")
     .internal()
     .doc("enable or disable share columnar BroadcastHashJoin hashtable")
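The new switch defaults to true and follows the pattern of the other columnar toggles, so it can be flipped per session; a minimal sketch (the SparkSession setup is assumed, the key is the one defined above):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical local session; any active session works the same way.
val spark = SparkSession.builder()
  .master("local[*]")
  .getOrCreate()

// Defaults to true. When false, AddTransformHintRule tags BroadcastNestedLoopJoinExec
// as not transformable and the join stays row-based.
spark.conf.set("spark.omni.sql.columnar.broadcastNestedJoin", "false")
```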
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala
index 553463d56a7876adb129ba282481e34ec2909156..90399b6ff60484dc09f2d3ad184bea0b41b67b49 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQE
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarBroadcastNestedLoopJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarDataWritingCommandExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec}
 import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport
@@ -120,6 +120,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
   val enableColumnarExpand: Boolean = columnarConf.enableColumnarExpand
   val enableColumnarBroadcastExchange: Boolean = columnarConf.enableColumnarBroadcastExchange
   val enableColumnarBroadcastJoin: Boolean = columnarConf.enableColumnarBroadcastJoin
+  val enableColumnarBroadcastNestedJoin: Boolean = columnarConf.enableColumnarBroadcastNestedJoin
   val enableColumnarSortMergeJoin: Boolean = columnarConf.enableColumnarSortMergeJoin
   val enableShuffledHashJoin: Boolean = columnarConf.enableShuffledHashJoin
   val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan
@@ -315,6 +316,21 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
             plan.right,
             plan.isNullAwareAntiJoin).buildCheck()
           TransformHints.tagTransformable(plan)
+        case plan: BroadcastNestedLoopJoinExec =>
+          // We need to check if BroadcastExchangeExec can be converted to columnar-based.
+          // If not, BNLJ should also be row-based.
+          if (!enableColumnarBroadcastNestedJoin) {
+            TransformHints.tagNotTransformable(
+              plan, "columnar BroadcastNestedLoopJoin is not enabled in BroadcastNestedLoopJoinExec")
+            return
+          }
+          ColumnarBroadcastNestedLoopJoinExec(
+            plan.left,
+            plan.right,
+            plan.buildSide,
+            plan.joinType,
+            plan.condition).buildCheck()
+          TransformHints.tagTransformable(plan)
         case plan: SortMergeJoinExec =>
           if (!enableColumnarSortMergeJoin) {
             TransformHints.tagNotTransformable(
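Spark only plans a BroadcastNestedLoopJoinExec when a join has no usable equi-join keys and one side can be broadcast, so a non-equi predicate (or a plain cross join) is the easiest way to exercise the new hint path; a sketch, assuming an active `spark` session:

```scala
import org.apache.spark.sql.functions.col

// Two small frames, so one side is eligible for broadcast.
val left = spark.range(10).toDF("q")
val right = spark.range(10).toDF("c")

// No equality condition, so no hash join applies: Spark plans a
// BroadcastNestedLoopJoinExec, which the case above then tags transformable
// (buildCheck() permitting the join type / build side combination).
val nonEqui = left.join(right, col("q") < col("c"))

// Cross joins take the same path (see the test suite at the end of this patch).
val crossed = left.join(right)
```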
+ */ + +package org.apache.spark.sql.execution.joins + +import java.util.Optional +import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.collection.mutable +import com.huawei.boostkit.spark.ColumnarPluginConfig +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex,pruneOutput, reorderOutputVecs, transColBatchToOmniVecs} +import nova.hetu.omniruntime.constants.JoinType._ +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.operator.OmniOperator +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.join.{OmniNestedLoopJoinBuildOperatorFactory, OmniNestedLoopJoinLookupOperatorFactory} +import nova.hetu.omniruntime.vector.VecBatch +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{CodegenSupport, ColumnarHashedRelation, ExplainUtils, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Performs a nested loop join of two child relations. When the output RDD of this operator is + * being constructed, a Spark job is asynchronously started to calculate the values for the + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamedPlan + * relation is not shuffled. 
+ */ +case class ColumnarBroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression], + projectList: Seq[NamedExpression] = Seq.empty) extends JoinCodegenSupport { + + override def verboseStringWithOperatorId(): String = { + val joinCondStr = if (condition.isDefined) { + s"${condition.get}${condition.get.dataType}" + } else "None" + s""" + |$formattedNodeName + |$simpleStringWithNodeId + |${ExplainUtils.generateFieldString("buildInput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedInput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("condition", joinCondStr)} + |${ExplainUtils.generateFieldString("projectList", projectList.map(_.toAttribute) ++ projectList.map(_.toAttribute).map(_.dataType))} + |${ExplainUtils.generateFieldString("output", output ++ output.map(_.dataType))} + |Condition : $condition + |""".stripMargin + } + + private val (buildOutput, streamedOutput) = buildSide match { + case BuildLeft => (left.output, right.output) + case BuildRight => (right.output, left.output) + } + + override def leftKeys: Seq[Expression] = Nil + + override def rightKeys: Seq[Expression] = Nil + + private val (streamedPlan, buildPlan) = buildSide match { + case BuildRight => (left, right) + case BuildLeft => (right, left) + } + + protected lazy val (buildKeys, streamedKeys) = { + require(leftKeys.length == rightKeys.length && + leftKeys.map(_.dataType) + .zip(rightKeys.map(_.dataType)) + .forall(types => types._1.sameType(types._2)), + "Join keys from two sides should have same length and types") + buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) + } + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "lookupAddInputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup addInput"), + "lookupGetOutputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup getOutput"), + "lookupCodegenTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup codegen"), + "buildAddInputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni build addInput"), + "buildGetOutputTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni build getOutput"), + "buildCodegenTime" -> + SQLMetrics.createTimingMetric(sparkContext, "time in omni build codegen"), + "numOutputVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatches"), + "numMergedVecBatches" -> SQLMetrics.createMetric(sparkContext, "number of merged vecBatches") + ) + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): ColumnarBroadcastNestedLoopJoinExec = + copy(left = newLeft, right = newRight) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() + } + + override def supportsColumnar: Boolean = true + + override def supportCodegen: Boolean = false + + override def nodeName: String = "OmniColumnarBroadcastNestedLoopJoin" + + /** only for operator fusion */ + def getBuildOutput: Seq[Attribute] = { + buildOutput + } + + def getBuildPlan: SparkPlan = { + buildPlan + } + + def getStreamedOutput: Seq[Attribute] = { + streamedOutput + } + + def buildCheck(): Unit = { + joinType match { + case Inner => + case LeftOuter => + require(buildSide == BuildRight, "In left 
+      case RightOuter =>
+        require(buildSide == BuildLeft, "In right outer join case, buildSide must be BuildLeft.")
+      case _ =>
+        throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " +
+          s"in ${this.nodeName}")
+    }
+    val buildTypes = new Array[DataType](buildOutput.size) // {2, 2}, buildOutput:col1#12,col2#13
+    buildOutput.zipWithIndex.foreach { case (att, i) =>
+      buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
+    }
+    val probeTypes = new Array[DataType](streamedOutput.size)
+    streamedOutput.zipWithIndex.foreach { case (attr, i) =>
+      probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
+    }
+    condition match {
+      case Some(expr) =>
+        val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr,
+          OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput).map(_.toAttribute)))
+        if (!isSimpleColumn(filterExpr)) {
+          checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0))
+        }
+      case _ =>
+    }
+  }
+
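+  // doExecuteColumnar() below rebuilds the broadcast side once per partition by feeding the
+  // serialized batches of the ColumnarHashedRelation into an OmniNestedLoopJoinBuildOperatorFactory,
+  // then streams probe-side batches through an OmniNestedLoopJoinLookupOperatorFactory; when a
+  // fused projectList is present, output columns are pruned and reordered to match it.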
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numOutputRows = longMetric("numOutputRows")
+    val numOutputVecBatches = longMetric("numOutputVecBatches")
+    val numMergedVecBatches = longMetric("numMergedVecBatches")
+    val buildAddInputTime = longMetric("buildAddInputTime")
+    val buildCodegenTime = longMetric("buildCodegenTime")
+    val buildGetOutputTime = longMetric("buildGetOutputTime")
+    val lookupAddInputTime = longMetric("lookupAddInputTime")
+    val lookupCodegenTime = longMetric("lookupCodegenTime")
+    val lookupGetOutputTime = longMetric("lookupGetOutputTime")
+
+    val buildTypes = new Array[DataType](buildOutput.size) // {2,2}, buildOutput:col1#12,col2#13
+    buildOutput.zipWithIndex.foreach { case (att, i) =>
+      buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
+    }
+
+    val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf
+    val enableJoinBatchMerge: Boolean = columnarConf.enableJoinBatchMerge
+
+    val projectExprIdList = getExprIdForProjectList(projectList)
+    // {0}, buildKeys: col1#12
+    val buildOutputCols: Array[Int] = joinType match {
+      case Inner | LeftOuter | RightOuter =>
+        getIndexArray(buildOutput, projectExprIdList)
+      case x =>
+        throw new UnsupportedOperationException(s"ColumnarBroadcastNestedLoopJoin Join-type[$x] is not supported!")
+    }
+    val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList)
+    val buildOutputTypes = new Array[DataType](prunedBuildOutput.size)
+    prunedBuildOutput.zipWithIndex.foreach { case (att, i) =>
+      buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
+    }
+
+    val probeTypes = new Array[DataType](streamedOutput.size)
+    streamedOutput.zipWithIndex.foreach { case (attr, i) =>
+      probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
+    }
+    val probeOutputCols = getIndexArray(streamedOutput, projectExprIdList) // {0,1}
+    val prunedStreamedOutput = pruneOutput(streamedOutput, projectExprIdList)
+
+    val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamedOutput, prunedBuildOutput)
+    val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType)
+    val relation = buildPlan.executeBroadcast[ColumnarHashedRelation]()
+    streamedPlan.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) =>
+      val filter: Optional[String] = condition match {
+        case Some(expr) =>
+          Optional.of(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr,
+            OmniExpressionAdaptor.getExprIdMap((streamedOutput ++ buildOutput).map(_.toAttribute))))
+        case _ =>
+          Optional.empty()
+      }
+
+      def createBuildOpFactoryAndOp(): (OmniNestedLoopJoinBuildOperatorFactory, OmniOperator) = {
+        val startBuildCodegen = System.nanoTime()
+        val opFactory =
+          new OmniNestedLoopJoinBuildOperatorFactory(buildTypes, buildOutputCols)
+        val op = opFactory.createOperator()
+        buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen)
+
+        val deserializer = VecBatchSerializerFactory.create()
+        relation.value.buildData.foreach { input =>
+          val startBuildInput = System.nanoTime()
+          op.addInput(deserializer.deserialize(input))
+          buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput)
+        }
+        val startBuildGetOp = System.nanoTime()
+        try {
+          op.getOutput
+        } catch {
+          case e: Exception => {
+            op.close()
+            opFactory.close()
+            throw new RuntimeException("NestedLoopJoinBuilder getOutput failed", e)
+          }
+        }
+        buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp)
+        (opFactory, op)
+      }
+      var buildOp: OmniOperator = null
+      var buildOpFactory: OmniNestedLoopJoinBuildOperatorFactory = null
+      val (opFactory, op) = createBuildOpFactoryAndOp()
+      buildOpFactory = opFactory
+      buildOp = op
+
+      val startLookupCodegen = System.nanoTime()
+      val lookupOpFactory = new OmniNestedLoopJoinLookupOperatorFactory(lookupJoinType, probeTypes, probeOutputCols, filter, buildOpFactory,
+        new OperatorConfig(SpillConfig.NONE,
+          new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP))
+      val lookupOp = lookupOpFactory.createOperator()
+      lookupCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupCodegen)
+
+      // close operator
+      SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+        lookupOp.close()
+        lookupOpFactory.close()
+        buildOp.close()
+        buildOpFactory.close()
+      })
+
+      val resultSchema = this.schema
+      val reverse = buildSide == BuildLeft
+      var left = 0
+      var leftLen = prunedStreamedOutput.size
+      var right = prunedStreamedOutput.size
+      var rightLen = output.size
+      if (reverse) {
+        left = prunedStreamedOutput.size
+        leftLen = output.size
+        right = 0
+        rightLen = prunedStreamedOutput.size
+      }
+
+      val iterBatch = new Iterator[ColumnarBatch] {
+        private var results: java.util.Iterator[VecBatch] = _
+        var res: Boolean = true
+
+        override def hasNext: Boolean = {
+          while ((results == null || !res) && iter.hasNext) {
+            val batch = iter.next()
+            val input = transColBatchToOmniVecs(batch)
+            val vecBatch = new VecBatch(input, batch.numRows())
+            val startLookupInput = System.nanoTime()
+            lookupOp.addInput(vecBatch)
+            lookupAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupInput)
+            val startLookupGetOp = System.nanoTime()
+            results = lookupOp.getOutput
+            res = results.hasNext
+            lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp)
+          }
+          if (results == null) {
+            false
+          } else {
+            if (!res) {
+              false
+            } else {
+              val startLookupGetOp = System.nanoTime()
+              res = results.hasNext
+              lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp)
+              res
+            }
+          }
+        }
+        override def next(): ColumnarBatch = {
+          val startLookupGetOp = System.nanoTime()
+          val result = results.next()
+          res = results.hasNext
+          lookupGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startLookupGetOp)
+          val resultVecs = result.getVectors
+          val vecs = OmniColumnVector
+            .allocateColumns(result.getRowCount, resultSchema, false)
+          if (projectList.nonEmpty) {
+            reorderOutputVecs(projectListIndex, resultVecs, vecs)
+          } else {
+            var index = 0
+            for (i <- left until leftLen) {
+              val v = vecs(index)
+              v.reset()
+              v.setVec(resultVecs(i))
+              index += 1
+            }
+            for (i <- right until rightLen) {
+              val v = vecs(index)
+              v.reset()
+              v.setVec(resultVecs(i))
+              index += 1
+            }
+          }
+          val rowCnt: Int = result.getRowCount
+          numOutputRows += rowCnt
+          numOutputVecBatches += 1
+          result.close()
+          new ColumnarBatch(vecs.toArray, rowCnt)
+        }
+      }
+
+      if (enableJoinBatchMerge) {
+        val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches)
+        SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+          mergeIterator.close()
+        })
+        mergeIterator
+      } else {
+        iterBatch
+      }
+    }
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException(s"This operator doesn't support doExecute().")
+  }
+
+  override def doProduce(ctx: CodegenContext): String = {
+    throw new UnsupportedOperationException(s"This operator doesn't support doProduce().")
+  }
+
+  override def output: Seq[Attribute] = {
+    if (projectList.nonEmpty) {
+      projectList.map(_.toAttribute)
+    } else {
+      joinType match {
+        case Inner =>
+          left.output ++ right.output
+        case LeftOuter =>
+          left.output ++ right.output.map(_.withNullability(true))
+        case RightOuter =>
+          left.output.map(_.withNullability(true)) ++ right.output
+        case j: ExistenceJoin =>
+          left.output :+ j.exists
+        case LeftExistence(_) =>
+          left.output
+        case x =>
+          throw new IllegalArgumentException(s"NestedLoopJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}
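Besides the standalone operator, the ColumnarPlugin.scala hunk at the top of this patch folds a parent project into the join whenever every projection is simple and project fusion is enabled; a sketch of the plan shapes involved (shapes approximate, reusing the test suite's `left`/`right` fixtures, assumed in scope):

```scala
import org.apache.spark.sql.functions.col

// With fusion enabled, ColumnarPreOverrides rewrites roughly:
//   Project [a, c]
//   +- OmniColumnarBroadcastNestedLoopJoin (projectList = [])
// into:
//   OmniColumnarBroadcastNestedLoopJoin (projectList = [a, c])
// output is then derived from projectList, and doExecuteColumnar() prunes the
// build- and streamed-side columns via getExprIdForProjectList/pruneOutput.
val pruned = left.join(right, col("q") < col("c")).select("a", "c")
```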
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.optimizer.BuildRight +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} + +// refer to joins package +class ColumnarNestedLoopJoinExecSuite extends ColumnarSparkPlanTest { + + import testImplicits.{localSeqToDatasetHolder, newProductEncoder} + + private var left: DataFrame = _ + private var right: DataFrame = _ + private var leftWithNull: DataFrame = _ + private var rightWithNull: DataFrame = _ + private var person_test: DataFrame = _ + private var order_test: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + left = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 2.0), + ("", "Hello", 1, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "q", "d") + + right = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 1.0), + ("", "Hello", 2, 2.0), + (" add", "World", 1, 3.0), + (" yeah ", "yeah", 0, 4.0) + ).toDF("a", "b", "c", "d") + + leftWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", null, 4, 2.0), + ("", "Hello", null, 1.0), + (" add", "World", 8, 3.0), + (" yeah ", "yeah", 10, 8.0) + ).toDF("a", "b", "q", "d") + + rightWithNull = Seq[(String, String, java.lang.Integer, java.lang.Double)]( + ("abc", "", 4, 1.0), + ("", "Hello", 2, 2.0), + (" add", null, 1, null), + (" yeah ", null, null, 4.0) + ).toDF("a", "b", "c", "d") + } + test("columnar nestedLoopJoin Inner Join is equal to native") { + val df = left.join(right, col("q") < col("c")) + assert( + df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined, + s"ColumnarBroadcastNestedLoopJoinExec not happened, " + + s"executedPlan as follows: \n${df.queryExecution.executedPlan}") + checkAnswer(df, Seq( + Row("", "Hello", 1, 1.0, "abc", "", 4, 1.0), + Row("", "Hello", 1, 1.0, "", "Hello", 2, 2.0) + )) + } + + test("columnar nestedLoopJoin Inner Join is equal to native With NULL") { + val df = leftWithNull.join(rightWithNull, col("q") < col("c")) + assert( + df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined, + s"ColumnarBroadcastNestedLoopJoinExec not happened, " + + s"executedPlan as follows: \n${df.queryExecution.executedPlan}") + checkAnswer(df, Seq()) + } + + test("columnar nestedLoopJoin LeftOuter Join is equal to native") { + val df = left.join(right, col("q") < col("c"),"leftouter") + assert( + df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined, + s"ColumnarBroadcastNestedLoopJoinExec not happened, " + + s"executedPlan as follows: \n${df.queryExecution.executedPlan}") + checkAnswer(df, Seq( + Row("abc", "", 4, 2.0, null, null, null, null), + Row(" yeah ", "yeah", 10, 8.0, null, null, null, null), + Row("", "Hello", 1, 1.0, "abc", "", 4, 1.0), + Row("", "Hello", 1, 1.0, "", "Hello", 2, 2.0), + Row(" add", "World", 8, 3.0, null, null, null, null) + )) + } + + test("columnar nestedLoopJoin LeftOuter Join is equal to native With NULL") { + val df = leftWithNull.join(rightWithNull, col("q") < col("c"),"leftouter") + assert( + 
+      df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined,
+      s"ColumnarBroadcastNestedLoopJoinExec not happened, " +
+        s"executedPlan as follows: \n${df.queryExecution.executedPlan}")
+    checkAnswer(df, Seq(
+      Row(" add", "World", 8, 3.0, null, null, null, null),
+      Row(" yeah ", "yeah", 10, 8.0, null, null, null, null),
+      Row("", "Hello", null, 1.0, null, null, null, null),
+      Row("abc", null, 4, 2.0, null, null, null, null)
+    ))
+  }
+
+  test("columnar nestedLoopJoin right outer join is equal to native") {
+    val df = left.join(right, col("q") < col("c"), "rightouter")
+    assert(
+      df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined,
+      s"ColumnarBroadcastNestedLoopJoinExec not happened, " +
+        s"executedPlan as follows: \n${df.queryExecution.executedPlan}")
+    checkAnswer(df, Seq(
+      Row("", "Hello", 1, 1.0, "abc", "", 4, 1.0),
+      Row("", "Hello", 1, 1.0, "", "Hello", 2, 2.0),
+      Row(null, null, null, null, " add", "World", 1, 3.0),
+      Row(null, null, null, null, " yeah ", "yeah", 0, 4.0)
+    ))
+  }
+
+  test("columnar nestedLoopJoin right outer join is equal to native with null") {
+    val df = leftWithNull.join(rightWithNull, col("q") < col("c"), "rightouter")
+    assert(
+      df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined,
+      s"ColumnarBroadcastNestedLoopJoinExec not happened, " +
+        s"executedPlan as follows: \n${df.queryExecution.executedPlan}")
+    assert(QueryTest.sameRows(Seq(
+      Row(null, null, null, null, " add", null, 1, null),
+      Row(null, null, null, null, " yeah ", null, null, 4.0),
+      Row(null, null, null, null, "", "Hello", 2, 2.0),
+      Row(null, null, null, null, "abc", "", 4, 1.0)
+    ), df.collect()).isEmpty, "the returned rows do not match the expected rows")
+  }
+
+  test("columnar nestedLoopJoin Cross Join is equal to native") {
+    val df = left.join(right)
+    assert(
+      df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined,
+      s"ColumnarBroadcastNestedLoopJoinExec not happened, " +
+        s"executedPlan as follows: \n${df.queryExecution.executedPlan}")
+    checkAnswer(df, Seq(
+      Row("abc", "", 4, 2.0, "abc", "", 4, 1.0),
+      Row("abc", "", 4, 2.0, "", "Hello", 2, 2.0),
+      Row("abc", "", 4, 2.0, " yeah ", "yeah", 0, 4.0),
+      Row("abc", "", 4, 2.0, " add", "World", 1, 3.0),
+      Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0),
+      Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0),
+      Row(" yeah ", "yeah", 10, 8.0, " yeah ", "yeah", 0, 4.0),
+      Row(" yeah ", "yeah", 10, 8.0, " add", "World", 1, 3.0),
+      Row("", "Hello", 1, 1.0, "abc", "", 4, 1.0),
+      Row("", "Hello", 1, 1.0, "", "Hello", 2, 2.0),
+      Row("", "Hello", 1, 1.0, " yeah ", "yeah", 0, 4.0),
+      Row("", "Hello", 1, 1.0, " add", "World", 1, 3.0),
+      Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0),
+      Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0),
+      Row(" add", "World", 8, 3.0, " yeah ", "yeah", 0, 4.0),
+      Row(" add", "World", 8, 3.0, " add", "World", 1, 3.0)
+    ))
+  }
+
+  test("columnar nestedLoopJoin Cross Join is equal to native With NULL") {
+    val df = leftWithNull.join(rightWithNull)
+    assert(
+      df.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarBroadcastNestedLoopJoinExec]).isDefined,
+      s"ColumnarBroadcastNestedLoopJoinExec not happened, " +
+        s"executedPlan as follows: \n${df.queryExecution.executedPlan}")
+    checkAnswer(df, Seq(
+      Row("abc", null, 4, 2.0, " yeah ", null, null, 4.0),
+      Row("abc", null, 4, 2.0, "abc", "", 4, 1.0),
+      Row("abc", null, 4, 2.0, "", "Hello", 2, 2.0),
+      Row("abc", null, 4, 2.0, " add", null, 1, null),
2.0, " add", null, 1, null), + Row(" yeah ", "yeah", 10, 8.0, " yeah ", null, null, 4.0), + Row(" yeah ", "yeah", 10, 8.0, "abc", "", 4, 1.0), + Row(" yeah ", "yeah", 10, 8.0, "", "Hello", 2, 2.0), + Row(" yeah ", "yeah", 10, 8.0, " add", null, 1, null), + Row(" add", "World", 8, 3.0, " yeah ", null, null, 4.0), + Row(" add", "World", 8, 3.0, "abc", "", 4, 1.0), + Row(" add", "World", 8, 3.0, "", "Hello", 2, 2.0), + Row(" add", "World", 8, 3.0, " add", null, 1, null), + Row("", "Hello", null, 1.0, " yeah ", null, null, 4.0), + Row("", "Hello", null, 1.0, "abc", "", 4, 1.0), + Row("", "Hello", null, 1.0, "", "Hello", 2, 2.0), + Row("", "Hello", null, 1.0, " add", null, 1, null) + )) + } +} \ No newline at end of file