From ee4926244c708f87bd24559aac04e29df549e91c Mon Sep 17 00:00:00 2001
From: mushsoooup
Date: Fri, 7 Mar 2025 13:00:33 +0800
Subject: [PATCH] optimize file partition

---
 .../ColumnarFileSourceScanExec.scala          |   4 +-
 .../datasources/OmniFilePartition.scala       | 131 ++++++++++++++++++
 2 files changed, 133 insertions(+), 2 deletions(-)
 create mode 100644 omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFilePartition.scala

diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index ddabce367..f9bd8d236 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -525,7 +525,7 @@ abstract class BaseColumnarFileSourceScanExec(
       fsRelation: HadoopFsRelation): RDD[InternalRow] = {
     val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
     val maxSplitBytes =
-      FilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions)
+      OmniFilePartition.maxSplitBytes(fsRelation.sparkSession, selectedPartitions)
     logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
       s"open cost is considered as scanning $openCostInBytes bytes.")
 
@@ -562,7 +562,7 @@ abstract class BaseColumnarFileSourceScanExec(
     }.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
 
     val partitions =
-      FilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
+      OmniFilePartition.getFilePartitions(relation.sparkSession, splitFiles, maxSplitBytes)
 
     new FileScanRDD(fsRelation.sparkSession, readFile, partitions,
       new StructType(requiredSchema.fields ++ fsRelation.partitionSchema.fields), metadataColumns)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFilePartition.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFilePartition.scala
new file mode 100644
index 000000000..9b6d2d42c
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFilePartition.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.Partition
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.read.InputPartition
+
+/**
+ * A collection of file blocks that should be read as a single task
+ * (possibly from multiple partitioned directories).
+ */
+case class FilePartition(index: Int, files: Array[PartitionedFile])
+  extends Partition with InputPartition {
+  override def preferredLocations(): Array[String] = {
+    // Computes the total number of bytes that can be retrieved from each host.
+    val hostToNumBytes = mutable.HashMap.empty[String, Long]
+    files.foreach { file =>
+      file.locations.filter(_ != "localhost").foreach { host =>
+        hostToNumBytes(host) = hostToNumBytes.getOrElse(host, 0L) + file.length
+      }
+    }
+
+    // Takes the 3 hosts with the most data to be retrieved.
+    hostToNumBytes.toSeq.sortBy {
+      case (host, numBytes) => numBytes
+    }.reverse.take(3).map {
+      case (host, numBytes) => host
+    }.toArray
+  }
+}
+
+object OmniFilePartition extends Logging {
+
+  def getFilePartitions(
+      sparkSession: SparkSession,
+      partitionedFiles: Seq[PartitionedFile],
+      maxSplitBytes: Long): Seq[FilePartition] = {
+    val partitions = new ArrayBuffer[FilePartition]
+    val currentFiles = new ArrayBuffer[PartitionedFile]
+
+    /** Close the current partition and move to the next. */
+    def closePartition(): Unit = {
+      if (currentFiles.nonEmpty) {
+        // Copy to a new Array.
+        val newPartition = FilePartition(partitions.size, currentFiles.toArray)
+        partitions += newPartition
+      }
+      currentFiles.clear()
+    }
+
+    val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
+    // partitionedFiles is sorted by length in descending order, so files too
+    // large to share a partition with any other file form a prefix; give each
+    // of them its own partition.
+    var start = 0
+    partitionedFiles.iterator.takeWhile(
+      _.length + openCostInBytes >= maxSplitBytes).foreach { file =>
+      currentFiles += file
+      closePartition()
+      start += 1
+    }
+    // Pack the remaining files with a first-fit strategy: place each file in
+    // the first bin that still has enough capacity, or open a new bin.
+    val fileFragments = partitionedFiles.slice(start, partitionedFiles.length)
+    var bins: List[List[PartitionedFile]] = List()
+    var binRemain: List[Long] = List()
+    fileFragments.iterator.foreach { file =>
+      var placed = false
+      for (i <- bins.indices if !placed) {
+        val bin = bins(i)
+        val currentBinCapacity = binRemain(i)
+        val expectedBinCapacity = currentBinCapacity - file.length - openCostInBytes
+        if (expectedBinCapacity >= 0) {
+          bins = bins.updated(i, bin :+ file)
+          binRemain = binRemain.updated(i, expectedBinCapacity)
+          placed = true
+        }
+      }
+      if (!placed) {
+        bins = bins :+ List(file)
+        binRemain = binRemain :+ (maxSplitBytes - openCostInBytes - file.length)
+      }
+    }
+    // Each bin becomes one partition.
+    bins.iterator.foreach { bin =>
+      bin.iterator.foreach { file =>
+        currentFiles += file
+      }
+      closePartition()
+    }
+    partitions.toSeq
+  }
+
+  def maxSplitBytes(
+      sparkSession: SparkSession,
+      selectedPartitions: Seq[PartitionDirectory]): Long = {
+    val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
+    val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
+    val minPartitionNum = sparkSession.sessionState.conf.filesMinPartitionNum
+      .getOrElse(sparkSession.leafNodeDefaultParallelism)
+    val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
+    val bytesPerCore = totalBytes / minPartitionNum
+
+    // Halve bytesPerCore until it no longer exceeds the configured maximum,
+    // keeping the split size a power-of-two fraction of the per-core share.
+    var adaptiveMaxSplitBytes = bytesPerCore
+    while (adaptiveMaxSplitBytes > defaultMaxSplitBytes) {
+      adaptiveMaxSplitBytes /= 2
+    }
+    adaptiveMaxSplitBytes
+  }
+}
-- 
Gitee
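
Note (illustration, not part of the patch): getFilePartitions packs files that
arrive already sorted by length in descending order, so the strategy amounts to
first-fit decreasing bin packing. Below is a minimal standalone sketch of the
same packing loop; FileLen, pack, and all sizes are hypothetical stand-ins, not
APIs from this repo.

    object FirstFitDemo {
      // Stand-in for PartitionedFile: just a name and a length in bytes.
      case class FileLen(name: String, length: Long)

      // First-fit: place each file into the first bin with enough remaining
      // capacity, charging openCostInBytes per file; otherwise open a new bin.
      def pack(
          files: Seq[FileLen],
          maxSplitBytes: Long,
          openCostInBytes: Long): List[List[FileLen]] = {
        val sorted = files.sortBy(-_.length) // first-fit *decreasing*
        var bins: List[List[FileLen]] = List()
        var remain: List[Long] = List()
        sorted.foreach { file =>
          val i = remain.indexWhere(_ >= file.length + openCostInBytes)
          if (i >= 0) {
            bins = bins.updated(i, bins(i) :+ file)
            remain = remain.updated(i, remain(i) - file.length - openCostInBytes)
          } else {
            bins = bins :+ List(file)
            remain = remain :+ (maxSplitBytes - openCostInBytes - file.length)
          }
        }
        bins
      }

      def main(args: Array[String]): Unit = {
        val files = Seq(FileLen("a", 90), FileLen("b", 60), FileLen("c", 50),
          FileLen("d", 30), FileLen("e", 10))
        // With maxSplitBytes = 100 and openCostInBytes = 4 this yields
        // [a], [b, d], [c, e]: 3 partitions, where the sequential fill in
        // vanilla Spark's FilePartition.getFilePartitions produces 4.
        pack(files, 100, 4).zipWithIndex.foreach { case (bin, i) =>
          println(s"partition $i: ${bin.map(_.name).mkString(", ")}")
        }
      }
    }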
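
Note (illustration, not part of the patch): maxSplitBytes also diverges from
vanilla Spark's FilePartition.maxSplitBytes, which clamps with
Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)). Here
bytesPerCore is halved until it fits under the configured maximum, so the split
size stays a power-of-two fraction of the per-core share. A sketch of the loop
on made-up numbers:

    // Hypothetical: 9 GiB of input across 16 cores, 128 MiB default maximum.
    val bytesPerCore = (9L << 30) / 16       // 576 MiB per core
    val defaultMaxSplitBytes = 128L << 20    // 128 MiB
    var split = bytesPerCore
    while (split > defaultMaxSplitBytes) split /= 2
    // 576 MiB -> 288 -> 144 -> 72 MiB: the per-core share divides into exactly
    // 8 equal splits, where clamping straight to 128 MiB would leave 4 full
    // splits plus a 64 MiB tail.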