diff --git a/omniadvisor/pom.xml b/omniadvisor/pom.xml index 74ad839147c027537db0dee19ef162f0df6743e7..d72c33d17aa732bcbca24f960e15d14c3c1ad05c 100644 --- a/omniadvisor/pom.xml +++ b/omniadvisor/pom.xml @@ -305,6 +305,10 @@ org.apache.commons commons-text + + org.apache.commons + commons-compress + com.google.guava guava diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala index 65368433a2bf733b6c5bdb63784980528765a712..b15800ef498fa27c58c80a1d0133432066e6486a 100644 --- a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala +++ b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala @@ -17,11 +17,10 @@ package org.apache.spark import com.huawei.boostkit.omniadvisor.fetcher.FetcherType import com.huawei.boostkit.omniadvisor.models.AppResult -import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.parseMapToJsonString -import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.checkSuccess +import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.{checkSuccess, parseMapToJsonString} import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils import com.nimbusds.jose.util.StandardCharset -import org.apache.spark.sql.execution.ui.SQLExecutionUIData +import org.apache.spark.sql.execution.ui.{SQLExecutionUIData, SparkPlanGraph} import org.apache.spark.status.api.v1._ import org.slf4j.{Logger, LoggerFactory} @@ -41,7 +40,8 @@ object SparkApplicationDataExtractor { workload: String, environmentInfo: ApplicationEnvironmentInfo, jobsList: Seq[JobData], - sqlExecutionsList: Seq[SQLExecutionUIData]): AppResult = { + sqlExecutionsList: Seq[SQLExecutionUIData], + sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): AppResult = { val appResult = new AppResult appResult.applicationId = appInfo.id appResult.applicationName = appInfo.name @@ -57,7 +57,7 @@ object SparkApplicationDataExtractor { val 
attempt: ApplicationAttemptInfo = lastAttempt(appInfo) if (attempt.completed && jobsList.nonEmpty && checkSuccess(jobsList)) { - saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList) + saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList, sqlGraphMap) } else { saveFailedStatus(appResult, attempt) } @@ -112,7 +112,7 @@ object SparkApplicationDataExtractor { appResult.query = "" } - private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData]): Unit = { + private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): Unit = { appResult.executionStatus = AppResult.SUCCEEDED_STATUS val (startTime, finishTime) = extractJobsTime(jobsList) @@ -122,7 +122,7 @@ object SparkApplicationDataExtractor { finishTime - startTime else AppResult.FAILED_JOB_DURATION if (appResult.submit_method.equals(AppResult.SPARK_SQL)) { - appResult.query = extractQuerySQL(sqlExecutionsList) + appResult.query = extractQuerySQL(sqlExecutionsList, sqlGraphMap) } else { appResult.query = "" } @@ -153,16 +153,32 @@ object SparkApplicationDataExtractor { (startTime, finishTime) } - private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData]): String = { + private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): String = { require(sqlExecutionsList.nonEmpty) - val nonEmptyDescriptions = sqlExecutionsList.flatMap { execution => - Option(execution.description).filter(_.nonEmpty) + val descriptions: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String] + var previousExecutionDesc: Option[String] = None + for ((execution, index) <- sqlExecutionsList.zipWithIndex) { + if (index > 0) { + val sqlGraph = sqlGraphMap(execution.executionId) + if (execution.description.nonEmpty && !isDuplicatedQuery(execution, previousExecutionDesc, sqlGraph)) { + 
descriptions += execution.description.trim + } + } else { + if (execution.description.nonEmpty) { + descriptions += execution.description.trim + } + } + previousExecutionDesc = Some(execution.description) } - nonEmptyDescriptions.mkString(";") + descriptions.mkString(";\n") } private def lastAttempt(applicationInfo: ApplicationInfo): ApplicationAttemptInfo = { require(applicationInfo.attempts.nonEmpty) applicationInfo.attempts.last } + + private def isDuplicatedQuery(execution: SQLExecutionUIData, previousExecutionDesc: Option[String], sqlGraph: SparkPlanGraph): Boolean = { + execution.description.equals(previousExecutionDesc.getOrElse("")) && sqlGraph.allNodes.size == 1 && sqlGraph.allNodes.head.name.equals("LocalTableScan") + } } diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala index 236469cce9da0828965d2a694a426e214069e1fb..aa5a4bf7d8b1c68c249575acd90e94518a62376e 100644 --- a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala +++ b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala @@ -16,16 +16,17 @@ package org.apache.spark import com.huawei.boostkit.omniadvisor.models.AppResult -import org.apache.spark.status.api.v1 -import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} import org.apache.spark.internal.config.Status.ASYNC_TRACKING_ENABLED import org.apache.spark.scheduler.ReplayListenerBus +import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData, SparkPlanGraph} +import org.apache.spark.status.api.v1 import org.apache.spark.status.{AppStatusListener, AppStatusStore, ElementTrackingStore} -import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData} import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} import org.slf4j.{Logger, LoggerFactory} import java.io.InputStream +import 
scala.collection.mutable class SparkDataCollection { val LOG: Logger = LoggerFactory.getLogger(classOf[SparkDataCollection]) @@ -36,6 +37,7 @@ class SparkDataCollection { var jobsList: Seq[v1.JobData] = _ var appInfo: v1.ApplicationInfo = _ var sqlExecutionsList: Seq[SQLExecutionUIData] = _ + var sqlGraphMap: mutable.Map[Long, SparkPlanGraph] = _ def replayEventLogs(in: InputStream, sourceName: String): Unit = { @@ -66,12 +68,22 @@ class SparkDataCollection { val sqlAppStatusStore: SQLAppStatusStore = new SQLAppStatusStore(store) sqlExecutionsList = sqlAppStatusStore.executionsList() + sqlGraphMap = mutable.HashMap.empty[Long, SparkPlanGraph] + sqlExecutionsList.foreach { sqlExecution => + try { + val planGraph = sqlAppStatusStore.planGraph(sqlExecution.executionId) + sqlGraphMap.put(sqlExecution.executionId, planGraph) + } catch { + case e: Exception => + LOG.warn(s"Get PlanGraph for SQLExecution [${sqlExecution.executionId}] in ${appInfo.id} failed") + } + } appStatusStore.close() } def getAppResult(workload: String): AppResult = { - SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList) + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList, sqlGraphMap) } private def createInMemoryStore(): KVStore = { diff --git a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java index 7eed130902a290b108988aadbedd3db301053c24..02953cb7cc46400f5de9343995b8e07dfbf47d28 100644 --- a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java +++ b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java @@ -113,7 +113,7 @@ public class TestSparkApplicationDataExtractor { sqlExecutionList.add(new SQLExecutionUIData(1, sqlQ10, "", "", asScalaBuffer(ImmutableList.of()), 0, 
Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>())); sqlExecutionList.add(new SQLExecutionUIData(2, sqlQ11, "", "", asScalaBuffer(ImmutableList.of()), 0, Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList)); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-sql"); @@ -134,7 +134,7 @@ public class TestSparkApplicationDataExtractor { when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clientSparkProperties)); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties)); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-submit"); @@ -154,7 +154,7 @@ public class TestSparkApplicationDataExtractor { when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clusterSparkProperties)); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties)); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), 
asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.durationTime, 15 * 60 * 1000L); assertEquals(result.submit_method, "spark-submit"); @@ -169,7 +169,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -181,7 +181,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, 
TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -193,7 +193,7 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of())); + AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null); assertEquals(result.applicationId, "id"); assertEquals(result.executionStatus, AppResult.FAILED_STATUS); assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION); @@ -205,6 +205,6 @@ public class TestSparkApplicationDataExtractor { ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class); when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of())); - SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of())); + SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null); } } diff --git 
a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp index 99956458a90124e168d708fc49358f97549fa622..00f53c8f1517dc66b4aee65fa2f7239347fbf04b 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -19,6 +19,7 @@ #include "OrcColumnarBatchJniReader.h" #include +#include #include "jni_common.h" #include "common/UriInfo.h" @@ -334,75 +335,80 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea JNI_FUNC_END(runtimeExceptionClass) } -template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyFixedWidth(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - newVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - newVector->SetValue(i, (T)(values[i])); + newVectorPtr->SetValue(i, (T)(values[i])); } } - return (uint64_t)newVector; + return newVector; } -template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field) +template +std::unique_ptr CopyOptimizedForInt64(orc::ColumnVectorBatch *field) { using T = typename NativeType::type; ORC_TYPE *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + 
auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); // Check ColumnVectorBatch has null or not firstly if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - newVector->SetValues(0, values, numElements); - return (uint64_t)newVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -uint64_t CopyVarWidth(orc::ColumnVectorBatch *field) +std::unique_ptr CopyVarWidth(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto newVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { auto data = std::string_view(reinterpret_cast(values[i]), lens[i]); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } } - return (uint64_t)newVector; + return newVector; } inline void FindLastNotEmpty(const char *chars, long &len) @@ -412,14 +418,15 @@ inline void FindLastNotEmpty(const char *chars, long &len) } } -uint64_t CopyCharType(orc::ColumnVectorBatch *field) +std::unique_ptr CopyCharType(orc::ColumnVectorBatch *field) { orc::StringVectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->data.data(); auto notNulls = lvb->notNull.data(); auto lens = lvb->length.data(); - auto newVector = new Vector>(numElements); + auto newVector = std::make_unique>>(numElements); + auto newVectorPtr = newVector.get(); if 
(lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { @@ -427,9 +434,9 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { @@ -438,202 +445,156 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field) auto len = lens[i]; FindLastNotEmpty(chars, len); auto data = std::string_view(chars, len); - newVector->SetValue(i, data); + newVectorPtr->SetValue(i, data); } } - return (uint64_t)newVector; + return newVector; } -uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field) { orc::Decimal128VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (notNulls[i]) { - newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); + __int128_t dst = values[i].getHighBits(); + dst <<= 64; + dst |= values[i].getLowBits(); + newVectorPtr->SetValue(i, Decimal128(dst)); } else { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } else { for (uint i = 0; i < numElements; i++) { - newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); + newVectorPtr->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits())); } } - return (uint64_t)newVector; + return newVector; } -uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field) { orc::Decimal64VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto 
values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } } } - newVector->SetValues(0, values, numElements); - return (uint64_t)newVector; + newVectorPtr->SetValues(0, values, numElements); + return newVector; } -uint64_t CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field) +std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field) { orc::Decimal64VectorBatch *lvb = dynamic_cast(field); auto numElements = lvb->numElements; auto values = lvb->values.data(); auto notNulls = lvb->notNull.data(); - auto newVector = new Vector(numElements); + auto newVector = std::make_unique>(numElements); + auto newVectorPtr = newVector.get(); if (lvb->hasNulls) { for (uint i = 0; i < numElements; i++) { if (!notNulls[i]) { - newVector->SetNull(i); + newVectorPtr->SetNull(i); } else { Decimal128 d128(values[i]); - newVector->SetValue(i, d128); + newVectorPtr->SetValue(i, d128); } } } else { for (uint i = 0; i < numElements; i++) { Decimal128 d128(values[i]); - newVector->SetValue(i, d128); + newVectorPtr->SetValue(i, d128); } } - return (uint64_t)newVector; + return newVector; } -uint64_t dealLongVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { - switch (id) { - case omniruntime::type::OMNI_BOOLEAN: - return CopyFixedWidth(field); - case omniruntime::type::OMNI_SHORT: - return CopyFixedWidth(field); - case omniruntime::type::OMNI_INT: - return CopyFixedWidth(field); - case omniruntime::type::OMNI_LONG: - return CopyOptimizedForInt64(field); - case omniruntime::type::OMNI_DATE32: - return CopyFixedWidth(field); - case omniruntime::type::OMNI_DATE64: - return CopyOptimizedForInt64(field); - default: { - throw std::runtime_error("dealLongVectorBatch not support for type: " + 
id); - } - } - return -1; -} - -uint64_t dealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { - switch (id) { - case omniruntime::type::OMNI_DOUBLE: - return CopyOptimizedForInt64(field); - default: { - throw std::runtime_error("dealDoubleVectorBatch not support for type: " + id); - } - } - return -1; -} - -uint64_t dealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { - switch (id) { - case omniruntime::type::OMNI_DECIMAL64: - return CopyToOmniDecimal64Vec(field); - case omniruntime::type::OMNI_DECIMAL128: - return CopyToOmniDecimal128VecFrom64(field); - default: { - throw std::runtime_error("dealDecimal64VectorBatch not support for type: " + id); - } - } - return -1; -} - -uint64_t dealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { - switch (id) { - case omniruntime::type::OMNI_DECIMAL128: - return CopyToOmniDecimal128Vec(field); - default: { - throw std::runtime_error("dealDecimal128VectorBatch not support for type: " + id); - } - } - return -1; -} - -int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, +std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId, orc::ColumnVectorBatch *field, bool isDecimal64Transfor128) { - DataTypeId dataTypeId = static_cast(omniTypeId); switch (type->getKind()) { case orc::TypeKind::BOOLEAN: + omniTypeId = static_cast(OMNI_BOOLEAN); + return CopyFixedWidth(field); case orc::TypeKind::SHORT: + omniTypeId = static_cast(OMNI_SHORT); + return CopyFixedWidth(field); case orc::TypeKind::DATE: + omniTypeId = static_cast(OMNI_DATE32); + return CopyFixedWidth(field); case orc::TypeKind::INT: + omniTypeId = static_cast(OMNI_INT); + return CopyFixedWidth(field); case orc::TypeKind::LONG: - omniVecId = dealLongVectorBatch(dataTypeId, field); - break; + omniTypeId = static_cast(OMNI_LONG); + return CopyOptimizedForInt64(field); case orc::TypeKind::DOUBLE: - omniVecId = dealDoubleVectorBatch(dataTypeId, field); - 
break; + omniTypeId = static_cast(OMNI_DOUBLE); + return CopyOptimizedForInt64(field); case orc::TypeKind::CHAR: - if (dataTypeId != OMNI_VARCHAR) { - throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char"); - } - omniVecId = CopyCharType(field); - break; + omniTypeId = static_cast(OMNI_VARCHAR); + return CopyCharType(field); case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - if (dataTypeId != OMNI_VARCHAR) { - throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar"); - } - omniVecId = CopyVarWidth(field); - break; + omniTypeId = static_cast(OMNI_VARCHAR); + return CopyVarWidth(field); case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { - omniVecId = dealDecimal128VectorBatch(dataTypeId, field); + omniTypeId = static_cast(OMNI_DECIMAL128); + return CopyToOmniDecimal128Vec(field); + } else if (isDecimal64Transfor128) { + omniTypeId = static_cast(OMNI_DECIMAL128); + return CopyToOmniDecimal128VecFrom64(field); } else { - omniVecId = dealDecimal64VectorBatch(dataTypeId, field); + omniTypeId = static_cast(OMNI_DECIMAL64); + return CopyToOmniDecimal64Vec(field); } - break; default: { throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind()); } } - return 1; } JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId) { - JNI_FUNC_START orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; + std::vector> omniVecs; + const orc::Type &baseTp = rowReaderPtr->getSelectedType(); - int vecCnt = 0; - long batchRowSize = 0; - auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE); + uint64_t batchRowSize = 0; if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = 
dynamic_cast(columnVectorBatch); - vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; - for (int id = 0; id < vecCnt; id++) { + int32_t vecCnt = root->fields.size(); + std::vector omniTypeIds(vecCnt, 0); + for (int32_t id = 0; id < vecCnt; id++) { auto type = baseTp.getSubtype(id); - int omniTypeId = ptr[id]; - uint64_t omniVecId = 0; - CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id], isDecimal64Transfor128); - jlong omniVec = static_cast(omniVecId); + omniVecs.emplace_back(CopyToOmniVec(type, omniTypeIds[id], root->fields[id], isDecimal64Transfor128)); + } + for (int32_t id = 0; id < vecCnt; id++) { + env->SetIntArrayRegion(typeId, id, 1, omniTypeIds.data() + id); + jlong omniVec = reinterpret_cast(omniVecs[id].release()); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } } - return (jlong)batchRowSize; - JNI_FUNC_END(runtimeExceptionClass) + return (jlong) batchRowSize; } /* diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp index 21c0b81c96b2c83fcade1207ae7a2816da7af7fb..991699a7be573db1f191fe7b20de0a45baf4964d 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp @@ -93,10 +93,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn { JNI_FUNC_START ParquetReader *pReader = (ParquetReader *)reader; - std::vector recordBatch(pReader->columnReaders.size()); + std::vector recordBatch(pReader->columnReaders.size(), 0); long batchRowSize = 0; auto state = pReader->ReadNextBatch(recordBatch, &batchRowSize); if (state != Status::OK()) { + for (auto vec : recordBatch) { + delete vec; + } + recordBatch.clear(); env->ThrowNew(runtimeExceptionClass, state.ToString().c_str()); return 0; } diff --git 
a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp index 5f4f379e413c854161dc5e04cce146e98e5a148f..672f4ceedffa043b5577f7a954154ae8e8e599e2 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp @@ -143,8 +143,12 @@ Status ParquetReader::GetRecordBatchReader(const std::vector &row_group_ind return Status::OK(); } - for (uint64_t i = 0; i < columnReaders.size(); i++) { - RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i])); + try { + for (uint64_t i = 0; i < columnReaders.size(); i++) { + RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i])); + } + } catch (const std::exception &e) { + return Status::Invalid(e.what()); } // Check BaseVector diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h index 76108fab68de5c3583b6708cfba60f899a7aa2d9..3f602c979d71e76d2995e99c15f73355c956a98c 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h @@ -486,11 +486,10 @@ namespace omniruntime::reader { virtual void InitVec(int64_t capacity) { vec_ = new Vector(capacity); + auto capacity_bytes = capacity * byte_width_; if (parquet_vec_ != nullptr) { - auto capacity_bytes = capacity * byte_width_; memset(parquet_vec_, 0, capacity_bytes); } else { - auto capacity_bytes = capacity * byte_width_; parquet_vec_ = new uint8_t[capacity_bytes]; } // Init nulls diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt index 491cfb7086037229608f2963cf6c278ca132b198..10f630ad13925922872540fb13b379a0b52e15b3 100644 --- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt +++ 
b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt @@ -1,9 +1,9 @@ -# project name -project(spark-thestral-plugin) - # required cmake version cmake_minimum_required(VERSION 3.10) +# project name +project(spark-thestral-plugin) + # configure cmake set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_COMPILER "g++") diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h index 73fe13732d27dca87e12ac72900635f8f26cd5f4..ab8a52c229017b4277c7c3b5552477133aa27e4b 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h +++ b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h @@ -16,29 +16,42 @@ * limitations under the License. */ - #ifndef CPP_BUFFER_H - #define CPP_BUFFER_H +#ifndef CPP_BUFFER_H +#define CPP_BUFFER_H - #include - #include - #include - #include - #include - - class Buffer { - public: - Buffer(uint8_t* data, int64_t size, int64_t capacity) - : data_(data), - size_(size), - capacity_(capacity) { +#include +#include +#include +#include +#include +#include + +class Buffer { +public: + Buffer(uint8_t* data, int64_t size, int64_t capacity, bool isOmniAllocated = true) + : data_(data), + size_(size), + capacity_(capacity), + allocatedByOmni(isOmniAllocated) { + } + + ~Buffer() { + if (allocatedByOmni && not releaseFlag) { + auto *allocator = omniruntime::mem::Allocator::GetAllocator(); + allocator->Free(data_, capacity_); } + } - ~Buffer() {} + void SetReleaseFlag() { + releaseFlag = true; + } - public: - uint8_t * data_; - int64_t size_; - int64_t capacity_; - }; +public: + uint8_t * data_; + int64_t size_; + int64_t capacity_; + bool allocatedByOmni = true; + bool releaseFlag = false; +}; - #endif //CPP_BUFFER_H \ No newline at end of file +#endif //CPP_BUFFER_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index 
ca982c0a4ca56100cb6c11599d6d0c334009da92..14785a9cf453f5925974f8b85dc80538a5b85a17 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -131,7 +131,6 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) { - JNI_FUNC_START auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); if (!splitter) { std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); @@ -140,10 +139,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split } auto vecBatch = (VectorBatch *) jVecBatchAddress; - + splitter->SetInputVecBatch(vecBatch); + JNI_FUNC_START splitter->Split(*vecBatch); return 0L; - JNI_FUNC_END(runtimeExceptionClass) + JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch()) } JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h index 4b59296e152876062a06db3d69c81a7ed22b670b..964fab6dfc06ac692294fa212f40afc21a4d1041 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h @@ -48,6 +48,15 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const ch return; \ } \ +#define JNI_FUNC_END_WITH_VECBATCH(exceptionClass, toDeleteVecBatch) \ + } \ + catch (const std::exception &e) \ + { \ + VectorHelper::FreeVecBatch(toDeleteVecBatch); \ + env->ThrowNew(exceptionClass, e.what()); \ + return 0; \ + } + extern jclass runtimeExceptionClass; extern jclass splitResultClass; extern jclass jsonClass; diff --git 
a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index f0a9b3e76ab9956581a13744b16851cd4f34d040..b0deafe611c37e67561e6092e61c64a4fd67a294 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -449,6 +449,7 @@ int Splitter::DoSplit(VectorBatch& vb) { num_row_splited_ += vb.GetRowCount(); // release the fixed width vector and release vectorBatch at the same time ReleaseVectorBatch(&vb); + this->ResetInputVecBatch(); // 阈值检查,是否溢写 if (num_row_splited_ >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) { @@ -675,7 +676,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, valueStr.resize(onceCopyLen); std::string nullStr; - std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen)); + std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen, false)); std::shared_ptr ptr_validity; // options_.spill_batch_row_num长度切割与拼接 @@ -698,7 +699,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]); if (not nullAllocated && partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) { nullStr.resize(splitRowInfoTmp->onceCopyRow); - ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow)); + ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow, false)); nullAllocated = true; } if ((onceCopyLen - destCopyedLength) >= (cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp])) { @@ -714,9 +715,11 @@ void Splitter::SerializingFixedColumns(int32_t partitionId, // 释放内存 options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_, 
partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->SetReleaseFlag(); } options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_, partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_); + partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->SetReleaseFlag(); destCopyedLength += memCopyLen; splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移 splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] = 0; // 初始化下一个cacheBatch的起始偏移 @@ -1058,6 +1061,7 @@ int Splitter::DeleteSpilledTmpFile() { auto tmpDataFilePath = pair.first + ".data"; // 释放存储有各个临时文件的偏移数据内存 options_.allocator->Free(pair.second->data_, pair.second->capacity_); + pair.second->SetReleaseFlag(); if (IsFileExist(tmpDataFilePath)) { remove(tmpDataFilePath.c_str()); } diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index 4e9684de963a3bb4eedf0720ec8a01c1d2160f2b..d54b97849dadb06c72c6a456aedd9eefa1f21092 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -165,6 +165,7 @@ private: } } vectorAddress.clear(); + vb->ClearVectors(); delete vb; } @@ -173,7 +174,7 @@ private: std::vector vector_batch_col_types_; InputDataTypes input_col_types; std::vector binary_array_empirical_size_; - + omniruntime::vec::VectorBatch *inputVecBatch = nullptr; public: bool singlePartitionFlag = false; int32_t num_partitions_; @@ -221,6 +222,22 @@ public: int64_t TotalComputePidTime() const { return total_compute_pid_time_; } const 
std::vector& PartitionLengths() const { return partition_lengths_; } + + omniruntime::vec::VectorBatch *GetInputVecBatch() + { + return inputVecBatch; + } + + void SetInputVecBatch(omniruntime::vec::VectorBatch *inVecBatch) + { + inputVecBatch = inVecBatch; + } + + // the input batch is released once it has been split, so reset the pointer to keep the exception path from freeing it again + void ResetInputVecBatch() + { + inputVecBatch = nullptr; + } }; diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 9f6cadf70b7348a1696e140a13887c87018af4b2..6a0c1b27c4282016aecada2ba4ef0c48c320f20f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -20,6 +20,7 @@ package com.huawei.boostkit.spark.serialize; import com.google.protobuf.InvalidProtocolBufferException; +import nova.hetu.omniruntime.utils.OmniRuntimeException; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.DoubleVec; @@ -35,21 +36,31 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; - public class ShuffleDataSerializer { public static ColumnarBatch deserialize(byte[] bytes) { + ColumnVector[] vecs = null; try { VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); int vecCount = vecBatch.getVecCnt(); int rowCount = vecBatch.getRowCnt(); - ColumnVector[] vecs = new ColumnVector[vecCount]; + vecs = new ColumnVector[vecCount]; for (int i = 0; i < vecCount; i++) { vecs[i] = buildVec(vecBatch.getVecs(i), rowCount); } return new ColumnarBatch(vecs, rowCount); } catch 
(InvalidProtocolBufferException e) { throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); + } catch (OmniRuntimeException e) { + if (vecs != null) { + for (int i = 0; i < vecs.length; i++) { + ColumnVector vec = vecs[i]; + if (vec != null) { + vec.close(); + } + } + } + throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index e8e7db3af876ce726d9b22e8ead47cf9f15a9fbe..e301372f5223be6c432df82fd38d9d1d417f8126 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -118,8 +118,10 @@ public class OmniOrcColumnarBatchReader extends RecordReader - if (!isSlice) { - vector.getVec - } else { - vector.getVec.slice(0, cb.numRows()) + try { + for (i <- 0 until cb.numCols()) { + val omniVec: Vec = cb.column(i) match { + case vector: OmniColumnVector => + if (!isSlice) { + vector.getVec + } else { + vector.getVec.slice(0, cb.numRows()) + } + case vector: ColumnVector => + transColumnVector(vector, cb.numRows()) + case _ => + throw new UnsupportedOperationException("unsupport column vector!") + } + input(i) = omniVec + } + } catch { + case e: Exception => { + for (j <- 0 until cb.numCols()) { + val vec = input(j) + if (vec != null) vec.close + cb.column(j) match { + case vector: OmniColumnVector => + vector.close() } - case vector: ColumnVector => - transColumnVector(vector, cb.numRows()) - case _ => - throw new UnsupportedOperationException("unsupport column vector!") + } + throw new RuntimeException("allocate memory 
failed!") } - input(i) = omniVec } input } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index 77707365e2e92680b027c4649144a0d599768720..33e8037a12aa253da8b7a40f0895c511d31cfec8 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -312,7 +312,7 @@ object ColumnarBatchToInternalRow { val batchIter = batches.flatMap { batch => - // toClosedVecs closed case: + // toClosedVecs closed case: [Deprecated] // 1) all rows of batch fetched and closed // 2) only fetch parital rows(eg: top-n, limit-n), closed at task CompletionListener callback val toClosedVecs = new ListBuffer[Vec] @@ -330,27 +330,22 @@ object ColumnarBatchToInternalRow { new Iterator[InternalRow] { val numOutputRowsMetric: SQLMetric = numOutputRows - var closed = false - - // only invoke if fetch partial rows of batch - if (mayPartialFetch) { - SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => - if (!closed) { - toClosedVecs.foreach {vec => - vec.close() - } + + + SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ => + toClosedVecs.foreach {vec => + vec.close() } - } } override def hasNext: Boolean = { val has = iter.hasNext - // fetch all rows and closed - if (!has && !closed) { + // fetch all rows + if (!has) { toClosedVecs.foreach {vec => vec.close() + toClosedVecs.remove(toClosedVecs.indexOf(vec)) } - closed = true } has } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 
11541ccee1c0e9ad775eb71360bf0d081d72bcd2..12fbdb6b4a2c5b8b4eec6ea51fbe7f6d919dbb76 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -330,24 +330,27 @@ case class ColumnarOptRollupExec( omniOutputPartials) omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) + + val results = new ListBuffer[VecBatch]() + var hashaggResults: java.util.Iterator[VecBatch] = null + // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperators.foreach(operator => operator.close()) hashaggOperator.close() + results.foreach(vecBatch => { + vecBatch.releaseAllVectors() + vecBatch.close() + }) }) - val results = new ListBuffer[VecBatch]() - var hashaggResults: java.util.Iterator[VecBatch] = null - while (iter.hasNext) { val batch = iter.next() val input = transColBatchToOmniVecs(batch) val vecBatch = new VecBatch(input, batch.numRows()) results.append(vecBatch) projectOperators.foreach(projectOperator => { - val vecs = vecBatch.getVectors.map(vec => { - vec.slice(0, vecBatch.getRowCount) - }) + val vecs = transColBatchToOmniVecs(batch, true) val projectInput = new VecBatch(vecs, vecBatch.getRowCount) var startInput = System.nanoTime() diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index f928ad11c258bb07c64f11d065adadb5f439dcc0..233503141dc11a6d74ea8e000a02594cb4607116 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -40,7 +40,7 @@ 
import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric._ -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.internal.SQLConf @@ -167,10 +167,14 @@ class ColumnarShuffleExchangeExec( if (enableShuffleBatchMerge) { cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => - new MergeIterator(iter, + val mergeIterator = new MergeIterator(iter, StructType.fromAttributes(child.output), longMetric("numMergedVecBatches"), longMetric("bypassVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } } else { cachedShuffleRDD diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala index d7365c8c152b507c366d917cce8fd127acf55a2b..098b606becd652009b272f2bbbc4e130f81e8f80 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit import org.apache.spark.sql.execution._ import 
org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.util.MergeIterator +import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -207,10 +207,14 @@ case class ColumnarCustomShuffleReaderExec( if (enableShuffleBatchMerge) { cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) => - new MergeIterator(iter, + val mergeIterator = new MergeIterator(iter, StructType.fromAttributes(child.output), longMetric("numMergedVecBatches"), longMetric("bypassVecBatches")) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } } else { cachedShuffleRDD diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index 40e43f707a5fd9cfe647d677df82558bfc0927e3..d5fd3a10b81aadba5acdc378bef55d63955f7d4b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -340,7 +340,7 @@ case class ColumnarBroadcastHashJoinExec( Optional.empty() } - def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { + def createBuildOpFactoryAndOp(isShared: Boolean): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { val startBuildCodegen = System.nanoTime() val opFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1, @@ -349,6 +349,10 @@ case class 
ColumnarBroadcastHashJoinExec( val op = opFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, + opFactory, op) + } val deserializer = VecBatchSerializerFactory.create() relation.value.buildData.foreach { input => val startBuildInput = System.nanoTime() @@ -356,7 +360,19 @@ case class ColumnarBroadcastHashJoinExec( buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) } val startBuildGetOp = System.nanoTime() - op.getOutput + try { + op.getOutput + } catch { + case e: Exception => { + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id) + } else { + op.close() + opFactory.close() + } + throw new RuntimeException("HashBuilder getOutput failed") + } + } buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) (opFactory, op) } @@ -368,11 +384,9 @@ case class ColumnarBroadcastHashJoinExec( try { buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id) if (buildOpFactory == null) { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(true) buildOpFactory = opFactory buildOp = op - OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, - buildOpFactory, buildOp) } } catch { case e: Exception => { @@ -382,7 +396,7 @@ case class ColumnarBroadcastHashJoinExec( OmniHashBuilderWithExprOperatorFactory.gLock.unlock() } } else { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(false) buildOpFactory = opFactory buildOp = op } @@ -492,7 +506,11 @@ case class ColumnarBroadcastHashJoinExec( } if (enableJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + val mergeIterator = new 
MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index f3fb45a1cab4a0b1f9fe23c45c9f298b728a3761..8a61c3fdbe61f4a568d10c796dd32ba5e64bbab0 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -285,7 +285,11 @@ case class ColumnarSortMergeJoinExec( } if (enableSortMergeJoinBatchMerge) { - new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches) + SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { + mergeIterator.close() + }) + mergeIterator } else { iterBatch } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index 105d3435367c2818c7f2b1da267ca19f43ce27ef..cea35f2a37d159fee797632c49101939b09ae67c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -44,29 +44,40 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, private def createOmniVectors(schema: StructType, columnSize: Int): Array[Vec] = 
{ val vecs = new Array[Vec](schema.fields.length) - schema.fields.zipWithIndex.foreach { case (field, index) => - field.dataType match { - case LongType => - vecs(index) = new LongVec(columnSize) - case DateType | IntegerType => - vecs(index) = new IntVec(columnSize) - case ShortType => - vecs(index) = new ShortVec(columnSize) - case DoubleType => - vecs(index) = new DoubleVec(columnSize) - case BooleanType => - vecs(index) = new BooleanVec(columnSize) - case StringType => - val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) - vecs(index) = new VarcharVec(columnSize) - case dt: DecimalType => - if (DecimalType.is64BitDecimalType(dt)) { + try { + schema.fields.zipWithIndex.foreach { case (field, index) => + field.dataType match { + case LongType => vecs(index) = new LongVec(columnSize) - } else { - vecs(index) = new Decimal128Vec(columnSize) + case DateType | IntegerType => + vecs(index) = new IntVec(columnSize) + case ShortType => + vecs(index) = new ShortVec(columnSize) + case DoubleType => + vecs(index) = new DoubleVec(columnSize) + case BooleanType => + vecs(index) = new BooleanVec(columnSize) + case StringType => + val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata) + vecs(index) = new VarcharVec(columnSize) + case dt: DecimalType => + if (DecimalType.is64BitDecimalType(dt)) { + vecs(index) = new LongVec(columnSize) + } else { + vecs(index) = new Decimal128Vec(columnSize) + } + case _ => + throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + } + } catch { + case e: Exception => { + for (vec <- vecs) { + if (vec != null) { + vec.close() } - case _ => - throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields") + } + throw new RuntimeException("allocate memory failed!") } } vecs @@ -165,4 +176,15 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, def isFull(): Boolean = { totalRows > maxRowCount || 
currentBatchSizeInBytes >= maxBatchSizeInBytes } + + def close(): Unit = { + for (elem <- bufferedVecBatch) { + elem.releaseAllVectors() + elem.close() + } + for (elem <- outputQueue) { + elem.releaseAllVectors() + elem.close() + } + } }