diff --git a/omniadvisor/pom.xml b/omniadvisor/pom.xml
index 74ad839147c027537db0dee19ef162f0df6743e7..d72c33d17aa732bcbca24f960e15d14c3c1ad05c 100644
--- a/omniadvisor/pom.xml
+++ b/omniadvisor/pom.xml
@@ -305,6 +305,10 @@
org.apache.commons
commons-text
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-compress</artifactId>
+        </dependency>
com.google.guava
guava
diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala
index 65368433a2bf733b6c5bdb63784980528765a712..b15800ef498fa27c58c80a1d0133432066e6486a 100644
--- a/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala
+++ b/omniadvisor/src/main/scala/org/apache/spark/SparkApplicationDataExtractor.scala
@@ -17,11 +17,10 @@ package org.apache.spark
import com.huawei.boostkit.omniadvisor.fetcher.FetcherType
import com.huawei.boostkit.omniadvisor.models.AppResult
-import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.parseMapToJsonString
-import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.checkSuccess
+import com.huawei.boostkit.omniadvisor.spark.utils.ScalaUtils.{checkSuccess, parseMapToJsonString}
import com.huawei.boostkit.omniadvisor.spark.utils.SparkUtils
import com.nimbusds.jose.util.StandardCharset
-import org.apache.spark.sql.execution.ui.SQLExecutionUIData
+import org.apache.spark.sql.execution.ui.{SQLExecutionUIData, SparkPlanGraph}
import org.apache.spark.status.api.v1._
import org.slf4j.{Logger, LoggerFactory}
@@ -41,7 +40,8 @@ object SparkApplicationDataExtractor {
workload: String,
environmentInfo: ApplicationEnvironmentInfo,
jobsList: Seq[JobData],
- sqlExecutionsList: Seq[SQLExecutionUIData]): AppResult = {
+ sqlExecutionsList: Seq[SQLExecutionUIData],
+ sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): AppResult = {
val appResult = new AppResult
appResult.applicationId = appInfo.id
appResult.applicationName = appInfo.name
@@ -57,7 +57,7 @@ object SparkApplicationDataExtractor {
val attempt: ApplicationAttemptInfo = lastAttempt(appInfo)
if (attempt.completed && jobsList.nonEmpty && checkSuccess(jobsList)) {
- saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList)
+ saveSuccessfulStatus(appResult, jobsList, sqlExecutionsList, sqlGraphMap)
} else {
saveFailedStatus(appResult, attempt)
}
@@ -112,7 +112,7 @@ object SparkApplicationDataExtractor {
appResult.query = ""
}
- private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData]): Unit = {
+ private def saveSuccessfulStatus(appResult: AppResult, jobsList: Seq[JobData], sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): Unit = {
appResult.executionStatus = AppResult.SUCCEEDED_STATUS
val (startTime, finishTime) = extractJobsTime(jobsList)
@@ -122,7 +122,7 @@ object SparkApplicationDataExtractor {
finishTime - startTime else AppResult.FAILED_JOB_DURATION
if (appResult.submit_method.equals(AppResult.SPARK_SQL)) {
- appResult.query = extractQuerySQL(sqlExecutionsList)
+ appResult.query = extractQuerySQL(sqlExecutionsList, sqlGraphMap)
} else {
appResult.query = ""
}
@@ -153,16 +153,32 @@ object SparkApplicationDataExtractor {
(startTime, finishTime)
}
- private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData]): String = {
+ private def extractQuerySQL(sqlExecutionsList: Seq[SQLExecutionUIData], sqlGraphMap: mutable.Map[Long, SparkPlanGraph]): String = {
require(sqlExecutionsList.nonEmpty)
- val nonEmptyDescriptions = sqlExecutionsList.flatMap { execution =>
- Option(execution.description).filter(_.nonEmpty)
+ val descriptions: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty[String]
+ var previousExecutionDesc: Option[String] = None
+ for ((execution, index) <- sqlExecutionsList.zipWithIndex) {
+ if (index > 0) {
+ val sqlGraph = Option(sqlGraphMap).flatMap(_.get(execution.executionId))
+ if (execution.description.nonEmpty && !isDuplicatedQuery(execution, previousExecutionDesc, sqlGraph)) {
+ descriptions += execution.description.trim
+ }
+ } else {
+ if (execution.description.nonEmpty) {
+ descriptions += execution.description.trim
+ }
+ }
+ previousExecutionDesc = Some(execution.description)
}
- nonEmptyDescriptions.mkString(";")
+ descriptions.mkString(";\n")
}
private def lastAttempt(applicationInfo: ApplicationInfo): ApplicationAttemptInfo = {
require(applicationInfo.attempts.nonEmpty)
applicationInfo.attempts.last
}
+
+ private def isDuplicatedQuery(execution: SQLExecutionUIData, previousExecutionDesc: Option[String], sqlGraph: Option[SparkPlanGraph]): Boolean = {
+ execution.description.equals(previousExecutionDesc.getOrElse("")) && sqlGraph.exists(graph => graph.allNodes.size == 1 && graph.allNodes.head.name.equals("LocalTableScan"))
+ }
}
diff --git a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala
index 236469cce9da0828965d2a694a426e214069e1fb..aa5a4bf7d8b1c68c249575acd90e94518a62376e 100644
--- a/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala
+++ b/omniadvisor/src/main/scala/org/apache/spark/SparkDataCollection.scala
@@ -16,16 +16,17 @@
package org.apache.spark
import com.huawei.boostkit.omniadvisor.models.AppResult
-import org.apache.spark.status.api.v1
-import org.apache.spark.util.kvstore.{InMemoryStore, KVStore}
import org.apache.spark.internal.config.Status.ASYNC_TRACKING_ENABLED
import org.apache.spark.scheduler.ReplayListenerBus
+import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData, SparkPlanGraph}
+import org.apache.spark.status.api.v1
import org.apache.spark.status.{AppStatusListener, AppStatusStore, ElementTrackingStore}
-import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLExecutionUIData}
import org.apache.spark.util.Utils
+import org.apache.spark.util.kvstore.{InMemoryStore, KVStore}
import org.slf4j.{Logger, LoggerFactory}
import java.io.InputStream
+import scala.collection.mutable
class SparkDataCollection {
val LOG: Logger = LoggerFactory.getLogger(classOf[SparkDataCollection])
@@ -36,6 +37,7 @@ class SparkDataCollection {
var jobsList: Seq[v1.JobData] = _
var appInfo: v1.ApplicationInfo = _
var sqlExecutionsList: Seq[SQLExecutionUIData] = _
+ var sqlGraphMap: mutable.Map[Long, SparkPlanGraph] = _
def replayEventLogs(in: InputStream, sourceName: String): Unit = {
@@ -66,12 +68,22 @@ class SparkDataCollection {
val sqlAppStatusStore: SQLAppStatusStore = new SQLAppStatusStore(store)
sqlExecutionsList = sqlAppStatusStore.executionsList()
+ sqlGraphMap = mutable.HashMap.empty[Long, SparkPlanGraph]
+ sqlExecutionsList.foreach { sqlExecution =>
+ try {
+ val planGraph = sqlAppStatusStore.planGraph(sqlExecution.executionId)
+ sqlGraphMap.put(sqlExecution.executionId, planGraph)
+ } catch {
+ case e: Exception =>
LOG.warn(s"Get PlanGraph for SQLExecution [${sqlExecution.executionId}] in ${appInfo.id} failed", e)
+ }
+ }
appStatusStore.close()
}
def getAppResult(workload: String): AppResult = {
- SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList)
+ SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(appInfo, workload, environmentInfo, jobsList, sqlExecutionsList, sqlGraphMap)
}
private def createInMemoryStore(): KVStore = {
diff --git a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java
index 7eed130902a290b108988aadbedd3db301053c24..02953cb7cc46400f5de9343995b8e07dfbf47d28 100644
--- a/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java
+++ b/omniadvisor/src/test/java/org/apache/spark/TestSparkApplicationDataExtractor.java
@@ -113,7 +113,7 @@ public class TestSparkApplicationDataExtractor {
sqlExecutionList.add(new SQLExecutionUIData(1, sqlQ10, "", "", asScalaBuffer(ImmutableList.of()), 0, Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>()));
sqlExecutionList.add(new SQLExecutionUIData(2, sqlQ11, "", "", asScalaBuffer(ImmutableList.of()), 0, Option.apply(new Date()), new HashMap<>(), new HashSet<>(), new HashMap<>()));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(sqlExecutionList), null);
assertEquals(result.applicationId, "id");
assertEquals(result.durationTime, 15 * 60 * 1000L);
assertEquals(result.submit_method, "spark-sql");
@@ -134,7 +134,7 @@ public class TestSparkApplicationDataExtractor {
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clientSparkProperties));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null);
assertEquals(result.applicationId, "id");
assertEquals(result.durationTime, 15 * 60 * 1000L);
assertEquals(result.submit_method, "spark-submit");
@@ -154,7 +154,7 @@ public class TestSparkApplicationDataExtractor {
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(clusterSparkProperties));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(systemProperties));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(jobsList), asScalaBuffer(ImmutableList.of()), null);
assertEquals(result.applicationId, "id");
assertEquals(result.durationTime, 15 * 60 * 1000L);
assertEquals(result.submit_method, "spark-submit");
@@ -169,7 +169,7 @@ public class TestSparkApplicationDataExtractor {
ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class);
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of()));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(runningData)), asScalaBuffer(ImmutableList.of()), null);
assertEquals(result.applicationId, "id");
assertEquals(result.executionStatus, AppResult.FAILED_STATUS);
assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION);
@@ -181,7 +181,7 @@ public class TestSparkApplicationDataExtractor {
ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class);
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of()));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of(failedData)), asScalaBuffer(ImmutableList.of()), null);
assertEquals(result.applicationId, "id");
assertEquals(result.executionStatus, AppResult.FAILED_STATUS);
assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION);
@@ -193,7 +193,7 @@ public class TestSparkApplicationDataExtractor {
ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class);
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
- AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()));
+ AppResult result = SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null);
assertEquals(result.applicationId, "id");
assertEquals(result.executionStatus, AppResult.FAILED_STATUS);
assertEquals(result.durationTime, AppResult.FAILED_JOB_DURATION);
@@ -205,6 +205,6 @@ public class TestSparkApplicationDataExtractor {
ApplicationEnvironmentInfo environmentInfo = Mockito.mock(ApplicationEnvironmentInfo.class);
when(environmentInfo.sparkProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
when(environmentInfo.systemProperties()).thenReturn(asScalaBuffer(ImmutableList.of()));
- SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()));
+ SparkApplicationDataExtractor.extractAppResultFromAppStatusStore(applicationInfo, TEST_WORK_LOAD, environmentInfo, asScalaBuffer(ImmutableList.of()), asScalaBuffer(ImmutableList.of()), null);
}
}
diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
index 99956458a90124e168d708fc49358f97549fa622..00f53c8f1517dc66b4aee65fa2f7239347fbf04b 100644
--- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
+++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
@@ -19,6 +19,7 @@
#include "OrcColumnarBatchJniReader.h"
#include
+#include <memory>
#include "jni_common.h"
#include "common/UriInfo.h"
@@ -334,75 +335,80 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea
JNI_FUNC_END(runtimeExceptionClass)
}
-template uint64_t CopyFixedWidth(orc::ColumnVectorBatch *field)
+template
+std::unique_ptr CopyFixedWidth(orc::ColumnVectorBatch *field)
{
using T = typename NativeType::type;
ORC_TYPE *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->data.data();
auto notNulls = lvb->notNull.data();
- auto newVector = new Vector(numElements);
+ auto newVector = std::make_unique>(numElements);
+ auto newVectorPtr = newVector.get();
// Check ColumnVectorBatch has null or not firstly
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (notNulls[i]) {
- newVector->SetValue(i, (T)(values[i]));
+ newVectorPtr->SetValue(i, (T)(values[i]));
} else {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
} else {
for (uint i = 0; i < numElements; i++) {
- newVector->SetValue(i, (T)(values[i]));
+ newVectorPtr->SetValue(i, (T)(values[i]));
}
}
- return (uint64_t)newVector;
+ return newVector;
}
-template uint64_t CopyOptimizedForInt64(orc::ColumnVectorBatch *field)
+template
+std::unique_ptr CopyOptimizedForInt64(orc::ColumnVectorBatch *field)
{
using T = typename NativeType::type;
ORC_TYPE *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->data.data();
auto notNulls = lvb->notNull.data();
- auto newVector = new Vector(numElements);
+ auto newVector = std::make_unique>(numElements);
+ auto newVectorPtr = newVector.get();
// Check ColumnVectorBatch has null or not firstly
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (!notNulls[i]) {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
}
- newVector->SetValues(0, values, numElements);
- return (uint64_t)newVector;
+ newVectorPtr->SetValues(0, values, numElements);
+ return newVector;
}
-uint64_t CopyVarWidth(orc::ColumnVectorBatch *field)
+std::unique_ptr CopyVarWidth(orc::ColumnVectorBatch *field)
{
orc::StringVectorBatch *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->data.data();
auto notNulls = lvb->notNull.data();
auto lens = lvb->length.data();
- auto newVector = new Vector>(numElements);
+ auto newVector = std::make_unique>>(numElements);
+ auto newVectorPtr = newVector.get();
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (notNulls[i]) {
auto data = std::string_view(reinterpret_cast(values[i]), lens[i]);
- newVector->SetValue(i, data);
+ newVectorPtr->SetValue(i, data);
} else {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
} else {
for (uint i = 0; i < numElements; i++) {
auto data = std::string_view(reinterpret_cast(values[i]), lens[i]);
- newVector->SetValue(i, data);
+ newVectorPtr->SetValue(i, data);
}
}
- return (uint64_t)newVector;
+ return newVector;
}
inline void FindLastNotEmpty(const char *chars, long &len)
@@ -412,14 +418,15 @@ inline void FindLastNotEmpty(const char *chars, long &len)
}
}
-uint64_t CopyCharType(orc::ColumnVectorBatch *field)
+std::unique_ptr CopyCharType(orc::ColumnVectorBatch *field)
{
orc::StringVectorBatch *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->data.data();
auto notNulls = lvb->notNull.data();
auto lens = lvb->length.data();
- auto newVector = new Vector>(numElements);
+ auto newVector = std::make_unique>>(numElements);
+ auto newVectorPtr = newVector.get();
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (notNulls[i]) {
@@ -427,9 +434,9 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field)
auto len = lens[i];
FindLastNotEmpty(chars, len);
auto data = std::string_view(chars, len);
- newVector->SetValue(i, data);
+ newVectorPtr->SetValue(i, data);
} else {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
} else {
@@ -438,202 +445,156 @@ uint64_t CopyCharType(orc::ColumnVectorBatch *field)
auto len = lens[i];
FindLastNotEmpty(chars, len);
auto data = std::string_view(chars, len);
- newVector->SetValue(i, data);
+ newVectorPtr->SetValue(i, data);
}
}
- return (uint64_t)newVector;
+ return newVector;
}
-uint64_t CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field)
+std::unique_ptr CopyToOmniDecimal128Vec(orc::ColumnVectorBatch *field)
{
orc::Decimal128VectorBatch *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->values.data();
auto notNulls = lvb->notNull.data();
- auto newVector = new Vector(numElements);
+ auto newVector = std::make_unique>(numElements);
+ auto newVectorPtr = newVector.get();
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (notNulls[i]) {
- newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits()));
+ __int128_t dst = values[i].getHighBits();
+ dst <<= 64;
+ dst |= values[i].getLowBits();
+ newVectorPtr->SetValue(i, Decimal128(dst));
} else {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
} else {
for (uint i = 0; i < numElements; i++) {
- newVector->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits()));
+ newVectorPtr->SetValue(i, Decimal128(values[i].getHighBits(), values[i].getLowBits()));
}
}
- return (uint64_t)newVector;
+ return newVector;
}
-uint64_t CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field)
+std::unique_ptr CopyToOmniDecimal64Vec(orc::ColumnVectorBatch *field)
{
orc::Decimal64VectorBatch *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->values.data();
auto notNulls = lvb->notNull.data();
- auto newVector = new Vector(numElements);
+ auto newVector = std::make_unique>(numElements);
+ auto newVectorPtr = newVector.get();
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (!notNulls[i]) {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
}
}
}
- newVector->SetValues(0, values, numElements);
- return (uint64_t)newVector;
+ newVectorPtr->SetValues(0, values, numElements);
+ return newVector;
}
-uint64_t CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field)
+std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch *field)
{
orc::Decimal64VectorBatch *lvb = dynamic_cast(field);
auto numElements = lvb->numElements;
auto values = lvb->values.data();
auto notNulls = lvb->notNull.data();
- auto newVector = new Vector(numElements);
+ auto newVector = std::make_unique>(numElements);
+ auto newVectorPtr = newVector.get();
if (lvb->hasNulls) {
for (uint i = 0; i < numElements; i++) {
if (!notNulls[i]) {
- newVector->SetNull(i);
+ newVectorPtr->SetNull(i);
} else {
Decimal128 d128(values[i]);
- newVector->SetValue(i, d128);
+ newVectorPtr->SetValue(i, d128);
}
}
} else {
for (uint i = 0; i < numElements; i++) {
Decimal128 d128(values[i]);
- newVector->SetValue(i, d128);
+ newVectorPtr->SetValue(i, d128);
}
}
- return (uint64_t)newVector;
+ return newVector;
}
-uint64_t dealLongVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
- switch (id) {
- case omniruntime::type::OMNI_BOOLEAN:
- return CopyFixedWidth(field);
- case omniruntime::type::OMNI_SHORT:
- return CopyFixedWidth(field);
- case omniruntime::type::OMNI_INT:
- return CopyFixedWidth(field);
- case omniruntime::type::OMNI_LONG:
- return CopyOptimizedForInt64(field);
- case omniruntime::type::OMNI_DATE32:
- return CopyFixedWidth(field);
- case omniruntime::type::OMNI_DATE64:
- return CopyOptimizedForInt64(field);
- default: {
- throw std::runtime_error("dealLongVectorBatch not support for type: " + id);
- }
- }
- return -1;
-}
-
-uint64_t dealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
- switch (id) {
- case omniruntime::type::OMNI_DOUBLE:
- return CopyOptimizedForInt64(field);
- default: {
- throw std::runtime_error("dealDoubleVectorBatch not support for type: " + id);
- }
- }
- return -1;
-}
-
-uint64_t dealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
- switch (id) {
- case omniruntime::type::OMNI_DECIMAL64:
- return CopyToOmniDecimal64Vec(field);
- case omniruntime::type::OMNI_DECIMAL128:
- return CopyToOmniDecimal128VecFrom64(field);
- default: {
- throw std::runtime_error("dealDecimal64VectorBatch not support for type: " + id);
- }
- }
- return -1;
-}
-
-uint64_t dealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
- switch (id) {
- case omniruntime::type::OMNI_DECIMAL128:
- return CopyToOmniDecimal128Vec(field);
- default: {
- throw std::runtime_error("dealDecimal128VectorBatch not support for type: " + id);
- }
- }
- return -1;
-}
-
-int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field,
+std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId, orc::ColumnVectorBatch *field,
bool isDecimal64Transfor128)
{
- DataTypeId dataTypeId = static_cast(omniTypeId);
switch (type->getKind()) {
case orc::TypeKind::BOOLEAN:
+ omniTypeId = static_cast(OMNI_BOOLEAN);
+ return CopyFixedWidth(field);
case orc::TypeKind::SHORT:
+ omniTypeId = static_cast(OMNI_SHORT);
+ return CopyFixedWidth(field);
case orc::TypeKind::DATE:
+ omniTypeId = static_cast(OMNI_DATE32);
+ return CopyFixedWidth(field);
case orc::TypeKind::INT:
+ omniTypeId = static_cast(OMNI_INT);
+ return CopyFixedWidth(field);
case orc::TypeKind::LONG:
- omniVecId = dealLongVectorBatch(dataTypeId, field);
- break;
+ omniTypeId = static_cast(OMNI_LONG);
+ return CopyOptimizedForInt64(field);
case orc::TypeKind::DOUBLE:
- omniVecId = dealDoubleVectorBatch(dataTypeId, field);
- break;
+ omniTypeId = static_cast(OMNI_DOUBLE);
+ return CopyOptimizedForInt64(field);
case orc::TypeKind::CHAR:
- if (dataTypeId != OMNI_VARCHAR) {
- throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char");
- }
- omniVecId = CopyCharType(field);
- break;
+ omniTypeId = static_cast(OMNI_VARCHAR);
+ return CopyCharType(field);
case orc::TypeKind::STRING:
case orc::TypeKind::VARCHAR:
- if (dataTypeId != OMNI_VARCHAR) {
- throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar");
- }
- omniVecId = CopyVarWidth(field);
- break;
+ omniTypeId = static_cast(OMNI_VARCHAR);
+ return CopyVarWidth(field);
case orc::TypeKind::DECIMAL:
if (type->getPrecision() > MAX_DECIMAL64_DIGITS) {
- omniVecId = dealDecimal128VectorBatch(dataTypeId, field);
+ omniTypeId = static_cast(OMNI_DECIMAL128);
+ return CopyToOmniDecimal128Vec(field);
+ } else if (isDecimal64Transfor128) {
+ omniTypeId = static_cast(OMNI_DECIMAL128);
+ return CopyToOmniDecimal128VecFrom64(field);
} else {
- omniVecId = dealDecimal64VectorBatch(dataTypeId, field);
+ omniTypeId = static_cast(OMNI_DECIMAL64);
+ return CopyToOmniDecimal64Vec(field);
}
- break;
default: {
throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind());
}
}
- return 1;
}
JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env,
jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId)
{
- JNI_FUNC_START
orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader;
orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch;
+ std::vector> omniVecs;
+
const orc::Type &baseTp = rowReaderPtr->getSelectedType();
- int vecCnt = 0;
- long batchRowSize = 0;
- auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE);
+ uint64_t batchRowSize = 0;
if (rowReaderPtr->next(*columnVectorBatch)) {
orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch);
- vecCnt = root->fields.size();
batchRowSize = root->fields[0]->numElements;
- for (int id = 0; id < vecCnt; id++) {
+ int32_t vecCnt = root->fields.size();
+ std::vector omniTypeIds(vecCnt, 0);
+ for (int32_t id = 0; id < vecCnt; id++) {
auto type = baseTp.getSubtype(id);
- int omniTypeId = ptr[id];
- uint64_t omniVecId = 0;
- CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id], isDecimal64Transfor128);
- jlong omniVec = static_cast(omniVecId);
+ omniVecs.emplace_back(CopyToOmniVec(type, omniTypeIds[id], root->fields[id], isDecimal64Transfor128));
+ }
+ for (int32_t id = 0; id < vecCnt; id++) {
+ env->SetIntArrayRegion(typeId, id, 1, omniTypeIds.data() + id);
+ jlong omniVec = reinterpret_cast(omniVecs[id].release());
env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec);
}
}
- return (jlong)batchRowSize;
- JNI_FUNC_END(runtimeExceptionClass)
+ return (jlong) batchRowSize;
}
/*
diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp
index 21c0b81c96b2c83fcade1207ae7a2816da7af7fb..991699a7be573db1f191fe7b20de0a45baf4964d 100644
--- a/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp
+++ b/omnioperator/omniop-native-reader/cpp/src/jni/ParquetColumnarBatchJniReader.cpp
@@ -93,10 +93,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_ParquetColumnarBatchJn
{
JNI_FUNC_START
ParquetReader *pReader = (ParquetReader *)reader;
- std::vector recordBatch(pReader->columnReaders.size());
+ std::vector recordBatch(pReader->columnReaders.size(), 0);
long batchRowSize = 0;
auto state = pReader->ReadNextBatch(recordBatch, &batchRowSize);
if (state != Status::OK()) {
+ for (auto vec : recordBatch) {
+ delete vec;
+ }
+ recordBatch.clear();
env->ThrowNew(runtimeExceptionClass, state.ToString().c_str());
return 0;
}
diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp
index 5f4f379e413c854161dc5e04cce146e98e5a148f..672f4ceedffa043b5577f7a954154ae8e8e599e2 100644
--- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp
+++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetReader.cpp
@@ -143,8 +143,12 @@ Status ParquetReader::GetRecordBatchReader(const std::vector &row_group_ind
return Status::OK();
}
- for (uint64_t i = 0; i < columnReaders.size(); i++) {
- RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i]));
+ try {
+ for (uint64_t i = 0; i < columnReaders.size(); i++) {
+ RETURN_NOT_OK(columnReaders[i]->NextBatch(read_size, &batch[i]));
+ }
+ } catch (const std::exception &e) {
+ return Status::Invalid(e.what());
}
// Check BaseVector
diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h
index 76108fab68de5c3583b6708cfba60f899a7aa2d9..3f602c979d71e76d2995e99c15f73355c956a98c 100644
--- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h
+++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetTypedRecordReader.h
@@ -486,11 +486,10 @@ namespace omniruntime::reader {
virtual void InitVec(int64_t capacity) {
vec_ = new Vector(capacity);
+ auto capacity_bytes = capacity * byte_width_;
if (parquet_vec_ != nullptr) {
- auto capacity_bytes = capacity * byte_width_;
memset(parquet_vec_, 0, capacity_bytes);
} else {
- auto capacity_bytes = capacity * byte_width_;
parquet_vec_ = new uint8_t[capacity_bytes];
}
// Init nulls
diff --git a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt
index 491cfb7086037229608f2963cf6c278ca132b198..10f630ad13925922872540fb13b379a0b52e15b3 100644
--- a/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt
+++ b/omnioperator/omniop-spark-extension/cpp/CMakeLists.txt
@@ -1,9 +1,9 @@
-# project name
-project(spark-thestral-plugin)
-
# required cmake version
cmake_minimum_required(VERSION 3.10)
+# project name
+project(spark-thestral-plugin)
+
# configure cmake
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_COMPILER "g++")
diff --git a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h
index 73fe13732d27dca87e12ac72900635f8f26cd5f4..ab8a52c229017b4277c7c3b5552477133aa27e4b 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/common/Buffer.h
@@ -16,29 +16,42 @@
* limitations under the License.
*/
- #ifndef CPP_BUFFER_H
- #define CPP_BUFFER_H
+#ifndef CPP_BUFFER_H
+#define CPP_BUFFER_H
- #include
- #include
- #include
- #include
- #include
-
- class Buffer {
- public:
- Buffer(uint8_t* data, int64_t size, int64_t capacity)
- : data_(data),
- size_(size),
- capacity_(capacity) {
+#include
+#include
+#include
+#include
+#include
+#include
+
+class Buffer {
+public:
+ Buffer(uint8_t* data, int64_t size, int64_t capacity, bool isOmniAllocated = true)
+ : data_(data),
+ size_(size),
+ capacity_(capacity),
+ allocatedByOmni(isOmniAllocated) {
+ }
+
+ ~Buffer() {
+ if (allocatedByOmni && not releaseFlag) {
+ auto *allocator = omniruntime::mem::Allocator::GetAllocator();
+ allocator->Free(data_, capacity_);
}
+ }
- ~Buffer() {}
+ void SetReleaseFlag() {
+ releaseFlag = true;
+ }
- public:
- uint8_t * data_;
- int64_t size_;
- int64_t capacity_;
- };
+public:
+ uint8_t * data_;
+ int64_t size_;
+ int64_t capacity_;
+ bool allocatedByOmni = true;
+ bool releaseFlag = false;
+};
- #endif //CPP_BUFFER_H
\ No newline at end of file
+#endif //CPP_BUFFER_H
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
index ca982c0a4ca56100cb6c11599d6d0c334009da92..14785a9cf453f5925974f8b85dc80538a5b85a17 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
@@ -131,7 +131,6 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ
JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split(
JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress)
{
- JNI_FUNC_START
auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id);
if (!splitter) {
std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
@@ -140,10 +139,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split
}
auto vecBatch = (VectorBatch *) jVecBatchAddress;
-
+ splitter->SetInputVecBatch(vecBatch);
+ JNI_FUNC_START
splitter->Split(*vecBatch);
return 0L;
- JNI_FUNC_END(runtimeExceptionClass)
+ JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch())
}
JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop(
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h
index 4b59296e152876062a06db3d69c81a7ed22b670b..964fab6dfc06ac692294fa212f40afc21a4d1041 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.h
@@ -48,6 +48,15 @@ jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const ch
return; \
} \
+#define JNI_FUNC_END_WITH_VECBATCH(exceptionClass, toDeleteVecBatch) \
+ } \
+ catch (const std::exception &e) \
+ { \
+ VectorHelper::FreeVecBatch(toDeleteVecBatch); \
+ env->ThrowNew(exceptionClass, e.what()); \
+ return 0; \
+ }
+
extern jclass runtimeExceptionClass;
extern jclass splitResultClass;
extern jclass jsonClass;
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
index f0a9b3e76ab9956581a13744b16851cd4f34d040..b0deafe611c37e67561e6092e61c64a4fd67a294 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
@@ -449,6 +449,7 @@ int Splitter::DoSplit(VectorBatch& vb) {
num_row_splited_ += vb.GetRowCount();
// release the fixed width vector and release vectorBatch at the same time
ReleaseVectorBatch(&vb);
+ this->ResetInputVecBatch();
// 阈值检查,是否溢写
if (num_row_splited_ >= SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD) {
@@ -675,7 +676,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId,
valueStr.resize(onceCopyLen);
std::string nullStr;
- std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen));
+ std::shared_ptr ptr_value (new Buffer((uint8_t*)valueStr.data(), 0, onceCopyLen, false));
std::shared_ptr ptr_validity;
// options_.spill_batch_row_num长度切割与拼接
@@ -698,7 +699,7 @@ void Splitter::SerializingFixedColumns(int32_t partitionId,
splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp]);
if (not nullAllocated && partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0] != nullptr) {
nullStr.resize(splitRowInfoTmp->onceCopyRow);
- ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow));
+ ptr_validity.reset(new Buffer((uint8_t*)nullStr.data(), 0, splitRowInfoTmp->onceCopyRow, false));
nullAllocated = true;
}
if ((onceCopyLen - destCopyedLength) >= (cacheBatchSize - splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp])) {
@@ -714,9 +715,11 @@ void Splitter::SerializingFixedColumns(int32_t partitionId,
// 释放内存
options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->data_,
partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->capacity_);
+ partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][0]->SetReleaseFlag();
}
options_.allocator->Free(partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->data_,
partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->capacity_);
+ partition_cached_vectorbatch_[partitionId][splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp]][fixColIndexTmp][1]->SetReleaseFlag();
destCopyedLength += memCopyLen;
splitRowInfoTmp->cacheBatchIndex[fixColIndexTmp] += 1; // cacheBatchIndex下标后移
splitRowInfoTmp->cacheBatchCopyedLen[fixColIndexTmp] = 0; // 初始化下一个cacheBatch的起始偏移
@@ -1058,6 +1061,7 @@ int Splitter::DeleteSpilledTmpFile() {
auto tmpDataFilePath = pair.first + ".data";
// 释放存储有各个临时文件的偏移数据内存
options_.allocator->Free(pair.second->data_, pair.second->capacity_);
+ pair.second->SetReleaseFlag();
if (IsFileExist(tmpDataFilePath)) {
remove(tmpDataFilePath.c_str());
}
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
index 4e9684de963a3bb4eedf0720ec8a01c1d2160f2b..d54b97849dadb06c72c6a456aedd9eefa1f21092 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
@@ -165,6 +165,7 @@ private:
}
}
vectorAddress.clear();
+ vb->ClearVectors();
delete vb;
}
@@ -173,7 +174,7 @@ private:
std::vector vector_batch_col_types_;
InputDataTypes input_col_types;
std::vector binary_array_empirical_size_;
-
+ omniruntime::vec::VectorBatch *inputVecBatch = nullptr;
public:
bool singlePartitionFlag = false;
int32_t num_partitions_;
@@ -221,6 +222,22 @@ public:
int64_t TotalComputePidTime() const { return total_compute_pid_time_; }
const std::vector& PartitionLengths() const { return partition_lengths_; }
+
+ omniruntime::vec::VectorBatch *GetInputVecBatch()
+ {
+ return inputVecBatch;
+ }
+
+ void SetInputVecBatch(omniruntime::vec::VectorBatch *inVecBatch)
+ {
+ inputVecBatch = inVecBatch;
+ }
+
+ // no need to clear the memory on exception, so we have to reset the pointer
+ void ResetInputVecBatch()
+ {
+ inputVecBatch = nullptr;
+ }
};
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
index 9f6cadf70b7348a1696e140a13887c87018af4b2..6a0c1b27c4282016aecada2ba4ef0c48c320f20f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
@@ -20,6 +20,7 @@ package com.huawei.boostkit.spark.serialize;
import com.google.protobuf.InvalidProtocolBufferException;
+import nova.hetu.omniruntime.utils.OmniRuntimeException;
import nova.hetu.omniruntime.vector.BooleanVec;
import nova.hetu.omniruntime.vector.Decimal128Vec;
import nova.hetu.omniruntime.vector.DoubleVec;
@@ -35,21 +36,31 @@ import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
-
public class ShuffleDataSerializer {
public static ColumnarBatch deserialize(byte[] bytes) {
+ ColumnVector[] vecs = null;
try {
VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes);
int vecCount = vecBatch.getVecCnt();
int rowCount = vecBatch.getRowCnt();
- ColumnVector[] vecs = new ColumnVector[vecCount];
+ vecs = new ColumnVector[vecCount];
for (int i = 0; i < vecCount; i++) {
vecs[i] = buildVec(vecBatch.getVecs(i), rowCount);
}
return new ColumnarBatch(vecs, rowCount);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage());
+ } catch (OmniRuntimeException e) {
+ if (vecs != null) {
+ for (int i = 0; i < vecs.length; i++) {
+ ColumnVector vec = vecs[i];
+ if (vec != null) {
+ vec.close();
+ }
+ }
+ }
+ throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage());
}
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
index e8e7db3af876ce726d9b22e8ead47cf9f15a9fbe..e301372f5223be6c432df82fd38d9d1d417f8126 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
@@ -118,8 +118,10 @@ public class OmniOrcColumnarBatchReader extends RecordReader
- if (!isSlice) {
- vector.getVec
- } else {
- vector.getVec.slice(0, cb.numRows())
+ try {
+ for (i <- 0 until cb.numCols()) {
+ val omniVec: Vec = cb.column(i) match {
+ case vector: OmniColumnVector =>
+ if (!isSlice) {
+ vector.getVec
+ } else {
+ vector.getVec.slice(0, cb.numRows())
+ }
+ case vector: ColumnVector =>
+ transColumnVector(vector, cb.numRows())
+ case _ =>
+ throw new UnsupportedOperationException("unsupport column vector!")
+ }
+ input(i) = omniVec
+ }
+ } catch {
+ case e: Exception => {
+ for (j <- 0 until cb.numCols()) {
+ val vec = input(j)
+ if (vec != null) vec.close
+ cb.column(j) match {
+ case vector: OmniColumnVector =>
+ vector.close()
}
- case vector: ColumnVector =>
- transColumnVector(vector, cb.numRows())
- case _ =>
- throw new UnsupportedOperationException("unsupport column vector!")
+ }
+ throw new RuntimeException("allocate memory failed!")
}
- input(i) = omniVec
}
input
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
index 77707365e2e92680b027c4649144a0d599768720..33e8037a12aa253da8b7a40f0895c511d31cfec8 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala
@@ -312,7 +312,7 @@ object ColumnarBatchToInternalRow {
val batchIter = batches.flatMap { batch =>
- // toClosedVecs closed case:
+ // toClosedVecs closed case: [Deprecated]
// 1) all rows of batch fetched and closed
// 2) only fetch parital rows(eg: top-n, limit-n), closed at task CompletionListener callback
val toClosedVecs = new ListBuffer[Vec]
@@ -330,27 +330,22 @@ object ColumnarBatchToInternalRow {
new Iterator[InternalRow] {
val numOutputRowsMetric: SQLMetric = numOutputRows
- var closed = false
-
- // only invoke if fetch partial rows of batch
- if (mayPartialFetch) {
- SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ =>
- if (!closed) {
- toClosedVecs.foreach {vec =>
- vec.close()
- }
+
+
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener { _ =>
+ toClosedVecs.foreach {vec =>
+ vec.close()
}
- }
}
override def hasNext: Boolean = {
val has = iter.hasNext
- // fetch all rows and closed
- if (!has && !closed) {
+ // fetch all rows
+ if (!has) {
toClosedVecs.foreach {vec =>
vec.close()
+ toClosedVecs.remove(toClosedVecs.indexOf(vec))
}
- closed = true
}
has
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
index 11541ccee1c0e9ad775eb71360bf0d081d72bcd2..12fbdb6b4a2c5b8b4eec6ea51fbe7f6d919dbb76 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
@@ -330,24 +330,27 @@ case class ColumnarOptRollupExec(
omniOutputPartials)
omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen)
+
+ val results = new ListBuffer[VecBatch]()
+ var hashaggResults: java.util.Iterator[VecBatch] = null
+
// close operator
addLeakSafeTaskCompletionListener[Unit](_ => {
projectOperators.foreach(operator => operator.close())
hashaggOperator.close()
+ results.foreach(vecBatch => {
+ vecBatch.releaseAllVectors()
+ vecBatch.close()
+ })
})
- val results = new ListBuffer[VecBatch]()
- var hashaggResults: java.util.Iterator[VecBatch] = null
-
while (iter.hasNext) {
val batch = iter.next()
val input = transColBatchToOmniVecs(batch)
val vecBatch = new VecBatch(input, batch.numRows())
results.append(vecBatch)
projectOperators.foreach(projectOperator => {
- val vecs = vecBatch.getVectors.map(vec => {
- vec.slice(0, vecBatch.getRowCount)
- })
+ val vecs = transColBatchToOmniVecs(batch, true)
val projectInput = new VecBatch(vecs, vecBatch.getRowCount)
var startInput = System.nanoTime()
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index f928ad11c258bb07c64f11d065adadb5f439dcc0..233503141dc11a6d74ea8e000a02594cb4607116 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.createShuffleWriteProcessor
import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.metric._
-import org.apache.spark.sql.execution.util.MergeIterator
+import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils}
import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener
import org.apache.spark.sql.execution.vectorized.OmniColumnVector
import org.apache.spark.sql.internal.SQLConf
@@ -167,10 +167,14 @@ class ColumnarShuffleExchangeExec(
if (enableShuffleBatchMerge) {
cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) =>
- new MergeIterator(iter,
+ val mergeIterator = new MergeIterator(iter,
StructType.fromAttributes(child.output),
longMetric("numMergedVecBatches"),
longMetric("bypassVecBatches"))
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+ mergeIterator.close()
+ })
+ mergeIterator
}
} else {
cachedShuffleRDD
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
index d7365c8c152b507c366d917cce8fd127acf55a2b..098b606becd652009b272f2bbbc4e130f81e8f80 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/adaptive/ColumnarCustomShuffleReaderExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartit
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeLike}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.util.MergeIterator
+import org.apache.spark.sql.execution.util.{MergeIterator, SparkMemoryUtils}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -207,10 +207,14 @@ case class ColumnarCustomShuffleReaderExec(
if (enableShuffleBatchMerge) {
cachedShuffleRDD.mapPartitionsWithIndexInternal { (index, iter) =>
- new MergeIterator(iter,
+ val mergeIterator = new MergeIterator(iter,
StructType.fromAttributes(child.output),
longMetric("numMergedVecBatches"),
longMetric("bypassVecBatches"))
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+ mergeIterator.close()
+ })
+ mergeIterator
}
} else {
cachedShuffleRDD
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
index 40e43f707a5fd9cfe647d677df82558bfc0927e3..d5fd3a10b81aadba5acdc378bef55d63955f7d4b 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala
@@ -340,7 +340,7 @@ case class ColumnarBroadcastHashJoinExec(
Optional.empty()
}
- def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = {
+ def createBuildOpFactoryAndOp(isShared: Boolean): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = {
val startBuildCodegen = System.nanoTime()
val opFactory =
new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1,
@@ -349,6 +349,10 @@ case class ColumnarBroadcastHashJoinExec(
val op = opFactory.createOperator()
buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen)
+ if (isShared) {
+ OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id,
+ opFactory, op)
+ }
val deserializer = VecBatchSerializerFactory.create()
relation.value.buildData.foreach { input =>
val startBuildInput = System.nanoTime()
@@ -356,7 +360,19 @@ case class ColumnarBroadcastHashJoinExec(
buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput)
}
val startBuildGetOp = System.nanoTime()
- op.getOutput
+ try {
+ op.getOutput
+ } catch {
+ case e: Exception => {
+ if (isShared) {
+ OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id)
+ } else {
+ op.close()
+ opFactory.close()
+ }
+ throw new RuntimeException("HashBuilder getOutput failed")
+ }
+ }
buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp)
(opFactory, op)
}
@@ -368,11 +384,9 @@ case class ColumnarBroadcastHashJoinExec(
try {
buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id)
if (buildOpFactory == null) {
- val (opFactory, op) = createBuildOpFactoryAndOp()
+ val (opFactory, op) = createBuildOpFactoryAndOp(true)
buildOpFactory = opFactory
buildOp = op
- OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id,
- buildOpFactory, buildOp)
}
} catch {
case e: Exception => {
@@ -382,7 +396,7 @@ case class ColumnarBroadcastHashJoinExec(
OmniHashBuilderWithExprOperatorFactory.gLock.unlock()
}
} else {
- val (opFactory, op) = createBuildOpFactoryAndOp()
+ val (opFactory, op) = createBuildOpFactoryAndOp(false)
buildOpFactory = opFactory
buildOp = op
}
@@ -492,7 +506,11 @@ case class ColumnarBroadcastHashJoinExec(
}
if (enableJoinBatchMerge) {
- new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches)
+ val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches)
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+ mergeIterator.close()
+ })
+ mergeIterator
} else {
iterBatch
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
index f3fb45a1cab4a0b1f9fe23c45c9f298b728a3761..8a61c3fdbe61f4a568d10c796dd32ba5e64bbab0 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
@@ -285,7 +285,11 @@ case class ColumnarSortMergeJoinExec(
}
if (enableSortMergeJoinBatchMerge) {
- new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches)
+ val mergeIterator = new MergeIterator(iterBatch, resultSchema, numMergedVecBatches, bypassVecBatches)
+ SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
+ mergeIterator.close()
+ })
+ mergeIterator
} else {
iterBatch
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala
index 105d3435367c2818c7f2b1da267ca19f43ce27ef..cea35f2a37d159fee797632c49101939b09ae67c 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala
@@ -44,29 +44,40 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType,
private def createOmniVectors(schema: StructType, columnSize: Int): Array[Vec] = {
val vecs = new Array[Vec](schema.fields.length)
- schema.fields.zipWithIndex.foreach { case (field, index) =>
- field.dataType match {
- case LongType =>
- vecs(index) = new LongVec(columnSize)
- case DateType | IntegerType =>
- vecs(index) = new IntVec(columnSize)
- case ShortType =>
- vecs(index) = new ShortVec(columnSize)
- case DoubleType =>
- vecs(index) = new DoubleVec(columnSize)
- case BooleanType =>
- vecs(index) = new BooleanVec(columnSize)
- case StringType =>
- val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata)
- vecs(index) = new VarcharVec(columnSize)
- case dt: DecimalType =>
- if (DecimalType.is64BitDecimalType(dt)) {
+ try {
+ schema.fields.zipWithIndex.foreach { case (field, index) =>
+ field.dataType match {
+ case LongType =>
vecs(index) = new LongVec(columnSize)
- } else {
- vecs(index) = new Decimal128Vec(columnSize)
+ case DateType | IntegerType =>
+ vecs(index) = new IntVec(columnSize)
+ case ShortType =>
+ vecs(index) = new ShortVec(columnSize)
+ case DoubleType =>
+ vecs(index) = new DoubleVec(columnSize)
+ case BooleanType =>
+ vecs(index) = new BooleanVec(columnSize)
+ case StringType =>
+ val vecType: DataType = sparkTypeToOmniType(field.dataType, field.metadata)
+ vecs(index) = new VarcharVec(columnSize)
+ case dt: DecimalType =>
+ if (DecimalType.is64BitDecimalType(dt)) {
+ vecs(index) = new LongVec(columnSize)
+ } else {
+ vecs(index) = new Decimal128Vec(columnSize)
+ }
+ case _ =>
+ throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields")
+ }
+ }
+ } catch {
+ case e: Exception => {
+ for (vec <- vecs) {
+ if (vec != null) {
+ vec.close()
}
- case _ =>
- throw new UnsupportedOperationException("Fail to create omni vector, unsupported fields")
+ }
+ throw new RuntimeException("allocate memory failed!")
}
}
vecs
@@ -165,4 +176,15 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType,
def isFull(): Boolean = {
totalRows > maxRowCount || currentBatchSizeInBytes >= maxBatchSizeInBytes
}
+
+ def close(): Unit = {
+ for (elem <- bufferedVecBatch) {
+ elem.releaseAllVectors()
+ elem.close()
+ }
+ for (elem <- outputQueue) {
+ elem.releaseAllVectors()
+ elem.close()
+ }
+ }
}