diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 7b94255e4774a75220ad4765152335d90ce0b2d5..28eb2579c0a5e85502be43394a5ca41a17203d3a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport +import scala.collection.mutable.ListBuffer + case class ColumnarPreOverrides() extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan @@ -67,6 +69,31 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { } } + def GetJoins(join: SparkPlan): ListBuffer[ColumnarBroadcastHashJoinExec] = { + val joins = new ListBuffer[ColumnarBroadcastHashJoinExec]() + var isMatch = true + var curJoin = join + while (isMatch) { + isMatch = curJoin match { + case curJoinTemp@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + nextJoin@ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, _, _, _), _, _, _) => + joins.append(curJoinTemp) + curJoin = nextJoin + true + case curJoinTemp@ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, + ColumnarFilterExec(_, _), _, _) => + joins.append(curJoinTemp) + false + case curJoinTemp@ColumnarBroadcastHashJoinExec(_, _, _, _, _, + ColumnarFilterExec(_, _), _, _, _) => + joins.append(curJoinTemp) + false + case _ => false + } + } + joins + } + def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match { case plan: RowGuard => val actualPlan: SparkPlan = plan.child match { @@ -152,99 +179,47 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") if (enableFusion && !isSupportAdaptive) { if (plan.aggregateExpressions.forall(_.mode == Partial)) { + val joins = GetJoins(child) child match { - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj4 @ ColumnarProjectExec(_, - join4 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _) - ), _, _, _)), _, _, _)), _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - proj4, - join4, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, _, - filter @ 
ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) - case proj1 @ ColumnarProjectExec(_, - join1 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj2 @ ColumnarProjectExec(_, - join2 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - proj3 @ ColumnarProjectExec(_, - join3 @ ColumnarBroadcastHashJoinExec(_, _, _, _, _, - filter @ ColumnarFilterExec(_, - scan @ ColumnarFileSourceScanExec(_, _, _, _, _, _, _, _, _)), _, _, _)) , _, _, _)), _, _, _)) - if checkBhjRightChild( - child.asInstanceOf[ColumnarProjectExec].child.children(1) - .asInstanceOf[ColumnarBroadcastExchangeExec].child) => - ColumnarMultipleOperatorExec1( - plan, - proj1, - join1, - proj2, - join2, - proj3, - join3, - filter, - scan.relation, - plan.output, - scan.requiredSchema, - scan.partitionFilters, - scan.optionalBucketSet, - scan.optionalNumCoalescedBuckets, - scan.dataFilters, - scan.tableIdentifier, - scan.disableBucketedScan) + case _:ColumnarBroadcastHashJoinExec if joins.nonEmpty && + checkBhjRightChild(child.children(1) + .asInstanceOf[ColumnarBroadcastExchangeExec].child) => + val lastBhj = joins.last + lastBhj.left match { + case exec: ColumnarFilterExec => + val scan = exec.child.asInstanceOf[ColumnarFileSourceScanExec] + ColumnarFusedOperatorExec(plan, joins, exec, scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case _ => lastBhj.right match { + case exec: ColumnarFilterExec => + val scan = exec.child.asInstanceOf[ColumnarFileSourceScanExec] + ColumnarFusedOperatorExec(plan, joins, exec, scan.relation, + plan.output, + scan.requiredSchema, + scan.partitionFilters, + scan.optionalBucketSet, + scan.optionalNumCoalescedBuckets, + scan.dataFilters, + scan.tableIdentifier, + scan.disableBucketedScan) + case _ => + new ColumnarHashAggregateExec( + plan.requiredChildDistributionExpressions, + plan.groupingExpressions, + plan.aggregateExpressions, + plan.aggregateAttributes, + plan.initialInputBufferOffset, + plan.resultExpressions, + child) + } + } case _ => new ColumnarHashAggregateExec( plan.requiredChildDistributionExpressions, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFusedOperatorExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFusedOperatorExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..e7bd1fec1a514e99739fa9721f2d6cfdabec1037 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFusedOperatorExec.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor +import com.huawei.boostkit.spark.util.OmniAdaptorUtil +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs +import nova.hetu.omniruntime.`type`.DataType +import nova.hetu.omniruntime.constants.JoinType.OMNI_JOIN_TYPE_INNER +import nova.hetu.omniruntime.constants.OperatorType +import nova.hetu.omniruntime.operator.OmniOperatorFactory +import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory +import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} +import nova.hetu.omniruntime.operator.filter.OmniFilterAndProjectOperatorFactory +import nova.hetu.omniruntime.operator.fusion.OmniFusionOperatorFactory +import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory} +import nova.hetu.omniruntime.vector.VecBatch +import nova.hetu.omniruntime.vector.serialize.VecBatchSerializerFactory +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.datasources.HadoopFsRelation +import org.apache.spark.sql.execution.joins.ColumnarBroadcastHashJoinExec +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.util.SparkMemoryUtils +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.collection.BitSet + +import java.util.Optional +import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.collection.JavaConverters.seqAsJavaList +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +case class ColumnarFusedOperatorExec( + aggregate: HashAggregateExec, + joins: Seq[ColumnarBroadcastHashJoinExec], + filter: ColumnarFilterExec, + @transient relation: HadoopFsRelation, + output: Seq[Attribute], + requiredSchema: StructType, + partitionFilters: Seq[Expression], + optionalBucketSet: Option[BitSet], + optionalNumCoalescedBuckets: Option[Int], + dataFilters: Seq[Expression], + tableIdentifier: Option[TableIdentifier], + disableBucketedScan: Boolean = false) + extends BaseColumnarFileSourceScanExec( + relation, + output, + requiredSchema, + partitionFilters, + optionalBucketSet, + optionalNumCoalescedBuckets, + dataFilters, + tableIdentifier, + disableBucketedScan) { + protected override def doPrepare(): Unit = { 
+ super.doPrepare() + for (join <- joins) { + join.getBuildPlan.asInstanceOf[ColumnarBroadcastExchangeExec].relationFuture + } + } + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files read"), + "metadataTime" -> SQLMetrics.createTimingMetric(sparkContext, "metadata time"), + "filesSize" -> SQLMetrics.createSizeMetric(sparkContext, "size of files read"), + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "addInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni addInput"), + "outputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni getOutput"), + "omniJitTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni codegen"), + "numOutputVecBatchs" -> SQLMetrics.createMetric(sparkContext, "number of output vecBatchs"), + //operator metric + "lookupAddInputTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omni lookup addInput"), + // + ) ++ { + // Tracking scan time has overhead, we can't afford to do it for each row, and can only do + // it for each batch. + if (supportsColumnar) { + Some("scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + } else { + None + } + } ++ { + if (relation.partitionSchemaOption.isDefined) { + Map( + "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions read"), + "pruningTime" -> + SQLMetrics.createTimingMetric(sparkContext, "dynamic partition pruning time")) + } else { + Map.empty[String, SQLMetric] + } + } ++ staticMetrics + + protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numOutputRows = longMetric("numOutputRows") + val scanTime = longMetric("scanTime") + val numInputRows = longMetric("numInputRows") + val numOutputVecBatchs = longMetric("numOutputVecBatchs") + val addInputTime = longMetric("addInputTime") + val omniCodegenTime = longMetric("omniJitTime") + val getOutputTime = longMetric("outputTime") + + // get agg parameters for creating factory + val (omniGroupByChanel, omniAggChannels, omniAggSourceTypes, omniAggFunctionTypes, omniAggOutputTypes, + omniAggInputRaw, omniAggOutputPartial, resultIdxToOmniResultIdxMap) = genAggOutput(aggregate) + + // get join parameters for creating factory + val joinParams = new ListBuffer[(Array[DataType], Array[String], Option[String], Array[DataType], Array[Int], + Array[String], Array[Int], Array[DataType], Broadcast[ColumnarHashedRelation], Boolean)]() + for (join <- joins) { + joinParams.append(genJoinOutputWithReverseFlag(join)) + } + + // get filter parameters for creating factory + val (conditionExpression, omniCondInputTypes, omniCondExpressions) = genFilterOutput(filter) + + val operatorFactories = ArrayBuffer[OmniOperatorFactory[_]]() + val operatorTypes = ArrayBuffer[OperatorType]() + + inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches => + // create hash aggregation factory + val deserializer = VecBatchSerializerFactory.create() + val startCodegen = System.nanoTime() + val aggFactory = new OmniHashAggregationWithExprOperatorFactory( + omniGroupByChanel, + omniAggChannels, + omniAggSourceTypes, + omniAggFunctionTypes, + omniAggOutputTypes, + omniAggInputRaw, + omniAggOutputPartial, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + operatorFactories.prepend(aggFactory) + operatorTypes.prepend(OperatorType.OMNI_HASH_AGGREGATION) + + // 
create join factories + for (i <- joins.indices) { + // create join factories + val (buildTypes, buildJoinColsExp, joinFilter, probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, relation, needReverseJoinOutput) = joinParams(i) + val buildOpFactory = new OmniHashBuilderWithExprOperatorFactory(buildTypes, + buildJoinColsExp, if (joinFilter.nonEmpty) { + Optional.of(joinFilter.get) + } else { + Optional.empty() + }, 1, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + val buildOp = buildOpFactory.createOperator() + // close build operator + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + buildOp.close() + buildOpFactory.close() + }) + + relation.value.buildData.foreach { input => + buildOp.addInput(deserializer.deserialize(input)) + } + buildOp.getOutput + val lookupOpFactory = new OmniLookupJoinWithExprOperatorFactory(probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, OMNI_JOIN_TYPE_INNER, buildOpFactory, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP,needReverseJoinOutput)) + operatorFactories.prepend(lookupOpFactory) + operatorTypes.prepend(OperatorType.OMNI_LOOKUP_JOIN) + + // close lookupOpFactory + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + lookupOpFactory.close() + }) + } + + // create filter factory + val condOperatorFactory = new OmniFilterAndProjectOperatorFactory( + conditionExpression, omniCondInputTypes, seqAsJavaList(omniCondExpressions), 1, + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + operatorFactories.prepend(condOperatorFactory) + operatorTypes.prepend(OperatorType.OMNI_FILTER_AND_PROJECT) + + // create fusion factory and operator + val fusionOperatorFactory = new OmniFusionOperatorFactory(operatorFactories.toArray, operatorTypes.toArray) + val fusionOperator = fusionOperatorFactory.createOperator + // close fusion operator + addLeakSafeTaskCompletionListener[Unit](_ => { + fusionOperator.close() + }) + + // fusion operator addinput + while (batches.hasNext) { + val batch = batches.next() + val startInput = System.nanoTime() + val input = transColBatchToOmniVecs(batch) + val vecBatch = new VecBatch(input, batch.numRows()) + fusionOperator.addInput(vecBatch) + addInputTime += NANOSECONDS.toMillis(System.nanoTime() - startInput) + } + // fusion operator getoutput + val startGetOp = System.nanoTime() + val fusionOutput = fusionOperator.getOutput + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + val localSchema = aggregate.schema + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = { + // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call. 
+ val startNs = System.nanoTime() + val res = fusionOutput.hasNext + scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs) + res + } + + override def next(): ColumnarBatch = { + val vecBatch = fusionOutput.next() + val vectors: Seq[OmniColumnVector] = OmniColumnVector.allocateColumns( + vecBatch.getRowCount, localSchema, false) + vectors.zipWithIndex.foreach { case (vector, i) => + vector.reset() + vector.setVec(vecBatch.getVectors()(resultIdxToOmniResultIdxMap(i))) + } + numOutputRows += vecBatch.getRowCount + numOutputVecBatchs += 1 + getOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startGetOp) + vecBatch.close() + new ColumnarBatch(vectors.toArray, vecBatch.getRowCount) + } + } + } + } + + override val nodeNamePrefix: String = "" + + override val nodeName: String = "ColumnarFusedOperatorExec" + + override protected def doCanonicalize(): SparkPlan = WrapperLeafExec() + + def genJoinOutputWithReverseFlag(join: ColumnarBroadcastHashJoinExec) = { + val buildTypes = new Array[DataType](join.getBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + join.getBuildOutput.zipWithIndex.foreach { case (att, i) => + buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + // {0}, buildKeys: col1#12 + val buildOutputCols = join.getIndexArray(join.getBuildOutput, join.projectList) // {0,1} + val buildJoinColsExp = join.getBuildKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getBuildOutput.map(_.toAttribute))) + }.toArray + val relation = join.getBuildPlan.executeBroadcast[ColumnarHashedRelation]() + + // since project will trim some columns, buildOutputTypes no longer equals to buildTypes + val prunedBuildOutput = join.pruneOutput(join.getBuildOutput, join.projectList) + val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 + prunedBuildOutput.zipWithIndex.foreach { case (att, i) => + buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) + } + + val probeTypes = new Array[DataType](join.getStreamedOutput.size) // {2,2},streamedOutput:col1#10,col2#11 + join.getStreamedOutput.zipWithIndex.foreach { case (attr, i) => + probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) + } + val probeOutputCols = join.getIndexArray(join.getStreamedOutput, join.projectList) // {0,1} + val probeHashColsExp = join.getStreamedKeys.map { x => + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, + OmniExpressionAdaptor.getExprIdMap(join.getStreamedOutput.map(_.toAttribute))) + }.toArray + val filter: Option[String] = join.condition match { + case Some(expr) => + Some(OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, + OmniExpressionAdaptor.getExprIdMap((join.getStreamedOutput ++ join.getBuildOutput).map(_.toAttribute)))) + case _ => None + } + + val needReverseJoinOutput = join.buildSide == BuildLeft + (buildTypes, buildJoinColsExp, filter, probeTypes, probeOutputCols, + probeHashColsExp, buildOutputCols, buildOutputTypes, relation, needReverseJoinOutput) + } +}
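
The GetJoins helper added to ColumnarPlugin.scala replaces the removed three- and four-level nested pattern matches with an iterative walk down a chain of ColumnarBroadcastHashJoinExec nodes, stopping once the next child is a ColumnarFilterExec, so the fusion case no longer depends on a fixed join depth. Below is a minimal, self-contained sketch of that chain-walking idea; Node, Join, Filter and Scan are hypothetical stand-ins for the real SparkPlan operators, not the classes used in the patch.

import scala.collection.mutable.ListBuffer

object JoinChainSketch {
  sealed trait Node
  case class Join(left: Node, right: Node) extends Node
  case class Filter(child: Node) extends Node
  case object Scan extends Node

  // Collect consecutive joins along the probe side; a Filter child terminates
  // the chain, any other shape stops the walk and returns what was collected.
  def collectJoins(root: Node): ListBuffer[Join] = {
    val joins = ListBuffer.empty[Join]
    var cur = root
    var continue = true
    while (continue) {
      continue = cur match {
        case j @ Join(next: Join, _) => joins += j; cur = next; true
        case j @ Join(Filter(_), _)  => joins += j; false
        case j @ Join(_, Filter(_))  => joins += j; false
        case _                       => false
      }
    }
    joins
  }

  def main(args: Array[String]): Unit = {
    // join over join over join(filter(scan)), mirroring the fused-plan shape
    val plan = Join(Join(Join(Filter(Scan), Scan), Scan), Scan)
    println(collectJoins(plan).size) // 3
  }
}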
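
ColumnarFusedOperatorExec.doPrepare references relationFuture on each join's ColumnarBroadcastExchangeExec, which forces the lazily started broadcast builds to begin before the fused operator runs instead of on first use inside the task. A small sketch of that pattern, assuming a hypothetical BroadcastExchange with a lazy relationFuture (the real exchange builds a ColumnarHashedRelation, not a String):

import scala.concurrent.{ExecutionContext, Future}

object EagerBroadcastSketch {
  // Hypothetical stand-in for ColumnarBroadcastExchangeExec.
  final class BroadcastExchange(name: String)(implicit ec: ExecutionContext) {
    // Lazily started; the first access kicks off the (expensive) build.
    lazy val relationFuture: Future[String] = Future {
      println(s"building broadcast relation for $name")
      s"$name-relation"
    }
  }

  // doPrepare-style hook: touch every lazy future so all broadcasts start now.
  def prepare(exchanges: Seq[BroadcastExchange]): Unit =
    exchanges.foreach(_.relationFuture)

  def main(args: Array[String]): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global
    prepare(Seq(new BroadcastExchange("join1"), new BroadcastExchange("join2")))
    Thread.sleep(200) // let the async builds log before the JVM exits
  }
}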
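
In doExecuteColumnar the per-operator factories are prepended to operatorFactories and operatorTypes: the hash aggregation factory is created first but ends up last, each lookup-join factory is layered on top of it, and the filter/project factory is created last and ends up first, so OmniFusionOperatorFactory receives the stages in data-flow order (filter, joins from innermost to outermost, aggregation). The sketch below illustrates only that ordering; Stage and FusedPipeline are hypothetical stand-ins for the Omni factory types.

import scala.collection.mutable.ArrayBuffer

object FusionOrderSketch {
  final case class Stage(name: String)

  // Pretend "fusion factory": records the stages in the order handed over.
  final case class FusedPipeline(stages: Seq[Stage]) {
    def describe: String = stages.map(_.name).mkString(" -> ")
  }

  def buildPipeline(numJoins: Int): FusedPipeline = {
    val stages = ArrayBuffer[Stage]()

    // Created first, but must run last: the partial hash aggregation.
    stages.prepend(Stage("hash-aggregation"))

    // One lookup-join stage per broadcast join; prepending in collected
    // (top-down) order leaves the innermost join right after the filter.
    for (i <- 0 until numJoins)
      stages.prepend(Stage(s"lookup-join-$i"))

    // Created last, runs first on every scanned batch: the filter/project.
    stages.prepend(Stage("filter-and-project"))

    FusedPipeline(stages.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // filter-and-project -> lookup-join-2 -> lookup-join-1 -> lookup-join-0 -> hash-aggregation
    println(buildPipeline(3).describe)
  }
}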
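
genJoinOutputWithReverseFlag mirrors the existing join parameter extraction but additionally derives needReverseJoinOutput from join.buildSide == BuildLeft and passes it through the OperatorConfig of the lookup-join factory. The flag appears to compensate for the build side having been the original left child, in which case build-side columns must precede streamed-side columns in the declared output. A hedged sketch of that reordering, with a hypothetical Column type rather than the real attribute and vector handling:

object ReverseOutputSketch {
  sealed trait BuildSide
  case object BuildLeft  extends BuildSide
  case object BuildRight extends BuildSide

  final case class Column(name: String)

  // The lookup join naturally emits streamed columns first, build columns last;
  // with BuildLeft the plan's output order is build-first, so swap them back.
  def joinedOutput(streamed: Seq[Column],
                   build: Seq[Column],
                   buildSide: BuildSide): Seq[Column] = {
    val needReverseJoinOutput = buildSide == BuildLeft
    if (needReverseJoinOutput) build ++ streamed else streamed ++ build
  }

  def main(args: Array[String]): Unit = {
    val streamed = Seq(Column("s1"), Column("s2"))
    val build    = Seq(Column("b1"))
    println(joinedOutput(streamed, build, BuildRight).map(_.name)) // List(s1, s2, b1)
    println(joinedOutput(streamed, build, BuildLeft).map(_.name))  // List(b1, s1, s2)
  }
}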