diff --git a/omnidata/omnidata-spark-connector/connector/pom.xml b/omnidata/omnidata-spark-connector/connector/pom.xml
index 4fd2668b10b13f63ccef418c56af7ee9d1bfe8ce..dda5fe43f387e6d1fa7346321109b821f8cd0322 100644
--- a/omnidata/omnidata-spark-connector/connector/pom.xml
+++ b/omnidata/omnidata-spark-connector/connector/pom.xml
@@ -5,12 +5,12 @@
org.apache.spark
omnidata-spark-connector-root
- 1.4.0
+ 1.5.0
4.0.0
boostkit-omnidata-spark-sql_2.12-3.1.1
- 1.4.0
+ 1.5.0
boostkit omnidata spark sql
2021
jar
@@ -55,7 +55,7 @@
com.huawei.boostkit
boostkit-omnidata-stub
- 1.4.0
+ 1.5.0
compile
diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
index 7ba2967f87f8ebcaeb6959b51c23cef857462a07..8aa1e62449a5c8729506f5854772fb8a14b687b8 100644
--- a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
+++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
@@ -48,7 +48,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC
Arrow::arrow_shared
Parquet::parquet_shared
orc
- boostkit-omniop-vector-1.4.0-aarch64
+ boostkit-omniop-vector-1.5.0-aarch64
hdfs
)
diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
index e1300c4e062857780c1c30d56fe9d366cce6f868..f8ee293e2dc7576f0784fe5f2e862141616dc3f4 100644
--- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
+++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp
@@ -28,7 +28,6 @@ using namespace std;
using namespace orc;
static constexpr int32_t MAX_DECIMAL64_DIGITS = 18;
-bool isDecimal64Transfor128 = false;
// vecFildsNames stores the name of every column in the file; it is obtained on the ORC reader C++ side and passed back to the Java side
JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env,
@@ -74,7 +73,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea
UriInfo uri{schemaStr, fileStr, hostStr, std::to_string(port)};
reader = createReader(orc::readFileOverride(uri), readerOptions);
std::vector orcColumnNames = reader->getAllFiedsName();
- for (int i = 0; i < orcColumnNames.size(); i++) {
+ for (uint32_t i = 0; i < orcColumnNames.size(); i++) {
jstring fildname = env->NewStringUTF(orcColumnNames[i].c_str());
// use ArrayList and function
env->CallBooleanMethod(vecFildsNames, arrayListAdd, fildname);
@@ -268,12 +267,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea
{
JNI_FUNC_START
orc::Reader *readerPtr = (orc::Reader *)reader;
- // Get if the decimal for spark or hive
- jboolean jni_isDecimal64Transfor128 = env->CallBooleanMethod(jsonObj, jsonMethodHas,
- env->NewStringUTF("isDecimal64Transfor128"));
- if (jni_isDecimal64Transfor128) {
- isDecimal64Transfor128 = true;
- }
+
// get offset from json obj
jlong offset = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("offset"));
jlong length = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("length"));
@@ -520,45 +514,86 @@ std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch
return newVector;
}
-std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId, orc::ColumnVectorBatch *field,
- bool isDecimal64Transfor128)
+std::unique_ptr DealLongVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
+ switch (id) {
+ case omniruntime::type::OMNI_BOOLEAN:
+ return CopyFixedWidth(field);
+ case omniruntime::type::OMNI_SHORT:
+ return CopyFixedWidth(field);
+ case omniruntime::type::OMNI_INT:
+ return CopyFixedWidth(field);
+ case omniruntime::type::OMNI_LONG:
+ return CopyOptimizedForInt64(field);
+ case omniruntime::type::OMNI_DATE32:
+ return CopyFixedWidth(field);
+ case omniruntime::type::OMNI_DATE64:
+ return CopyOptimizedForInt64(field);
+ default: {
+ throw std::runtime_error("DealLongVectorBatch not support for type: " + id);
+ }
+ }
+}
+
+std::unique_ptr DealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
+ switch (id) {
+ case omniruntime::type::OMNI_DOUBLE:
+ return CopyOptimizedForInt64(field);
+ default: {
+ throw std::runtime_error("DealDoubleVectorBatch not support for type: " + id);
+ }
+ }
+}
+
+std::unique_ptr DealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
+ switch (id) {
+ case omniruntime::type::OMNI_DECIMAL64:
+ return CopyToOmniDecimal64Vec(field);
+ case omniruntime::type::OMNI_DECIMAL128:
+ return CopyToOmniDecimal128VecFrom64(field);
+ default: {
+ throw std::runtime_error("DealDecimal64VectorBatch not support for type: " + id);
+ }
+ }
+}
+
+std::unique_ptr DealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) {
+ switch (id) {
+ case omniruntime::type::OMNI_DECIMAL128:
+ return CopyToOmniDecimal128Vec(field);
+ default: {
+ throw std::runtime_error("DealDecimal128VectorBatch not support for type: " + id);
+ }
+ }
+}
+
+std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field)
{
+ DataTypeId dataTypeId = static_cast(omniTypeId);
switch (type->getKind()) {
case orc::TypeKind::BOOLEAN:
- omniTypeId = static_cast(OMNI_BOOLEAN);
- return CopyFixedWidth(field);
case orc::TypeKind::SHORT:
- omniTypeId = static_cast(OMNI_SHORT);
- return CopyFixedWidth(field);
case orc::TypeKind::DATE:
- omniTypeId = static_cast(OMNI_DATE32);
- return CopyFixedWidth(field);
case orc::TypeKind::INT:
- omniTypeId = static_cast(OMNI_INT);
- return CopyFixedWidth(field);
case orc::TypeKind::LONG:
- omniTypeId = static_cast(OMNI_LONG);
- return CopyOptimizedForInt64(field);
+ return DealLongVectorBatch(dataTypeId, field);
case orc::TypeKind::DOUBLE:
- omniTypeId = static_cast(OMNI_DOUBLE);
- return CopyOptimizedForInt64(field);
+ return DealDoubleVectorBatch(dataTypeId, field);
case orc::TypeKind::CHAR:
- omniTypeId = static_cast(OMNI_VARCHAR);
+ if (dataTypeId != OMNI_VARCHAR) {
+ throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char");
+ }
return CopyCharType(field);
case orc::TypeKind::STRING:
case orc::TypeKind::VARCHAR:
- omniTypeId = static_cast(OMNI_VARCHAR);
+ if (dataTypeId != OMNI_VARCHAR) {
+ throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar");
+ }
return CopyVarWidth(field);
case orc::TypeKind::DECIMAL:
if (type->getPrecision() > MAX_DECIMAL64_DIGITS) {
- omniTypeId = static_cast(OMNI_DECIMAL128);
- return CopyToOmniDecimal128Vec(field);
- } else if (isDecimal64Transfor128) {
- omniTypeId = static_cast(OMNI_DECIMAL128);
- return CopyToOmniDecimal128VecFrom64(field);
+ return DealDecimal128VectorBatch(dataTypeId, field);
} else {
- omniTypeId = static_cast(OMNI_DECIMAL64);
- return CopyToOmniDecimal64Vec(field);
+ return DealDecimal64VectorBatch(dataTypeId, field);
}
default: {
throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind());
@@ -569,28 +604,37 @@ std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId
JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env,
jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId)
{
+ JNI_FUNC_START
orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader;
orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch;
std::vector> omniVecs;
const orc::Type &baseTp = rowReaderPtr->getSelectedType();
uint64_t batchRowSize = 0;
+ auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE);
+ if (ptr == NULL) {
+ throw std::runtime_error("Types should not be null");
+ }
+ int32_t arrLen = (int32_t) env->GetArrayLength(typeId);
if (rowReaderPtr->next(*columnVectorBatch)) {
orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch);
batchRowSize = root->fields[0]->numElements;
int32_t vecCnt = root->fields.size();
- std::vector omniTypeIds(vecCnt, 0);
+ if (vecCnt != arrLen) {
+ throw std::runtime_error("Types should align to root fields");
+ }
for (int32_t id = 0; id < vecCnt; id++) {
auto type = baseTp.getSubtype(id);
- omniVecs.emplace_back(CopyToOmniVec(type, omniTypeIds[id], root->fields[id], isDecimal64Transfor128));
+ int omniTypeId = ptr[id];
+ omniVecs.emplace_back(CopyToOmniVec(type, omniTypeId, root->fields[id]));
}
for (int32_t id = 0; id < vecCnt; id++) {
- env->SetIntArrayRegion(typeId, id, 1, omniTypeIds.data() + id);
jlong omniVec = reinterpret_cast(omniVecs[id].release());
env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec);
}
}
+ env->ReleaseIntArrayElements(typeId, ptr, JNI_ABORT);
return (jlong) batchRowSize;
+ JNI_FUNC_END(runtimeExceptionClass)
}
/*
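For context, the removed isDecimal64Transfor128 flag is replaced by the caller-supplied type id: for an ORC DECIMAL column whose precision is at most 18, passing OMNI_DECIMAL64 keeps the 64-bit copy while passing OMNI_DECIMAL128 widens it. Illustrative fragment only; decType stands for a pointer to such an ORC DECIMAL type node and field for its ColumnVectorBatch.

// not part of the patch, identifiers as in this file
auto asDec64  = CopyToOmniVec(decType, omniruntime::type::OMNI_DECIMAL64,  field);  // plain 64-bit copy via CopyToOmniDecimal64Vec
auto asDec128 = CopyToOmniVec(decType, omniruntime::type::OMNI_DECIMAL128, field);  // widened via CopyToOmniDecimal128VecFrom64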
diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h
index 829f5c0744d3d563601ec6506ebfc82b5a020e93..8b942fe8b3e0975cccfece4a68e9112b6c550802 100644
--- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h
+++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h
@@ -141,8 +141,7 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList
bool StringToBool(const std::string &boolStr);
-int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field,
- bool isDecimal64Transfor128);
+std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field);
#ifdef __cplusplus
}
diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
index a36c2e2acb430d15e32f1a1da1be6c83700ecd7d..0bc32f33277da54253f4ca81bab490a3c2a1875d 100644
--- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
+++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h
@@ -346,6 +346,7 @@ namespace omniruntime::reader {
vec->SetValue(i + offset, value);
}
values_decoded += num_indices;
+ offset += num_indices;
}
*out_num_values = values_decoded;
return Status::OK();
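The single added line fixes the running write position in the dictionary-decode loop; a standalone toy illustration of the pattern (not the library code) is:

#include <cstdint>
#include <vector>

// Copy indices chunk by chunk into `out` (pre-sized to the total count); `offset`
// must advance by the chunk size after every pass, which is what `offset += num_indices` adds.
void DecodeChunks(const std::vector<std::vector<int32_t>> &chunks, std::vector<int32_t> &out)
{
    size_t offset = 0;
    for (const auto &chunk : chunks) {
        for (size_t i = 0; i < chunk.size(); ++i) {
            out[i + offset] = chunk[i];   // same addressing as vec->SetValue(i + offset, value)
        }
        offset += chunk.size();           // without this, every chunk overwrites rows 0..chunk.size()-1
    }
}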
diff --git a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt
index 3d1d559df94b1137db424b3318b86b956830e09c..cff2824fae3053fd2ec04b16471ab775a23b2de9 100644
--- a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt
+++ b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt
@@ -31,7 +31,7 @@ target_link_libraries(${TP_TEST_TARGET}
pthread
stdc++
dl
- boostkit-omniop-vector-1.4.0-aarch64
+ boostkit-omniop-vector-1.5.0-aarch64
securec
spark_columnar_plugin)
diff --git a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml
index 99c66a43076a3c9b7f8b529749436e9c4ea10ed9..8f6a401efe919fc58e426b25dd185f383fca2266 100644
--- a/omnioperator/omniop-native-reader/java/pom.xml
+++ b/omnioperator/omniop-native-reader/java/pom.xml
@@ -8,7 +8,7 @@
com.huawei.boostkit
boostkit-omniop-native-reader
jar
- 3.3.1-1.4.0
+ 3.3.1-1.5.0
BoostKit Spark Native Sql Engine Extension With OmniOperator
@@ -31,7 +31,7 @@
com.huawei.boostkit
boostkit-omniop-bindings
aarch64
- 1.4.0
+ 1.5.0
org.slf4j
diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml
index 608a3ca714fe707b61758f870d0b32877e22f9b2..e8815aea6ee6936e5d7a4ab2083ec115a3c347f7 100644
--- a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml
+++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml
@@ -65,12 +65,12 @@
com.huawei.boostkit
boostkit-omniop-bindings
- 1.3.0
+ 1.4.0
com.huawei.kunpeng
boostkit-omniop-spark
- 3.3.1-1.3.0
+ 3.3.1-1.4.0
org.scalatest
diff --git a/omnioperator/omniop-spark-extension-ock/pom.xml b/omnioperator/omniop-spark-extension-ock/pom.xml
index 84c9208cc3caebf6a4a3dff7a9b4cd9b0b4d63ee..8201865444e199606338b6812f6d11b5fe8d250b 100644
--- a/omnioperator/omniop-spark-extension-ock/pom.xml
+++ b/omnioperator/omniop-spark-extension-ock/pom.xml
@@ -62,12 +62,12 @@
com.huawei.boostkit
boostkit-omniop-bindings
- 1.3.0
+ 1.4.0
com.huawei.kunpeng
boostkit-omniop-spark
- 3.3.1-1.3.0
+ 3.3.1-1.4.0
com.huawei.ock
diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt
index 26df3cb85b9255ee0969d8631deb6ab76488101d..fe4dc5fc593d0435571859b9b687bff24170f265 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt
+++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt
@@ -42,7 +42,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC
snappy
lz4
zstd
- boostkit-omniop-vector-1.4.0-aarch64
+ boostkit-omniop-vector-1.5.0-aarch64
)
set_target_properties(${PROJ_TARGET} PROPERTIES
diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc
index 3c6e3b3bc31746c7c28a2a63f5bd1b5b1b2a3e44..7e46b9f560f80d1c689e9a1e9781c08de8b3ee54 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc
+++ b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc
@@ -141,14 +141,18 @@ namespace spark {
if (closed) {
throw std::logic_error("Cannot write to closed stream.");
}
- ssize_t bytesWrite = ::write(file, buf, length);
- if (bytesWrite == -1) {
- throw std::runtime_error("Bad write of " + filename);
- }
- if (static_cast(bytesWrite) != length) {
- throw std::runtime_error("Short write of " + filename);
+
+ size_t bytesWritten = 0;
+ while (bytesWritten < length) {
+ ssize_t actualBytes = ::write(file, static_cast(buf) + bytesWritten, length - bytesWritten);
+ if (actualBytes <= 0) {
+ close();
+ std::string errMsg = "Bad write of " + filename + " since " + strerror(errno) + ",actual write bytes " +
+ std::to_string(actualBytes) + ".";
+ throw std::runtime_error(errMsg);
+ }
+ bytesWritten += actualBytes;
}
- bytesWritten += static_cast(bytesWrite);
}
const std::string& getName() const override {
@@ -177,4 +181,4 @@ namespace spark {
InputStream::~InputStream() {
// PASS
};
-}
\ No newline at end of file
+}
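The hunk above replaces a single ::write call with a loop that tolerates short writes; a self-contained sketch of the same pattern (assumed helper name, not part of the patch):

#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <stdexcept>
#include <string>

// Write exactly `length` bytes to `fd`, retrying on partial writes and
// reporting errno when the descriptor stops accepting data.
void WriteFully(int fd, const void *buf, size_t length, const std::string &name)
{
    size_t written = 0;
    while (written < length) {
        ssize_t n = ::write(fd, static_cast<const char *>(buf) + written, length - written);
        if (n <= 0) {
            throw std::runtime_error("Bad write of " + name + " since " + std::strerror(errno));
        }
        written += static_cast<size_t>(n);
    }
}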
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
index 14785a9cf453f5925974f8b85dc80538a5b85a17..d67ba33c7c0da99334e5f573ef397db0494ddfad 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp
@@ -34,7 +34,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ
jstring jInputType, jint jNumCols, jint buffer_size,
jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs,
jstring local_dirs_jstr, jlong compress_block_size,
- jint spill_batch_row, jlong spill_memory_threshold)
+ jint spill_batch_row, jlong task_spill_memory_threshold, jlong executor_spill_memory_threshold)
{
JNI_FUNC_START
if (partitioning_name_jstr == nullptr) {
@@ -107,8 +107,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ
if (spill_batch_row > 0) {
splitOptions.spill_batch_row_num = spill_batch_row;
}
- if (spill_memory_threshold > 0) {
- splitOptions.spill_mem_threshold = spill_memory_threshold;
+ if (task_spill_memory_threshold > 0) {
+ splitOptions.task_spill_mem_threshold = task_spill_memory_threshold;
+ }
+ if (executor_spill_memory_threshold > 0) {
+ splitOptions.executor_spill_mem_threshold = executor_spill_memory_threshold;
}
if (compress_block_size > 0) {
splitOptions.compress_block_size = compress_block_size;
@@ -124,16 +127,16 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ
auto splitter = Splitter::Make(partitioning_name, inputDataTypesTmp, jNumCols, num_partitions,
std::move(splitOptions));
- return g_shuffleSplitterHolder.Insert(std::shared_ptr(splitter));
+ return reinterpret_cast(static_cast(splitter));
JNI_FUNC_END(runtimeExceptionClass)
}
JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split(
- JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress)
+ JNIEnv *env, jobject jObj, jlong splitter_addr, jlong jVecBatchAddress)
{
- auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id);
+ auto splitter = reinterpret_cast(splitter_addr);
if (!splitter) {
- std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
env->ThrowNew(runtimeExceptionClass, error_message.c_str());
return -1;
}
@@ -146,14 +149,33 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split
JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch())
}
+JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowSplit(
+ JNIEnv *env, jobject jObj, jlong splitter_addr, jlong jVecBatchAddress)
+{
+ auto splitter = reinterpret_cast(splitter_addr);
+ if (!splitter) {
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
+ env->ThrowNew(runtimeExceptionClass, error_message.c_str());
+ return -1;
+ }
+
+ auto vecBatch = (VectorBatch *) jVecBatchAddress;
+ splitter->SetInputVecBatch(vecBatch);
+ JNI_FUNC_START
+ splitter->SplitByRow(vecBatch);
+ return 0L;
+ JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch())
+}
+
JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop(
- JNIEnv* env, jobject, jlong splitter_id)
+ JNIEnv* env, jobject, jlong splitter_addr)
{
JNI_FUNC_START
- auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id);
+ auto splitter = reinterpret_cast(splitter_addr);
if (!splitter) {
- std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
env->ThrowNew(runtimeExceptionClass, error_message.c_str());
+ return nullptr;
}
splitter->Stop();
@@ -170,15 +192,40 @@ JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_sto
JNI_FUNC_END(runtimeExceptionClass)
}
+JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowStop(
+ JNIEnv* env, jobject, jlong splitter_addr)
+{
+ JNI_FUNC_START
+ auto splitter = reinterpret_cast(splitter_addr);
+ if (!splitter) {
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
+ env->ThrowNew(runtimeExceptionClass, error_message.c_str());
+ return nullptr;
+ }
+ splitter->StopByRow();
+
+ const auto& partition_length = splitter->PartitionLengths();
+ auto partition_length_arr = env->NewLongArray(partition_length.size());
+ auto src = reinterpret_cast(partition_length.data());
+ env->SetLongArrayRegion(partition_length_arr, 0, partition_length.size(), src);
+ jobject split_result = env->NewObject(
+ splitResultClass, splitResultConstructor, splitter->TotalComputePidTime(),
+ splitter->TotalWriteTime(), splitter->TotalSpillTime(),
+ splitter->TotalBytesWritten(), splitter->TotalBytesSpilled(), partition_length_arr);
+
+ return split_result;
+ JNI_FUNC_END(runtimeExceptionClass)
+}
+
JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close(
- JNIEnv* env, jobject, jlong splitter_id)
+ JNIEnv* env, jobject, jlong splitter_addr)
{
JNI_FUNC_START
- auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id);
+ auto splitter = reinterpret_cast(splitter_addr);
if (!splitter) {
- std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
env->ThrowNew(runtimeExceptionClass, error_message.c_str());
}
- g_shuffleSplitterHolder.Erase(splitter_id);
+ delete splitter;
JNI_FUNC_END_VOID(runtimeExceptionClass)
}
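The wrapper no longer keys splitters through g_shuffleSplitterHolder; the object's address itself is the jlong handle. A self-contained illustration of that round trip with a stand-in type (Splitter itself is defined in shuffle/splitter.h):

#include <cstdint>

struct Handle { int value = 0; };                            // stand-in for Splitter

int main()
{
    auto *obj = new Handle();
    int64_t handle = reinterpret_cast<int64_t>(obj);         // what nativeMake returns to Java
    auto *recovered = reinterpret_cast<Handle *>(handle);    // what split/stop/close do on entry
    recovered->value = 42;
    delete recovered;                                        // close() now owns destruction
    return 0;
}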
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh
index c98c10383c4cab04c4770adb8ebdab0ebdb4424b..15076b2ab6c3d2900e0393200766b63db56f07d8 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh
@@ -20,7 +20,6 @@
#include
#include
#include
-#include "concurrent_map.h"
#include "shuffle/splitter.h"
#ifndef SPARK_JNI_WRAPPER
@@ -39,21 +38,27 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake(
jstring jInputType, jint jNumCols, jint buffer_size,
jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs,
jstring local_dirs_jstr, jlong compress_block_size,
- jint spill_batch_row, jlong spill_memory_threshold);
+ jint spill_batch_row, jlong task_spill_memory_threshold, jlong executor_spill_memory_threshold);
JNIEXPORT jlong JNICALL
Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split(
JNIEnv* env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress);
+JNIEXPORT jlong JNICALL
+Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowSplit(
+ JNIEnv* env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress);
+
JNIEXPORT jobject JNICALL
Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop(
JNIEnv* env, jobject, jlong splitter_id);
-
+
+JNIEXPORT jobject JNICALL
+Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowStop(
+ JNIEnv* env, jobject, jlong splitter_id);
+
JNIEXPORT void JNICALL
Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close(
- JNIEnv* env, jobject, jlong splitter_id);
-
-static ConcurrentMap> g_shuffleSplitterHolder;
+ JNIEnv* env, jobject, jlong splitter_id);
#ifdef __cplusplus
}
diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp
index f0e3a225363913268b885cbcde2903d00eea7476..605107e52b2adad036f7682b161c46fecd84f190 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp
@@ -109,8 +109,6 @@ void JNI_OnUnload(JavaVM* vm, void* reserved)
env->DeleteGlobalRef(jsonClass);
env->DeleteGlobalRef(arrayListClass);
env->DeleteGlobalRef(threadClass);
-
- g_shuffleSplitterHolder.Clear();
}
#endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_CPP
diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto
index c40472020171692ea7b0acde2dd873efeda691f4..33ee64ec84d12b4bbf76c579645a8cb8a2e9db7d 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto
+++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto
@@ -57,4 +57,16 @@ message VecType {
NANOSEC = 3;
}
TimeUnit timeUnit = 6;
+}
+
+message ProtoRow {
+ bytes data = 1;
+ uint32 length = 2;
+}
+
+message ProtoRowBatch {
+ int32 rowCnt = 1;
+ int32 vecCnt = 2;
+ repeated VecType vecTypes = 3;
+ repeated ProtoRow rows = 4;
}
\ No newline at end of file
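A sketch of how the new row messages are filled, based on the generated-API calls used later in splitter.cpp (RowInfo comes from vector/omni_row.h; the helper name and parameters here are assumptions for illustration):

// Assumed illustrative helper, not part of the patch.
// Needs <vector>, vec_data.pb.h and vector/omni_row.h.
void FillProtoRowBatch(spark::ProtoRowBatch *batch,
                       const std::vector<spark::VecType::VecTypeId> &colTypes,
                       const std::vector<RowInfo *> &rows)
{
    batch->set_rowcnt(static_cast<int32_t>(rows.size()));
    batch->set_veccnt(static_cast<int32_t>(colTypes.size()));
    for (auto t : colTypes) {
        batch->add_vectypes()->set_typeid_(t);         // one VecType per column
    }
    for (const RowInfo *row : rows) {
        spark::ProtoRow *protoRow = batch->add_rows();
        protoRow->set_data(row->row, row->length);     // raw row bytes
        protoRow->set_length(row->length);
    }
    // batch->SerializeToZeroCopyStream(...) then batch->Clear() for the next slice
}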
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
index c503c38f085cd51ed29061fbfb4d91d2e573cd8c..2e85b61a2fe90ecdc16b23a73e9d3b081175cada 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp
@@ -77,7 +77,7 @@ int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) {
case SHUFFLE_DECIMAL128:
default: {
void *ptr_tmp = static_cast(options_.allocator->Alloc(new_size * (1 << column_type_id_[i])));
- fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]);
+ fixed_valueBuffer_size_[partition_id] += new_size * (1 << column_type_id_[i]);
if (nullptr == ptr_tmp) {
throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! ");
}
@@ -238,6 +238,7 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) {
if (varcharVector->GetEncoding() == OMNI_DICTIONARY) {
auto vc = reinterpret_cast> *>(
varcharVector);
+ cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t));
for (auto row = 0; row < num_rows; ++row) {
auto pid = partition_id_[row];
uint8_t *dst = nullptr;
@@ -272,7 +273,8 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) {
}
} else {
auto vc = reinterpret_cast> *>(varcharVector);
- for (auto row = 0; row < num_rows; ++row) {
+ cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)) + sizeof(int32_t);
+ for (auto row = 0; row < num_rows; ++row) {
auto pid = partition_id_[row];
uint8_t *dst = nullptr;
uint32_t str_len = 0;
@@ -310,7 +312,6 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) {
int Splitter::SplitBinaryArray(VectorBatch& vb)
{
- const auto num_rows = vb.GetRowCount();
auto vec_cnt_vb = vb.GetVectorCount();
auto vec_cnt_schema = singlePartitionFlag ? vec_cnt_vb : vec_cnt_vb - 1;
for (auto col_schema = 0; col_schema < vec_cnt_schema; ++col_schema) {
@@ -355,7 +356,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){
dst_addrs[pid] = const_cast(validity_buffer->data_);
std::memset(validity_buffer->data_, 0, new_size);
partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer);
- fixed_nullBuffer_size_[pid] = new_size;
+ fixed_nullBuffer_size_[pid] += new_size;
}
}
@@ -412,9 +413,11 @@ int Splitter::CacheVectorBatch(int32_t partition_id, bool reset_buffers) {
}
}
}
- cached_vectorbatch_size_ += batch_partition_size;
- partition_cached_vectorbatch_[partition_id].push_back(std::move(bufferArrayTotal));
- partition_buffer_idx_base_[partition_id] = 0;
+ cached_vectorbatch_size_ += batch_partition_size;
+ partition_cached_vectorbatch_[partition_id].push_back(std::move(bufferArrayTotal));
+ fixed_valueBuffer_size_[partition_id] = 0;
+ fixed_nullBuffer_size_[partition_id] = 0;
+ partition_buffer_idx_base_[partition_id] = 0;
}
return 0;
}
@@ -458,7 +461,7 @@ int Splitter::DoSplit(VectorBatch& vb) {
TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile());
isSpill = true;
}
- if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.spill_mem_threshold) {
+ if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.task_spill_mem_threshold) {
LogsDebug(" Spill For Memory Size Threshold.");
TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile());
isSpill = true;
@@ -521,7 +524,7 @@ void Splitter::ToSplitterTypeId(int num_cols)
void Splitter::CastOmniToShuffleType(DataTypeId omniType, ShuffleTypeId shuffleType)
{
- vector_batch_col_types_.push_back(omniType);
+ proto_col_types_.push_back(CastOmniTypeIdToProtoVecType(omniType));
column_type_id_.push_back(shuffleType);
}
@@ -529,35 +532,21 @@ int Splitter::Split_Init(){
num_row_splited_ = 0;
cached_vectorbatch_size_ = 0;
- partition_id_cnt_cur_ = static_cast(malloc(num_partitions_ * sizeof(int32_t)));
- std::memset(partition_id_cnt_cur_, 0, num_partitions_ * sizeof(int32_t));
-
- partition_id_cnt_cache_ = static_cast(malloc(num_partitions_ * sizeof(uint64_t)));
- std::memset(partition_id_cnt_cache_, 0, num_partitions_ * sizeof(uint64_t));
-
- partition_buffer_size_ = static_cast(malloc(num_partitions_ * sizeof(int32_t)));
- std::memset(partition_buffer_size_, 0, num_partitions_ * sizeof(int32_t));
-
- partition_buffer_idx_base_ = static_cast(malloc(num_partitions_ * sizeof(int32_t)));
- std::memset(partition_buffer_idx_base_, 0, num_partitions_ * sizeof(int32_t));
-
- partition_buffer_idx_offset_ = static_cast(malloc(num_partitions_ * sizeof(int32_t)));
- std::memset(partition_buffer_idx_offset_, 0, num_partitions_ * sizeof(int32_t));
-
- partition_serialization_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t)));
- std::memset(partition_serialization_size_, 0, num_partitions_ * sizeof(uint32_t));
+ partition_id_cnt_cur_ = new int32_t[num_partitions_]();
+ partition_id_cnt_cache_ = new uint64_t[num_partitions_]();
+ partition_buffer_size_ = new int32_t[num_partitions_]();
+ partition_buffer_idx_base_ = new int32_t[num_partitions_]();
+ partition_buffer_idx_offset_ = new int32_t[num_partitions_]();
+ partition_serialization_size_ = new uint32_t[num_partitions_]();
partition_cached_vectorbatch_.resize(num_partitions_);
fixed_width_array_idx_.clear();
partition_lengths_.resize(num_partitions_);
- fixed_valueBuffer_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t)));
- std::memset(fixed_valueBuffer_size_, 0, num_partitions_ * sizeof(uint32_t));
+ fixed_valueBuffer_size_ = new uint32_t[num_partitions_]();
+ fixed_nullBuffer_size_ = new uint32_t[num_partitions_]();
- fixed_nullBuffer_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t)));
- std::memset(fixed_nullBuffer_size_, 0, num_partitions_ * sizeof(uint32_t));
-
- //obtain configed dir from Environment Variables
+ // obtain configed dir from Environment Variables
configured_dirs_ = GetConfiguredLocalDirs();
sub_dir_selection_.assign(configured_dirs_.size(), 0);
@@ -601,19 +590,76 @@ int Splitter::Split_Init(){
for (auto i = 0; i < num_partitions_; ++i) {
vc_partition_array_buffers_[i].resize(column_type_id_.size());
}
+
+ partition_rows.resize(num_partitions_);
return 0;
}
int Splitter::Split(VectorBatch& vb )
{
- //compute the vectorBatch partitioning info
+ // compute the vectorBatch partitioning info
LogsTrace(" split vb row number: %d ", vb.GetRowCount());
TIME_NANO_OR_RAISE(total_compute_pid_time_, ComputeAndCountPartitionId(vb));
- //perform the partitioning action
+ // perform the partitioning action
DoSplit(vb);
return 0;
}
+int Splitter::SplitByRow(VectorBatch *vecBatch) {
+ int32_t rowCount = vecBatch->GetRowCount();
+ for (int pid = 0; pid < num_partitions_; ++pid) {
+ auto needCapacity = partition_rows[pid].size() + rowCount;
+ if (partition_rows[pid].capacity() < needCapacity) {
+ auto prepareCapacity = partition_rows[pid].capacity() * expansion;
+ auto newCapacity = prepareCapacity > needCapacity ? prepareCapacity : needCapacity;
+ partition_rows[pid].reserve(newCapacity);
+ }
+ }
+
+ if (singlePartitionFlag) {
+ RowBatch *rowBatch = VectorHelper::TransRowBatchFromVectorBatch(vecBatch);
+ for (int i = 0; i < rowCount; ++i) {
+ RowInfo *rowInfo = rowBatch->Get(i);
+ partition_rows[0].emplace_back(rowInfo);
+ total_input_size += rowInfo->length;
+ }
+ delete vecBatch;
+ } else {
+ auto pidVec = reinterpret_cast *>(vecBatch->Get(0));
+ auto tmpVectorBatch = new VectorBatch(rowCount);
+ for (int i = 1; i < vecBatch->GetVectorCount(); ++i) {
+ tmpVectorBatch->Append(vecBatch->Get(i));
+ }
+ vecBatch->ResizeVectorCount(1);
+ RowBatch *rowBatch = VectorHelper::TransRowBatchFromVectorBatch(tmpVectorBatch);
+ for (int i = 0; i < rowCount; ++i) {
+ auto pid = pidVec->GetValue(i);
+ RowInfo *rowInfo = rowBatch->Get(i);
+ partition_rows[pid].emplace_back(rowInfo);
+ total_input_size += rowInfo->length;
+ }
+ delete vecBatch;
+ delete tmpVectorBatch;
+ }
+
+ // spill
+ // process level: If the memory usage of the current executor exceeds the threshold, spill is triggered.
+ auto usedMemorySize = omniruntime::mem::MemoryManager::GetGlobalAccountedMemory();
+ if (usedMemorySize > options_.executor_spill_mem_threshold) {
+ TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFileByRow());
+ total_input_size = 0;
+ isSpill = true;
+ }
+
+ // task level: If the memory usage of the current task exceeds the threshold, spill is triggered.
+ if (total_input_size > options_.task_spill_mem_threshold) {
+ TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFileByRow());
+ total_input_size = 0;
+ isSpill = true;
+ }
+ return 0;
+}
+
std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() {
void *ptr_tmp = static_cast(options_.allocator->Alloc((num_partitions_ + 1) * sizeof(uint64_t)));
if (nullptr == ptr_tmp) {
@@ -633,8 +679,8 @@ std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() {
return ptrPartitionOffsets;
}
-spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) {
- switch (tmpType) {
+spark::VecType::VecTypeId Splitter::CastOmniTypeIdToProtoVecType(int32_t omniType) {
+ switch (omniType) {
case OMNI_NONE:
return spark::VecType::VEC_TYPE_NONE;
case OMNI_INT:
@@ -674,7 +720,7 @@ spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) {
case DataTypeId::OMNI_INVALID:
return spark::VecType::VEC_TYPE_INVALID;
default: {
- throw std::runtime_error("castShuffleTypeIdToVecType() unexpected ShuffleTypeId");
+ throw std::runtime_error("CastOmniTypeIdToProtoVecType() unexpected OmniTypeId");
}
}
};
@@ -822,13 +868,13 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptrset_veccnt(column_type_id_.size());
int fixColIndexTmp = 0;
for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) {
- spark::Vec * vec = vecBatchProto->add_vecs();
+ spark::Vec *vec = vecBatchProto->add_vecs();
switch (column_type_id_[indexSchema]) {
case ShuffleTypeId::SHUFFLE_1BYTE:
case ShuffleTypeId::SHUFFLE_2BYTE:
case ShuffleTypeId::SHUFFLE_4BYTE:
case ShuffleTypeId::SHUFFLE_8BYTE:
- case ShuffleTypeId::SHUFFLE_DECIMAL128:{
+ case ShuffleTypeId::SHUFFLE_DECIMAL128: {
SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp);
fixColIndexTmp++; // one more fixed-width column serialized
break;
@@ -842,13 +888,13 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptrmutable_vectype();
- vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema]));
- LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
- indexSchema, input_col_types.inputDataPrecisions[indexSchema],
- indexSchema, input_col_types.inputDataScales[indexSchema]);
+ vt->set_typeid_(proto_col_types_[indexSchema]);
if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){
vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]);
vt->set_scale(input_col_types.inputDataScales[indexSchema]);
+ LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
+ indexSchema, input_col_types.inputDataPrecisions[indexSchema],
+ indexSchema, input_col_types.inputDataScales[indexSchema]);
}
}
curBatch++;
@@ -882,7 +928,67 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut) {
+ uint64_t rowCount = partition_rows[partition_id].size();
+ uint64_t onceCopyRow = 0;
+ uint32_t batchCount = 0;
+ while (0 < rowCount) {
+ if (options_.spill_batch_row_num < rowCount) {
+ onceCopyRow = options_.spill_batch_row_num;
+ } else {
+ onceCopyRow = rowCount;
+ }
+
+ protoRowBatch->set_rowcnt(onceCopyRow);
+ protoRowBatch->set_veccnt(proto_col_types_.size());
+ for (int i = 0; i < proto_col_types_.size(); ++i) {
+ spark::VecType *vt = protoRowBatch->add_vectypes();
+ vt->set_typeid_(proto_col_types_[i]);
+ if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){
+ vt->set_precision(input_col_types.inputDataPrecisions[i]);
+ vt->set_scale(input_col_types.inputDataScales[i]);
+ LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
+ i, input_col_types.inputDataPrecisions[i],
+ i, input_col_types.inputDataScales[i]);
+ }
+ }
+
+ int64_t offset = batchCount * options_.spill_batch_row_num;
+ auto rowInfoPtr = partition_rows[partition_id].data() + offset;
+ for (int i = 0; i < onceCopyRow; ++i) {
+ RowInfo *rowInfo = rowInfoPtr[i];
+ spark::ProtoRow *protoRow = protoRowBatch->add_rows();
+ protoRow->set_data(rowInfo->row, rowInfo->length);
+ protoRow->set_length(rowInfo->length);
+ // free row memory
+ delete rowInfo;
+ }
+
+ if (protoRowBatch->ByteSizeLong() > UINT32_MAX) {
+ throw std::runtime_error("Unsafe static_cast long to uint_32t.");
+ }
+ uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong()));
+ if (bufferStream->Next(&bufferOut, &sizeOut)) {
+ std::memcpy(bufferOut, &protoRowBatchSize, sizeof(protoRowBatchSize));
+ if (sizeof(protoRowBatchSize) < sizeOut) {
+ bufferStream->BackUp(sizeOut - sizeof(protoRowBatchSize));
+ }
+ }
+
+ protoRowBatch->SerializeToZeroCopyStream(bufferStream.get());
+ rowCount -= onceCopyRow;
+ batchCount++;
+ protoRowBatch->Clear();
+ }
+ uint64_t partitionBatchSize = bufferStream->flush();
+ total_bytes_written_ += partitionBatchSize;
+ partition_lengths_[partition_id] += partitionBatchSize;
+ partition_rows[partition_id].clear();
+ LogsDebug(" partitionBatch write length: %lu", partitionBatchSize);
+ return 0;
}
int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) {
@@ -908,13 +1014,13 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptrset_veccnt(column_type_id_.size());
int fixColIndexTmp = 0;
for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) {
- spark::Vec * vec = vecBatchProto->add_vecs();
+ spark::Vec *vec = vecBatchProto->add_vecs();
switch (column_type_id_[indexSchema]) {
case ShuffleTypeId::SHUFFLE_1BYTE:
case ShuffleTypeId::SHUFFLE_2BYTE:
case ShuffleTypeId::SHUFFLE_4BYTE:
case ShuffleTypeId::SHUFFLE_8BYTE:
- case ShuffleTypeId::SHUFFLE_DECIMAL128:{
+ case ShuffleTypeId::SHUFFLE_DECIMAL128: {
SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp);
fixColIndexTmp++; // one more fixed-width column serialized
break;
@@ -928,13 +1034,13 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptrmutable_vectype();
- vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema]));
- LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
- indexSchema, input_col_types.inputDataPrecisions[indexSchema],
- indexSchema, input_col_types.inputDataScales[indexSchema]);
+ vt->set_typeid_(proto_col_types_[indexSchema]);
if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){
vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]);
vt->set_scale(input_col_types.inputDataScales[indexSchema]);
+ LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
+ indexSchema, input_col_types.inputDataPrecisions[indexSchema],
+ indexSchema, input_col_types.inputDataScales[indexSchema]);
}
}
curBatch++;
@@ -974,6 +1080,70 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) {
+ uint64_t rowCount = partition_rows[partition_id].size();
+ total_spill_row_num_ += rowCount;
+
+ uint64_t onceCopyRow = 0;
+ uint32_t batchCount = 0;
+ while (0 < rowCount) {
+ if (options_.spill_batch_row_num < rowCount) {
+ onceCopyRow = options_.spill_batch_row_num;
+ } else {
+ onceCopyRow = rowCount;
+ }
+
+ protoRowBatch->set_rowcnt(onceCopyRow);
+ protoRowBatch->set_veccnt(proto_col_types_.size());
+ for (int i = 0; i < proto_col_types_.size(); ++i) {
+ spark::VecType *vt = protoRowBatch->add_vectypes();
+ vt->set_typeid_(proto_col_types_[i]);
+ if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){
+ vt->set_precision(input_col_types.inputDataPrecisions[i]);
+ vt->set_scale(input_col_types.inputDataScales[i]);
+ LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ",
+ i, input_col_types.inputDataPrecisions[i],
+ i, input_col_types.inputDataScales[i]);
+ }
+ }
+
+ int64_t offset = batchCount * options_.spill_batch_row_num;
+ auto rowInfoPtr = partition_rows[partition_id].data() + offset;
+ for (int i = 0; i < onceCopyRow; ++i) {
+ RowInfo *rowInfo = rowInfoPtr[i];
+ spark::ProtoRow *protoRow = protoRowBatch->add_rows();
+ protoRow->set_data(rowInfo->row, rowInfo->length);
+ protoRow->set_length(rowInfo->length);
+ // free row memory
+ delete rowInfo;
+ }
+
+ if (protoRowBatch->ByteSizeLong() > UINT32_MAX) {
+ throw std::runtime_error("Unsafe static_cast long to uint_32t.");
+ }
+ uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong()));
+ void *buffer = nullptr;
+ if (!bufferStream->NextNBytes(&buffer, sizeof(protoRowBatchSize))) {
+ throw std::runtime_error("Allocate Memory Failed: Flush Spilled Data, Next failed.");
+ }
+ // set serizalized bytes to stream
+ memcpy(buffer, &protoRowBatchSize, sizeof(protoRowBatchSize));
+ LogsDebug(" A Slice Of vecBatchProtoSize: %d ", reversebytes_uint32t(protoRowBatchSize));
+
+ protoRowBatch->SerializeToZeroCopyStream(bufferStream.get());
+ rowCount -= onceCopyRow;
+ batchCount++;
+ protoRowBatch->Clear();
+ }
+
+ uint64_t partitionBatchSize = bufferStream->flush();
+ total_bytes_spilled_ += partitionBatchSize;
+ partition_serialization_size_[partition_id] = partitionBatchSize;
+ partition_rows[partition_id].clear();
+ LogsDebug(" partitionBatch write length: %lu", partitionBatchSize);
+ return 0;
+}
+
int Splitter::WriteDataFileProto() {
LogsDebug(" spill DataFile: %s ", (options_.next_spilled_file_dir + ".data").c_str());
std::unique_ptr outStream = writeLocalFile(options_.next_spilled_file_dir + ".data");
@@ -991,10 +1161,26 @@ int Splitter::WriteDataFileProto() {
return 0;
}
+int Splitter::WriteDataFileProtoByRow() {
+ LogsDebug(" spill DataFile: %s ", (options_.next_spilled_file_dir + ".data").c_str());
+ std::unique_ptr outStream = writeLocalFile(options_.next_spilled_file_dir + ".data");
+ WriterOptions options;
+ // tmp spilled file no need compression
+ options.setCompression(CompressionKind_NONE);
+ std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get());
+ std::unique_ptr bufferStream = streamsFactory->createStream();
+ // write each partition's data out sequentially
+ for (auto pid = 0; pid < num_partitions_; ++pid) {
+ protoSpillPartitionByRow(pid, bufferStream);
+ }
+ outStream->close();
+ return 0;
+}
+
void Splitter::MergeSpilled() {
for (auto pid = 0; pid < num_partitions_; ++pid) {
CacheVectorBatch(pid, true);
- partition_buffer_size_[pid] = 0; //reset after spilling; a conditional spill must re-allocate the buffer
+ partition_buffer_size_[pid] = 0; // reset after spilling; a conditional spill must re-allocate the buffer
}
std::unique_ptr outStream = writeLocalFile(options_.data_file);
@@ -1004,13 +1190,13 @@ void Splitter::MergeSpilled() {
options.setCompressionBlockSize(options_.compress_block_size);
options.setCompressionStrategy(CompressionStrategy_COMPRESSION);
std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get());
- std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
+ std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
void* bufferOut = nullptr;
int sizeOut = 0;
for (int pid = 0; pid < num_partitions_; pid++) {
ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut);
- LogsDebug(" MergeSplled traversal partition( %d ) ",pid);
+ LogsDebug(" MergeSpilled traversal partition( %d ) ", pid);
for (auto &pair : spilled_tmp_files_info_) {
auto tmpDataFilePath = pair.first + ".data";
auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid];
@@ -1047,10 +1233,56 @@ void Splitter::MergeSpilled() {
outStream->close();
}
+void Splitter::MergeSpilledByRow() {
+ std::unique_ptr outStream = writeLocalFile(options_.data_file);
+ LogsDebug(" Merge Spilled Tmp File: %s ", options_.data_file.c_str());
+ WriterOptions options;
+ options.setCompression(options_.compression_type);
+ options.setCompressionBlockSize(options_.compress_block_size);
+ options.setCompressionStrategy(CompressionStrategy_COMPRESSION);
+ std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get());
+ std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
+
+ void* bufferOut = nullptr;
+ int sizeOut = 0;
+ for (int pid = 0; pid < num_partitions_; pid++) {
+ ProtoWritePartitionByRow(pid, bufferOutPutStream, bufferOut, sizeOut);
+ LogsDebug(" MergeSpilled traversal partition( %d ) ", pid);
+ for (auto &pair : spilled_tmp_files_info_) {
+ auto tmpDataFilePath = pair.first + ".data";
+ auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid];
+ auto tmpPartitionSize = reinterpret_cast(pair.second->data_)[pid + 1] - reinterpret_cast(pair.second->data_)[pid];
+ LogsDebug(" get Partition Stream...tmpPartitionOffset %d tmpPartitionSize %d path %s",
+ tmpPartitionOffset, tmpPartitionSize, tmpDataFilePath.c_str());
+ std::unique_ptr inputStream = readLocalFile(tmpDataFilePath);
+ uint64_t targetLen = tmpPartitionSize;
+ uint64_t seekPosit = tmpPartitionOffset;
+ uint64_t onceReadLen = 0;
+ while ((targetLen > 0) && bufferOutPutStream->Next(&bufferOut, &sizeOut)) {
+ onceReadLen = targetLen > sizeOut ? sizeOut : targetLen;
+ inputStream->read(bufferOut, onceReadLen, seekPosit);
+ targetLen -= onceReadLen;
+ seekPosit += onceReadLen;
+ if (onceReadLen < sizeOut) {
+ // Reached END.
+ bufferOutPutStream->BackUp(sizeOut - onceReadLen);
+ break;
+ }
+ }
+
+ uint64_t flushSize = bufferOutPutStream->flush();
+ total_bytes_written_ += flushSize;
+ LogsDebug(" Merge Flush Partition[%d] flushSize: %ld ", pid, flushSize);
+ partition_lengths_[pid] += flushSize;
+ }
+ }
+ outStream->close();
+}
+
void Splitter::WriteSplit() {
for (auto pid = 0; pid < num_partitions_; ++pid) {
CacheVectorBatch(pid, true);
- partition_buffer_size_[pid] = 0; //reset after spilling; a conditional spill must re-allocate the buffer
+ partition_buffer_size_[pid] = 0; // reset after spilling; a conditional spill must re-allocate the buffer
}
std::unique_ptr outStream = writeLocalFile(options_.data_file);
@@ -1059,11 +1291,11 @@ void Splitter::WriteSplit() {
options.setCompressionBlockSize(options_.compress_block_size);
options.setCompressionStrategy(CompressionStrategy_COMPRESSION);
std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get());
- std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
+ std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
void* bufferOut = nullptr;
int32_t sizeOut = 0;
- for (auto pid = 0; pid < num_partitions_; ++ pid) {
+ for (auto pid = 0; pid < num_partitions_; ++pid) {
ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut);
}
@@ -1074,6 +1306,23 @@ void Splitter::WriteSplit() {
outStream->close();
}
+void Splitter::WriteSplitByRow() {
+ std::unique_ptr outStream = writeLocalFile(options_.data_file);
+ WriterOptions options;
+ options.setCompression(options_.compression_type);
+ options.setCompressionBlockSize(options_.compress_block_size);
+ options.setCompressionStrategy(CompressionStrategy_COMPRESSION);
+ std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get());
+ std::unique_ptr bufferOutPutStream = streamsFactory->createStream();
+
+ void* bufferOut = nullptr;
+ int32_t sizeOut = 0;
+ for (auto pid = 0; pid < num_partitions_; ++pid) {
+ ProtoWritePartitionByRow(pid, bufferOutPutStream, bufferOut, sizeOut);
+ }
+ outStream->close();
+}
+
int Splitter::DeleteSpilledTmpFile() {
for (auto &pair : spilled_tmp_files_info_) {
auto tmpDataFilePath = pair.first + ".data";
@@ -1092,7 +1341,7 @@ int Splitter::DeleteSpilledTmpFile() {
int Splitter::SpillToTmpFile() {
for (auto pid = 0; pid < num_partitions_; ++pid) {
CacheVectorBatch(pid, true);
- partition_buffer_size_[pid] = 0; //reset after spilling; a conditional spill must re-allocate the buffer
+ partition_buffer_size_[pid] = 0; // reset after spilling; a conditional spill must re-allocate the buffer
}
options_.next_spilled_file_dir = CreateTempShuffleFile(NextSpilledFileDir());
@@ -1105,6 +1354,14 @@ int Splitter::SpillToTmpFile() {
return 0;
}
+int Splitter::SpillToTmpFileByRow() {
+ options_.next_spilled_file_dir = CreateTempShuffleFile(NextSpilledFileDir());
+ WriteDataFileProtoByRow();
+ std::shared_ptr ptrTmp = CaculateSpilledTmpFilePartitionOffsets();
+ spilled_tmp_files_info_[options_.next_spilled_file_dir] = ptrTmp;
+ return 0;
+}
+
Splitter::Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_partitions, SplitOptions options, bool flag)
: input_col_types(inputDataTypes),
singlePartitionFlag(flag),
@@ -1116,23 +1373,18 @@ Splitter::Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_
ToSplitterTypeId(num_cols);
}
-std::shared_ptr Create(InputDataTypes inputDataTypes,
+Splitter *Create(InputDataTypes inputDataTypes,
int32_t num_cols,
int32_t num_partitions,
SplitOptions options,
bool flag)
{
- std::shared_ptr res(
- new Splitter(inputDataTypes,
- num_cols,
- num_partitions,
- std::move(options),
- flag));
+ auto res = new Splitter(inputDataTypes, num_cols, num_partitions, std::move(options), flag);
res->Split_Init();
return res;
}
-std::shared_ptr Splitter::Make(
+Splitter *Splitter::Make(
const std::string& short_name,
InputDataTypes inputDataTypes,
int32_t num_cols,
@@ -1168,14 +1420,19 @@ int Splitter::Stop() {
if (nullptr == vecBatchProto) {
throw std::runtime_error("delete nullptr error for free protobuf vecBatch memory");
}
- delete vecBatchProto; //free protobuf vecBatch memory
- delete partition_id_cnt_cur_;
- delete partition_id_cnt_cache_;
- delete fixed_valueBuffer_size_;
- delete fixed_nullBuffer_size_;
- delete partition_buffer_size_;
- delete partition_buffer_idx_base_;
- delete partition_buffer_idx_offset_;
- delete partition_serialization_size_;
+ return 0;
+}
+
+int Splitter::StopByRow() {
+ if (isSpill) {
+ TIME_NANO_OR_RAISE(total_write_time_, MergeSpilledByRow());
+ TIME_NANO_OR_RAISE(total_write_time_, DeleteSpilledTmpFile());
+ LogsDebug(" Spill For Splitter Stopped. total_spill_row_num_: %ld ", total_spill_row_num_);
+ } else {
+ TIME_NANO_OR_RAISE(total_write_time_, WriteSplitByRow());
+ }
+ if (nullptr == protoRowBatch) {
+ throw std::runtime_error("delete nullptr error for free protobuf rowBatch memory");
+ }
return 0;
}
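Putting the splitter and JNI changes together, the row-based lifecycle is expected to look roughly like this (sketch only; partitioningName, inputDataTypes, numCols, numPartitions, options and batches stand for values supplied by the Java side):

Splitter *splitter = Splitter::Make(partitioningName, inputDataTypes, numCols,
                                    numPartitions, std::move(options));          // nativeMake
for (VectorBatch *vb : batches) {
    splitter->SplitByRow(vb);        // rowSplit: buffer rows per partition, spill past the thresholds
}
splitter->StopByRow();               // rowStop: merge spilled files or write directly
const auto &lengths = splitter->PartitionLengths();   // reported back in the SplitResult
delete splitter;                     // close: frees the splitter now that the holder map is gone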
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
index ec0cc661f0a49d531b47dab87299cb6a8dfbde2a..9f0e8fa582aa30e7ef6933546367854e50fa986c 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h
@@ -35,6 +35,7 @@
#include "../common/common.h"
#include "vec_data.pb.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "vector/omni_row.h"
using namespace std;
using namespace spark;
@@ -51,13 +52,16 @@ struct SplitRowInfo {
};
class Splitter {
-
virtual int DoSplit(VectorBatch& vb);
int WriteDataFileProto();
+ int WriteDataFileProtoByRow();
+
std::shared_ptr CaculateSpilledTmpFilePartitionOffsets();
+ spark::VecType::VecTypeId CastOmniTypeIdToProtoVecType(int32_t omniType);
+
void SerializingFixedColumns(int32_t partitionId,
spark::Vec& vec,
int fixColIndexTmp,
@@ -70,8 +74,12 @@ class Splitter {
int protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream);
+ int protoSpillPartitionByRow(int32_t partition_id, std::unique_ptr &bufferStream);
+
int32_t ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut);
+ int32_t ProtoWritePartitionByRow(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut);
+
int ComputeAndCountPartitionId(VectorBatch& vb);
int AllocatePartitionBuffers(int32_t partition_id, int32_t new_size);
@@ -93,9 +101,28 @@ class Splitter {
void MergeSpilled();
+ void MergeSpilledByRow();
+
void WriteSplit();
+ void WriteSplitByRow();
+
+ // Common structures for row formats and col formats
bool isSpill = false;
+ int64_t total_bytes_written_ = 0;
+ int64_t total_bytes_spilled_ = 0;
+ int64_t total_write_time_ = 0;
+ int64_t total_spill_time_ = 0;
+ int64_t total_spill_row_num_ = 0;
+
+ // configured local dirs for spilled file
+ int32_t dir_selection_ = 0;
+ std::vector sub_dir_selection_;
+ std::vector configured_dirs_;
+
+ // Data structures required to handle col formats
+ int64_t total_compute_pid_time_ = 0;
+ std::vector partition_lengths_;
std::vector partition_id_; // pid of every row in the current vb
int32_t *partition_id_cnt_cur_; // row count per partition for the vb currently being processed
uint64_t *partition_id_cnt_cache_; // row count per partition for the cached data
@@ -117,12 +144,6 @@ class Splitter {
int32_t *partition_buffer_idx_base_; // rows already cached per partition, used to locate the next free slot in the buffer
int32_t *partition_buffer_idx_offset_; // temporary offset counter used when splitting fixed-width columns
uint32_t *partition_serialization_size_; // serialized size of each partition in bytes, used by stop to return partition offsets
-
- // configured local dirs for spilled file
- int32_t dir_selection_ = 0;
- std::vector sub_dir_selection_;
- std::vector configured_dirs_;
-
std::vector>>>> partition_cached_vectorbatch_;
/*
* varchar buffers:
@@ -130,14 +151,13 @@ class Splitter {
*
*/
std::vector>> vc_partition_array_buffers_;
+ spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf serialization object
- int64_t total_bytes_written_ = 0;
- int64_t total_bytes_spilled_ = 0;
- int64_t total_write_time_ = 0;
- int64_t total_spill_time_ = 0;
- int64_t total_compute_pid_time_ = 0;
- int64_t total_spill_row_num_ = 0;
- std::vector partition_lengths_;
+ // Data structures required to handle row formats
+ std::vector> partition_rows; // pid : std::vector
+ uint64_t total_input_size = 0; // total row size in bytes
+ uint32_t expansion = 2; // expansion coefficient
+ spark::ProtoRowBatch *protoRowBatch = new ProtoRowBatch();
private:
void ReleaseVarcharVector()
@@ -167,37 +187,41 @@ private:
delete vb;
}
+ // Data structures required to handle col formats
std::set varcharVectorCache;
- std::vector vector_batch_col_types_;
- InputDataTypes input_col_types;
- std::vector binary_array_empirical_size_;
- omniruntime::vec::VectorBatch *inputVecBatch = nullptr;
public:
+ // Common structures for row formats and col formats
bool singlePartitionFlag = false;
int32_t num_partitions_;
SplitOptions options_;
// number of partitions
int32_t num_fields_;
-
+ InputDataTypes input_col_types;
+ std::vector proto_col_types_; // Avoid repeated type conversion during the split process.
+ omniruntime::vec::VectorBatch *inputVecBatch = nullptr;
std::map> spilled_tmp_files_info_;
- spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf serialization object
-
virtual int Split_Init();
virtual int Split(VectorBatch& vb);
+ virtual int SplitByRow(VectorBatch* vb);
+
int Stop();
+ int StopByRow();
+
int SpillToTmpFile();
+ int SpillToTmpFileByRow();
+
Splitter(InputDataTypes inputDataTypes,
int32_t num_cols,
int32_t num_partitions,
SplitOptions options,
bool flag);
- static std::shared_ptr Make(
+ static Splitter *Make(
const std::string &short_name,
InputDataTypes inputDataTypes,
int32_t num_cols,
@@ -220,6 +244,24 @@ public:
const std::vector& PartitionLengths() const { return partition_lengths_; }
+ virtual ~Splitter()
+ {
+ delete vecBatchProto; //free protobuf vecBatch memory
+ delete protoRowBatch; //free protobuf rowBatch memory
+ delete[] partition_id_cnt_cur_;
+ delete[] partition_id_cnt_cache_;
+ delete[] partition_buffer_size_;
+ delete[] partition_buffer_idx_base_;
+ delete[] partition_buffer_idx_offset_;
+ delete[] partition_serialization_size_;
+ delete[] fixed_valueBuffer_size_;
+ delete[] fixed_nullBuffer_size_;
+ partition_fixed_width_buffers_.clear();
+ partition_binary_builders_.clear();
+ partition_cached_vectorbatch_.clear();
+ spilled_tmp_files_info_.clear();
+ }
+
omniruntime::vec::VectorBatch *GetInputVecBatch()
{
return inputVecBatch;
diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
index 04d90130dea30a83651fff3526c08dc0992f9928..61b4bc1498b2dbc3b2435f428401c36bb17c5725 100644
--- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
+++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h
@@ -43,7 +43,8 @@ struct SplitOptions {
Allocator *allocator = Allocator::GetAllocator();
uint64_t spill_batch_row_num = 4096; // default value
- uint64_t spill_mem_threshold = 1024 * 1024 * 1024; // default value
+ uint64_t task_spill_mem_threshold = 1024 * 1024 * 1024; // default value
+ uint64_t executor_spill_mem_threshold = UINT64_MAX; // default value
uint64_t compress_block_size = 64 * 1024; // default value
static SplitOptions Defaults();
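A short sketch of how the two thresholds are intended to be used (values are illustrative): the task threshold bounds the bytes buffered by one task's splitter, while the executor threshold is checked against process-wide accounted memory in SplitByRow.

SplitOptions options = SplitOptions::Defaults();
options.task_spill_mem_threshold = 1024ULL * 1024 * 1024;           // spill once this task buffers ~1 GiB of rows
options.executor_spill_mem_threshold = 8ULL * 1024 * 1024 * 1024;   // spill once the executor exceeds ~8 GiB overall
// passed in through nativeMake's task_spill_memory_threshold / executor_spill_memory_threshold arguments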
diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
index f53ac2ad45e95bddfe3a15a4da78b245e373981e..287223e1dfbd63405f6d0ae7d1605698b9566c3d 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
+++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt
@@ -2,12 +2,14 @@ aux_source_directory(${CMAKE_CURRENT_LIST_DIR} TEST_ROOT_SRCS)
add_subdirectory(shuffle)
add_subdirectory(utils)
+add_subdirectory(benchmark)
# configure
set(TP_TEST_TARGET tptest)
set(MY_LINK
shuffletest
utilstest
+ benchmark_test
)
# find gtest package
@@ -27,7 +29,7 @@ target_link_libraries(${TP_TEST_TARGET}
pthread
stdc++
dl
- boostkit-omniop-vector-1.4.0-aarch64
+ boostkit-omniop-vector-1.5.0-aarch64
securec
spark_columnar_plugin)
diff --git a/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e4b47d29549ad375eea27fbb9a7105569f5e8a73
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt
@@ -0,0 +1,8 @@
+aux_source_directory(${CMAKE_CURRENT_LIST_DIR} BENCHMARK_LIST)
+set(BENCHMARK_TEST_TARGET benchmark_test)
+add_library(${BENCHMARK_TEST_TARGET} STATIC ${BENCHMARK_LIST})
+target_compile_options(${BENCHMARK_TEST_TARGET} PUBLIC )
+target_link_libraries(${BENCHMARK_TEST_TARGET} utilstest)
+target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src)
+target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include)
+target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux)
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp b/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..db273192fc88cde342ca49951e1d740cdd8874a2
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp
@@ -0,0 +1,386 @@
+/**
+ * Copyright (C) 2020-2024. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include "gtest/gtest.h"
+#include "../utils/test_utils.h"
+
+using namespace omniruntime::type;
+using namespace omniruntime;
+
+static constexpr int ROWS = 300;
+static constexpr int COLS = 300;
+static constexpr int PARTITION_SIZE = 600;
+static constexpr int BATCH_COUNT = 20;
+
+static int generateRandomNumber() {
+ return std::rand() % PARTITION_SIZE;
+}
+
+// construct data
+static std::vector<BaseVector *> constructVecs(int rows, int cols, int* inputTypeIds, double nullProbability) {
+ std::srand(time(nullptr));
+ std::vector<BaseVector *> vecs;
+ vecs.resize(cols);
+
+ for (int i = 0; i < cols; ++i) {
+ BaseVector *vector = VectorHelper::CreateFlatVector(inputTypeIds[i], rows);
+ if (inputTypeIds[i] == OMNI_VARCHAR) {
+ auto strVec = reinterpret_cast<Vector<LargeStringContainer<std::string_view>> *>(vector);
+ for (int j = 0; j < rows; ++j) {
+ auto randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < nullProbability) {
+ strVec->SetNull(j);
+ } else {
+ std::string_view str("hello world");
+ strVec->SetValue(j, str);
+ }
+ }
+ } else if (inputTypeIds[i] == OMNI_LONG) {
+ auto longVec = reinterpret_cast<Vector<int64_t> *>(vector);
+ for (int j = 0; j < rows; ++j) {
+ auto randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < nullProbability) {
+ longVec->SetNull(j);
+ } else {
+ long value = generateRandomNumber();
+ longVec->SetValue(j, value);
+ }
+ }
+ }
+ vecs[i] = vector;
+ }
+ return vecs;
+}
+
+// generate partitionId
+static Vector<int32_t>* constructPidVec(int rows) {
+ srand(time(nullptr));
+ auto pidVec = new Vector<int32_t>(rows);
+ for (int j = 0; j < rows; ++j) {
+ int pid = generateRandomNumber();
+ pidVec->SetValue(j, pid);
+ }
+ return pidVec;
+}
+
+static std::vector<VectorBatch *> generateData(int rows, int cols, int* inputTypeIds, double nullProbability) {
+ std::vector<VectorBatch *> vecBatches;
+ vecBatches.resize(BATCH_COUNT);
+ for (int i = 0; i < BATCH_COUNT; ++i) {
+ auto vecBatch = new VectorBatch(rows);
+ auto pidVec = constructPidVec(rows);
+ vecBatch->Append(pidVec);
+ auto vecs = constructVecs(rows, cols, inputTypeIds, nullProbability);
+ for (int j = 0; j < vecs.size(); ++j) {
+ vecBatch->Append(vecs[j]);
+ }
+ vecBatches[i] = vecBatch;
+ }
+ return vecBatches;
+}
+
+static std::vector<VectorBatch *> copyData(const std::vector<VectorBatch *>& origin) {
+ std::vector<VectorBatch *> vecBatches;
+ vecBatches.resize(origin.size());
+ for (int i = 0; i < origin.size(); ++i) {
+ auto originBatch = origin[i];
+ auto vecBatch = new VectorBatch(originBatch->GetRowCount());
+
+ for (int j = 0; j < originBatch->GetVectorCount(); ++j) {
+ BaseVector *vec = originBatch->Get(j);
+ BaseVector *sliceVec = VectorHelper::SliceVector(vec, 0, originBatch->GetRowCount());
+ vecBatch->Append(sliceVec);
+ }
+ vecBatches[i] = vecBatch;
+ }
+ return vecBatches;
+}
+
+static void bm_row_handle(const std::vector<VectorBatch *>& vecBatches, int *inputTypeIds, int cols) {
+ Timer timer;
+ timer.SetStart();
+
+ InputDataTypes inputDataTypes;
+ inputDataTypes.inputVecTypeIds = inputTypeIds;
+
+ auto splitOptions = SplitOptions::Defaults();
+ splitOptions.buffer_size = 4096;
+
+ auto compression_type_result = GetCompressionType("lz4");
+ splitOptions.compression_type = compression_type_result;
+ auto splitter = Splitter::Make("hash", inputDataTypes, cols, PARTITION_SIZE, std::move(splitOptions));
+
+ for (int i = 0; i < vecBatches.size(); ++i) {
+ VectorBatch *vb = vecBatches[i];
+ splitter->SplitByRow(vb);
+ }
+ splitter->StopByRow();
+
+ timer.CalculateElapse();
+ double wallElapsed = timer.GetWallElapse();
+ double cpuElapsed = timer.GetCpuElapse();
+ std::cout << "row time, wall " << wallElapsed << " cpu " << cpuElapsed << std::endl;
+
+ delete splitter;
+}
+
+static void bm_col_handle(const std::vector<VectorBatch *>& vecBatches, int *inputTypeIds, int cols) {
+ Timer timer;
+ timer.SetStart();
+
+ InputDataTypes inputDataTypes;
+ inputDataTypes.inputVecTypeIds = inputTypeIds;
+
+ auto splitOptions = SplitOptions::Defaults();
+ splitOptions.buffer_size = 4096;
+
+ auto compression_type_result = GetCompressionType("lz4");
+ splitOptions.compression_type = compression_type_result;
+ auto splitter = Splitter::Make("hash", inputDataTypes, cols, PARTITION_SIZE, std::move(splitOptions));
+
+ for (int i = 0; i < vecBatches.size(); ++i) {
+ VectorBatch *vb = vecBatches[i];
+ splitter->Split(*vb);
+ }
+ splitter->Stop();
+
+ timer.CalculateElapse();
+ double wallElapsed = timer.GetWallElapse();
+ double cpuElapsed = timer.GetCpuElapse();
+ std::cout << "col time, wall " << wallElapsed << " cpu " << cpuElapsed << std::endl;
+
+ delete splitter;
+}
+
+TEST(shuffle_benchmark, null_0) {
+ double strProbability = 0.25;
+ double nullProbability = 0;
+
+ int *inputTypeIds = new int32_t[COLS];
+ for (int i = 0; i < COLS; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, COLS);
+ bm_col_handle(vecBatches2, inputTypeIds, COLS);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_25) {
+ double strProbability = 0.25;
+ double nullProbability = 0.25;
+
+ int *inputTypeIds = new int32_t[COLS];
+ for (int i = 0; i < COLS; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, COLS);
+ bm_col_handle(vecBatches2, inputTypeIds, COLS);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_50) {
+ double strProbability = 0.25;
+ double nullProbability = 0.5;
+
+ int *inputTypeIds = new int32_t[COLS];
+ for (int i = 0; i < COLS; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, COLS);
+ bm_col_handle(vecBatches2, inputTypeIds, COLS);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_75) {
+ double strProbability = 0.25;
+ double nullProbability = 0.75;
+
+ int *inputTypeIds = new int32_t[COLS];
+ for (int i = 0; i < COLS; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, COLS);
+ bm_col_handle(vecBatches2, inputTypeIds, COLS);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_100) {
+ double strProbability = 0.25;
+ double nullProbability = 1;
+
+ int *inputTypeIds = new int32_t[COLS];
+ for (int i = 0; i < COLS; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, COLS);
+ bm_col_handle(vecBatches2, inputTypeIds, COLS);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_25_row_900_col_100) {
+ double strProbability = 0.25;
+ double nullProbability = 0.25;
+ int rows = 900;
+ int cols = 100;
+
+ int *inputTypeIds = new int32_t[cols];
+ for (int i = 0; i < cols; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, cols);
+ bm_col_handle(vecBatches2, inputTypeIds, cols);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_25_row_1800_col_50) {
+ double strProbability = 0.25;
+ double nullProbability = 0.25;
+ int rows = 1800;
+ int cols = 50;
+
+ int *inputTypeIds = new int32_t[cols];
+ for (int i = 0; i < cols; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, cols);
+ bm_col_handle(vecBatches2, inputTypeIds, cols);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_25_row_9000_col_10) {
+ double strProbability = 0.25;
+ double nullProbability = 0.25;
+ int rows = 9000;
+ int cols = 10;
+
+ int *inputTypeIds = new int32_t[cols];
+ for (int i = 0; i < cols; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, cols);
+ bm_col_handle(vecBatches2, inputTypeIds, cols);
+ delete[] inputTypeIds;
+}
+
+TEST(shuffle_benchmark, null_25_row_18000_col_5) {
+ double strProbability = 0.25;
+ double nullProbability = 0.25;
+ int rows = 18000;
+ int cols = 5;
+
+ int *inputTypeIds = new int32_t[cols];
+ for (int i = 0; i < cols; ++i) {
+ double randNum = static_cast<double>(std::rand()) / RAND_MAX;
+ if (randNum < strProbability) {
+ inputTypeIds[i] = OMNI_VARCHAR;
+ } else {
+ inputTypeIds[i] = OMNI_LONG;
+ }
+ }
+
+ auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability);
+ auto vecBatches2 = copyData(vecBatches1);
+
+ std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl;
+ bm_row_handle(vecBatches1, inputTypeIds, cols);
+ bm_col_handle(vecBatches2, inputTypeIds, cols);
+ delete[] inputTypeIds;
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
index 3031943eeae22b591de4c4b3693eb1e1744b3ac3..27e1297e75f155862bbbf4cd5d82db89827cca80 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp
@@ -39,7 +39,6 @@ protected:
if (IsFileExist(tmpTestingDir)) {
DeletePathAll(tmpTestingDir.c_str());
}
- testShuffleSplitterHolder.Clear();
}
// run before each case...
@@ -63,7 +62,7 @@ TEST_F (ShuffleTest, Split_SingleVarChar) {
inputDataTypes.inputVecTypeIds = inputVecTypeIds;
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
4,
inputDataTypes,
colNumber,
@@ -73,21 +72,21 @@ TEST_F (ShuffleTest, Split_SingleVarChar) {
0,
tmpTestingDir);
VectorBatch* vb1 = CreateVectorBatch_1row_varchar_withPid(3, "N");
- Test_splitter_split(splitterId, vb1);
+ Test_splitter_split(splitterAddr, vb1);
VectorBatch* vb2 = CreateVectorBatch_1row_varchar_withPid(2, "F");
- Test_splitter_split(splitterId, vb2);
+ Test_splitter_split(splitterAddr, vb2);
VectorBatch* vb3 = CreateVectorBatch_1row_varchar_withPid(3, "N");
- Test_splitter_split(splitterId, vb3);
+ Test_splitter_split(splitterAddr, vb3);
VectorBatch* vb4 = CreateVectorBatch_1row_varchar_withPid(2, "F");
- Test_splitter_split(splitterId, vb4);
+ Test_splitter_split(splitterAddr, vb4);
VectorBatch* vb5 = CreateVectorBatch_1row_varchar_withPid(2, "F");
- Test_splitter_split(splitterId, vb5);
+ Test_splitter_split(splitterAddr, vb5);
VectorBatch* vb6 = CreateVectorBatch_1row_varchar_withPid(1, "R");
- Test_splitter_split(splitterId, vb6);
+ Test_splitter_split(splitterAddr, vb6);
VectorBatch* vb7 = CreateVectorBatch_1row_varchar_withPid(3, "N");
- Test_splitter_split(splitterId, vb7);
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_split(splitterAddr, vb7);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -101,7 +100,7 @@ TEST_F (ShuffleTest, Split_Fixed_Cols) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -112,10 +111,10 @@ TEST_F (ShuffleTest, Split_Fixed_Cols) {
tmpTestingDir);
for (uint64_t j = 0; j < 1; j++) {
VectorBatch* vb = CreateVectorBatch_5fixedCols_withPid(partitionNum, 999);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -129,7 +128,7 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 1;
- int splitterId = Test_splitter_nativeMake("single",
+ long splitterAddr = Test_splitter_nativeMake("single",
partitionNum,
inputDataTypes,
colNumber,
@@ -140,10 +139,10 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) {
tmpTestingDir);
for (uint64_t j = 0; j < 100; j++) {
VectorBatch* vb = CreateVectorBatch_someNullRow_vectorBatch();
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -157,7 +156,7 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullCol) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 1;
- int splitterId = Test_splitter_nativeMake("single",
+ long splitterAddr = Test_splitter_nativeMake("single",
partitionNum,
inputDataTypes,
colNumber,
@@ -168,10 +167,10 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullCol) {
tmpTestingDir);
for (uint64_t j = 0; j < 100; j++) {
VectorBatch* vb = CreateVectorBatch_someNullCol_vectorBatch();
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -205,7 +204,7 @@ TEST_F (ShuffleTest, Split_Mix_LargeSize) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -216,10 +215,10 @@ TEST_F (ShuffleTest, Split_Mix_LargeSize) {
tmpTestingDir);
for (uint64_t j = 0; j < 999; j++) {
VectorBatch* vb = CreateVectorBatch_4col_withPid(partitionNum, 999);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -233,7 +232,7 @@ TEST_F (ShuffleTest, Split_Short_10WRows) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 10;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -244,10 +243,10 @@ TEST_F (ShuffleTest, Split_Short_10WRows) {
tmpTestingDir);
for (uint64_t j = 0; j < 100; j++) {
VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, ShortType());
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -261,7 +260,7 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 10;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -272,10 +271,10 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) {
tmpTestingDir);
for (uint64_t j = 0; j < 100; j++) {
VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, BooleanType());
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -289,7 +288,7 @@ TEST_F (ShuffleTest, Split_Long_100WRows) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 10;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -300,10 +299,10 @@ TEST_F (ShuffleTest, Split_Long_100WRows) {
tmpTestingDir);
for (uint64_t j = 0; j < 100; j++) {
VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, LongType());
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -317,7 +316,7 @@ TEST_F (ShuffleTest, Split_VarChar_LargeSize) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -328,10 +327,10 @@ TEST_F (ShuffleTest, Split_VarChar_LargeSize) {
tmpTestingDir);
for (uint64_t j = 0; j < 99; j++) {
VectorBatch* vb = CreateVectorBatch_4varcharCols_withPid(partitionNum, 99);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -345,7 +344,7 @@ TEST_F (ShuffleTest, Split_VarChar_First) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -355,27 +354,27 @@ TEST_F (ShuffleTest, Split_VarChar_First) {
0,
tmpTestingDir);
VectorBatch* vb0 = CreateVectorBatch_2column_1row_withPid(0, "corpbrand #4", 1);
- Test_splitter_split(splitterId, vb0);
+ Test_splitter_split(splitterAddr, vb0);
VectorBatch* vb1 = CreateVectorBatch_2column_1row_withPid(3, "brandmaxi #4", 1);
- Test_splitter_split(splitterId, vb1);
+ Test_splitter_split(splitterAddr, vb1);
VectorBatch* vb2 = CreateVectorBatch_2column_1row_withPid(1, "edu packnameless #9", 1);
- Test_splitter_split(splitterId, vb2);
+ Test_splitter_split(splitterAddr, vb2);
VectorBatch* vb3 = CreateVectorBatch_2column_1row_withPid(1, "amalgunivamalg #11", 1);
- Test_splitter_split(splitterId, vb3);
+ Test_splitter_split(splitterAddr, vb3);
VectorBatch* vb4 = CreateVectorBatch_2column_1row_withPid(0, "brandcorp #2", 1);
- Test_splitter_split(splitterId, vb4);
+ Test_splitter_split(splitterAddr, vb4);
VectorBatch* vb5 = CreateVectorBatch_2column_1row_withPid(0, "scholarbrand #2", 1);
- Test_splitter_split(splitterId, vb5);
+ Test_splitter_split(splitterAddr, vb5);
VectorBatch* vb6 = CreateVectorBatch_2column_1row_withPid(2, "edu packcorp #6", 1);
- Test_splitter_split(splitterId, vb6);
+ Test_splitter_split(splitterAddr, vb6);
VectorBatch* vb7 = CreateVectorBatch_2column_1row_withPid(2, "edu packamalg #1", 1);
- Test_splitter_split(splitterId, vb7);
+ Test_splitter_split(splitterAddr, vb7);
VectorBatch* vb8 = CreateVectorBatch_2column_1row_withPid(0, "brandnameless #8", 1);
- Test_splitter_split(splitterId, vb8);
+ Test_splitter_split(splitterAddr, vb8);
VectorBatch* vb9 = CreateVectorBatch_2column_1row_withPid(2, "univmaxi #2", 1);
- Test_splitter_split(splitterId, vb9);
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_split(splitterAddr, vb9);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -389,7 +388,7 @@ TEST_F (ShuffleTest, Split_Dictionary) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -400,10 +399,10 @@ TEST_F (ShuffleTest, Split_Dictionary) {
tmpTestingDir);
for (uint64_t j = 0; j < 2; j++) {
VectorBatch* vb = CreateVectorBatch_2dictionaryCols_withPid(partitionNum);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -417,7 +416,7 @@ TEST_F (ShuffleTest, Split_Char) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -428,10 +427,10 @@ TEST_F (ShuffleTest, Split_Char) {
tmpTestingDir);
for (uint64_t j = 0; j < 99; j++) {
VectorBatch* vb = CreateVectorBatch_4charCols_withPid(partitionNum, 99);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -445,7 +444,7 @@ TEST_F (ShuffleTest, Split_Decimal128) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -456,10 +455,10 @@ TEST_F (ShuffleTest, Split_Decimal128) {
tmpTestingDir);
for (uint64_t j = 0; j < 999; j++) {
VectorBatch* vb = CreateVectorBatch_1decimal128Col_withPid(partitionNum, 999);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -473,7 +472,7 @@ TEST_F (ShuffleTest, Split_Decimal64) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -484,10 +483,10 @@ TEST_F (ShuffleTest, Split_Decimal64) {
tmpTestingDir);
for (uint64_t j = 0; j < 999; j++) {
VectorBatch* vb = CreateVectorBatch_1decimal64Col_withPid(partitionNum, 999);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
@@ -501,7 +500,7 @@ TEST_F (ShuffleTest, Split_Decimal64_128) {
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = 4;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterAddr = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -512,10 +511,10 @@ TEST_F (ShuffleTest, Split_Decimal64_128) {
tmpTestingDir);
for (uint64_t j = 0; j < 999; j++) {
VectorBatch* vb = CreateVectorBatch_2decimalCol_withPid(partitionNum, 999);
- Test_splitter_split(splitterId, vb);
+ Test_splitter_split(splitterAddr, vb);
}
- Test_splitter_stop(splitterId);
- Test_splitter_close(splitterId);
+ Test_splitter_stop(splitterAddr);
+ Test_splitter_close(splitterAddr);
delete[] inputDataTypes.inputDataPrecisions;
delete[] inputDataTypes.inputDataScales;
}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp
index 9c30ed17e05f13bb438856f459501f4978a1220f..abf9f8074c08d8c7cb749b7f5c3295b1cbd4e3a9 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp
+++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp
@@ -47,7 +47,9 @@ BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t
va_start(args, idsCount);
BaseVector *dictionary = CreateVector(dataType, rowCount, args);
va_end(args);
- return DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount);
+ BaseVector *dictVector = DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount);
+ delete dictionary;
+ return dictVector;
}
/**
@@ -379,7 +381,7 @@ void Test_Shuffle_Compression(std::string compStr, int32_t numPartition, int32_t
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber];
inputDataTypes.inputDataScales = new uint32_t[colNumber];
int partitionNum = numPartition;
- int splitterId = Test_splitter_nativeMake("hash",
+ long splitterId = Test_splitter_nativeMake("hash",
partitionNum,
inputDataTypes,
colNumber,
@@ -422,31 +424,31 @@ long Test_splitter_nativeMake(std::string partitioning_name,
splitOptions.compression_type = compression_type_result;
splitOptions.data_file = data_file_jstr;
auto splitter = Splitter::Make(partitioning_name, inputDataTypes, numCols, num_partitions, std::move(splitOptions));
- return testShuffleSplitterHolder.Insert(std::shared_ptr<Splitter>(splitter));
+ return reinterpret_cast<long>(static_cast<void *>(splitter));
}
-void Test_splitter_split(long splitter_id, VectorBatch* vb) {
- auto splitter = testShuffleSplitterHolder.Lookup(splitter_id);
+void Test_splitter_split(long splitter_addr, VectorBatch* vb) {
+ auto splitter = reinterpret_cast<Splitter *>(splitter_addr);
// Initialize split global variables
splitter->Split(*vb);
}
-void Test_splitter_stop(long splitter_id) {
- auto splitter = testShuffleSplitterHolder.Lookup(splitter_id);
+void Test_splitter_stop(long splitter_addr) {
+ auto splitter = reinterpret_cast<Splitter *>(splitter_addr);
if (!splitter) {
- std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
throw std::runtime_error("Test no splitter.");
}
splitter->Stop();
}
-void Test_splitter_close(long splitter_id) {
- auto splitter = testShuffleSplitterHolder.Lookup(splitter_id);
+void Test_splitter_close(long splitter_addr) {
+ auto splitter = reinterpret_cast<Splitter *>(splitter_addr);
if (!splitter) {
- std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+ std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
throw std::runtime_error("Test no splitter.");
}
- testShuffleSplitterHolder.Erase(splitter_id);
+ delete splitter;
}
void GetFilePath(const char *path, const char *filename, char *filepath) {
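
The tests no longer register splitters in a ConcurrentMap of shared_ptrs; the raw Splitter* itself,
carried as a long, is the handle. A tiny sketch of that convention with hypothetical helper names:

    #include "shuffle/splitter.h"

    // Hypothetical helpers illustrating the address-as-handle convention used by the tests.
    long ToSplitterHandle(Splitter *splitter)
    {
        return reinterpret_cast<long>(static_cast<void *>(splitter));
    }

    Splitter *FromSplitterHandle(long handle)
    {
        return reinterpret_cast<Splitter *>(handle);
    }
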
diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
index b7380254a687ed6f3eaf8234df944feac9087404..b588ea6f2427f185686ca6357570ec3a44b52691 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
+++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
@@ -25,10 +25,7 @@
#include
#include
#include
-#include "shuffle/splitter.h"
-#include "jni/concurrent_map.h"
-
-static ConcurrentMap<std::shared_ptr<Splitter>> testShuffleSplitterHolder;
+#include "../../src/shuffle/splitter.h"
static std::string s_shuffle_tests_dir = "/tmp/shuffleTests";
@@ -131,4 +128,71 @@ void GetFilePath(const char *path, const char *filename, char *filepath);
void DeletePathAll(const char* path);
+class Timer {
+public:
+ Timer() : wallElapsed(0), cpuElapsed(0) {}
+
+ ~Timer() {}
+
+ void SetStart() {
+ clock_gettime(CLOCK_REALTIME, &wallStart);
+ clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+ }
+
+ void CalculateElapse() {
+ clock_gettime(CLOCK_REALTIME, &wallEnd);
+ clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuEnd);
+ long secondsWall = wallEnd.tv_sec - wallStart.tv_sec;
+ long secondsCpu = cpuEnd.tv_sec - cpuStart.tv_sec;
+ long nsWall = wallEnd.tv_nsec - wallStart.tv_nsec;
+ long nsCpu = cpuEnd.tv_nsec - cpuStart.tv_nsec;
+ wallElapsed = secondsWall + nsWall * 1e-9;
+ cpuElapsed = secondsCpu + nsCpu * 1e-9;
+ }
+
+ void Start(const char *TestTitle) {
+ wallElapsed = 0;
+ cpuElapsed = 0;
+ clock_gettime(CLOCK_REALTIME, &wallStart);
+ clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+ this->title = TestTitle;
+ }
+
+ void End() {
+ clock_gettime(CLOCK_REALTIME, &wallEnd);
+ clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuEnd);
+ long secondsWall = wallEnd.tv_sec - wallStart.tv_sec;
+ long secondsCpu = cpuEnd.tv_sec - cpuStart.tv_sec;
+ long nsWall = wallEnd.tv_nsec - wallStart.tv_nsec;
+ long nsCpu = cpuEnd.tv_nsec - cpuStart.tv_nsec;
+ wallElapsed = secondsWall + nsWall * 1e-9;
+ cpuElapsed = secondsCpu + nsCpu * 1e-9;
+ std::cout << title << "\t: wall " << wallElapsed << " \tcpu " << cpuElapsed << std::endl;
+ }
+
+ double GetWallElapse() {
+ return wallElapsed;
+ }
+
+ double GetCpuElapse() {
+ return cpuElapsed;
+ }
+
+ void Reset() {
+ wallElapsed = 0;
+ cpuElapsed = 0;
+ clock_gettime(CLOCK_REALTIME, &wallStart);
+ clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+ }
+
+private:
+ double wallElapsed;
+ double cpuElapsed;
+ struct timespec cpuStart;
+ struct timespec wallStart;
+ struct timespec cpuEnd;
+ struct timespec wallEnd;
+ const char *title = nullptr;
+};
+
#endif //SPARK_THESTRAL_PLUGIN_TEST_UTILS_H
\ No newline at end of file
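
A minimal usage sketch for the Timer helper added above (the measured blocks are placeholders):

    #include <iostream>
    #include "test_utils.h"

    void MeasureExample()
    {
        Timer timer;
        timer.SetStart();
        // ... code under test ...
        timer.CalculateElapse();
        std::cout << "wall " << timer.GetWallElapse() << " cpu " << timer.GetCpuElapse() << std::endl;

        // Start()/End() bundle the same measurement and print it with a title.
        timer.Start("example case");
        // ... code under test ...
        timer.End();
    }
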
diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml
index 62c407dc3dedf2df0c4ca7ab891083467a8279f7..1fd6bb40d13a5670f440a468d5fb02022642116c 100644
--- a/omnioperator/omniop-spark-extension/java/pom.xml
+++ b/omnioperator/omniop-spark-extension/java/pom.xml
@@ -7,7 +7,7 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
- 3.3.1-1.4.0
+ 3.3.1-1.5.0
../pom.xml
@@ -46,13 +46,13 @@
com.huawei.boostkit
boostkit-omniop-bindings
- 1.4.0
+ 1.5.0
aarch64
com.huawei.boostkit
boostkit-omniop-native-reader
- 3.3.1-1.4.0
+ 3.3.1-1.5.0
junit
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
index 1d858a5e3f22353ddfd18588ec69f50d96de5852..73438aa4355b090b78bed387d0c6fa81a308f339 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
@@ -20,11 +20,8 @@ package com.huawei.boostkit.spark.jni;
import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader;
import nova.hetu.omniruntime.type.DataType;
-import nova.hetu.omniruntime.type.Decimal128DataType;
import nova.hetu.omniruntime.vector.*;
-import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.security.token.Token;
import org.apache.spark.sql.catalyst.util.RebaseDateTime;
import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree;
import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf;
@@ -33,13 +30,10 @@ import org.apache.orc.Reader.Options;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.orc.TypeDescription;
-import java.io.IOException;
import java.net.URI;
import java.sql.Date;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
public class OrcColumnarBatchScanReader {
@@ -186,13 +180,6 @@ public class OrcColumnarBatchScanReader {
}
job.put("tailLocation", 9223372036854775807L);
- // handle delegate token for native orc reader
- OrcColumnarBatchScanReader.tokenDebug("initializeReader");
- JSONObject tokenJsonObj = constructTokensJSONObject();
- if (null != tokenJsonObj) {
- job.put("tokens", tokenJsonObj);
- }
-
job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme());
job.put("host", uri.getHost() == null ? "" : uri.getHost());
job.put("port", uri.getPort());
@@ -228,12 +215,7 @@ public class OrcColumnarBatchScanReader {
}
job.put("includedColumns", colToInclu.toArray());
- // handle delegate token for native orc reader
- OrcColumnarBatchScanReader.tokenDebug("initializeRecordReader");
- JSONObject tokensJsonObj = constructTokensJSONObject();
- if (null != tokensJsonObj) {
- job.put("tokens", tokensJsonObj);
- }
+
recordReader = jniReader.initializeRecordReader(reader, job);
return recordReader;
}
@@ -271,8 +253,7 @@ public class OrcColumnarBatchScanReader {
}
}
- public int next(Vec[] vecList) {
- int[] typeIds = new int[realColsCnt];
+ public int next(Vec[] vecList, int[] typeIds) {
long[] vecNativeIds = new long[realColsCnt];
long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds);
if (rtn == 0) {
@@ -342,39 +323,4 @@ public class OrcColumnarBatchScanReader {
return hexString.toString().toLowerCase();
}
-
- public static JSONObject constructTokensJSONObject() {
- JSONObject tokensJsonItem = new JSONObject();
- try {
- ArrayList child = new ArrayList();
- for (Token> token : UserGroupInformation.getCurrentUser().getTokens()) {
- JSONObject tokenJsonItem = new JSONObject();
- tokenJsonItem.put("identifier", bytesToHexString(token.getIdentifier()));
- tokenJsonItem.put("password", bytesToHexString(token.getPassword()));
- tokenJsonItem.put("kind", token.getKind().toString());
- tokenJsonItem.put("service", token.getService().toString());
- child.add(tokenJsonItem);
- }
- tokensJsonItem.put("token", child.toArray());
- } catch (IOException e) {
- tokensJsonItem = null;
- } finally {
- LOGGER.debug("\n\n================== tokens-json ==================\n" + tokensJsonItem.toString());
- return tokensJsonItem;
- }
- }
-
- public static void tokenDebug(String mesg) {
- try {
- LOGGER.debug("\n\n=============" + mesg + "=============\n" + UserGroupInformation.getCurrentUser().toString());
- for (Token> token : UserGroupInformation.getCurrentUser().getTokens()) {
- LOGGER.debug("\n\ntoken identifier:" + bytesToHexString(token.getIdentifier()));
- LOGGER.debug("\ntoken password:" + bytesToHexString(token.getPassword()));
- LOGGER.debug("\ntoken kind:" + token.getKind());
- LOGGER.debug("\ntoken service:" + token.getService());
- }
- } catch (IOException e) {
- LOGGER.debug("\n\n**********" + mesg + " exception **********\n");
- }
- }
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java
index 9aa7c414bc841fc10f2e0daae30fd168b1690055..9a49812e678286ba734c3b7d65ffd864ed20df4e 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java
@@ -35,7 +35,8 @@ public class SparkJniWrapper {
String localDirs,
long shuffleCompressBlockSize,
int shuffleSpillBatchRowNum,
- long shuffleSpillMemoryThreshold) {
+ long shuffleTaskSpillMemoryThreshold,
+ long shuffleExecutorSpillMemoryThreshold) {
return nativeMake(
part.getPartitionName(),
part.getPartitionNum(),
@@ -48,7 +49,8 @@ public class SparkJniWrapper {
localDirs,
shuffleCompressBlockSize,
shuffleSpillBatchRowNum,
- shuffleSpillMemoryThreshold);
+ shuffleTaskSpillMemoryThreshold,
+ shuffleExecutorSpillMemoryThreshold);
}
public native long nativeMake(
@@ -63,7 +65,8 @@ public class SparkJniWrapper {
String localDirs,
long shuffleCompressBlockSize,
int shuffleSpillBatchRowNum,
- long shuffleSpillMemoryThreshold
+ long shuffleTaskSpillMemoryThreshold,
+ long shuffleExecutorSpillMemoryThreshold
);
/**
@@ -75,6 +78,16 @@ public class SparkJniWrapper {
*/
public native void split(long splitterId, long nativeVectorBatch);
+ /**
+ * Split one record batch into several batches, using the first column as the partition id. The batch is
+ * converted to row format before splitting. During splitting, the data in native buffers will be
+ * written to disk when the buffers are full.
+ *
+ * @param splitterId address of the native splitter instance
+ * @param nativeVectorBatch address of the native VectorBatch
+ */
+ public native void rowSplit(long splitterId, long nativeVectorBatch);
+
/**
* Write the data remained in the buffers hold by native splitter to each partition's temporary
* file. And stop processing splitting
@@ -84,6 +97,15 @@ public class SparkJniWrapper {
*/
public native SplitResult stop(long splitterId);
+ /**
+ * Write the data remaining in the row buffers held by the native splitter to each partition's temporary
+ * file, and stop splitting.
+ *
+ * @param splitterId splitter instance id
+ * @return SplitResult
+ */
+ public native SplitResult rowStop(long splitterId);
+
/**
* Release resources associated with designated splitter instance.
*
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
index 6a0c1b27c4282016aecada2ba4ef0c48c320f20f..99759e4a30cc8d88189e703ac5fc36541446f1e7 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java
@@ -20,6 +20,8 @@ package com.huawei.boostkit.spark.serialize;
import com.google.protobuf.InvalidProtocolBufferException;
+import nova.hetu.omniruntime.type.*;
import nova.hetu.omniruntime.utils.OmniRuntimeException;
import nova.hetu.omniruntime.vector.BooleanVec;
import nova.hetu.omniruntime.vector.Decimal128Vec;
@@ -29,16 +31,28 @@ import nova.hetu.omniruntime.vector.LongVec;
import nova.hetu.omniruntime.vector.ShortVec;
import nova.hetu.omniruntime.vector.VarcharVec;
import nova.hetu.omniruntime.vector.Vec;
+import nova.hetu.omniruntime.vector.serialize.OmniRowDeserializer;
import org.apache.spark.sql.execution.vectorized.OmniColumnVector;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
public class ShuffleDataSerializer {
+ private static final Logger LOG = LoggerFactory.getLogger(ShuffleDataSerializer.class);
- public static ColumnarBatch deserialize(byte[] bytes) {
+ public static ColumnarBatch deserialize(boolean isRowShuffle, byte[] bytes) {
+ if (!isRowShuffle) {
+ return deserializeByColumn(bytes);
+ } else {
+ return deserializeByRow(bytes);
+ }
+ }
+
+ public static ColumnarBatch deserializeByColumn(byte[] bytes) {
ColumnVector[] vecs = null;
try {
VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes);
@@ -64,6 +78,30 @@ public class ShuffleDataSerializer {
}
}
+ public static ColumnarBatch deserializeByRow(byte[] bytes) {
+ try {
+ VecData.ProtoRowBatch rowBatch = VecData.ProtoRowBatch.parseFrom(bytes);
+ int vecCount = rowBatch.getVecCnt();
+ int rowCount = rowBatch.getRowCnt();
+ OmniColumnVector[] columnarVecs = new OmniColumnVector[vecCount];
+ long[] omniVecs = new long[vecCount];
+ int[] omniTypes = new int[vecCount];
+ createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount);
+ OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes);
+
+ for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
+ VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx);
+ byte[] array = protoRow.getData().toByteArray();
+ deserializer.parse(array, omniVecs, rowIdx);
+ }
+
+ deserializer.close();
+ return new ColumnarBatch(columnarVecs, rowCount);
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage());
+ }
+ }
+
private static ColumnVector buildVec(VecData.Vec protoVec, int vecSize) {
VecData.VecType protoTypeId = protoVec.getVecType();
Vec vec;
@@ -128,4 +166,75 @@ public class ShuffleDataSerializer {
vecTmp.setVec(vec);
return vecTmp;
}
+
+ public static void createEmptyVec(VecData.ProtoRowBatch rowBatch, int[] omniTypes, long[] omniVecs, OmniColumnVector[] columnarVectors, int vecCount, int rowCount) {
+ for (int i = 0; i < vecCount; i++) {
+ VecData.VecType protoTypeId = rowBatch.getVecTypes(i);
+ DataType sparkType;
+ Vec omniVec;
+ switch (protoTypeId.getTypeId()) {
+ case VEC_TYPE_INT:
+ sparkType = DataTypes.IntegerType;
+ omniTypes[i] = IntDataType.INTEGER.getId().toValue();
+ omniVec = new IntVec(rowCount);
+ break;
+ case VEC_TYPE_DATE32:
+ sparkType = DataTypes.DateType;
+ omniTypes[i] = Date32DataType.DATE32.getId().toValue();
+ omniVec = new IntVec(rowCount);
+ break;
+ case VEC_TYPE_LONG:
+ sparkType = DataTypes.LongType;
+ omniTypes[i] = LongDataType.LONG.getId().toValue();
+ omniVec = new LongVec(rowCount);
+ break;
+ case VEC_TYPE_DATE64:
+ sparkType = DataTypes.DateType;
+ omniTypes[i] = Date64DataType.DATE64.getId().toValue();
+ omniVec = new LongVec(rowCount);
+ break;
+ case VEC_TYPE_DECIMAL64:
+ sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale());
+ omniTypes[i] = new Decimal64DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue();
+ omniVec = new LongVec(rowCount);
+ break;
+ case VEC_TYPE_SHORT:
+ sparkType = DataTypes.ShortType;
+ omniTypes[i] = ShortDataType.SHORT.getId().toValue();
+ omniVec = new ShortVec(rowCount);
+ break;
+ case VEC_TYPE_BOOLEAN:
+ sparkType = DataTypes.BooleanType;
+ omniTypes[i] = BooleanDataType.BOOLEAN.getId().toValue();
+ omniVec = new BooleanVec(rowCount);
+ break;
+ case VEC_TYPE_DOUBLE:
+ sparkType = DataTypes.DoubleType;
+ omniTypes[i] = DoubleDataType.DOUBLE.getId().toValue();
+ omniVec = new DoubleVec(rowCount);
+ break;
+ case VEC_TYPE_VARCHAR:
+ case VEC_TYPE_CHAR:
+ sparkType = DataTypes.StringType;
+ omniTypes[i] = VarcharDataType.VARCHAR.getId().toValue();
+ omniVec = new VarcharVec(rowCount);
+ break;
+ case VEC_TYPE_DECIMAL128:
+ sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale());
+ omniTypes[i] = new Decimal128DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue();
+ omniVec = new Decimal128Vec(rowCount);
+ break;
+ case VEC_TYPE_TIME32:
+ case VEC_TYPE_TIME64:
+ case VEC_TYPE_INTERVAL_DAY_TIME:
+ case VEC_TYPE_INTERVAL_MONTHS:
+ default:
+ throw new IllegalStateException("Unexpected value: " + protoTypeId.getTypeId());
+ }
+
+ omniVecs[i] = omniVec.getNativeVector();
+ columnarVectors[i] = new OmniColumnVector(rowCount, sparkType, false);
+ columnarVectors[i].setVec(omniVec);
+ }
+ }
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
index 108562dc66926aac1d5adadccc3d7568d79071da..d4a26303181bfa476f01f9293c0bca172e61dacc 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala
@@ -20,23 +20,26 @@ package com.huawei.boostkit.spark
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor
import com.huawei.boostkit.spark.util.PhysicalPlanSelector
+import nova.hetu.omniruntime.memory.MemoryManager
+import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge}
-import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate}
+import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, CombineJoinedAggregates, RewriteSelfJoinInInPredicate}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.execution.window.{WindowExec, TopNPushDownForWindow}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport
import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener
case class ColumnarPreOverrides() extends Rule[SparkPlan] {
val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf
@@ -67,6 +70,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold
val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce
val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization
+ val enableRowShuffle: Boolean = columnarConf.enableRowShuffle
+ val columnsThreshold: Int = columnarConf.columnsThreshold
def apply(plan: SparkPlan): SparkPlan = {
replaceWithColumnarPlan(plan)
@@ -544,11 +549,18 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] {
val children = plan.children.map(replaceWithColumnarPlan)
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarUnionExec(children)
- case plan: ShuffleExchangeExec if enableColumnarShuffle =>
+ case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle =>
val child = replaceWithColumnarPlan(plan.child)
if (child.output.nonEmpty) {
logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.")
- new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin)
+ if (child.isInstanceOf[ColumnarHashAggregateExec] && child.output.size > columnsThreshold
+ && enableRowShuffle) {
+ new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true)
+ } else if (enableColumnarShuffle) {
+ new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false)
+ } else {
+ plan
+ }
} else {
plan
}
@@ -751,5 +763,25 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging {
extensions.injectOptimizerRule(_ => DelayCartesianProduct)
extensions.injectOptimizerRule(_ => HeuristicJoinReorder)
extensions.injectOptimizerRule(_ => MergeSubqueryFilters)
+ extensions.injectOptimizerRule(_ => CombineJoinedAggregates)
+ extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow)
+ }
+}
+
+private class OmniTaskStartExecutorPlugin extends ExecutorPlugin {
+ override def onTaskStart(): Unit = {
+ addLeakSafeTaskCompletionListener[Unit](_ => {
+ MemoryManager.clearMemory()
+ })
+ }
+}
+
+class OmniSparkPlugin extends SparkPlugin {
+ override def executorPlugin(): ExecutorPlugin = {
+ new OmniTaskStartExecutorPlugin()
+ }
+
+ override def driverPlugin(): DriverPlugin = {
+ null
}
}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
index e87122e87c6b675dfaeab352af390143111510d9..5fc711a3982ba4633e7cff785414198e751fcd94 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala
@@ -135,9 +135,9 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
val columnarShuffleSpillBatchRowNum =
conf.getConfString("spark.shuffle.columnar.shuffleSpillBatchRowNum", "10000").toInt
- // columnar shuffle spill memory threshold
- val columnarShuffleSpillMemoryThreshold =
- conf.getConfString("spark.shuffle.columnar.shuffleSpillMemoryThreshold",
+ // columnar shuffle spill memory threshold in task level
+ val columnarShuffleTaskSpillMemoryThreshold =
+ conf.getConfString("spark.shuffle.columnar.shuffleTaskSpillMemoryThreshold",
"2147483648").toLong
// columnar shuffle compress block size
@@ -156,12 +156,15 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
val columnarShuffleNativeBufferSize =
conf.getConfString("spark.sql.execution.columnar.maxRecordsPerBatch", "4096").toInt
+ val columnarSpillWriteBufferSize: Long =
+ conf.getConfString("spark.omni.sql.columnar.spill.writeBufferSize", "4121440").toLong
+
// columnar spill threshold - Percentage of memory usage, associate with the "spark.memory.offHeap" together
val columnarSpillMemPctThreshold: Integer =
conf.getConfString("spark.omni.sql.columnar.spill.memFraction", "90").toInt
// columnar spill dir disk reserve Size, default 10GB
- val columnarSpillDirDiskReserveSize:Long =
+ val columnarSpillDirDiskReserveSize: Long =
conf.getConfString("spark.omni.sql.columnar.spill.dirDiskReserveSize", "10737418240").toLong
// enable or disable columnar sort spill
@@ -244,6 +247,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoinThreshold", "3").toInt
val filterMergeEnable: Boolean = conf.getConfString("spark.sql.execution.filterMerge.enabled", "false").toBoolean
+ val combineJoinedAggregatesEnabled: Boolean = conf.getConfString("spark.sql.execution.combineJoinedAggregates.enabled", "false").toBoolean
val filterMergeThreshold: Double = conf.getConfString("spark.sql.execution.filterMerge.maxCost", "100.0").toDouble
@@ -259,6 +263,13 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging {
val radixSortThreshold: Int =
conf.getConfString("spark.omni.sql.columnar.radixSortThreshold", "1000000").toInt
+
+ // enable or disable row shuffle
+ val enableRowShuffle: Boolean =
+ conf.getConfString("spark.omni.sql.columnar.rowShuffle.enabled", "true").toBoolean
+
+ val columnsThreshold: Int =
+ conf.getConfString("spark.omni.sql.columnar.columnsThreshold", "10").toInt
}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
index 11ff8e12b54519484e6aa04885e3a50933d859a4..cfc95ae37506fc7ec2a39d45b8c8747783fcf298 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString
import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery
import org.apache.spark.sql.expression.ColumnarExpressionConverter
import org.apache.spark.sql.hive.HiveUdfAdaptorUtil
-import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType}
import org.json.{JSONArray, JSONObject}
import java.util.Locale
@@ -74,10 +74,10 @@ object OmniExpressionAdaptor extends Logging {
}
}
- private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = {
+ private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = {
def doSupportCastToString(dataType: DataType): Boolean = {
if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType]
- || dataType.isInstanceOf[LongType]) {
+ || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType]) {
true
} else {
false
@@ -86,7 +86,7 @@ object OmniExpressionAdaptor extends Logging {
def doSupportCastFromString(dataType: DataType): Boolean = {
if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType]
- || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType]) {
+ || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DoubleType]) {
true
} else {
false
@@ -103,10 +103,6 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported expression: $expr")
}
- // not support Cast(double as decimal)
- if (cast.dataType.isInstanceOf[DecimalType] && cast.child.dataType.isInstanceOf[DoubleType]) {
- throw new UnsupportedOperationException(s"Unsupported expression: $expr")
- }
}
def rewriteToOmniJsonExpressionLiteral(expr: Expression,
@@ -242,6 +238,7 @@ object OmniExpressionAdaptor extends Logging {
case alias: Alias => rewriteToOmniJsonExpressionLiteralJsonObject(alias.child, exprsIndexMap)
case literal: Literal => toOmniJsonLiteral(literal)
+
case not: Not =>
not.child match {
case isnull: IsNull =>
@@ -263,6 +260,7 @@ object OmniExpressionAdaptor extends Logging {
.put("operator", "not")
.put("expr", rewriteToOmniJsonExpressionLiteralJsonObject(not.child, exprsIndexMap))
}
+
case isnotnull: IsNotNull =>
new JSONObject().put("exprType", "UNARY")
.addOmniExpJsonType("returnType", BooleanType)
@@ -287,7 +285,7 @@ object OmniExpressionAdaptor extends Logging {
.put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap)))
// Cast
- case cast: Cast =>
+ case cast: CastBase =>
unsupportedCastCheck(expr, cast)
cast.dataType match {
case StringType =>
@@ -302,8 +300,8 @@ object OmniExpressionAdaptor extends Logging {
.addOmniExpJsonType("returnType", cast.dataType)
.put("function_name", "CAST")
.put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(cast.child, exprsIndexMap)))
-
}
+
// Abs
case abs: Abs =>
new JSONObject().put("exprType", "FUNCTION")
@@ -414,6 +412,13 @@ object OmniExpressionAdaptor extends Logging {
.put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.str, exprsIndexMap))
.put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.substr, exprsIndexMap)))
+ case rlike: RLike =>
+ new JSONObject().put("exprType", "FUNCTION")
+ .addOmniExpJsonType("returnType", rlike.dataType)
+ .put("function_name", "RLike")
+ .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.left, exprsIndexMap))
+ .put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.right, exprsIndexMap)))
+
// for floating numbers normalize
case normalizeNaNAndZero: NormalizeNaNAndZero =>
new JSONObject().put("exprType", "FUNCTION")
@@ -450,6 +455,25 @@ object OmniExpressionAdaptor extends Logging {
throw new UnsupportedOperationException(s"Unsupported right expression in like expression: $endsWith")
}
+ case truncDate: TruncDate =>
+ new JSONObject().put("exprType", "FUNCTION")
+ .addOmniExpJsonType("returnType", truncDate.dataType)
+ .put("function_name", "trunc_date")
+ .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(truncDate.left, exprsIndexMap))
+ .put(rewriteToOmniJsonExpressionLiteralJsonObject(truncDate.right, exprsIndexMap)))
+
+ case md5: Md5 =>
+ md5.child match {
+ case Cast(inputExpression, outputType, _, _) if outputType == BinaryType =>
+ inputExpression match {
+ case AttributeReference(_, dataType, _, _) if dataType == StringType =>
+ new JSONObject().put("exprType", "FUNCTION")
+ .addOmniExpJsonType("returnType", md5.dataType)
+ .put("function_name", "Md5")
+ .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inputExpression, exprsIndexMap)))
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupported expression: $expr")
+ }
+ case _ =>
+ throw new UnsupportedOperationException(s"Unsupported expression: $expr")
+ }
+
case _ =>
if (HiveUdfAdaptorUtil.isHiveUdf(expr) && ColumnarPluginConfig.getSessionConf.enableColumnarUdf) {
val hiveUdf = HiveUdfAdaptorUtil.asHiveSimpleUDF(expr)
@@ -723,6 +747,10 @@ object OmniExpressionAdaptor extends Logging {
}
}
+ def sparkTypeToOmniType(dataType: DataType): Int = {
+ sparkTypeToOmniType(dataType, Metadata.empty).getId.ordinal()
+ }
+
def sparkTypeToOmniType(dataType: DataType, metadata: Metadata = Metadata.empty):
nova.hetu.omniruntime.`type`.DataType = {
dataType match {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala
index 07ac07e8f81a47811ce8b979c9efcc3b52f882e3..26e2b7a3e057662fd9abee0db0345beeeb63f523 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala
@@ -28,15 +28,18 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.vectorized.ColumnarBatch
-class ColumnarBatchSerializer(readBatchNumRows: SQLMetric, numOutputRows: SQLMetric)
+class ColumnarBatchSerializer(readBatchNumRows: SQLMetric,
+ numOutputRows: SQLMetric,
+ isRowShuffle: Boolean = false)
extends Serializer
with Serializable {
/** Creates a new [[SerializerInstance]]. */
override def newInstance(): SerializerInstance =
- new ColumnarBatchSerializerInstance(readBatchNumRows, numOutputRows)
+ new ColumnarBatchSerializerInstance(isRowShuffle, readBatchNumRows, numOutputRows)
}
private class ColumnarBatchSerializerInstance(
+ isRowShuffle: Boolean,
readBatchNumRows: SQLMetric,
numOutputRows: SQLMetric)
extends SerializerInstance with Logging {
@@ -62,7 +65,6 @@ private class ColumnarBatchSerializerInstance(
new DataInputStream(new BufferedInputStream(in))
}
private[this] var columnarBuffer: Array[Byte] = new Array[Byte](1024)
- val ibuffer: ByteBuffer = ByteBuffer.allocateDirect(4)
private[this] val EOF: Int = -1
@@ -85,7 +87,7 @@ private class ColumnarBatchSerializerInstance(
}
ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize)
// protobuf serialize
- val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize))
+ val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize))
dataSize = readSize()
if (dataSize == EOF) {
dIn.close()
@@ -114,7 +116,7 @@ private class ColumnarBatchSerializerInstance(
}
ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize)
// protobuf serialize
- val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize))
+ val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize))
numBatchesTotal += 1
numRowsTotal += columnarBatch.numRows()
columnarBatch.asInstanceOf[T]
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
index 875fe939dcbd932877165079a6f16400fc968865..113e88399ba897dcc38c3c40841a18dafd2a9315 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
@@ -330,12 +330,12 @@ object OmniAdaptorUtil {
operator
}
- def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = {
- if (projectList.nonEmpty) {
+ def pruneOutput(output: Seq[Attribute], projectExprIdList: Seq[ExprId]): Seq[Attribute] = {
+ if (projectExprIdList.nonEmpty) {
val projectOutput = ListBuffer[Attribute]()
- for (project <- projectList) {
+ for (index <- projectExprIdList.indices) {
for (col <- output) {
- if (col.exprId.equals(getProjectAliasExprId(project))) {
+ if (col.exprId.equals(projectExprIdList(index))) {
projectOutput += col
}
}
@@ -346,13 +346,13 @@ object OmniAdaptorUtil {
}
}
- def getIndexArray(output: Seq[Attribute], projectList: Seq[NamedExpression]): Array[Int] = {
- if (projectList.nonEmpty) {
+ def getIndexArray(output: Seq[Attribute], projectExprIdList: Seq[ExprId]): Array[Int] = {
+ if (projectExprIdList.nonEmpty) {
val indexList = ListBuffer[Int]()
- for (project <- projectList) {
+ for (index <- projectExprIdList.indices) {
for (i <- output.indices) {
val col = output(i)
- if (col.exprId.equals(getProjectAliasExprId(project))) {
+ if (col.exprId.equals(projectExprIdList(index))) {
indexList += i
}
}
@@ -363,23 +363,50 @@ object OmniAdaptorUtil {
}
}
- def reorderVecs(prunedOutput: Seq[Attribute], projectList: Seq[NamedExpression], resultVecs: Array[nova.hetu.omniruntime.vector.Vec], vecs: Array[OmniColumnVector]) = {
- val used = new Array[Boolean](resultVecs.length)
- for (index <- projectList.indices) {
- val project = projectList(index)
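+ // Bind each output column vector to the native result vector selected by projectListIndex.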
+ def reorderOutputVecs(projectListIndex: Array[Int], omniVecs: Array[nova.hetu.omniruntime.vector.Vec],
+ outputVecs: Array[OmniColumnVector]) = {
+ for (index <- projectListIndex.indices) {
+ val outputVec = outputVecs(index)
+ outputVec.reset()
+ val projectIndex = projectListIndex(index)
+ outputVec.setVec(omniVecs(projectIndex))
+ }
+ }
+
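+ // Map each projected ExprId to its column index in the pruned (probe ++ build) output:
+ // probe columns are indexed from 0, build columns from probeOutput.size onwards.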
+ def getProjectListIndex(projectExprIdList: Seq[ExprId], probeOutput: Seq[Attribute],
+ buildOutput: Seq[Attribute]): Array[Int] = {
+ val projectListIndex = ListBuffer[Int]()
+ var probeIndex = 0
+ var buildIndex = probeOutput.size
+ for (index <- projectExprIdList.indices) {
breakable {
- for (i <- prunedOutput.indices) {
- val col = prunedOutput(i)
- if (!used(i) && col.exprId.equals(getProjectAliasExprId(project))) {
- val v = vecs(index)
- v.reset()
- v.setVec(resultVecs(i))
- used(i) = true;
+ for (probeAttr <- probeOutput) {
+ if (probeAttr.exprId.equals(projectExprIdList(index))) {
+ projectListIndex += probeIndex
+ probeIndex += 1
break
}
}
}
+ breakable {
+ for (buildAttr <- buildOutput) {
+ if (buildAttr.exprId.equals(projectExprIdList(index))) {
+ projectListIndex += buildIndex
+ buildIndex += 1
+ break
+ }
+ }
+ }
+ }
+ projectListIndex.toArray
+ }
+
+ def getExprIdForProjectList(projectList: Seq[NamedExpression]): Seq[ExprId] = {
+ val exprIdList = ListBuffer[ExprId]()
+ for (project <- projectList) {
+ exprIdList += getProjectAliasExprId(project)
}
+ exprIdList
}
def getProjectAliasExprId(project: NamedExpression): ExprId = {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
index 4c27688cb741eec4913ea51e23e14dfa16aa6b64..215be3846b17c6f902987364acb33693a6e8750b 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala
@@ -48,6 +48,7 @@ class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
override val aggregator: Option[Aggregator[K, V, C]] = None,
override val mapSideCombine: Boolean = false,
override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
+ val handleRow: Boolean,
val partitionInfo: PartitionInfo,
val dataSize: SQLMetric,
val bytesSpilled: SQLMetric,
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index 615ddb6b7449d3ba3f2b6839df8eee67b6d5b05e..a8b7d9eab39374404a5bff4f176569e97e2bd0de 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -25,6 +25,7 @@ import nova.hetu.omniruntime.vector.VecBatch
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.sql.execution.util.SparkMemoryUtils
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.Utils
@@ -49,7 +50,8 @@ class ColumnarShuffleWriter[K, V](
val columnarConf = ColumnarPluginConfig.getSessionConf
val shuffleSpillBatchRowNum = columnarConf.columnarShuffleSpillBatchRowNum
- val shuffleSpillMemoryThreshold = columnarConf.columnarShuffleSpillMemoryThreshold
+ val shuffleTaskSpillMemoryThreshold = columnarConf.columnarShuffleTaskSpillMemoryThreshold
+ val shuffleExecutorSpillMemoryThreshold = columnarConf.columnarSpillMemPctThreshold * SparkMemoryUtils.offHeapSize
val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize
val shuffleNativeBufferSize = columnarConf.columnarShuffleNativeBufferSize
val enableShuffleCompress = columnarConf.enableShuffleCompress
@@ -87,7 +89,8 @@ class ColumnarShuffleWriter[K, V](
localDirs,
shuffleCompressBlockSize,
shuffleSpillBatchRowNum,
- shuffleSpillMemoryThreshold)
+ shuffleTaskSpillMemoryThreshold,
+ shuffleExecutorSpillMemoryThreshold)
}
@@ -104,14 +107,22 @@ class ColumnarShuffleWriter[K, V](
dep.dataSize += input(col).getRealOffsetBufCapacityInBytes
}
val vb = new VecBatch(input, cb.numRows())
- jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch)
+ if (!dep.handleRow) {
+ jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch)
+ } else {
+ jniWrapper.rowSplit(nativeSplitter, vb.getNativeVectorBatch)
+ }
dep.splitTime.add(System.nanoTime() - startTime)
dep.numInputRows.add(cb.numRows)
writeMetrics.incRecordsWritten(cb.numRows)
}
}
val startTime = System.nanoTime()
- splitResult = jniWrapper.stop(nativeSplitter)
+ if (!dep.handleRow) {
+ splitResult = jniWrapper.stop(nativeSplitter)
+ } else {
+ splitResult = jniWrapper.rowStop(nativeSplitter)
+ }
dep.splitTime.add(System.nanoTime() - startTime - splitResult.getTotalSpillTime -
splitResult.getTotalWriteTime - splitResult.getTotalComputePidTime)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
new file mode 100644
index 0000000000000000000000000000000000000000..94e3e35e4d3bf29d24c2397c046b264e7eedd092
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable.ArrayBuffer
+
+import com.huawei.boostkit.spark.ColumnarPluginConfig
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType}
+
+/**
+ * This rule eliminates a [[Join]] whose sides are all [[Aggregate]]s by combining those
+ * [[Aggregate]]s. The rule also supports nested [[Join]]s, as long as all sides of every
+ * [[Join]] are [[Aggregate]]s.
+ *
+ * Note: this rule does not support the following cases:
+ * 1. [[Aggregate]]s to be merged where at least one of them has no predicate or has low
+ * predicate selectivity.
+ * 2. [[Aggregate]]s to be merged whose upstream nodes contain a [[Join]].
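+ *
+ * As a rough illustration of the intended rewrite (assuming both sides scan the same relation
+ * with identical or provably disjoint filters), a plan such as
+ * {{{
+ *   Join
+ *   :- Aggregate [sum(a) AS s1] <- Filter (b > 10) <- Scan t
+ *   +- Aggregate [sum(a) AS s2] <- Filter (b < 5)  <- Scan t
+ * }}}
+ * may be merged into a single scan, roughly
+ * {{{
+ *   Aggregate [sum(a) FILTER (b > 10) AS s1, sum(a) FILTER (b < 5) AS s2]
+ *   +- Filter (b > 10 OR b < 5)
+ *      +- Scan t
+ * }}}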
+ */
+object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
+
+ private def isSupportedJoinType(joinType: JoinType): Boolean =
+ Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)
+
+ private def maxTreeNodeNumOfPredicate: Int = 10
+
+ private def isCheapPredicate(e: Expression): Boolean = {
+ !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
+ REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION,
+ HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY)
+ }
+
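+ // Returns true when both conditions compare the same expression against foldable values and the
+ // resulting ranges cannot overlap, e.g. left `col > 10` vs. right `col < 5`, or `col = 1` vs. `NOT (col = 1)`.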
+ private def checkCondition(leftCondition: Expression, rightCondition: Expression): Boolean = {
+ val normalizedLeft = normalizeExpression(leftCondition)
+ val normalizedRight = normalizeExpression(rightCondition)
+ if (normalizedLeft.isDefined && normalizedRight.isDefined) {
+ (normalizedLeft.get, normalizedRight.get) match {
+ case (a GreaterThan b, c LessThan d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(b, d, a.dataType)
+ case (a LessThan b, c GreaterThan d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(d, b, a.dataType)
+ case (a GreaterThanOrEqual b, c LessThan d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(b, d, a.dataType)
+ case (a LessThan b, c GreaterThanOrEqual d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(d, b, a.dataType)
+ case (a GreaterThan b, c LessThanOrEqual d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(b, d, a.dataType)
+ case (a LessThanOrEqual b, c GreaterThan d) if a.semanticEquals(c) =>
+ isGreaterOrEqualTo(d, b, a.dataType)
+ case (a EqualTo b, Not(c EqualTo d)) if a.semanticEquals(c) =>
+ isEqualTo(b, d, a.dataType)
+ case _ => false
+ }
+ } else {
+ false
+ }
+ }
+
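+ // Normalizes a comparison so that the foldable (constant) operand ends up on the right-hand side,
+ // adjusting the operator accordingly; returns None for shapes checkCondition does not handle.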
+ private def normalizeExpression(expr: Expression): Option[Expression] = {
+ expr match {
+ case gt @ GreaterThan(_, r) if r.foldable =>
+ Some(gt)
+ case l GreaterThan r if l.foldable =>
+ Some(LessThanOrEqual(r, l))
+ case lt @ LessThan(_, r) if r.foldable =>
+ Some(lt)
+ case l LessThan r if l.foldable =>
+ Some(GreaterThanOrEqual(r, l))
+ case gte @ GreaterThanOrEqual(_, r) if r.foldable =>
+ Some(gte)
+ case l GreaterThanOrEqual r if l.foldable =>
+ Some(LessThan(r, l))
+ case lte @ LessThanOrEqual(_, r) if r.foldable =>
+ Some(lte)
+ case l LessThanOrEqual r if l.foldable =>
+ Some(GreaterThan(r, l))
+ case eq @ EqualTo(_, r) if r.foldable =>
+ Some(eq)
+ case l EqualTo r if l.foldable =>
+ Some(EqualTo(r, l))
+ case not @ Not(EqualTo(l, r)) if r.foldable =>
+ Some(not)
+ case Not(l EqualTo r) if l.foldable =>
+ Some(Not(EqualTo(r, l)))
+ case _ => None
+ }
+ }
+
+ private def isGreaterOrEqualTo(
+ left: Expression, right: Expression, dataType: DataType): Boolean = dataType match {
+ case ShortType => left.eval().asInstanceOf[Short] >= right.eval().asInstanceOf[Short]
+ case IntegerType => left.eval().asInstanceOf[Int] >= right.eval().asInstanceOf[Int]
+ case LongType => left.eval().asInstanceOf[Long] >= right.eval().asInstanceOf[Long]
+ case FloatType => left.eval().asInstanceOf[Float] >= right.eval().asInstanceOf[Float]
+ case DoubleType => left.eval().asInstanceOf[Double] >= right.eval().asInstanceOf[Double]
+ case DecimalType.Fixed(_, _) =>
+ left.eval().asInstanceOf[Decimal] >= right.eval().asInstanceOf[Decimal]
+ case _ => false
+ }
+
+ private def isEqualTo(
+ left: Expression, right: Expression, dataType: DataType): Boolean = dataType match {
+ case ShortType => left.eval().asInstanceOf[Short] == right.eval().asInstanceOf[Short]
+ case IntegerType => left.eval().asInstanceOf[Int] == right.eval().asInstanceOf[Int]
+ case LongType => left.eval().asInstanceOf[Long] == right.eval().asInstanceOf[Long]
+ case FloatType => left.eval().asInstanceOf[Float] == right.eval().asInstanceOf[Float]
+ case DoubleType => left.eval().asInstanceOf[Double] == right.eval().asInstanceOf[Double]
+ case DecimalType.Fixed(_, _) =>
+ left.eval().asInstanceOf[Decimal] == right.eval().asInstanceOf[Decimal]
+ case _ => false
+ }
+
+ def countPredicatesInExpressions(expression: Expression): Int = {
+ expression match {
+ // If the expression is a predicate, count it
+ case _: Predicate => 1
+ // If the expression is a complex expression, recursively count predicates in its children
+ case complexExpression =>
+ complexExpression.children.map(countPredicatesInExpressions).sum
+ }
+ }
+
+ def normalizeJoinExpression(expr: Expression): Expression = expr match {
+ case BinaryComparison(left, right) =>
+ val sortedChildren = Seq(left, right).sortBy(_.toString)
+ expr.withNewChildren(sortedChildren)
+ case _ => expr.transform {
+ case a: AttributeReference => UnresolvedAttribute(a.name)
+ }
+ }
+
+ def extendedNormalizeExpression(expr: Expression): Expression = {
+ expr.transformUp {
+ // Normalize attributes by name, ignoring exprId
+ case attr: AttributeReference =>
+ // Keep only the name, data type and nullability; exprId, metadata and other identifiers are dropped.
+ AttributeReference(attr.name, attr.dataType, attr.nullable)(exprId = NamedExpression.newExprId, qualifier = attr.qualifier)
+
+ // Unwrap aliases to compare the underlying expressions directly
+ case Alias(child, _) => child
+
+ case Cast(child, dataType,_,_) =>
+ // Normalize child and retain the cast's target data type
+ Cast(extendedNormalizeExpression(child), dataType)
+
+ // Handle commutative operations by sorting their children
+ case b: BinaryOperator if b.isInstanceOf[Add] || b.isInstanceOf[Multiply] =>
+ val sortedChildren = b.children.sortBy(_.toString())
+ b.withNewChildren(sortedChildren)
+
+ // Further transformations can be added here to handle other specific cases as needed
+ }
+ }
+
+ // Function to compare two join conditions after normalization
+ def isJoinConditionEqual(condition1: Option[Expression], condition2: Option[Expression]): Boolean = {
+ (condition1, condition2) match {
+ case (Some(expr1), Some(expr2)) =>
+ // Compare the join conditions textually after stripping exprIds (e.g. "#123").
+ val pattern = "#\\d+"
+ val result1 = expr1.toString().replaceAll(pattern, "")
+ val result2 = expr2.toString().replaceAll(pattern, "")
+ result1 == result2
+
+ case (None, None) => true // Both conditions are None
+ case _ => false // One condition is None and the other is not
+ }
+ }
+
+ // Function to check if two joins are the same
+ def areJoinsEqual(join1: Join, join2: Join): Boolean = {
+ // Check join type
+ if (join1.joinType != join2.joinType) return false
+
+ if (!isJoinConditionEqual(join1.condition, join2.condition)) return false
+
+ // Joins are equal
+ true
+ }
+
+ // Holds an expression together with a flag indicating whether it should be propagated to an ancestor Aggregate.
+ case class ExpressionHolder(expression: Expression, propagate: Boolean)
+ /**
+ * Try to merge two `Aggregate`s by traverse down recursively.
+ *
+ * @return The optional tuple as follows:
+ * 1. the merged plan
+ * 2. the attribute mapping from the old to the merged version
+ * 3. optional filters of both plans that need to be propagated and merged in an
+ * ancestor `Aggregate` node if possible.
+ */
+ private def mergePlan(
+ left: LogicalPlan,
+ right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[ExpressionHolder])] = {
+ (left, right) match {
+ case (la: Aggregate, ra: Aggregate) =>
+ mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) =>
+ val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap))
+
+ // Filter the sequence to include only those entries where the propagate is true
+ val filtersToBePropagated: Seq[ExpressionHolder] = filters.filter(_.propagate)
+ val mergedAggregateExprs = if (filtersToBePropagated.length == 2) {
+ Seq(
+ (la.aggregateExpressions, filtersToBePropagated.head.expression),
+ (rightAggregateExprs, filtersToBePropagated.last.expression)
+ ).flatMap { case (aggregateExpressions, propagatedFilter) =>
+ aggregateExpressions.map { ne =>
+ ne.transform {
+ case ae @ AggregateExpression(_, _, _, filterOpt, _) =>
+ // Keep an existing FILTER clause unchanged; otherwise attach the propagated filter.
+ val newFilter = filterOpt.orElse(Some(propagatedFilter))
+ ae.copy(filter = newFilter)
+ }.asInstanceOf[NamedExpression]
+ }
+ }
+ } else {
+ la.aggregateExpressions ++ rightAggregateExprs
+ }
+
+ (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty)
+ }
+ case (lp: Project, rp: Project) =>
+ val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*)
+
+ mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) =>
+ val allFilterReferences = filters.flatMap(_.expression.references)
+ val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne =>
+ val mapped = mapAttributes(ne, outputMap)
+
+ val withoutAlias = mapped match {
+ case Alias(child, _) => child
+ case e => e
+ }
+
+ val outputAttr = mergedProjectList.find {
+ case Alias(child, _) => child semanticEquals withoutAlias
+ case e => e semanticEquals withoutAlias
+ }.getOrElse {
+ mergedProjectList += mapped
+ mapped
+ }.toAttribute
+ ne.toAttribute -> outputAttr
+ })
+
+ (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters)
+ }
+ case (lf: Filter, rf: Filter)
+ if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) =>
+
+ val pattern = "#\\d+"
+ // Replace the matched pattern with an empty string
+ val result1 = lf.condition.toString().replaceAll(pattern, "")
+ val result2 = rf.condition.toString().replaceAll(pattern, "")
+
+ if (result1 == result2 || lf.condition == rf.condition || checkCondition(lf.condition, rf.condition)) {
+ // If both conditions are the same, proceed with one of them.
+ mergePlan(lf.child, rf.child).map { case (newChild, outputMap, filters) =>
+ (Filter(lf.condition, newChild), outputMap, Seq(ExpressionHolder(lf.condition, false)))
+ }
+ } else {
+ mergePlan(lf.child, rf.child).map {
+ case (newChild, outputMap, filters) =>
+ val mappedRightCondition = mapAttributes(rf.condition, outputMap)
+ val (newLeftCondition, newRightCondition) = if (filters.length == 2) {
+ (And(lf.condition, filters.head.expression), And(mappedRightCondition, filters.last.expression))
+ } else {
+ (lf.condition, mappedRightCondition)
+ }
+ val newCondition = Or(newLeftCondition, newRightCondition)
+ (Filter(newCondition, newChild), outputMap,
+ Seq(ExpressionHolder(newLeftCondition, true), ExpressionHolder(newRightCondition, true)))
+ }
+ }
+ case (lj: Join, rj: Join) =>
+ if (areJoinsEqual(lj, rj)) {
+ mergePlan(lj.left, rj.left).flatMap { case (newLeft, leftOutputMap, leftFilters) =>
+ mergePlan(lj.right, rj.right).map { case (newRight, rightOutputMap, rightFilters) =>
+ val newJoin = Join(newLeft, newRight, lj.joinType, lj.condition, lj.hint)
+ val mergedOutputMap = leftOutputMap ++ rightOutputMap
+ val mergedFilters = leftFilters ++ rightFilters
+ (newJoin, mergedOutputMap, mergedFilters)
+ }
+ }
+ } else {
+ None
+ }
+ case (ll: LeafNode, rl: LeafNode) =>
+ checkIdenticalPlans(rl, ll).map { outputMap =>
+ (ll, outputMap, Seq.empty)
+ }
+ case (ls: SerializeFromObject, rs: SerializeFromObject) =>
+ checkIdenticalPlans(rs, ls).map { outputMap =>
+ (ls, outputMap, Seq.empty)
+ }
+ case _ => None
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!ColumnarPluginConfig.getConf.combineJoinedAggregatesEnabled) return plan
+ // Apply the rule to children first, then to the node itself.
+ plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE)) {
+ case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _)
+ if isSupportedJoinType(joinType) &&
+ left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty =>
+ val mergedAggregate = mergePlan(left, right)
+ mergedAggregate.map(_._1).getOrElse(j)
+ }
+ }
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6d70bc416d8c4817b8e72ef883f698b6a16bd450
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * The helper class used to merge scalar subqueries.
+ */
+trait MergeScalarSubqueriesHelper {
+
+ // If 2 plans are identical return the attribute mapping from the left to the right.
+ protected def checkIdenticalPlans(
+ left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = {
+ if (left.canonicalized == right.canonicalized) {
+ Some(AttributeMap(left.output.zip(right.output)))
+ } else {
+ None
+ }
+ }
+
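+ // Replace attributes in `expr` with their counterparts from `outputMap`; unmapped attributes stay unchanged.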
+ protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
+ expr.transform {
+ case a: Attribute => outputMap.getOrElse(a, a)
+ }.asInstanceOf[T]
+ }
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
index 8eff1774a10663f0e7249b9ad1f0abe991ced544..55fba9f2b750d215326d6835ca849ba66d994caf 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
@@ -198,8 +198,9 @@ case class ColumnarHashAggregateExec(
val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
finalOut.map(
exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray
+ val finalAttrExprsIdMap = getExprIdMap(finalOut)
val projectExpressions: Array[AnyRef] = resultExpressions.map(
- exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray
+ exp => rewriteToOmniJsonExpressionLiteral(exp, finalAttrExprsIdMap)).toArray
if (!isSimpleColumnForAll(projectExpressions.map(expr => expr.toString))) {
checkOmniJsonWhiteList("", projectExpressions)
}
@@ -297,13 +298,14 @@ case class ColumnarHashAggregateExec(
child.executeColumnar().mapPartitionsWithIndex { (index, iter) =>
val columnarConf = ColumnarPluginConfig.getSessionConf
- val hashAggSpillRowThreshold = columnarConf.columnarHashAggSpillRowThreshold
+ val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize
val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold
val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize
val hashAggSpillEnable = columnarConf.enableHashAggSpill
+ val hashAggSpillRowThreshold = columnarConf.columnarHashAggSpillRowThreshold
val spillDirectory = generateSpillDir(tmpSparkConf, "columnarHashAggSpill")
val sparkSpillConf = new SparkSpillConfig(hashAggSpillEnable, spillDirectory,
- spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold)
+ spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize)
val startCodegen = System.nanoTime()
val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions,
@@ -373,10 +375,11 @@ case class ColumnarHashAggregateExec(
}
if (finalStep) {
val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
+ val finalAttrExprsIdMap = getExprIdMap(finalOut)
val projectInputTypes = finalOut.map(
exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray
val projectExpressions = resultExpressions.map(
- exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray
+ exp => rewriteToOmniJsonExpressionLiteral(exp, finalAttrExprsIdMap)).toArray
dealPartitionData(null, null, addInputTime, omniCodegenTime,
getOutputTime, projectInputTypes, projectExpressions, hashAggIter, this.schema)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
index 4ce57e12ab8cb9f939664c4ce8a59b441d65ffd2..3603ecccc9df432c2b2e19650204c2341ebf15d5 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
@@ -267,6 +267,7 @@ case class ColumnarTakeOrderedAndProjectExec(
child.output,
SinglePartition,
serializer,
+ handleRow = false,
writeMetrics,
longMetric("dataSize"),
longMetric("bytesSpilled"),
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
index 6e6588304b3ba809fc9a9a561e66703a4f988122..e1e07dd48e69443e3f95c01d1f3260d5ae25d083 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala
@@ -54,10 +54,13 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.XORShiftRandom
+import nova.hetu.omniruntime.vector.IntVec
+
case class ColumnarShuffleExchangeExec(
override val outputPartitioning: Partitioning,
child: SparkPlan,
- shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
+ shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
+ handleRow: Boolean = false)
extends ShuffleExchangeLike {
private lazy val writeMetrics =
@@ -78,13 +81,14 @@ case class ColumnarShuffleExchangeExec(
"numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")
) ++ readMetrics ++ writeMetrics
- override def nodeName: String = "OmniColumnarShuffleExchange"
+ override def nodeName: String = if (!handleRow) "OmniColumnarShuffleExchange" else "OmniRowShuffleExchange"
override def supportsColumnar: Boolean = true
val serializer: Serializer = new ColumnarBatchSerializer(
longMetric("avgReadBatchNumRows"),
- longMetric("numOutputRows"))
+ longMetric("numOutputRows"),
+ handleRow)
@transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar()
@@ -118,6 +122,7 @@ case class ColumnarShuffleExchangeExec(
child.output,
outputPartitioning,
serializer,
+ handleRow,
writeMetrics,
longMetric("dataSize"),
longMetric("bytesSpilled"),
@@ -189,6 +194,7 @@ object ColumnarShuffleExchangeExec extends Logging {
outputAttributes: Seq[Attribute],
newPartitioning: Partitioning,
serializer: Serializer,
+ handleRow: Boolean,
writeMetrics: Map[String, SQLMetric],
dataSize: SQLMetric,
bytesSpilled: SQLMetric,
@@ -277,16 +283,25 @@ object ColumnarShuffleExchangeExec extends Logging {
val addPid2ColumnBatch = addPidToColumnBatch()
cbIter.filter(cb => cb.numRows != 0 && cb.numCols != 0).map {
cb =>
- val pidArr = new Array[Int](cb.numRows)
- (0 until cb.numRows).foreach { i =>
- val row = cb.getRow(i)
- val pid = part.get.getPartition(partitionKeyExtractor(row))
- pidArr(i) = pid
- }
- val pidVec = new IntVec(cb.numRows)
- pidVec.put(pidArr, 0, 0, cb.numRows)
+ var pidVec: IntVec = null
+ try {
+ val pidArr = new Array[Int](cb.numRows)
+ (0 until cb.numRows).foreach { i =>
+ val row = cb.getRow(i)
+ val pid = part.get.getPartition(partitionKeyExtractor(row))
+ pidArr(i) = pid
+ }
+ pidVec = new IntVec(cb.numRows)
+ pidVec.put(pidArr, 0, 0, cb.numRows)
- addPid2ColumnBatch(pidVec, cb)
+ addPid2ColumnBatch(pidVec, cb)
+ } catch {
+ case e: Exception =>
+ if (pidVec != null) {
+ pidVec.close()
+ }
+ throw e
+ }
}
}
@@ -308,8 +323,17 @@ object ColumnarShuffleExchangeExec extends Logging {
val getRoundRobinPid = getRoundRobinPartitionKey
val addPid2ColumnBatch = addPidToColumnBatch()
cbIter.map { cb =>
- val pidVec = getRoundRobinPid(cb, numPartitions)
- addPid2ColumnBatch(pidVec, cb)
+ var pidVec: IntVec = null
+ try {
+ pidVec = getRoundRobinPid(cb, numPartitions)
+ addPid2ColumnBatch(pidVec, cb)
+ } catch {
+ case e: Exception =>
+ if (pidVec != null) {
+ pidVec.close()
+ }
+ throw e
+ }
}
}, isOrderSensitive = isOrderSensitive)
case RangePartitioning(sortingExpressions, _) =>
@@ -349,17 +373,26 @@ object ColumnarShuffleExchangeExec extends Logging {
})
cbIter.map { cb =>
- val vecs = transColBatchToOmniVecs(cb, true)
- op.addInput(new VecBatch(vecs, cb.numRows()))
- val res = op.getOutput
- if (res.hasNext) {
- val retBatch = res.next()
- val pidVec = retBatch.getVectors()(0)
- // close return VecBatch
- retBatch.close()
- addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb)
- } else {
- throw new Exception("Empty Project Operator Result...")
+ var pidVec: IntVec = null
+ try {
+ val vecs = transColBatchToOmniVecs(cb, true)
+ op.addInput(new VecBatch(vecs, cb.numRows()))
+ val res = op.getOutput
+ if (res.hasNext) {
+ val retBatch = res.next()
+ pidVec = retBatch.getVectors()(0).asInstanceOf[IntVec]
+ // close return VecBatch
+ retBatch.close()
+ addPid2ColumnBatch(pidVec, cb)
+ } else {
+ throw new Exception("Empty Project Operator Result...")
+ }
+ } catch {
+ case e: Exception =>
+ if (pidVec != null) {
+ pidVec.close()
+ }
+ throw e
}
}
}, isOrderSensitive = isOrderSensitive)
@@ -393,6 +426,7 @@ object ColumnarShuffleExchangeExec extends Logging {
rddWithPartitionId,
new PartitionIdPassthrough(newPartitioning.numPartitions),
serializer,
+ handleRow = handleRow,
shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
partitionInfo = partitionInfo,
dataSize = dataSize,
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
index 55e4c6d5d38e0c7e7d4b33943089ba5dabd8e4cc..d94d256568e70a80e2b593f7281765826695fd3d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
@@ -93,13 +93,14 @@ case class ColumnarSortExec(
child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) =>
val columnarConf = ColumnarPluginConfig.getSessionConf
- val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold
+ val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize
val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold
val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize
val sortSpillEnable = columnarConf.enableSortSpill
+ val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold
val spillDirectory = generateSpillDir(tmpSparkConf, "columnarSortSpill")
val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillDirectory, spillDirDiskReserveSize,
- sortSpillRowThreshold, spillMemPctThreshold)
+ sortSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize)
val startCodegen = System.nanoTime()
val radixSortEnable = columnarConf.enableRadixSort
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
index 7d1828c27d9fa009afb343acbce821fbf0ca1c5d..837760ac89c7f714a2cdb8c75a41f6969840f2ca 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
@@ -133,7 +133,11 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val windowFrameEndTypes = new Array[OmniWindowFrameBoundType](winExpressions.size)
val windowFrameEndChannels = new Array[Int](winExpressions.size)
var attrMap: Map[String, Int] = Map()
-
+
+ if (winExpressions.isEmpty) {
+ throw new UnsupportedOperationException(s"Unsupported empty winExpressions")
+ }
+
for (sortAttr <- orderSpec) {
if (!sortAttr.child.isInstanceOf[AttributeReference]) {
throw new UnsupportedOperationException(s"Unsupported sort col : ${sortAttr.child.nodeName}")
@@ -356,13 +360,14 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val windowExpressionWithProjectConstant = windowExpressionWithProject
child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) =>
val columnarConf = ColumnarPluginConfig.getSessionConf
- val windowSpillEnable = columnarConf.enableWindowSpill
+ val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize
+ val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold
val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize
+ val windowSpillEnable = columnarConf.enableWindowSpill
val windowSpillRowThreshold = columnarConf.columnarWindowSpillRowThreshold
- val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold
val spillDirectory = generateSpillDir(tmpSparkConf, "columnarWindowSpill")
val sparkSpillConfig = new SparkSpillConfig(windowSpillEnable, spillDirectory,
- spillDirDiskReserveSize, windowSpillRowThreshold, spillMemPctThreshold)
+ spillDirDiskReserveSize, windowSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize)
val startCodegen = System.nanoTime()
val windowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, outputCols,
diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
similarity index 96%
rename from omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
index aeaa10faab50bb0490529fadebb331db2c60efa5..93950e9f0f1b9339587893b95f99c70967024acc 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
@@ -19,6 +19,7 @@
package org.apache.spark.sql.execution.datasources.orc;
import com.google.common.annotations.VisibleForTesting;
+import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor;
import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader;
import nova.hetu.omniruntime.vector.Vec;
import org.apache.hadoop.conf.Configuration;
@@ -79,6 +80,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
+ // collect read cols types
+ ArrayList typeBuilder = new ArrayList<>();
for (int i = 0; i < requiredfieldNames.length; i++) {
String target = requiredfieldNames[i];
boolean is_find = false;
@@ -163,6 +168,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
- getIndexArray(buildOutput, projectList)
+ getIndexArray(buildOutput, projectExprIdList)
case LeftExistence(_) =>
Array[Int]()
case x =>
throw new UnsupportedOperationException(s"ColumnBroadcastHashJoin Join-type[$x] is not supported!")
}
+
+ val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))
val buildJoinColsExp = buildKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap)
}.toArray
val relation = buildPlan.executeBroadcast[ColumnarHashedRelation]()
- val prunedBuildOutput = pruneOutput(buildOutput, projectList)
+ val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList)
val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13
prunedBuildOutput.zipWithIndex.foreach { case (att, i) =>
buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
@@ -324,14 +326,17 @@ case class ColumnarBroadcastHashJoinExec(
streamedOutput.zipWithIndex.foreach { case (attr, i) =>
probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
- val probeOutputCols = getIndexArray(streamedOutput, projectList) // {0,1}
+ val probeOutputCols = getIndexArray(streamedOutput, projectExprIdList) // {0,1}
+ val probeOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))
val probeHashColsExp = streamedKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, probeOutputExprIdMap)
}.toArray
+ val prunedStreamedOutput = pruneOutput(streamedOutput, projectExprIdList)
+ val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamedOutput, prunedBuildOutput)
val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType)
val canShareBuildOp = (lookupJoinType != OMNI_JOIN_TYPE_RIGHT && lookupJoinType != OMNI_JOIN_TYPE_FULL)
+
streamedPlan.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) =>
val filter: Optional[String] = condition match {
case Some(expr) =>
@@ -341,7 +346,7 @@ case class ColumnarBroadcastHashJoinExec(
Optional.empty()
}
- def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = {
+ def createBuildOpFactoryAndOp(isShared: Boolean): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = {
val startBuildCodegen = System.nanoTime()
val opFactory =
new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1,
@@ -350,6 +355,10 @@ case class ColumnarBroadcastHashJoinExec(
val op = opFactory.createOperator()
buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen)
+ if (isShared) {
+ OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id,
+ opFactory, op)
+ }
val deserializer = VecBatchSerializerFactory.create()
relation.value.buildData.foreach { input =>
val startBuildInput = System.nanoTime()
@@ -357,7 +366,19 @@ case class ColumnarBroadcastHashJoinExec(
buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput)
}
val startBuildGetOp = System.nanoTime()
- op.getOutput
+ try {
+ op.getOutput
+ } catch {
+ case e: Exception => {
+ if (isShared) {
+ OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id)
+ } else {
+ op.close()
+ opFactory.close()
+ }
+ throw new RuntimeException("HashBuilder getOutput failed")
+ }
+ }
buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp)
(opFactory, op)
}
@@ -369,11 +390,9 @@ case class ColumnarBroadcastHashJoinExec(
try {
buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id)
if (buildOpFactory == null) {
- val (opFactory, op) = createBuildOpFactoryAndOp()
+ val (opFactory, op) = createBuildOpFactoryAndOp(true)
buildOpFactory = opFactory
buildOp = op
- OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id,
- buildOpFactory, buildOp)
}
} catch {
case e: Exception => {
@@ -383,7 +402,7 @@ case class ColumnarBroadcastHashJoinExec(
OmniHashBuilderWithExprOperatorFactory.gLock.unlock()
}
} else {
- val (opFactory, op) = createBuildOpFactoryAndOp()
+ val (opFactory, op) = createBuildOpFactoryAndOp(false)
buildOpFactory = opFactory
buildOp = op
}
@@ -410,19 +429,17 @@ case class ColumnarBroadcastHashJoinExec(
}
})
- val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList)
- val prunedOutput = streamedPlanOutput ++ prunedBuildOutput
val resultSchema = this.schema
val reverse = buildSide == BuildLeft
var left = 0
- var leftLen = streamedPlanOutput.size
- var right = streamedPlanOutput.size
+ var leftLen = prunedStreamedOutput.size
+ var right = prunedStreamedOutput.size
var rightLen = output.size
if (reverse) {
- left = streamedPlanOutput.size
+ left = prunedStreamedOutput.size
leftLen = output.size
right = 0
- rightLen = streamedPlanOutput.size
+ rightLen = prunedStreamedOutput.size
}
val iterBatch = new Iterator[ColumnarBatch] {
@@ -468,7 +485,7 @@ case class ColumnarBroadcastHashJoinExec(
val vecs = OmniColumnVector
.allocateColumns(result.getRowCount, resultSchema, false)
if (projectList.nonEmpty) {
- reorderVecs(prunedOutput, projectList, resultVecs, vecs)
+ reorderOutputVecs(projectListIndex, resultVecs, vecs)
} else {
var index = 0
for (i <- left until leftLen) {
@@ -566,4 +583,4 @@ case class ColumnarBroadcastHashJoinExec(
}
-}
\ No newline at end of file
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
index e041a1fb119bc63bf689361534c207304383e83e..6e1f76d75a540345b2e5dacad61361fd39f185e8 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala
@@ -24,7 +24,7 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll}
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
-import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs}
+import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex, pruneOutput, reorderOutputVecs, transColBatchToOmniVecs}
import nova.hetu.omniruntime.`type`.DataType
import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig}
import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory, OmniLookupOuterJoinWithExprOperatorFactory}
@@ -62,8 +62,8 @@ case class ColumnarShuffledHashJoinExec(
s"""
|$formattedNodeName
|$simpleStringWithNodeId
- |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))}
- |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))}
+ |${ExplainUtils.generateFieldString("buildInput", buildOutput ++ buildOutput.map(_.dataType))}
+ |${ExplainUtils.generateFieldString("streamedInput", streamedOutput ++ streamedOutput.map(_.dataType))}
|${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))}
|${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))}
|${ExplainUtils.generateFieldString("condition", joinCondStr)}
@@ -131,9 +131,9 @@ case class ColumnarShuffledHashJoinExec(
buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
}
+ val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))
val buildJoinColsExp = buildKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap)
}.toArray
if (!isSimpleColumnForAll(buildJoinColsExp)) {
@@ -145,9 +145,9 @@ case class ColumnarShuffledHashJoinExec(
probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
+ val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))
val probeHashColsExp = streamedKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap)
}.toArray
if (!isSimpleColumnForAll(probeHashColsExp)) {
@@ -186,21 +186,22 @@ case class ColumnarShuffledHashJoinExec(
buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
}
+ val projectExprIdList = getExprIdForProjectList(projectList)
val buildOutputCols: Array[Int] = joinType match {
case Inner | FullOuter | LeftOuter | RightOuter =>
- getIndexArray(buildOutput, projectList)
+ getIndexArray(buildOutput, projectExprIdList)
case LeftExistence(_) =>
Array[Int]()
case x =>
throw new UnsupportedOperationException(s"ColumnShuffledHashJoin Join-type[$x] is not supported!")
}
+ val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))
val buildJoinColsExp = buildKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap)
}.toArray
- val prunedBuildOutput = pruneOutput(buildOutput, projectList)
+ val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList)
val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13
prunedBuildOutput.zipWithIndex.foreach { case (att, i) =>
buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata)
@@ -210,12 +211,14 @@ case class ColumnarShuffledHashJoinExec(
streamedOutput.zipWithIndex.foreach { case (attr, i) =>
probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
- val probeOutputCols = getIndexArray(streamedOutput, projectList)
+ val probeOutputCols = getIndexArray(streamedOutput, projectExprIdList)
+ val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))
val probeHashColsExp = streamedKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap)
}.toArray
+ val prunedStreamedOutput = pruneOutput(streamedOutput, projectExprIdList)
+ val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamedOutput, prunedBuildOutput)
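+ // projectListIndex presumably maps each projected column to its position in (prunedStreamedOutput ++ prunedBuildOutput), so reorderOutputVecs can restore the projected order without resolving ExprIds per batch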
streamedPlan.executeColumnar.zipPartitions(buildPlan.executeColumnar()) {
(streamIter, buildIter) =>
val filter: Optional[String] = condition match {
@@ -264,19 +267,17 @@ case class ColumnarShuffledHashJoinExec(
buildOp.getOutput
buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp)
- val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList)
- val prunedOutput = streamedPlanOutput ++ prunedBuildOutput
val resultSchema = this.schema
val reverse = buildSide == BuildLeft
var left = 0
- var leftLen = streamedPlanOutput.size
- var right = streamedPlanOutput.size
+ var leftLen = prunedStreamedOutput.size
+ var right = prunedStreamedOutput.size
var rightLen = output.size
if (reverse) {
- left = streamedPlanOutput.size
+ left = prunedStreamedOutput.size
leftLen = output.size
right = 0
- rightLen = streamedPlanOutput.size
+ rightLen = prunedStreamedOutput.size
}
val joinIter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] {
@@ -320,7 +321,7 @@ case class ColumnarShuffledHashJoinExec(
val vecs = OmniColumnVector
.allocateColumns(result.getRowCount, resultSchema, false)
if (projectList.nonEmpty) {
- reorderVecs(prunedOutput, projectList, resultVecs, vecs)
+ reorderOutputVecs(projectListIndex, resultVecs, vecs)
} else {
var index = 0
for (i <- left until leftLen) {
@@ -375,7 +376,7 @@ case class ColumnarShuffledHashJoinExec(
val vecs = OmniColumnVector
.allocateColumns(result.getRowCount, resultSchema, false)
if (projectList.nonEmpty) {
- reorderVecs(prunedOutput, projectList, resultVecs, vecs)
+ reorderOutputVecs(projectListIndex, resultVecs, vecs)
} else {
var index = 0
for (i <- left until leftLen) {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
index c3a22b1ea3a8c8a250e75142af71d47df9f2bcb5..a5baa6bde611724f1333fc5f644da9f2bc34a7a2 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala
@@ -25,7 +25,7 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor
import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll}
import com.huawei.boostkit.spark.util.OmniAdaptorUtil
-import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs}
+import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex, pruneOutput, reorderOutputVecs, transColBatchToOmniVecs}
import nova.hetu.omniruntime.`type`.DataType
import nova.hetu.omniruntime.constants.JoinType._
import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig}
@@ -197,18 +197,18 @@ case class ColumnarSortMergeJoinExec(
left.output.zipWithIndex.foreach { case (attr, i) =>
streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
+ val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))
val streamedKeyColsExp: Array[AnyRef] = leftKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap)
}.toArray
val bufferedTypes = new Array[DataType](right.output.size)
right.output.zipWithIndex.foreach { case (attr, i) =>
bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
+ val bufferOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))
val bufferedKeyColsExp: Array[AnyRef] = rightKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, bufferOutputExprIdMap)
}.toArray
if (!isSimpleColumnForAll(streamedKeyColsExp.map(expr => expr.toString))) {
@@ -246,23 +246,24 @@ case class ColumnarSortMergeJoinExec(
left.output.zipWithIndex.foreach { case (attr, i) =>
streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
+ val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))
val streamedKeyColsExp = leftKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap)
}.toArray
- val streamedOutputChannel = getIndexArray(left.output, projectList)
+ val projectExprIdList = getExprIdForProjectList(projectList)
+ val streamedOutputChannel = getIndexArray(left.output, projectExprIdList)
val bufferedTypes = new Array[DataType](right.output.size)
right.output.zipWithIndex.foreach { case (attr, i) =>
bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata)
}
+ val bufferOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))
val bufferedKeyColsExp = rightKeys.map { x =>
- OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x,
- OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute)))
+ OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, bufferOutputExprIdMap)
}.toArray
val bufferedOutputChannel: Array[Int] = joinType match {
case Inner | LeftOuter | FullOuter =>
- getIndexArray(right.output, projectList)
+ getIndexArray(right.output, projectExprIdList)
case LeftExistence(_) =>
Array[Int]()
case x =>
@@ -275,6 +276,9 @@ case class ColumnarSortMergeJoinExec(
OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute)))
case _ => null
}
+ val prunedStreamOutput = pruneOutput(left.output, projectExprIdList)
+ val prunedBufferOutput = pruneOutput(right.output, projectExprIdList)
+ val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamOutput, prunedBufferOutput)
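+ // same pattern as in the shuffled hash join: precompute the projection order over the pruned stream/buffer outputs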
left.executeColumnar().zipPartitions(right.executeColumnar()) { (streamedIter, bufferedIter) =>
val filter: Optional[String] = Optional.ofNullable(filterString)
@@ -304,9 +308,6 @@ case class ColumnarSortMergeJoinExec(
streamedOpFactory.close()
})
- val prunedStreamOutput = pruneOutput(left.output, projectList)
- val prunedBufferOutput = pruneOutput(right.output, projectList)
- val prunedOutput = prunedStreamOutput ++ prunedBufferOutput
val resultSchema = this.schema
val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf
val enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge
@@ -415,7 +416,7 @@ case class ColumnarSortMergeJoinExec(
val resultVecs = result.getVectors
val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false)
if (projectList.nonEmpty) {
- reorderVecs(prunedOutput, projectList, resultVecs, vecs)
+ reorderOutputVecs(projectListIndex, resultVecs, vecs)
} else {
for (index <- output.indices) {
val v = vecs(index)
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala
index 946c90a9baf346dc4e47253ced50a53def22374b..2eb8fec002fb39b65284f33654c6df88a044474f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala
@@ -23,8 +23,8 @@ import org.apache.spark.{SparkEnv, TaskContext}
object SparkMemoryUtils {
- private val max: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g")
- MemoryManager.setGlobalMemoryLimit(max)
+ val offHeapSize: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g")
+ MemoryManager.setGlobalMemoryLimit(offHeapSize)
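+ // the limit itself is unchanged; the value is now exposed as offHeapSize, presumably so other components can read the configured off-heap size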
def init(): Unit = {}
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala
index 94e566f9b571c57c2a0e1cc17143c055d7be9229..d53c6e0286c21e026c5073335e96a5a00010a71a 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala
@@ -81,7 +81,7 @@ object TopNPushDownForWindow extends Rule[SparkPlan] with PredicateHelper {
private def isTopNExpression(e: Expression): Boolean = e match {
case Alias(child, _) => isTopNExpression(child)
case WindowExpression(windowFunction, _)
- if windowFunction.isInstanceOf[Rank] || windowFunction.isInstanceOf[RowNumber] => true
+ if windowFunction.isInstanceOf[Rank] => true
case _ => false
}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java
index d95be18832b926500b599821b6b6fd0baa8861c5..d1cd5b7f29c8249d81cdb458d89d82ecf38b8dbe 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java
@@ -117,7 +117,8 @@ public class ColumnShuffleCompressionTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
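+ // Long.MAX_VALUE fills the splitter's new last argument (presumably a memory/spill threshold), effectively disabling spilling in these tests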
for (int i = 0; i < 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 1000, partitionNum, true, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java
index c8fd474137a93ea8831d3dc3ab432e409018cc55..e0d271ab107073c85314c1b0b796506dd5892232 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java
@@ -115,7 +115,8 @@ public class ColumnShuffleDiffPartitionTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 99; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, pidVec);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java
index dc53fda8a1a04a15bf7ffb9919926d4812208fc0..0f935e68a156034d7973409b41098a3bff331330 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java
@@ -95,7 +95,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -125,7 +126,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
0,
4096,
- 1024*1024*1024);
+ 1024*1024*1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -155,7 +157,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 1024; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 1, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -185,7 +188,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 1; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 1024, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -214,7 +218,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 1; i < 1000; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, i, numPartition, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -243,7 +248,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
VecBatch vecBatchTmp1 = new VecBatch(buildValChar(3, "N"));
jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch());
VecBatch vecBatchTmp2 = new VecBatch(buildValChar(2, "F"));
@@ -282,7 +288,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
VecBatch vecBatchTmp1 = new VecBatch(buildValInt(3, 1));
jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch());
VecBatch vecBatchTmp2 = new VecBatch(buildValInt(2, 2));
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java
index 2ef81ac49e545aa617136b9d4f3e7e769ea34652..dcd1e8b8571f5942cdb8e7ff29f842eb2c325528 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java
@@ -95,7 +95,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 6 * 1024; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -124,7 +125,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 10 * 8 * 1024; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -153,7 +155,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
// Do not split the same VecBatch twice: the interface frees the VecBatch's memory, so a repeated split double-frees it and causes a core dump
for (int i = 0; i < 99; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, false, true);
@@ -183,7 +186,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 10 * 3 * 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
@@ -213,7 +217,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
// Do not split the same VecBatch twice: the interface frees the VecBatch's memory, so a repeated split double-frees it and causes a core dump
for (int i = 0; i < 6 * 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true);
@@ -244,7 +249,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 3 * 9 * 999; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java
index 98fc18dd8f3237928cc066887e6fcb2205686692..886c2f80643bc314ecc8e1f1fdd4272cc4fbebd2 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java
@@ -94,7 +94,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
// Do not split the same VecBatch twice: the interface frees the VecBatch's memory, so a repeated split double-frees it and causes a core dump
for (int i = 0; i < 1; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true);
@@ -124,7 +125,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
// Do not split the same VecBatch twice: the interface frees the VecBatch's memory, so a repeated split double-frees it and causes a core dump
for (int i = 0; i < 1; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true);
@@ -155,7 +157,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
// Do not split the same VecBatch twice: the interface frees the VecBatch's memory, so a repeated split double-frees it and causes a core dump
for (int i = 0; i < 1; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true);
@@ -186,7 +189,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest {
shuffleTestDir,
64 * 1024,
4096,
- 1024 * 1024 * 1024);
+ 1024 * 1024 * 1024,
+ Long.MAX_VALUE);
for (int i = 0; i < 1; i++) {
VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, numPartition, true, true);
jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch());
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java
index 77283e4d07eb7e1d77317a026f5d04ea4032c393..fe1c55ffb6eeecc7a16521430759e3239ca77dd4 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java
@@ -36,6 +36,10 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase {
public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
@@ -96,7 +100,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase {
@Test
public void testNext() {
- int[] typeId = new int[4];
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()};
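+ // typeId is now supplied by the caller with the expected Omni type id per column (previously the array was left zeroed, presumably filled in by the native reader)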
long[] vecNativeId = new long[4];
long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
assertTrue(rtn == 4096);
@@ -115,4 +119,27 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase {
vec3.close();
vec4.close();
}
+
+ // Here we use OMNI_LONG instead of OMNI_INT for the 4th field to simulate a schema change.
+ @Test
+ public void testNextIfSchemaChange() {
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_LONG.ordinal()};
+ long[] vecNativeId = new long[4];
+ long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
+ assertTrue(rtn == 4096);
+ LongVec vec1 = new LongVec(vecNativeId[0]);
+ VarcharVec vec2 = new VarcharVec(vecNativeId[1]);
+ VarcharVec vec3 = new VarcharVec(vecNativeId[2]);
+ LongVec vec4 = new LongVec(vecNativeId[3]);
+ assertTrue(vec1.get(10) == 11);
+ String tmp1 = new String(vec2.get(4080));
+ assertTrue(tmp1.equals("AAAAAAAABPPAAAAA"));
+ String tmp2 = new String(vec3.get(4070));
+ assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. Scottish, trying exercises operate marks; long, supreme miners may ro"));
+ assertTrue(0 == vec4.get(1000));
+ vec1.close();
+ vec2.close();
+ vec3.close();
+ vec4.close();
+ }
}
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java
index 72587b3f36a469fe130abc76c51197fb2a16bd29..995c434f66a8b03a6a76830b8fb0b08be9a3223e 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java
@@ -35,6 +35,9 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase {
public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
@@ -89,7 +92,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase {
@Test
public void testNext() {
- int[] typeId = new int[2];
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()};
long[] vecNativeId = new long[2];
long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
assertTrue(rtn == 4096);
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java
index 6c75eda79e38332e48e67df8a27c0e1394e4e477..c9ad9fadaf4ae13b1286467f965a23f66788c37f 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java
@@ -35,6 +35,9 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderPushDownTest extends TestCase {
public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
@@ -135,7 +138,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase {
@Test
public void testNext() {
- int[] typeId = new int[2];
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()};
long[] vecNativeId = new long[2];
long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
assertTrue(rtn == 4096);
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java
index 7fb87efa3a2899bc5375168f56d516243a10881d..8f4535338cc3715e33dbf336e5f15fd4eb91569f 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java
@@ -36,6 +36,10 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase {
public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
@@ -96,7 +100,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase {
@Test
public void testNext() {
- int[] typeId = new int[4];
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()};
long[] vecNativeId = new long[4];
long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
assertTrue(rtn == 4096);
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java
index 4ba4579cc9340ea410d1d1cdbbf5e0a88ebe2888..27bcf5d7bdbd4787f42e90e620d1627678f70910 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java
@@ -36,6 +36,10 @@ import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase {
public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
@@ -142,7 +146,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase {
@Test
public void testNext() {
- int[] typeId = new int[4];
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()};
long[] vecNativeId = new long[4];
long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId);
assertTrue(rtn == 4096);
diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
index c8581f35ebc605845896b5f35731b0908779f326..eab15fef660250780e0beb311b489e3ceeb8ff5b 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
+++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
@@ -21,20 +21,15 @@ package com.huawei.boostkit.spark.jni;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import junit.framework.TestCase;
-import nova.hetu.omniruntime.type.DataType;
-import nova.hetu.omniruntime.vector.IntVec;
import nova.hetu.omniruntime.vector.LongVec;
import nova.hetu.omniruntime.vector.VarcharVec;
import nova.hetu.omniruntime.vector.Vec;
import org.apache.commons.codec.binary.Base64;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl;
-import org.apache.orc.OrcConf;
import org.apache.orc.OrcFile;
-import org.apache.orc.Reader;
import org.apache.orc.TypeDescription;
import org.apache.orc.mapred.OrcInputFormat;
-import org.json.JSONObject;
import org.junit.After;
import org.junit.Before;
import org.junit.FixMethodOrder;
@@ -51,6 +46,9 @@ import org.apache.orc.Reader.Options;
import static org.junit.Assert.*;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG;
+import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR;
+
@FixMethodOrder(value = MethodSorters.NAME_ASCENDING )
public class OrcColumnarBatchJniReaderTest extends TestCase {
public Configuration conf = new Configuration();
@@ -152,7 +150,8 @@ public class OrcColumnarBatchJniReaderTest extends TestCase {
@Test
public void testNext() {
Vec[] vecs = new Vec[2];
- long rtn = orcColumnarBatchScanReader.next(vecs);
+ int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()};
+ long rtn = orcColumnarBatchScanReader.next(vecs, typeId);
assertTrue(rtn == 4096);
assertTrue(((LongVec) vecs[0]).get(0) == 1);
String str = new String(((VarcharVec) vecs[1]).get(0));
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0f0eda4f765d4eddb57fcfd0615b3918fed842d1
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala
@@ -0,0 +1,249 @@
+/*
+ * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{File, FileInputStream}
+
+import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer
+import com.huawei.boostkit.spark.vectorized.PartitionInfo
+import nova.hetu.omniruntime.`type`.{DataType, _}
+import nova.hetu.omniruntime.vector._
+import org.apache.spark.{HashPartitioner, SparkConf, TaskContext}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.sort.ColumnarShuffleHandle
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.vectorized.OmniColumnVector
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.Utils
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Mockito.{doAnswer, when}
+import org.mockito.invocation.InvocationOnMock
+
+class RowShuffleSerializerSuite extends SharedSparkSession {
+ @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency
+ : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _
+
+ override def sparkConf: SparkConf =
+ super.sparkConf
+ .setAppName("test row shuffle serializer")
+ .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager")
+ .set("spark.shuffle.compress", "true")
+ .set("spark.io.compression.codec", "lz4")
+
+ private var taskMetrics: TaskMetrics = _
+ private var tempDir: File = _
+ private var outputFile: File = _
+
+ private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _
+ private val numPartitions = 1
+
+ protected var avgBatchNumRows: SQLMetric = _
+ protected var outputNumRows: SQLMetric = _
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+
+ avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext,
+ "test serializer avg read batch num rows")
+ outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext,
+ "test serializer number of output rows")
+
+ tempDir = Utils.createTempDir()
+ outputFile = File.createTempFile("shuffle", null, tempDir)
+ taskMetrics = new TaskMetrics
+
+ MockitoAnnotations.initMocks(this)
+
+ shuffleHandle =
+ new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency)
+
+ val types : Array[DataType] = Array[DataType](
+ IntDataType.INTEGER,
+ ShortDataType.SHORT,
+ LongDataType.LONG,
+ DoubleDataType.DOUBLE,
+ new Decimal64DataType(18, 3),
+ new Decimal128DataType(28, 11),
+ VarcharDataType.VARCHAR,
+ BooleanDataType.BOOLEAN)
+ val inputTypes = DataTypeSerializer.serialize(types)
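+ // serialize the Omni schema; PartitionInfo below passes it to the native splitter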
+
+ when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions))
+ when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf))
+ when(dependency.handleRow).thenReturn(true) // enable the row-based shuffle path
+ when(dependency.partitionInfo).thenReturn(
+ new PartitionInfo("hash", numPartitions, types.length, inputTypes))
+ when(dependency.dataSize)
+ .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size"))
+ when(dependency.bytesSpilled)
+ .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled"))
+ when(dependency.numInputRows)
+ .thenReturn(SQLMetrics.createMetric(spark.sparkContext, "number of input rows"))
+ when(dependency.splitTime)
+ .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split"))
+ when(dependency.spillTime)
+ .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill"))
+ when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+ when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+
+ doAnswer { (invocationOnMock: InvocationOnMock) =>
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
+ if (tmp != null) {
+ outputFile.delete
+ tmp.renameTo(outputFile)
+ }
+ null
+ }.when(blockResolver)
+ .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File]))
+ }
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ }
+
+ test("row shuffle serialize and deserialize") {
+ val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
+ val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+ val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+ val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
+ 17L, 18L, 19L, 20L)
+ val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12,
+ 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.20)
+ val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L,
+ 17L, 18L, 19L, 20L)
+ val decimal128Array: Array[Array[Long]] = Array(
+ Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L),
+ Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), Array(10L, 10L), Array(11L, 11L), Array(12L, 12L), Array(13L, 13L),
+ Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), Array(20L, 20L))
+ val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg",
+ "hhhhhhhh", "iiiiiiiii", "jjjjjjjjjj", "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn",
+ "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss",
+ "tttttttttttttttttttt")
+ val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true,
+ false, false, false, false, false, false, false, false, false, false, false)
+
+ val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray)
+ val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray)
+ val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray)
+ val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray)
+ val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray)
+ val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array)
+ val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array)
+ val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray)
+ val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray)
+
+ val cb0 = ColumnarShuffleWriterSuite.makeColumnarBatch(
+ pidVector0.getVec.getSize,
+ List(pidVector0, intVector0, shortVector0, longVector0, doubleVector0,
+ decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0)
+ )
+
+ val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray)
+ val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray)
+ val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray)
+ val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray)
+ val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray)
+ val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array)
+ val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array)
+ val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray)
+ val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray)
+
+ val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch(
+ pidVector1.getVec.getSize,
+ List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1,
+ decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1)
+ )
+
+ def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1))
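+ // two identical 21-row batches, both keyed to partition 0; the writer is expected to merge them into a single 42-row batch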
+
+ val writer = new ColumnarShuffleWriter[Int, ColumnarBatch](
+ blockResolver,
+ shuffleHandle,
+ 0L, // MapId
+ taskContext.taskMetrics().shuffleWriteMetrics)
+
+ // exercise the row-based shuffle write path
+ writer.write(records)
+ writer.stop(success = true)
+
+ assert(writer.getPartitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.count(_ == 0L) === 0)
+ // with a single partition receiving all records, no zero-length partition files are expected
+
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.recordsWritten === pidArray.length * 2)
+
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+
+ // the shuffle writer emits a row-based format, so the stream must be deserialized row by row
+ val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows, true).newInstance()
+ val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile))
+
+ try {
+ val kv = deserializedStream.asKeyValueIterator
+ var length = 0
+ kv.foreach {
+ case (_, batch: ColumnarBatch) =>
+ length += 1
+ assert(batch.numRows == 42)
+ assert(batch.numCols == 8)
+ assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0)
+ assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19)
+ assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0)
+ assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19)
+ assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0)
+ assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19)
+ assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0)
+ assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19)
+ assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L)
+ assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L)
+ assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L))
+ assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) sameElements Array(19L, 19L))
+ assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "")
+ assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss")
+ assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true)
+ assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false)
+ (0 until batch.numCols).foreach { i =>
+ val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec
+ assert(valueVector.getSize == batch.numRows)
+ }
+ batch.close()
+ }
+ assert(length == 1)
+ } finally {
+ deserializedStream.close()
+ }
+ }
+}
diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala
index a788501ed8ed7d06e8939cd45560f59648a56acf..fa8c1390ebe2af2a3441234868c65321d2055847 100644
--- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala
+++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala
@@ -49,16 +49,17 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest {
test("Test topNSort") {
val sql1 = "select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;"
- assertColumnarTopNSortExecAndSparkResultEqual(sql1, true)
+ assertColumnarTopNSortExecAndSparkResultEqual(sql1, true, true)
val sql2 = "select * from (SELECT city, row_number() OVER (ORDER BY sales) AS rn FROM dealer) where rn < 4 order by rn;"
- assertColumnarTopNSortExecAndSparkResultEqual(sql2, false)
+ assertColumnarTopNSortExecAndSparkResultEqual(sql2, false, false)
val sql3 = "select * from (SELECT city, rank() OVER (PARTITION BY city ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;"
- assertColumnarTopNSortExecAndSparkResultEqual(sql3, true)
+ assertColumnarTopNSortExecAndSparkResultEqual(sql3, true, true)
}
- private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true): Unit = {
+ private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true,
+ hasTopNSortExec: Boolean = false): Unit = {
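+ // hasTopNSortExec: whether the plain Spark plan is still expected to contain a TopNSort node (row_number windows no longer trigger the push-down after this change)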
// run ColumnarTopNSortExec config
spark.conf.set("spark.omni.sql.columnar.topNSort", true)
spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true)
@@ -79,8 +80,10 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest {
val sparkPlan = sparkResult.queryExecution.executedPlan.toString()
assert(!sparkPlan.contains("ColumnarTopNSort"),
s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}")
- assert(sparkPlan.contains("TopNSort"),
- s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}")
+ if (hasTopNSortExec) {
+ assert(sparkPlan.contains("TopNSort"),
+ s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}")
+ }
// DataFrame does not support comparison with the equals method, so use DataFrame.except instead
// DataFrame.except treats reordered rows as equal (results with and without order by compare the same)
assert(omniResult.except(sparkResult).isEmpty,
diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml
index b7315c5b49145c0805d940cd088eb70cbc9c6265..24a42654b4f07b86944ff098639a3d3e237e9402 100644
--- a/omnioperator/omniop-spark-extension/pom.xml
+++ b/omnioperator/omniop-spark-extension/pom.xml
@@ -8,7 +8,7 @@
com.huawei.kunpeng
boostkit-omniop-spark-parent
pom
- 3.3.1-1.4.0
+ 3.3.1-1.5.0
BoostKit Spark Native Sql Engine Extension Parent Pom
@@ -20,7 +20,7 @@
UTF-8
3.13.0-h19
FALSE
- 1.4.0
+ 1.5.0
java