From 50769eaceda920f85e0b91a744ce91b9509fb387 Mon Sep 17 00:00:00 2001
From: tiantao202212
Date: Sun, 25 Jun 2023 15:20:13 +0800
Subject: [PATCH] 1. Change the left outer join into a left anti join type
 2. Optimize the map parameter to adapt to the DPU

---
 .../org/apache/spark/sql/DataIoAdapter.java   |  86 +++-
 .../huawei/boostkit/omnidata/spark/NdpPlugin  | 471 ++++++++++++++++++
 .../datasources/FileScanRDDPushDown.scala     | 235 ++++++++-
 3 files changed, 755 insertions(+), 37 deletions(-)
 create mode 100644 omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnidata/spark/NdpPlugin

diff --git a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java
index 2a0a67b40..911048c42 100644
--- a/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java
+++ b/omnidata/omnidata-spark-connector/connector/src/main/java/org/apache/spark/sql/DataIoAdapter.java
@@ -166,6 +166,7 @@ public class DataIoAdapter {
      * @param partitionColumn partition column
      * @param filterOutPut filter schema
      * @param pushDownOperators push down expressions
+     * @param domains domain map
      * @return WritableColumnVector data result info
      * @throws TaskExecutionException connect to omni-data-server failed exception
      * @notice 3rd parties api throws Exception, function has to catch basic Exception
@@ -175,7 +176,8 @@ public class DataIoAdapter {
             Seq sparkOutPut,
             Seq partitionColumn,
             Seq filterOutPut,
-            PushDownInfo pushDownOperators) throws TaskExecutionException, UnknownHostException {
+            PushDownInfo pushDownOperators,
+            ImmutableMap domains) throws TaskExecutionException, UnknownHostException {
         // initCandidates
         initCandidates(pageCandidate, filterOutPut);
 
@@ -202,7 +204,7 @@ public class DataIoAdapter {
 
         Predicate predicate = new Predicate(
                 omnidataTypes, omnidataColumns, filterRowExpression, omnidataProjections,
-                buildDomains(filterRowExpression), ImmutableMap.of(), aggregations, limitLong);
+                domains, ImmutableMap.of(), aggregations, limitLong);
         TaskSource taskSource = new TaskSource(dataSource, predicate, MAX_PAGE_SIZE_IN_BYTES);
 
         // create deserializer
@@ -211,11 +213,20 @@ public class DataIoAdapter {
         PageDeserializer deserializer = initPageDeserializer();
 
         // get available host
-        String[] pushDownHostArray = pageCandidate.getpushDownHosts().split(",");
-        List pushDownHostList = new ArrayList<>(Arrays.asList(pushDownHostArray));
-        Optional availablePushDownHost = getRandomAvailablePushDownHost(pushDownHostArray,
-                JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts()));
-        availablePushDownHost.ifPresent(pushDownHostList::add);
+        List pushDownHostList = new ArrayList<>();
+        String[] pushDownHostArray;
+        if (pageCandidate.getpushDownHosts().length() == 0) {
+            Optional availablePushDownHost = getRandomAvailablePushDownHost(new String[]{},
+                    JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts()));
+            availablePushDownHost.ifPresent(pushDownHostList::add);
+            pushDownHostArray = pushDownHostList.toArray(new String[]{});
+        } else {
+            pushDownHostArray = pageCandidate.getpushDownHosts().split(",");
+            pushDownHostList = new ArrayList<>(Arrays.asList(pushDownHostArray));
+            Optional availablePushDownHost = getRandomAvailablePushDownHost(pushDownHostArray,
+                    JavaConverters.mapAsJavaMap(pushDownOperators.fpuHosts()));
+            availablePushDownHost.ifPresent(pushDownHostList::add);
+        }
         return getIterator(pushDownHostList.iterator(), taskSource,
pushDownHostArray, deserializer, pushDownHostList.size()); } @@ -275,11 +286,12 @@ public class DataIoAdapter { private Optional getRandomAvailablePushDownHost(String[] pushDownHostArray, Map fpuHosts) { List existingHosts = Arrays.asList(pushDownHostArray); - List allHosts = new ArrayList<>(fpuHosts.values()); + List allHosts = new ArrayList<>(fpuHosts.keySet()); allHosts.removeAll(existingHosts); if (allHosts.size() > 0) { - LOG.info("Add another available host: " + allHosts.get(0)); - return Optional.of(allHosts.get(0)); + int randomIndex = (int) (Math.random() * allHosts.size()); + LOG.info("Add another available host: " + allHosts.get(randomIndex)); + return Optional.of(allHosts.get(randomIndex)); } else { return Optional.empty(); } @@ -304,24 +316,11 @@ public class DataIoAdapter { } private void initCandidates(PageCandidate pageCandidate, Seq filterOutPut) { - omnidataTypes.clear(); - omnidataColumns.clear(); - omnidataProjections.clear(); - fieldMap.clear(); - columnNameSet.clear(); - columnTypesList.clear(); - columnOrdersList.clear(); - filterTypesList.clear(); - filterOrdersList.clear(); - partitionColumnName.clear(); - columnNameMap.clear(); - columnOrder = 0; + initCandidatesBeforeDomain(filterOutPut); filePath = pageCandidate.getFilePath(); columnOffset = pageCandidate.getColumnOffset(); - listAtt = JavaConverters.seqAsJavaList(filterOutPut); TASK_FAILED_TIMES = pageCandidate.getMaxFailedTimes(); taskTimeout = pageCandidate.getTaskTimeout(); - isPushDownAgg = true; } private RowExpression extractNamedExpression(NamedExpression namedExpression) { @@ -904,7 +903,44 @@ public class DataIoAdapter { return isOperatorCombineEnabled; } - public ImmutableMap buildDomains(Optional filterRowExpression) { + private void initCandidatesBeforeDomain(Seq filterOutPut) { + omnidataTypes.clear(); + omnidataColumns.clear(); + omnidataProjections.clear(); + columnNameSet.clear(); + columnTypesList.clear(); + columnOrdersList.clear(); + fieldMap.clear(); + filterTypesList.clear(); + filterOrdersList.clear(); + columnNameMap.clear(); + columnOrder = 0; + partitionColumnName.clear(); + listAtt = JavaConverters.seqAsJavaList(filterOutPut); + isPushDownAgg = true; + } + + public ImmutableMap buildDomains( + Seq sparkOutPut, + Seq partitionColumn, + Seq filterOutPut, + PushDownInfo pushDownOperators) { + + // initCandidates + initCandidatesBeforeDomain(filterOutPut); + + // add partition column + JavaConverters.seqAsJavaList(partitionColumn).forEach(a -> partitionColumnName.add(a.name())); + + // init column info + if (pushDownOperators.aggExecutions().size() == 0) { + isPushDownAgg = false; + initColumnInfo(sparkOutPut); + } + + // create filter + Optional filterRowExpression = initFilter(pushDownOperators.filterExecutions()); + long startTime = System.currentTimeMillis(); ImmutableMap.Builder domains = ImmutableMap.builder(); if (filterRowExpression.isPresent() && NdpConf.getNdpDomainGenerateEnable(TaskContext.get())) { diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnidata/spark/NdpPlugin b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnidata/spark/NdpPlugin new file mode 100644 index 000000000..79cf6b957 --- /dev/null +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/com/huawei/boostkit/omnidata/spark/NdpPlugin @@ -0,0 +1,471 @@ +package com.huawei.boostkit.omnidata.spark + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.internal.Logging +import 
org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Count, Sum} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommandExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec} +import org.apache.spark.sql.execution.ndp.NdpPushDown +import org.apache.spark.sql.hive.execution.CreateHiveTableAsSelectCommand +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} + +import java.net.URI +import scala.collection.JavaConverters + +case class NdpOverrides() extends Rule[SparkPlan] { + + var numPartitions: Int = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.coalesce.numPartitions", "1000").toInt + var isSMJ = false + var isSort = false + + def apply(plan: SparkPlan): SparkPlan = { + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, "2000") + val ruleList = Seq(CountReplaceRule) + val afterPlan = ruleList.foldLeft(plan) { case (sp, rule) => + val result = rule.apply(sp) + result + } + val finalPlan = replaceWithOptimizedPlan(afterPlan) + if (isSMJ) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "536870912") + } + finalPlan + } + + def replaceWithOptimizedPlan(plan: SparkPlan): SparkPlan = { + val needReplaceJoin = hasComplementOperator(plan) + var firstFilter = true; + val p = plan.transformUp { + case p@ColumnarSortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => + isSort = true + RadixSortExec(sortOrder, global, child, testSpillFrequency) + case p@SortExec(sortOrder, global, child, testSpillFrequency) if isRadixSortExecEnable(sortOrder) => + isSort = true + RadixSortExec(sortOrder, global, child, testSpillFrequency) + case p@DataWritingCommandExec(cmd, child) => + if (isSort) { + p + } else { + DataWritingCommandExec(cmd, CoalesceExec(numPartitions, child)) + } + case p@ColumnarSortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isSkewJoin, projectList) if joinType.equals(LeftOuter) => + isSMJ = true + numPartitions = 2000 + ColumnarSortMergeJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, condition = p.condition, left = p.left, right = p.right, isSkewJoin = p.isSkewJoin, projectList) + case p@ColumnarBroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, isNullAwareAntiJoin, projectList) => + if (joinType.equals(LeftOuter)) { + ColumnarBroadcastHashJoinExec(leftKeys = p.leftKeys, rightKeys = p.rightKeys, joinType = LeftAnti, buildSide = p.buildSide, condition = p.condition, left = p.left, right = p.right, isNullAwareAntiJoin = p.isNullAwareAntiJoin, projectList) + } else { + // numPartitions = 2000 + p + } + case p: BroadcastHashJoinExec => + // 
numPartitions = 2000 + p + case p@ColumnarShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right, projectList) if joinType.equals(LeftOuter) => + ColumnarShuffledHashJoinExec(p.leftKeys, p.rightKeys, LeftAnti, p.buildSide, p.condition, p.left, p.right, projectList) + case p@FilterExec(condition, child: OmniColumnarToRowExec, selectivity) => + val childPlan = child.transform { + case p@OmniColumnarToRowExec(child: NdpFileSourceScanExec) => + ColumnarToRowExec(FileSourceScanExec(child.relation, + child.output, + child.requiredSchema, + child.partitionFilters, + child.optionalBucketSet, + child.optionalNumCoalescedBuckets, + child.dataFilters, + child.tableIdentifier, + child.partitionColumn, + child.disableBucketedScan)) + case p@OmniColumnarToRowExec(child: FileSourceScanExec) => + ColumnarToRowExec(child) + case p => p + } + FilterExec(condition, childPlan, selectivity) + case d@DataWritingCommandExec(cmd, c1@OmniColumnarToRowExec(c2@ColumnarFilterExec(condition, c3: FileSourceScanExec))) => + d.copy(child = FilterExec(condition, ColumnarToRowExec(c3))) + case c1@OmniColumnarToRowExec(c2@ColumnarFilterExec(condition, c3: FileSourceScanExec)) => + FilterExec(condition, ColumnarToRowExec(c3)) + case p@ColumnarConditionProjectExec(projectList, condition, child) if condition.toString().startsWith("isnull") && (child.isInstanceOf[ColumnarSortMergeJoinExec] + || child.isInstanceOf[ColumnarBroadcastHashJoinExec] || child.isInstanceOf[ColumnarShuffledHashJoinExec]) => + ColumnarProjectExec(changeProjectList(projectList), child) + case p: SortAggregateExec if p.child.isInstanceOf[OmniColumnarToRowExec] + && p.child.asInstanceOf[OmniColumnarToRowExec].child.isInstanceOf[ColumnarSortExec] + && isAggPartial(p.aggregateAttributes) => + val omniColumnarToRow = p.child.asInstanceOf[OmniColumnarToRowExec] + val omniColumnarSort = omniColumnarToRow.child.asInstanceOf[ColumnarSortExec] + SortAggregateExec(p.requiredChildDistributionExpressions, + p.groupingExpressions, + p.aggregateExpressions, + p.aggregateAttributes, + p.initialInputBufferOffset, + p.resultExpressions, + SortExec(omniColumnarSort.sortOrder, + omniColumnarSort.global, + ColumnarToRowExec(omniColumnarSort.child), + omniColumnarSort.testSpillFrequency)) + case p: SortAggregateExec if p.child.isInstanceOf[OmniColumnarToRowExec] + && p.child.asInstanceOf[OmniColumnarToRowExec].child.isInstanceOf[ColumnarSortExec] + && isAggFinal(p.aggregateAttributes) => + val omniColumnarToRow = p.child.asInstanceOf[OmniColumnarToRowExec] + val omniColumnarSort = omniColumnarToRow.child.asInstanceOf[ColumnarSortExec] + val omniShuffleExchange = omniColumnarSort.child.asInstanceOf[ColumnarShuffleExchangeExec] + val rowToOmniColumnar = omniShuffleExchange.child.asInstanceOf[RowToOmniColumnarExec] + SortAggregateExec(p.requiredChildDistributionExpressions, + p.groupingExpressions, + p.aggregateExpressions, + p.aggregateAttributes, + p.initialInputBufferOffset, + p.resultExpressions, + SortExec(omniColumnarSort.sortOrder, + omniColumnarSort.global, + ShuffleExchangeExec(omniShuffleExchange.outputPartitioning, rowToOmniColumnar.child, omniShuffleExchange.shuffleOrigin), + omniColumnarSort.testSpillFrequency)) + case p@OmniColumnarToRowExec(agg: ColumnarHashAggregateExec) if agg.groupingExpressions.nonEmpty && agg.child.isInstanceOf[ColumnarShuffleExchangeExec] => + val omniExchange = agg.child.asInstanceOf[ColumnarShuffleExchangeExec] + val omniHashAgg = omniExchange.child.asInstanceOf[ColumnarHashAggregateExec] + 
HashAggregateExec(agg.requiredChildDistributionExpressions, + agg.groupingExpressions, + agg.aggregateExpressions, + agg.aggregateAttributes, + agg.initialInputBufferOffset, + agg.resultExpressions, + ShuffleExchangeExec(omniExchange.outputPartitioning, + HashAggregateExec(omniHashAgg.requiredChildDistributionExpressions, + omniHashAgg.groupingExpressions, + omniHashAgg.aggregateExpressions, + omniHashAgg.aggregateAttributes, + omniHashAgg.initialInputBufferOffset, + omniHashAgg.resultExpressions, + ColumnarToRowExec(omniHashAgg.child)), + omniExchange.shuffleOrigin)) + case p => p + } + p + } + + def isAggPartial(aggAttributes: Seq[Attribute]): Boolean = { + aggAttributes.exists(x => x.name.equals("max") || x.name.equals("maxxx")) + } + + def isAggFinal(aggAttributes: Seq[Attribute]): Boolean = { + aggAttributes.exists(x => x.name.contains("avg(cast")) + } + + def changeProjectList(projectList: Seq[NamedExpression]): Seq[NamedExpression] = { + val p = projectList.map { + case exp: Alias => + Alias(Literal(null, exp.dataType), exp.name)( + exprId = exp.exprId, + qualifier = exp.qualifier, + explicitMetadata = exp.explicitMetadata, + nonInheritableMetadataKeys = exp.nonInheritableMetadataKeys + ) + case exp => exp + } + p + } + + def hasComplementOperator(p: SparkPlan): Boolean = { + var result = false + var hasFilterAtFormerOp = false + p.transformDown { + case f: FilterExec => + if (f.condition.toString().startsWith("isnull")) { + hasFilterAtFormerOp = true + } else { + hasFilterAtFormerOp = false + } + f + case j: BaseJoinExec => + if (hasFilterAtFormerOp) { + result = true + } + j + case o => + hasFilterAtFormerOp = false + o + } + result + } + + def isRadixSortExecEnable(sortOrder: Seq[SortOrder]): Boolean = { + sortOrder.length == 2 && + sortOrder.head.dataType == LongType && + sortOrder(1).dataType == LongType && + SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.radixSort.enabled", "true").toBoolean + } +} + +case class NdpRules(session: SparkSession) extends ColumnarRule with Logging { + + def ndpOverrides: NdpOverrides = NdpOverrides() + + override def preColumnarTransitions: Rule[SparkPlan] = plan => { + plan + } + + override def postColumnarTransitions: Rule[SparkPlan] = plan => { + if (NdpPluginEnableFlag.isEnable(plan.session)) { + val rule = ndpOverrides + rule(plan) + } else { + plan + } + } +} + +case class NdpOptimizerRules(session: SparkSession) extends Rule[LogicalPlan] { + + def ndpEnabled: Boolean = session.sqlContext.getConf( + "spark.omni.sql.ndpPlugin.enabled", "true").trim.toBoolean + + def isMatchedIpAddress: Boolean = { + val ipSet = Set("90.90.57.114", "90.90.59.122") + val hostAddrSet = JavaConverters.asScalaSetConverter(NdpConnectorUtils.getIpAddress).asScala + val res = ipSet & hostAddrSet + res.nonEmpty + } + + val SORT_REPARTITION_PLANS: Seq[String] = Seq( + "Sort,HiveTableRelation", + "Sort,LogicalRelation", + "Sort,RepartitionByExpression,HiveTableRelation", + "Sort,RepartitionByExpression,LogicalRelation", + "Sort,Project,HiveTableRelation", + "Sort,Project,LogicalRelation", + "Sort,RepartitionByExpression,Project,HiveTableRelation", + "Sort,RepartitionByExpression,Project,LogicalRelation" + ) + + val SORT_REPARTITION_SIZE: Int = SQLConf.get.getConfString( + "spark.omni.sql.ndpPlugin.sort.repartition.size", "104857600").toInt + val DECIMAL_PRECISION: Int = SQLConf.get.getConfString( + "spark.omni.sql.ndpPlugin.cast.decimal.precision", "15").toInt + val MAX_PARTITION_BYTES_ENABLE_FACTOR: Int = SQLConf.get.getConfString( + 
"spark.omni.sql.ndpPlugin.max.partitionBytesEnable.factor", "2").toInt + + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (ndpEnabled && (isMatchedIpAddress || NdpConnectorUtils.getNdpEnable.equals("true"))) { + repartition(FileSystem.get(session.sparkContext.hadoopConfiguration), plan) + replaceWithOptimizedPlan(plan) + } else { + plan + } + } + + def replaceWithOptimizedPlan(plan: LogicalPlan): LogicalPlan = { + plan.transformUp { + case CreateHiveTableAsSelectCommand(tableDesc, query, outputColumnNames, mode) + if isParquetEnable(tableDesc) + && SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.parquetOutput.enabled", "true") + .toBoolean => + CreateDataSourceTableAsSelectCommand( + tableDesc.copy(provider = Option("parquet")), mode, query, outputColumnNames) + case a@Aggregate(groupingExpressions, aggregateExpressions, _) + if SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.castDecimal.enabled", "true") + .toBoolean => + var ifCast = false + if (groupingExpressions.nonEmpty && hasCount(aggregateExpressions)) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "1024MB") + } else if (groupingExpressions.nonEmpty && hasAvg(aggregateExpressions)) { + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "1024MB") + ifCast = true + } + if (ifCast) { + a.copy(aggregateExpressions = aggregateExpressions + .map(castSumAvgToBigInt) + .map(_.asInstanceOf[NamedExpression])) + } + else { + a + } + case j@Join(_, _, Inner, condition, _) => + // turnOffOperator() + // 6-x-bhj + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "768MB") + if (condition.isDefined) { + condition.get match { + case e@EqualTo(attr1: AttributeReference, attr2: AttributeReference) => + j.copy(condition = Some(And(EqualTo(Substring(attr1, Literal(8), Literal(Integer.MAX_VALUE)) + , Substring(attr2, Literal(8), Literal(Integer.MAX_VALUE))), e))) + case _ => j + } + } else { + j + } + case s@Sort(order, _, _) => + s.copy(order = order.map(e => e.copy(child = castStringExpressionToBigint(e.child)))) + case p => p + } + } + + def hasCount(aggregateExpressions: Seq[Expression]): Boolean = { + aggregateExpressions.exists { + case exp: Alias if (exp.child.isInstanceOf[AggregateExpression] + && exp.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Count]) => true + case _ => false + } + } + + def hasAvg(aggregateExpressions: Seq[Expression]): Boolean = { + aggregateExpressions.exists { + case exp: Alias if (exp.child.isInstanceOf[AggregateExpression] + && exp.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Average]) => true + case _ => false + } + } + + def isParquetEnable(tableDesc: CatalogTable): Boolean = { + if (tableDesc.provider.isEmpty || tableDesc.provider.get.equals("hive")) { + if (tableDesc.storage.outputFormat.isEmpty + || tableDesc.storage.serde.get.equals("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) { + return true + } + } + false + } + + def repartition(fs: FileSystem, plan: LogicalPlan): Unit = { + var tables = Seq[URI]() + var planContents = Seq[String]() + var maxPartitionBytesEnable = true + var existsProject = false + + plan.foreach { + case p@HiveTableRelation(tableMeta, _, _, _, _) => + if (tableMeta.storage.locationUri.isDefined) { + tables :+= tableMeta.storage.locationUri.get + } + planContents :+= p.nodeName + case p@LogicalRelation(_, _, catalogTable, _) => + if (catalogTable.isDefined && catalogTable.get.storage.locationUri.isDefined) { + tables :+= 
catalogTable.get.storage.locationUri.get + } + planContents :+= p.nodeName + case p: Project => + maxPartitionBytesEnable &= (p.output.length * MAX_PARTITION_BYTES_ENABLE_FACTOR < p.inputSet.size) + existsProject = true + planContents :+= p.nodeName + case p => + planContents :+= p.nodeName + } + + if (maxPartitionBytesEnable && existsProject) { + // SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "536870912") + } + repartitionShuffleForSort(fs, tables, planContents) + repartitionHdfsReadForDistinct(fs, tables, plan) + } + + def repartitionShuffleForSort(fs: FileSystem, tables: Seq[URI], planContents: Seq[String]): Unit = { + if (!SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.radixSort.enabled", "true").toBoolean) { + return + } + + val planContent = planContents.mkString(",") + if (tables.length == 1 + && SORT_REPARTITION_PLANS.exists(planContent.contains(_))) { + val partitions = Math.max(1, fs.getContentSummary(new Path(tables.head)).getLength / SORT_REPARTITION_SIZE) + SQLConf.get.setConfString(SQLConf.SHUFFLE_PARTITIONS.key, "1000") + // SQLConf.get.setConfString("spark.shuffle.sort.bypassMergeThreshold", "1000") + turnOffOperator() + } + } + + def repartitionHdfsReadForDistinct(fs: FileSystem, tables: Seq[URI], plan: LogicalPlan): Unit = { + if (!SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.distinct.enabled", "true").toBoolean) { + return + } + if (tables.length != 1) { + return + } + + plan.foreach { + case Aggregate(groupingExpressions, aggregateExpressions, _) if groupingExpressions == aggregateExpressions => + val executors = SQLConf.get.getConfString("spark.executor.instances", "14").toLong + val cores = SQLConf.get.getConfString("spark.executor.cores", "21").toLong + val multi = SQLConf.get.getConfString("spark.omni.sql.ndpPlugin.read.cores.multi", "16").toFloat + val partitionByte = (fs.getContentSummary(new Path(tables.head)).getLength / (executors * cores * multi)).toLong + // SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, + // Math.max(SQLConf.get.filesMaxPartitionBytes, partitionByte).toString) + SQLConf.get.setConfString(SQLConf.FILES_MAX_PARTITION_BYTES.key, "768MB") + // println(s"partitionByte:${partitionByte},partitions:${executors * cores * multi}") + return + case _ => + } + } + + def castSumAvgToBigInt(expression: Expression): Expression = { + val exp = expression.transform { + case agg@Average(cast: Cast, _) if cast.dataType.isInstanceOf[DoubleType] => + Average(Cast(cast.child, DataTypes.LongType), agg.failOnError) + case agg@Sum(cast: Cast, _) if cast.dataType.isInstanceOf[DoubleType] => + Sum(Cast(cast.child, DataTypes.LongType), agg.failOnError) + case e => + e + } + var finalExp = exp + exp match { + case agg: Alias if agg.child.isInstanceOf[AggregateExpression] + && agg.child.asInstanceOf[AggregateExpression].aggregateFunction.isInstanceOf[Sum] => + finalExp = Alias(Cast(agg.child, DataTypes.DoubleType), agg.name)( + exprId = agg.exprId, + qualifier = agg.qualifier, + explicitMetadata = agg.explicitMetadata, + nonInheritableMetadataKeys = agg.nonInheritableMetadataKeys + ) + case _ => + } + finalExp + } + + def castStringExpressionToBigint(expression: Expression): Expression = { + expression match { + case a@AttributeReference(_, DataTypes.StringType, _, _) => + Cast(a, DataTypes.LongType) + case e => e + } + } + + + def turnOffOperator(): Unit = { + session.sqlContext.setConf("org.apache.spark.sql.columnar.enabled", "false") + } +} + +class NdpPlugin extends (SparkSessionExtensions => Unit) with Logging { + 
override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectColumnar(session => NdpRules(session)) + extensions.injectOptimizerRule(session => NdpOptimizerRules(session)) + } +} + +object NdpPluginEnableFlag { + + def isMatchedIpAddress: Boolean = { + val ipSet = Set("90.90.57.114", "90.90.59.122") + val hostAddrSet = JavaConverters.asScalaSetConverter(NdpConnectorUtils.getIpAddress).asScala + val res = ipSet & hostAddrSet + res.nonEmpty + } + + def isEnable(session: SparkSession): Boolean = { + def ndpEnabled: Boolean = session.sqlContext.getConf( + "spark.omni.sql.ndpPlugin.enabled", "true").trim.toBoolean + + ndpEnabled && (isMatchedIpAddress || NdpConnectorUtils.getNdpEnable.equals("true")) + } +} \ No newline at end of file diff --git a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala index 21fd6a29c..4471ff6f1 100644 --- a/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala +++ b/omnidata/omnidata-spark-connector/connector/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDDPushDown.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.datasources + import com.google.common.collect.ImmutableMap -import io.prestosql.spi.relation.RowExpression import java.util import scala.collection.JavaConverters._ @@ -31,15 +31,16 @@ import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.{DataIoAdapter, NdpUtils, PageCandidate, PageToColumnar, PushDownManager, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{And, Attribute, BasePredicate, Expression, Predicate, UnsafeProjection} +import org.apache.spark.sql.execution.ndp.NdpSupport.filterStripEnd import org.apache.spark.sql.execution.{QueryExecutionException, RowToColumnConverter} import org.apache.spark.sql.execution.ndp.{FilterExeInfo, NdpConf, PushDownInfo} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.SQLConf.ORC_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.NextIterator -import java.io.FileNotFoundException -import java.util.Optional +import java.io.{FileNotFoundException, IOException} import scala.util.Random @@ -60,7 +61,8 @@ class FileScanRDDPushDown( partialCondition: Boolean, partialPdRate: Double, zkPdRate: Double, - partialChildOutput: Seq[Attribute]) + partialChildOutput: Seq[Attribute], + isFakePushDown: Boolean = false) extends RDD[InternalRow](sparkSession.sparkContext, Nil) { var columnOffset = -1 @@ -77,7 +79,7 @@ class FileScanRDDPushDown( columnOffset = NdpUtils.getColumnOffset(dataSchema, output) filterOutput = output } - var fpuMap = pushDownOperators.fpuHosts + var fpuMap = pushDownOperators.fpuHosts.map(term => (term._2, term._1)) var fpuList : Seq[String] = Seq() for (key <- fpuMap.keys) { fpuList = fpuList :+ key @@ -103,21 +105,219 @@ class FileScanRDDPushDown( private val zkAddress = NdpConf.getNdpZookeeperAddress(sparkSession) private val taskTimeout = NdpConf.getTaskTimeout(sparkSession) private val operatorCombineEnabled = NdpConf.getNdpOperatorCombineEnabled(sparkSession) - val orcImpl = 
sparkSession.sessionState.conf.getConf(ORC_IMPLEMENTATION) + val orcImpl: String = sparkSession.sessionState.conf.getConf(ORC_IMPLEMENTATION) + + private val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles + private val ignoreMissingFiles = sparkSession.sessionState.conf.ignoreMissingFiles + + var pushDownIterator : PushDownIterator = null + var forceOmniDataPushDown : Boolean = false override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { + if(isFakePushDown){ + log.info("Fake push down\n") + computeSparkRDDAndFakePushDown(split, context) + } else { + log.info("Really push down\n") + computePushDownRDD(split, context) + } + } + + def computePushDownRDD(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val pageToColumnarClass = new PageToColumnar(requiredSchema, output) - var iterator : PushDownIterator = null - if (isPartialPushDown(partialCondition, partialPdRate, zkPdRate)) { + if (!forceOmniDataPushDown && isPartialPushDown(partialCondition, partialPdRate, zkPdRate)) { logDebug("partial push down task on spark") val partialFilterCondition = pushDownOperators.filterExecutions.reduce((a, b) => FilterExeInfo(And(a.filter, b.filter), partialChildOutput)) - val predicate = Predicate.create(partialFilterCondition.filter, partialChildOutput) + var partialFilter : Expression = null + if (orcImpl.equals("hive")) { + partialFilter = partialFilterCondition.filter + } else { + partialFilter = filterStripEnd(partialFilterCondition.filter) + } + val predicate = Predicate.create(partialFilter, partialChildOutput) predicate.initialize(0) - iterator = new PartialPushDownIterator(split, context, pageToColumnarClass, predicate) + pushDownIterator = new PartialPushDownIterator(split, context, pageToColumnarClass, predicate) } else { logDebug("partial push down task on omnidata") - iterator = new PushDownIterator(split, context, pageToColumnarClass) + pushDownIterator = new PushDownIterator(split, context, pageToColumnarClass) + } + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener[Unit](_ => pushDownIterator.close()) + + pushDownIterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. 
+ } + + class FakePushDownThread(sparkThread: Thread, + split: RDDPartition, + context: TaskContext, + scan : FileScanRDDPushDown, + sparkLog : org.slf4j.Logger) extends Thread { + var times = 0 + scan.forceOmniDataPushDown = true + val iter: Iterator[Any] = scan.computePushDownRDD(split, context) + override def run(): Unit = { + while (!context.isCompleted() && sparkThread.isAlive && times <= 5 && iter.hasNext) { + sparkLog.info(">>>>>>Fake push down Thread [running]>>>>>") + Thread.sleep(200) + times = times + 1 + val currentValue = iter.next() + currentValue match { + case batch: ColumnarBatch => batch.close() + case _ => + } + } + sparkLog.info(">>>>>>Fake push down Thread [end]>>>>>") + pushDownIterator.close() + sparkLog.info("pushDownIterator close") + this.interrupt() + } + } + + def doFakePush(split: RDDPartition, context: TaskContext, scan : FileScanRDDPushDown): Unit ={ + val fakePushDownThread = new FakePushDownThread(Thread.currentThread(), split, context, scan, log) + fakePushDownThread.start() + } + + def computeSparkRDDAndFakePushDown(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { + //this code (computeSparkRDDAndFakePushDown) from spark FileScanRDD + doFakePush(split, context, this) + val iterator = new Iterator[Object] with AutoCloseable { + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // apply readFunction, because it might read some bytes. + private val getBytesReadCallback = + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + // We get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + private def incTaskInputMetricsBytesRead(): Unit = { + inputMetrics.setBytesRead(existingBytesRead + getBytesReadCallback()) + } + + private[this] val files = split.asInstanceOf[FilePartition].files.toIterator + private[this] var currentFile: PartitionedFile = null + private[this] var currentIterator: Iterator[Object] = null + + def hasNext: Boolean = { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. + context.killTaskIfInterrupted() + (currentIterator != null && currentIterator.hasNext) || nextIterator() + } + def next(): Object = { + val nextElement = currentIterator.next() + // TODO: we should have a better separation of row based and batch based scan, so that we + // don't need to run this `if` for every record. 
+ val preNumRecordsRead = inputMetrics.recordsRead + if (nextElement.isInstanceOf[ColumnarBatch]) { + incTaskInputMetricsBytesRead() + inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) + } else { + // too costly to update every record + if (inputMetrics.recordsRead % + SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + incTaskInputMetricsBytesRead() + } + inputMetrics.incRecordsRead(1) + } + nextElement + } + + private def readCurrentFile(): Iterator[InternalRow] = { + try { + readFunction(currentFile) + } catch { + case e: FileNotFoundException => + throw new FileNotFoundException( + e.getMessage + "\n" + + "It is possible the underlying files have been updated. " + + "You can explicitly invalidate the cache in Spark by " + + "running 'REFRESH TABLE tableName' command in SQL or " + + "by recreating the Dataset/DataFrame involved.") + } + } + + /** Advances to the next file. Returns true if a new non-empty iterator is available. */ + private def nextIterator(): Boolean = { + if (files.hasNext) { + currentFile = files.next() + logInfo(s"Reading File $currentFile") + // Sets InputFileBlockHolder for the file block's information + InputFileBlockHolder.set(currentFile.filePath, currentFile.start, currentFile.length) + + if (ignoreMissingFiles || ignoreCorruptFiles) { + currentIterator = new NextIterator[Object] { + // The readFunction may read some bytes before consuming the iterator, e.g., + // vectorized Parquet reader. Here we use lazy val to delay the creation of + // iterator so that we will throw exception in `getNext`. + private lazy val internalIter = readCurrentFile() + + override def getNext(): AnyRef = { + try { + if (internalIter.hasNext) { + internalIter.next() + } else { + finished = true + null + } + } catch { + case e: FileNotFoundException if ignoreMissingFiles => + logWarning(s"Skipped missing file: $currentFile", e) + finished = true + null + // Throw FileNotFoundException even if `ignoreCorruptFiles` is true + case e: FileNotFoundException if !ignoreMissingFiles => throw e + case e @ (_: RuntimeException | _: IOException) if ignoreCorruptFiles => + logWarning( + s"Skipped the rest of the content in the corrupted file: $currentFile", e) + finished = true + null + } + } + + override def close(): Unit = {} + } + } else { + currentIterator = readCurrentFile() + } + + try { + hasNext + } catch { + case e: SchemaColumnConvertNotSupportedException => + val message = "Parquet column cannot be converted in " + + s"file ${currentFile.filePath}. Column: ${e.getColumn}, " + + s"Expected: ${e.getLogicalType}, Found: ${e.getPhysicalType}" + throw new QueryExecutionException(message, e) + case e: ParquetDecodingException => + if (e.getCause.isInstanceOf[SparkUpgradeException]) { + throw e.getCause + } else if (e.getMessage.contains("Can not read value at")) { + val message = "Encounter error while reading parquet files. " + + "One possible cause: Parquet column cannot be converted in the " + + "corresponding files. Details: " + throw new QueryExecutionException(message, e) + } + throw e + } + } else { + currentFile = null + InputFileBlockHolder.unset() + false + } + } + + override def close(): Unit = { + incTaskInputMetricsBytesRead() + InputFileBlockHolder.unset() + } } + // Register an on-task-completion callback to close the input stream. 
context.addTaskCompletionListener[Unit](_ => iterator.close()) @@ -134,6 +334,14 @@ class FileScanRDDPushDown( } override protected def getPartitions: Array[RDDPartition] = { + if(isFakePushDown) { + getSparkPartitions + } else { + getPushDownPartitions + } + } + + def getPushDownPartitions: Array[RDDPartition] = { filePartitions.map { partitionFile => { val retHost = mutable.HashMap.empty[String, Long] partitionFile.files.foreach { partitionMap => { @@ -178,6 +386,8 @@ class FileScanRDDPushDown( filePartitions.toArray } + def getSparkPartitions: Array[RDDPartition] = filePartitions.toArray + override protected def getPreferredLocations(split: RDDPartition): Seq[String] = { split.asInstanceOf[FilePartition].preferredLocations() } @@ -200,6 +410,7 @@ class FileScanRDDPushDown( var currentIterator: Iterator[Object] = null val sdiHosts: String = split.asInstanceOf[FilePartition].sdi val dataIoClass = new DataIoAdapter() + val domains: ImmutableMap[_, _] = dataIoClass.buildDomains(output,partitionColumns, filterOutput, pushDownOperators) def hasNext: Boolean = { // Kill the task in case it has been marked as killed. This logic is from @@ -255,7 +466,7 @@ class FileScanRDDPushDown( currentFile.length, columnOffset, sdiHosts, fileFormat.toString, maxFailedTimes, taskTimeout,operatorCombineEnabled) val dataIoPage = dataIoClass.getPageIterator(pageCandidate, output, - partitionColumns, filterOutput, pushDownOperators) + partitionColumns, filterOutput, pushDownOperators, domains) currentIterator = pageToColumnarClass.transPageToColumnar(dataIoPage, isColumnVector, dataIoClass.isOperatorCombineEnabled, output, orcImpl).asScala.iterator iteHasNext() -- Gitee
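
Note on change 1 (sketch only, placed after the patch trailer so it is ignored by git am): the join-type rewrite in NdpOverrides is applied to the columnar physical operators (ColumnarSortMergeJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec). The same idea expressed at the logical-plan level with only stock Catalyst classes, under the assumption that the guard mirrors the isnull-filter check in hasComplementOperator, looks roughly like the rule below; the object name is illustrative. A LEFT OUTER join whose result is immediately filtered on isnull of a right-side key keeps exactly the unmatched left rows, which is what LEFT ANTI computes directly; the patch additionally replaces right-side columns in the project list with null literals (changeProjectList), since an anti join no longer produces them.

    import org.apache.spark.sql.catalyst.expressions.{Attribute, IsNull}
    import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftOuter}
    import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan}
    import org.apache.spark.sql.catalyst.rules.Rule

    // Sketch only: rewrite `l LEFT OUTER JOIN r ON k WHERE isnull(r.k)` into a
    // left anti join so the null-padded right side never has to be materialised.
    object LeftOuterToLeftAntiSketch extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUp {
        case Filter(IsNull(a: Attribute), j @ Join(_, right, LeftOuter, _, _))
            if right.outputSet.contains(a) =>
          // Right-side columns disappear after the rewrite; the patch
          // compensates by projecting null literals in their place.
          j.copy(joinType = LeftAnti)
      }
    }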
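
Note on change 2 (sketch only): DataIoAdapter now selects the extra push-down host from the key side of the fpuHosts mapping and picks it at random instead of always taking the first entry, and FileScanRDDPushDown inverts the (host, DPU) pairs before use. A minimal Scala sketch of that selection, assuming fpuHosts is the inverted (DPU host -> datanode) map and with an illustrative function name, not part of the connector API:

    import scala.util.Random

    // `preferredHosts` are the hosts already resolved from the page candidate.
    def pickPushDownHosts(preferredHosts: Seq[String],
                          fpuHosts: Map[String, String]): Seq[String] = {
      val candidates = fpuHosts.keys.toSeq.filterNot(preferredHosts.contains)
      if (candidates.isEmpty) {
        preferredHosts
      } else {
        // A random choice spreads the extra load across DPUs instead of
        // always hitting the first host in iteration order.
        preferredHosts :+ candidates(Random.nextInt(candidates.size))
      }
    }

    // e.g. pickPushDownHosts(Seq("dpu-1"), Map("dpu-1" -> "dn-1", "dpu-2" -> "dn-2"))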