diff --git a/omnidata/omnidata-spark-connector/connector/pom.xml b/omnidata/omnidata-spark-connector/connector/pom.xml index 4fd2668b10b13f63ccef418c56af7ee9d1bfe8ce..dda5fe43f387e6d1fa7346321109b821f8cd0322 100644 --- a/omnidata/omnidata-spark-connector/connector/pom.xml +++ b/omnidata/omnidata-spark-connector/connector/pom.xml @@ -5,12 +5,12 @@ org.apache.spark omnidata-spark-connector-root - 1.4.0 + 1.5.0 4.0.0 boostkit-omnidata-spark-sql_2.12-3.1.1 - 1.4.0 + 1.5.0 boostkit omnidata spark sql 2021 jar @@ -55,7 +55,7 @@ com.huawei.boostkit boostkit-omnidata-stub - 1.4.0 + 1.5.0 compile diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt index 7ba2967f87f8ebcaeb6959b51c23cef857462a07..8aa1e62449a5c8729506f5854772fb8a14b687b8 100644 --- a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt @@ -48,7 +48,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC Arrow::arrow_shared Parquet::parquet_shared orc - boostkit-omniop-vector-1.4.0-aarch64 + boostkit-omniop-vector-1.5.0-aarch64 hdfs ) diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp index e1300c4e062857780c1c30d56fe9d366cce6f868..f8ee293e2dc7576f0784fe5f2e862141616dc3f4 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.cpp @@ -28,7 +28,6 @@ using namespace std; using namespace orc; static constexpr int32_t MAX_DECIMAL64_DIGITS = 18; -bool isDecimal64Transfor128 = false; // vecFildsNames存储文件每列的列名,从orc reader c++侧获取,回传到java侧使用 JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_initializeReader(JNIEnv *env, @@ -74,7 +73,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea UriInfo uri{schemaStr, fileStr, hostStr, std::to_string(port)}; reader = createReader(orc::readFileOverride(uri), readerOptions); std::vector orcColumnNames = reader->getAllFiedsName(); - for (int i = 0; i < orcColumnNames.size(); i++) { + for (uint32_t i = 0; i < orcColumnNames.size(); i++) { jstring fildname = env->NewStringUTF(orcColumnNames[i].c_str()); // use ArrayList and function env->CallBooleanMethod(vecFildsNames, arrayListAdd, fildname); @@ -268,12 +267,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea { JNI_FUNC_START orc::Reader *readerPtr = (orc::Reader *)reader; - // Get if the decimal for spark or hive - jboolean jni_isDecimal64Transfor128 = env->CallBooleanMethod(jsonObj, jsonMethodHas, - env->NewStringUTF("isDecimal64Transfor128")); - if (jni_isDecimal64Transfor128) { - isDecimal64Transfor128 = true; - } + // get offset from json obj jlong offset = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("offset")); jlong length = env->CallLongMethod(jsonObj, jsonMethodLong, env->NewStringUTF("length")); @@ -520,45 +514,86 @@ std::unique_ptr CopyToOmniDecimal128VecFrom64(orc::ColumnVectorBatch return newVector; } -std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId, orc::ColumnVectorBatch *field, - bool isDecimal64Transfor128) +std::unique_ptr DealLongVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_BOOLEAN: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_SHORT: + return CopyFixedWidth(field); 
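        // Editor's note (illustrative, not part of the original patch): with this change the OMNI type id is no
        // longer derived inside the reader and written back to Java; it now arrives from the Java side in the
        // typeId array passed to recordReaderNext, and these Deal*VectorBatch helpers map the requested OMNI
        // type to a copy routine, throwing when the ORC physical kind cannot produce it (this replaces the old
        // global isDecimal64Transfor128 switch for the decimal64 -> decimal128 case).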
+ case omniruntime::type::OMNI_INT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_LONG: + return CopyOptimizedForInt64(field); + case omniruntime::type::OMNI_DATE32: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_DATE64: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("DealLongVectorBatch not support for type: " + std::to_string(static_cast<int>(id))); + } + } +} + +std::unique_ptr DealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DOUBLE: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("DealDoubleVectorBatch not support for type: " + std::to_string(static_cast<int>(id))); + } + } +} + +std::unique_ptr DealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL64: + return CopyToOmniDecimal64Vec(field); + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128VecFrom64(field); + default: { + throw std::runtime_error("DealDecimal64VectorBatch not support for type: " + std::to_string(static_cast<int>(id))); + } + } +} + +std::unique_ptr DealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128Vec(field); + default: { + throw std::runtime_error("DealDecimal128VectorBatch not support for type: " + std::to_string(static_cast<int>(id))); + } + } +} + +std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field) { + DataTypeId dataTypeId = static_cast(omniTypeId); switch (type->getKind()) { case orc::TypeKind::BOOLEAN: - omniTypeId = static_cast(OMNI_BOOLEAN); - return CopyFixedWidth(field); case orc::TypeKind::SHORT: - omniTypeId = static_cast(OMNI_SHORT); - return CopyFixedWidth(field); case orc::TypeKind::DATE: - omniTypeId = static_cast(OMNI_DATE32); - return CopyFixedWidth(field); case orc::TypeKind::INT: - omniTypeId = static_cast(OMNI_INT); - return CopyFixedWidth(field); case orc::TypeKind::LONG: - omniTypeId = static_cast(OMNI_LONG); - return CopyOptimizedForInt64(field); + return DealLongVectorBatch(dataTypeId, field); case orc::TypeKind::DOUBLE: - omniTypeId = static_cast(OMNI_DOUBLE); - return CopyOptimizedForInt64(field); + return DealDoubleVectorBatch(dataTypeId, field); case orc::TypeKind::CHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char"); + } return CopyCharType(field); case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar"); + } return CopyVarWidth(field); case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { - omniTypeId = static_cast(OMNI_DECIMAL128); - return CopyToOmniDecimal128Vec(field); - } else if (isDecimal64Transfor128) { - omniTypeId = static_cast(OMNI_DECIMAL128); - return CopyToOmniDecimal128VecFrom64(field); + return DealDecimal128VectorBatch(dataTypeId, field); } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - return CopyToOmniDecimal64Vec(field); + return DealDecimal64VectorBatch(dataTypeId, field); } default: { throw std::runtime_error("Native ColumnarFileScan Not support For This Type: " + type->getKind()); @@ -569,28 +604,37 @@ std::unique_ptr CopyToOmniVec(const orc::Type *type, int &omniTypeId JNIEXPORT jlong JNICALL
Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniReader_recordReaderNext(JNIEnv *env, jobject jObj, jlong rowReader, jlong batch, jintArray typeId, jlongArray vecNativeId) { + JNI_FUNC_START orc::RowReader *rowReaderPtr = (orc::RowReader *)rowReader; orc::ColumnVectorBatch *columnVectorBatch = (orc::ColumnVectorBatch *)batch; std::vector> omniVecs; const orc::Type &baseTp = rowReaderPtr->getSelectedType(); uint64_t batchRowSize = 0; + auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE); + if (ptr == NULL) { + throw std::runtime_error("Types should not be null"); + } + int32_t arrLen = (int32_t) env->GetArrayLength(typeId); if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); batchRowSize = root->fields[0]->numElements; int32_t vecCnt = root->fields.size(); - std::vector omniTypeIds(vecCnt, 0); + if (vecCnt != arrLen) { + throw std::runtime_error("Types should align to root fields"); + } for (int32_t id = 0; id < vecCnt; id++) { auto type = baseTp.getSubtype(id); - omniVecs.emplace_back(CopyToOmniVec(type, omniTypeIds[id], root->fields[id], isDecimal64Transfor128)); + int omniTypeId = ptr[id]; + omniVecs.emplace_back(CopyToOmniVec(type, omniTypeId, root->fields[id])); } for (int32_t id = 0; id < vecCnt; id++) { - env->SetIntArrayRegion(typeId, id, 1, omniTypeIds.data() + id); jlong omniVec = reinterpret_cast(omniVecs[id].release()); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } } return (jlong) batchRowSize; + JNI_FUNC_END(runtimeExceptionClass) } /* diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h index 829f5c0744d3d563601ec6506ebfc82b5a020e93..8b942fe8b3e0975cccfece4a68e9112b6c550802 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -141,8 +141,7 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList bool StringToBool(const std::string &boolStr); -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, - bool isDecimal64Transfor128); +std::unique_ptr CopyToOmniVec(const orc::Type *type, int omniTypeId, orc::ColumnVectorBatch *field); #ifdef __cplusplus } diff --git a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h index a36c2e2acb430d15e32f1a1da1be6c83700ecd7d..0bc32f33277da54253f4ca81bab490a3c2a1875d 100644 --- a/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h +++ b/omnioperator/omniop-native-reader/cpp/src/parquet/ParquetDecoder.h @@ -346,6 +346,7 @@ namespace omniruntime::reader { vec->SetValue(i + offset, value); } values_decoded += num_indices; + offset += num_indices; } *out_num_values = values_decoded; return Status::OK(); diff --git a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt index 3d1d559df94b1137db424b3318b86b956830e09c..cff2824fae3053fd2ec04b16471ab775a23b2de9 100644 --- a/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/test/CMakeLists.txt @@ -31,7 +31,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.4.0-aarch64 + boostkit-omniop-vector-1.5.0-aarch64 securec spark_columnar_plugin) diff --git 
a/omnioperator/omniop-native-reader/java/pom.xml b/omnioperator/omniop-native-reader/java/pom.xml index 99c66a43076a3c9b7f8b529749436e9c4ea10ed9..8f6a401efe919fc58e426b25dd185f383fca2266 100644 --- a/omnioperator/omniop-native-reader/java/pom.xml +++ b/omnioperator/omniop-native-reader/java/pom.xml @@ -8,7 +8,7 @@ com.huawei.boostkit boostkit-omniop-native-reader jar - 3.3.1-1.4.0 + 3.3.1-1.5.0 BoostKit Spark Native Sql Engine Extension With OmniOperator @@ -31,7 +31,7 @@ com.huawei.boostkit boostkit-omniop-bindings aarch64 - 1.4.0 + 1.5.0 org.slf4j diff --git a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml index 608a3ca714fe707b61758f870d0b32877e22f9b2..e8815aea6ee6936e5d7a4ab2083ec115a3c347f7 100644 --- a/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/ock-omniop-tuning/pom.xml @@ -65,12 +65,12 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.3.0 + 1.4.0 com.huawei.kunpeng boostkit-omniop-spark - 3.3.1-1.3.0 + 3.3.1-1.4.0 org.scalatest diff --git a/omnioperator/omniop-spark-extension-ock/pom.xml b/omnioperator/omniop-spark-extension-ock/pom.xml index 84c9208cc3caebf6a4a3dff7a9b4cd9b0b4d63ee..8201865444e199606338b6812f6d11b5fe8d250b 100644 --- a/omnioperator/omniop-spark-extension-ock/pom.xml +++ b/omnioperator/omniop-spark-extension-ock/pom.xml @@ -62,12 +62,12 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.3.0 + 1.4.0 com.huawei.kunpeng boostkit-omniop-spark - 3.3.1-1.3.0 + 3.3.1-1.4.0 com.huawei.ock diff --git a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt index 26df3cb85b9255ee0969d8631deb6ab76488101d..fe4dc5fc593d0435571859b9b687bff24170f265 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/src/CMakeLists.txt @@ -42,7 +42,7 @@ target_link_libraries (${PROJ_TARGET} PUBLIC snappy lz4 zstd - boostkit-omniop-vector-1.4.0-aarch64 + boostkit-omniop-vector-1.5.0-aarch64 ) set_target_properties(${PROJ_TARGET} PROPERTIES diff --git a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc index 3c6e3b3bc31746c7c28a2a63f5bd1b5b1b2a3e44..7e46b9f560f80d1c689e9a1e9781c08de8b3ee54 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc +++ b/omnioperator/omniop-spark-extension/cpp/src/io/SparkFile.cc @@ -141,14 +141,18 @@ namespace spark { if (closed) { throw std::logic_error("Cannot write to closed stream."); } - ssize_t bytesWrite = ::write(file, buf, length); - if (bytesWrite == -1) { - throw std::runtime_error("Bad write of " + filename); - } - if (static_cast(bytesWrite) != length) { - throw std::runtime_error("Short write of " + filename); + + size_t bytesWritten = 0; + while (bytesWritten < length) { + ssize_t actualBytes = ::write(file, static_cast(buf) + bytesWritten, length - bytesWritten); + if (actualBytes <= 0) { + close(); + std::string errMsg = "Bad write of " + filename + " since " + strerror(errno) + ",actual write bytes " + + std::to_string(actualBytes) + "."; + throw std::runtime_error(errMsg); + } + bytesWritten += actualBytes; } - bytesWritten += static_cast(bytesWrite); } const std::string& getName() const override { @@ -177,4 +181,4 @@ namespace spark { InputStream::~InputStream() { // PASS }; -} \ No newline at end of file +} diff --git 
a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp index 14785a9cf453f5925974f8b85dc80538a5b85a17..d67ba33c7c0da99334e5f573ef397db0494ddfad 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.cpp @@ -34,7 +34,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ jstring jInputType, jint jNumCols, jint buffer_size, jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs, jstring local_dirs_jstr, jlong compress_block_size, - jint spill_batch_row, jlong spill_memory_threshold) + jint spill_batch_row, jlong task_spill_memory_threshold, jlong executor_spill_memory_threshold) { JNI_FUNC_START if (partitioning_name_jstr == nullptr) { @@ -107,8 +107,11 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ if (spill_batch_row > 0) { splitOptions.spill_batch_row_num = spill_batch_row; } - if (spill_memory_threshold > 0) { - splitOptions.spill_mem_threshold = spill_memory_threshold; + if (task_spill_memory_threshold > 0) { + splitOptions.task_spill_mem_threshold = task_spill_memory_threshold; + } + if (executor_spill_memory_threshold > 0) { + splitOptions.executor_spill_mem_threshold = executor_spill_memory_threshold; } if (compress_block_size > 0) { splitOptions.compress_block_size = compress_block_size; @@ -124,16 +127,16 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativ auto splitter = Splitter::Make(partitioning_name, inputDataTypesTmp, jNumCols, num_partitions, std::move(splitOptions)); - return g_shuffleSplitterHolder.Insert(std::shared_ptr(splitter)); + return reinterpret_cast(static_cast(splitter)); JNI_FUNC_END(runtimeExceptionClass) } JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( - JNIEnv *env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress) + JNIEnv *env, jobject jObj, jlong splitter_addr, jlong jVecBatchAddress) { - auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); + auto splitter = reinterpret_cast(splitter_addr); if (!splitter) { - std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); env->ThrowNew(runtimeExceptionClass, error_message.c_str()); return -1; } @@ -146,14 +149,33 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch()) } +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowSplit( + JNIEnv *env, jobject jObj, jlong splitter_addr, jlong jVecBatchAddress) +{ + auto splitter = reinterpret_cast(splitter_addr); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); + env->ThrowNew(runtimeExceptionClass, error_message.c_str()); + return -1; + } + + auto vecBatch = (VectorBatch *) jVecBatchAddress; + splitter->SetInputVecBatch(vecBatch); + JNI_FUNC_START + splitter->SplitByRow(vecBatch); + return 0L; + JNI_FUNC_END_WITH_VECBATCH(runtimeExceptionClass, splitter->GetInputVecBatch()) +} + JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( - JNIEnv* env, jobject, jlong splitter_id) + JNIEnv* env, jobject, jlong splitter_addr) { JNI_FUNC_START - auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); + auto 
splitter = reinterpret_cast(splitter_addr); if (!splitter) { - std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); env->ThrowNew(runtimeExceptionClass, error_message.c_str()); + return nullptr; } splitter->Stop(); @@ -170,15 +192,40 @@ JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_sto JNI_FUNC_END(runtimeExceptionClass) } +JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowStop( + JNIEnv* env, jobject, jlong splitter_addr) +{ + JNI_FUNC_START + auto splitter = reinterpret_cast(splitter_addr); + if (!splitter) { + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); + env->ThrowNew(runtimeExceptionClass, error_message.c_str()); + return nullptr; + } + splitter->StopByRow(); + + const auto& partition_length = splitter->PartitionLengths(); + auto partition_length_arr = env->NewLongArray(partition_length.size()); + auto src = reinterpret_cast(partition_length.data()); + env->SetLongArrayRegion(partition_length_arr, 0, partition_length.size(), src); + jobject split_result = env->NewObject( + splitResultClass, splitResultConstructor, splitter->TotalComputePidTime(), + splitter->TotalWriteTime(), splitter->TotalSpillTime(), + splitter->TotalBytesWritten(), splitter->TotalBytesSpilled(), partition_length_arr); + + return split_result; + JNI_FUNC_END(runtimeExceptionClass) +} + JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( - JNIEnv* env, jobject, jlong splitter_id) + JNIEnv* env, jobject, jlong splitter_addr) { JNI_FUNC_START - auto splitter = g_shuffleSplitterHolder.Lookup(splitter_id); + auto splitter = reinterpret_cast(splitter_addr); if (!splitter) { - std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); env->ThrowNew(runtimeExceptionClass, error_message.c_str()); } - g_shuffleSplitterHolder.Erase(splitter_id); + delete splitter; JNI_FUNC_END_VOID(runtimeExceptionClass) } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh index c98c10383c4cab04c4770adb8ebdab0ebdb4424b..15076b2ab6c3d2900e0393200766b63db56f07d8 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/SparkJniWrapper.hh @@ -20,7 +20,6 @@ #include #include #include -#include "concurrent_map.h" #include "shuffle/splitter.h" #ifndef SPARK_JNI_WRAPPER @@ -39,21 +38,27 @@ Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_nativeMake( jstring jInputType, jint jNumCols, jint buffer_size, jstring compression_type_jstr, jstring data_file_jstr, jint num_sub_dirs, jstring local_dirs_jstr, jlong compress_block_size, - jint spill_batch_row, jlong spill_memory_threshold); + jint spill_batch_row, jlong task_spill_memory_threshold, jlong executor_spill_memory_threshold); JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_split( JNIEnv* env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress); +JNIEXPORT jlong JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowSplit( + JNIEnv* env, jobject jObj, jlong splitter_id, jlong jVecBatchAddress); + JNIEXPORT jobject JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_stop( JNIEnv* env, jobject, jlong splitter_id); - + +JNIEXPORT jobject 
JNICALL +Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowStop( + JNIEnv* env, jobject, jlong splitter_id); + JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_close( - JNIEnv* env, jobject, jlong splitter_id); - -static ConcurrentMap> g_shuffleSplitterHolder; + JNIEnv* env, jobject, jlong splitter_id); #ifdef __cplusplus } diff --git a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp index f0e3a225363913268b885cbcde2903d00eea7476..605107e52b2adad036f7682b161c46fecd84f190 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/jni/jni_common.cpp @@ -109,8 +109,6 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) env->DeleteGlobalRef(jsonClass); env->DeleteGlobalRef(arrayListClass); env->DeleteGlobalRef(threadClass); - - g_shuffleSplitterHolder.Clear(); } #endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_CPP diff --git a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto index c40472020171692ea7b0acde2dd873efeda691f4..33ee64ec84d12b4bbf76c579645a8cb8a2e9db7d 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto +++ b/omnioperator/omniop-spark-extension/cpp/src/proto/vec_data.proto @@ -57,4 +57,16 @@ message VecType { NANOSEC = 3; } TimeUnit timeUnit = 6; +} + +message ProtoRow { + bytes data = 1; + uint32 length = 2; +} + +message ProtoRowBatch { + int32 rowCnt = 1; + int32 vecCnt = 2; + repeated VecType vecTypes = 3; + repeated ProtoRow rows = 4; } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp index c503c38f085cd51ed29061fbfb4d91d2e573cd8c..2e85b61a2fe90ecdc16b23a73e9d3b081175cada 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.cpp @@ -77,7 +77,7 @@ int Splitter::AllocatePartitionBuffers(int32_t partition_id, int32_t new_size) { case SHUFFLE_DECIMAL128: default: { void *ptr_tmp = static_cast(options_.allocator->Alloc(new_size * (1 << column_type_id_[i]))); - fixed_valueBuffer_size_[partition_id] = new_size * (1 << column_type_id_[i]); + fixed_valueBuffer_size_[partition_id] += new_size * (1 << column_type_id_[i]); if (nullptr == ptr_tmp) { throw std::runtime_error("Allocator for AllocatePartitionBuffers Failed! 
"); } @@ -238,6 +238,7 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { if (varcharVector->GetEncoding() == OMNI_DICTIONARY) { auto vc = reinterpret_cast> *>( varcharVector); + cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)); for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; @@ -272,7 +273,8 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { } } else { auto vc = reinterpret_cast> *>(varcharVector); - for (auto row = 0; row < num_rows; ++row) { + cached_vectorbatch_size_ += num_rows * (sizeof(bool) + sizeof(int32_t)) + sizeof(int32_t); + for (auto row = 0; row < num_rows; ++row) { auto pid = partition_id_[row]; uint8_t *dst = nullptr; uint32_t str_len = 0; @@ -310,7 +312,6 @@ void Splitter::SplitBinaryVector(BaseVector *varcharVector, int col_schema) { int Splitter::SplitBinaryArray(VectorBatch& vb) { - const auto num_rows = vb.GetRowCount(); auto vec_cnt_vb = vb.GetVectorCount(); auto vec_cnt_schema = singlePartitionFlag ? vec_cnt_vb : vec_cnt_vb - 1; for (auto col_schema = 0; col_schema < vec_cnt_schema; ++col_schema) { @@ -355,7 +356,7 @@ int Splitter::SplitFixedWidthValidityBuffer(VectorBatch& vb){ dst_addrs[pid] = const_cast(validity_buffer->data_); std::memset(validity_buffer->data_, 0, new_size); partition_fixed_width_buffers_[col][pid][0] = std::move(validity_buffer); - fixed_nullBuffer_size_[pid] = new_size; + fixed_nullBuffer_size_[pid] += new_size; } } @@ -412,9 +413,11 @@ int Splitter::CacheVectorBatch(int32_t partition_id, bool reset_buffers) { } } } - cached_vectorbatch_size_ += batch_partition_size; - partition_cached_vectorbatch_[partition_id].push_back(std::move(bufferArrayTotal)); - partition_buffer_idx_base_[partition_id] = 0; + cached_vectorbatch_size_ += batch_partition_size; + partition_cached_vectorbatch_[partition_id].push_back(std::move(bufferArrayTotal)); + fixed_valueBuffer_size_[partition_id] = 0; + fixed_nullBuffer_size_[partition_id] = 0; + partition_buffer_idx_base_[partition_id] = 0; } return 0; } @@ -458,7 +461,7 @@ int Splitter::DoSplit(VectorBatch& vb) { TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); isSpill = true; } - if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.spill_mem_threshold) { + if (cached_vectorbatch_size_ + current_fixed_alloc_buffer_size_ >= options_.task_spill_mem_threshold) { LogsDebug(" Spill For Memory Size Threshold."); TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFile()); isSpill = true; @@ -521,7 +524,7 @@ void Splitter::ToSplitterTypeId(int num_cols) void Splitter::CastOmniToShuffleType(DataTypeId omniType, ShuffleTypeId shuffleType) { - vector_batch_col_types_.push_back(omniType); + proto_col_types_.push_back(CastOmniTypeIdToProtoVecType(omniType)); column_type_id_.push_back(shuffleType); } @@ -529,35 +532,21 @@ int Splitter::Split_Init(){ num_row_splited_ = 0; cached_vectorbatch_size_ = 0; - partition_id_cnt_cur_ = static_cast(malloc(num_partitions_ * sizeof(int32_t))); - std::memset(partition_id_cnt_cur_, 0, num_partitions_ * sizeof(int32_t)); - - partition_id_cnt_cache_ = static_cast(malloc(num_partitions_ * sizeof(uint64_t))); - std::memset(partition_id_cnt_cache_, 0, num_partitions_ * sizeof(uint64_t)); - - partition_buffer_size_ = static_cast(malloc(num_partitions_ * sizeof(int32_t))); - std::memset(partition_buffer_size_, 0, num_partitions_ * sizeof(int32_t)); - - partition_buffer_idx_base_ = static_cast(malloc(num_partitions_ * sizeof(int32_t))); - 
std::memset(partition_buffer_idx_base_, 0, num_partitions_ * sizeof(int32_t)); - - partition_buffer_idx_offset_ = static_cast(malloc(num_partitions_ * sizeof(int32_t))); - std::memset(partition_buffer_idx_offset_, 0, num_partitions_ * sizeof(int32_t)); - - partition_serialization_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t))); - std::memset(partition_serialization_size_, 0, num_partitions_ * sizeof(uint32_t)); + partition_id_cnt_cur_ = new int32_t[num_partitions_](); + partition_id_cnt_cache_ = new uint64_t[num_partitions_](); + partition_buffer_size_ = new int32_t[num_partitions_](); + partition_buffer_idx_base_ = new int32_t[num_partitions_](); + partition_buffer_idx_offset_ = new int32_t[num_partitions_](); + partition_serialization_size_ = new uint32_t[num_partitions_](); partition_cached_vectorbatch_.resize(num_partitions_); fixed_width_array_idx_.clear(); partition_lengths_.resize(num_partitions_); - fixed_valueBuffer_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t))); - std::memset(fixed_valueBuffer_size_, 0, num_partitions_ * sizeof(uint32_t)); + fixed_valueBuffer_size_ = new uint32_t[num_partitions_](); + fixed_nullBuffer_size_ = new uint32_t[num_partitions_](); - fixed_nullBuffer_size_ = static_cast(malloc(num_partitions_ * sizeof(uint32_t))); - std::memset(fixed_nullBuffer_size_, 0, num_partitions_ * sizeof(uint32_t)); - - //obtain configed dir from Environment Variables + // obtain configed dir from Environment Variables configured_dirs_ = GetConfiguredLocalDirs(); sub_dir_selection_.assign(configured_dirs_.size(), 0); @@ -601,19 +590,76 @@ int Splitter::Split_Init(){ for (auto i = 0; i < num_partitions_; ++i) { vc_partition_array_buffers_[i].resize(column_type_id_.size()); } + + partition_rows.resize(num_partitions_); return 0; } int Splitter::Split(VectorBatch& vb ) { - //计算vectorBatch分区信息 + // 计算vectorBatch分区信息 LogsTrace(" split vb row number: %d ", vb.GetRowCount()); TIME_NANO_OR_RAISE(total_compute_pid_time_, ComputeAndCountPartitionId(vb)); - //执行分区动作 + // 执行分区动作 DoSplit(vb); return 0; } +int Splitter::SplitByRow(VectorBatch *vecBatch) { + int32_t rowCount = vecBatch->GetRowCount(); + for (int pid = 0; pid < num_partitions_; ++pid) { + auto needCapacity = partition_rows[pid].size() + rowCount; + if (partition_rows[pid].capacity() < needCapacity) { + auto prepareCapacity = partition_rows[pid].capacity() * expansion; + auto newCapacity = prepareCapacity > needCapacity ? prepareCapacity : needCapacity; + partition_rows[pid].reserve(newCapacity); + } + } + + if (singlePartitionFlag) { + RowBatch *rowBatch = VectorHelper::TransRowBatchFromVectorBatch(vecBatch); + for (int i = 0; i < rowCount; ++i) { + RowInfo *rowInfo = rowBatch->Get(i); + partition_rows[0].emplace_back(rowInfo); + total_input_size += rowInfo->length; + } + delete vecBatch; + } else { + auto pidVec = reinterpret_cast *>(vecBatch->Get(0)); + auto tmpVectorBatch = new VectorBatch(rowCount); + for (int i = 1; i < vecBatch->GetVectorCount(); ++i) { + tmpVectorBatch->Append(vecBatch->Get(i)); + } + vecBatch->ResizeVectorCount(1); + RowBatch *rowBatch = VectorHelper::TransRowBatchFromVectorBatch(tmpVectorBatch); + for (int i = 0; i < rowCount; ++i) { + auto pid = pidVec->GetValue(i); + RowInfo *rowInfo = rowBatch->Get(i); + partition_rows[pid].emplace_back(rowInfo); + total_input_size += rowInfo->length; + } + delete vecBatch; + delete tmpVectorBatch; + } + + // spill + // process level: If the memory usage of the current executor exceeds the threshold, spill is triggered. 
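    // Editor's note (illustrative, not part of the original patch): SplitByRow spills at two levels.
    //   executor level: omniruntime's globally accounted memory for the executor process exceeds
    //                   options_.executor_spill_mem_threshold (default UINT64_MAX, i.e. effectively off until configured);
    //   task level:     this splitter's accumulated row bytes (total_input_size) exceed
    //                   options_.task_spill_mem_threshold (default 1 GiB).
    // For example, with a hypothetical executor_spill_mem_threshold of 8 GiB, ten concurrent tasks caching
    // ~0.9 GiB each still trigger an executor-wide spill even though none of them crossed the 1 GiB task threshold.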
+ auto usedMemorySize = omniruntime::mem::MemoryManager::GetGlobalAccountedMemory(); + if (usedMemorySize > options_.executor_spill_mem_threshold) { + TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFileByRow()); + total_input_size = 0; + isSpill = true; + } + + // task level: If the memory usage of the current task exceeds the threshold, spill is triggered. + if (total_input_size > options_.task_spill_mem_threshold) { + TIME_NANO_OR_RAISE(total_spill_time_, SpillToTmpFileByRow()); + total_input_size = 0; + isSpill = true; + } + return 0; +} + std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() { void *ptr_tmp = static_cast(options_.allocator->Alloc((num_partitions_ + 1) * sizeof(uint64_t))); if (nullptr == ptr_tmp) { @@ -633,8 +679,8 @@ std::shared_ptr Splitter::CaculateSpilledTmpFilePartitionOffsets() { return ptrPartitionOffsets; } -spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) { - switch (tmpType) { +spark::VecType::VecTypeId Splitter::CastOmniTypeIdToProtoVecType(int32_t omniType) { + switch (omniType) { case OMNI_NONE: return spark::VecType::VEC_TYPE_NONE; case OMNI_INT: @@ -674,7 +720,7 @@ spark::VecType::VecTypeId CastShuffleTypeIdToVecType(int32_t tmpType) { case DataTypeId::OMNI_INVALID: return spark::VecType::VEC_TYPE_INVALID; default: { - throw std::runtime_error("castShuffleTypeIdToVecType() unexpected ShuffleTypeId"); + throw std::runtime_error("CastOmniTypeIdToProtoVecType() unexpected OmniTypeId"); } } }; @@ -822,13 +868,13 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptrset_veccnt(column_type_id_.size()); int fixColIndexTmp = 0; for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) { - spark::Vec * vec = vecBatchProto->add_vecs(); + spark::Vec *vec = vecBatchProto->add_vecs(); switch (column_type_id_[indexSchema]) { case ShuffleTypeId::SHUFFLE_1BYTE: case ShuffleTypeId::SHUFFLE_2BYTE: case ShuffleTypeId::SHUFFLE_4BYTE: case ShuffleTypeId::SHUFFLE_8BYTE: - case ShuffleTypeId::SHUFFLE_DECIMAL128:{ + case ShuffleTypeId::SHUFFLE_DECIMAL128: { SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp); fixColIndexTmp++; // 定长序列化数量++ break; @@ -842,13 +888,13 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptrmutable_vectype(); - vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema])); - LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", - indexSchema, input_col_types.inputDataPrecisions[indexSchema], - indexSchema, input_col_types.inputDataScales[indexSchema]); + vt->set_typeid_(proto_col_types_[indexSchema]); if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){ vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]); vt->set_scale(input_col_types.inputDataScales[indexSchema]); + LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", + indexSchema, input_col_types.inputDataPrecisions[indexSchema], + indexSchema, input_col_types.inputDataScales[indexSchema]); } } curBatch++; @@ -882,7 +928,67 @@ int32_t Splitter::ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut) { + uint64_t rowCount = partition_rows[partition_id].size(); + uint64_t onceCopyRow = 0; + uint32_t batchCount = 0; + while (0 < rowCount) { + if (options_.spill_batch_row_num < rowCount) { + onceCopyRow = options_.spill_batch_row_num; + } else { + onceCopyRow = rowCount; + } + + 
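        // Editor's note (illustrative, not part of the original patch): rows cached for this partition are
        // serialized in chunks of at most options_.spill_batch_row_num (default 4096); e.g. 10000 cached rows
        // become three ProtoRowBatch messages of 4096, 4096 and 1808 rows, each preceded by its byte-swapped
        // uint32 size so the reader can frame the stream.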
protoRowBatch->set_rowcnt(onceCopyRow); + protoRowBatch->set_veccnt(proto_col_types_.size()); + for (int i = 0; i < proto_col_types_.size(); ++i) { + spark::VecType *vt = protoRowBatch->add_vectypes(); + vt->set_typeid_(proto_col_types_[i]); + if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){ + vt->set_precision(input_col_types.inputDataPrecisions[i]); + vt->set_scale(input_col_types.inputDataScales[i]); + LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", + i, input_col_types.inputDataPrecisions[i], + i, input_col_types.inputDataScales[i]); + } + } + + int64_t offset = batchCount * options_.spill_batch_row_num; + auto rowInfoPtr = partition_rows[partition_id].data() + offset; + for (int i = 0; i < onceCopyRow; ++i) { + RowInfo *rowInfo = rowInfoPtr[i]; + spark::ProtoRow *protoRow = protoRowBatch->add_rows(); + protoRow->set_data(rowInfo->row, rowInfo->length); + protoRow->set_length(rowInfo->length); + // free row memory + delete rowInfo; + } + + if (protoRowBatch->ByteSizeLong() > UINT32_MAX) { + throw std::runtime_error("Unsafe static_cast long to uint_32t."); + } + uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong())); + if (bufferStream->Next(&bufferOut, &sizeOut)) { + std::memcpy(bufferOut, &protoRowBatchSize, sizeof(protoRowBatchSize)); + if (sizeof(protoRowBatchSize) < sizeOut) { + bufferStream->BackUp(sizeOut - sizeof(protoRowBatchSize)); + } + } + + protoRowBatch->SerializeToZeroCopyStream(bufferStream.get()); + rowCount -= onceCopyRow; + batchCount++; + protoRowBatch->Clear(); + } + uint64_t partitionBatchSize = bufferStream->flush(); + total_bytes_written_ += partitionBatchSize; + partition_lengths_[partition_id] += partitionBatchSize; + partition_rows[partition_id].clear(); + LogsDebug(" partitionBatch write length: %lu", partitionBatchSize); + return 0; } int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) { @@ -908,13 +1014,13 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptrset_veccnt(column_type_id_.size()); int fixColIndexTmp = 0; for (size_t indexSchema = 0; indexSchema < column_type_id_.size(); indexSchema++) { - spark::Vec * vec = vecBatchProto->add_vecs(); + spark::Vec *vec = vecBatchProto->add_vecs(); switch (column_type_id_[indexSchema]) { case ShuffleTypeId::SHUFFLE_1BYTE: case ShuffleTypeId::SHUFFLE_2BYTE: case ShuffleTypeId::SHUFFLE_4BYTE: case ShuffleTypeId::SHUFFLE_8BYTE: - case ShuffleTypeId::SHUFFLE_DECIMAL128:{ + case ShuffleTypeId::SHUFFLE_DECIMAL128: { SerializingFixedColumns(partition_id, *vec, fixColIndexTmp, &splitRowInfoTmp); fixColIndexTmp++; // 定长序列化数量++ break; @@ -928,13 +1034,13 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptrmutable_vectype(); - vt->set_typeid_(CastShuffleTypeIdToVecType(vector_batch_col_types_[indexSchema])); - LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", - indexSchema, input_col_types.inputDataPrecisions[indexSchema], - indexSchema, input_col_types.inputDataScales[indexSchema]); + vt->set_typeid_(proto_col_types_[indexSchema]); if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){ vt->set_precision(input_col_types.inputDataPrecisions[indexSchema]); vt->set_scale(input_col_types.inputDataScales[indexSchema]); + LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", + indexSchema, 
input_col_types.inputDataPrecisions[indexSchema], + indexSchema, input_col_types.inputDataScales[indexSchema]); } } curBatch++; @@ -974,6 +1080,70 @@ int Splitter::protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream) { + uint64_t rowCount = partition_rows[partition_id].size(); + total_spill_row_num_ += rowCount; + + uint64_t onceCopyRow = 0; + uint32_t batchCount = 0; + while (0 < rowCount) { + if (options_.spill_batch_row_num < rowCount) { + onceCopyRow = options_.spill_batch_row_num; + } else { + onceCopyRow = rowCount; + } + + protoRowBatch->set_rowcnt(onceCopyRow); + protoRowBatch->set_veccnt(proto_col_types_.size()); + for (int i = 0; i < proto_col_types_.size(); ++i) { + spark::VecType *vt = protoRowBatch->add_vectypes(); + vt->set_typeid_(proto_col_types_[i]); + if(vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL128 || vt->typeid_() == spark::VecType::VEC_TYPE_DECIMAL64){ + vt->set_precision(input_col_types.inputDataPrecisions[i]); + vt->set_scale(input_col_types.inputDataScales[i]); + LogsDebug("precision[indexSchema %d]: %d , scale[indexSchema %d]: %d ", + i, input_col_types.inputDataPrecisions[i], + i, input_col_types.inputDataScales[i]); + } + } + + int64_t offset = batchCount * options_.spill_batch_row_num; + auto rowInfoPtr = partition_rows[partition_id].data() + offset; + for (int i = 0; i < onceCopyRow; ++i) { + RowInfo *rowInfo = rowInfoPtr[i]; + spark::ProtoRow *protoRow = protoRowBatch->add_rows(); + protoRow->set_data(rowInfo->row, rowInfo->length); + protoRow->set_length(rowInfo->length); + // free row memory + delete rowInfo; + } + + if (protoRowBatch->ByteSizeLong() > UINT32_MAX) { + throw std::runtime_error("Unsafe static_cast long to uint_32t."); + } + uint32_t protoRowBatchSize = reversebytes_uint32t(static_cast(protoRowBatch->ByteSizeLong())); + void *buffer = nullptr; + if (!bufferStream->NextNBytes(&buffer, sizeof(protoRowBatchSize))) { + throw std::runtime_error("Allocate Memory Failed: Flush Spilled Data, Next failed."); + } + // set serizalized bytes to stream + memcpy(buffer, &protoRowBatchSize, sizeof(protoRowBatchSize)); + LogsDebug(" A Slice Of vecBatchProtoSize: %d ", reversebytes_uint32t(protoRowBatchSize)); + + protoRowBatch->SerializeToZeroCopyStream(bufferStream.get()); + rowCount -= onceCopyRow; + batchCount++; + protoRowBatch->Clear(); + } + + uint64_t partitionBatchSize = bufferStream->flush(); + total_bytes_spilled_ += partitionBatchSize; + partition_serialization_size_[partition_id] = partitionBatchSize; + partition_rows[partition_id].clear(); + LogsDebug(" partitionBatch write length: %lu", partitionBatchSize); + return 0; +} + int Splitter::WriteDataFileProto() { LogsDebug(" spill DataFile: %s ", (options_.next_spilled_file_dir + ".data").c_str()); std::unique_ptr outStream = writeLocalFile(options_.next_spilled_file_dir + ".data"); @@ -991,10 +1161,26 @@ int Splitter::WriteDataFileProto() { return 0; } +int Splitter::WriteDataFileProtoByRow() { + LogsDebug(" spill DataFile: %s ", (options_.next_spilled_file_dir + ".data").c_str()); + std::unique_ptr outStream = writeLocalFile(options_.next_spilled_file_dir + ".data"); + WriterOptions options; + // tmp spilled file no need compression + options.setCompression(CompressionKind_NONE); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferStream = streamsFactory->createStream(); + // 顺序写入每个partition的offset + for (auto pid = 0; pid < num_partitions_; ++pid) { + protoSpillPartitionByRow(pid, bufferStream); + } + 
outStream->close(); + return 0; +} + void Splitter::MergeSpilled() { for (auto pid = 0; pid < num_partitions_; ++pid) { CacheVectorBatch(pid, true); - partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + partition_buffer_size_[pid] = 0; // 溢写之后将其清零,条件溢写需要重新分配内存 } std::unique_ptr outStream = writeLocalFile(options_.data_file); @@ -1004,13 +1190,13 @@ void Splitter::MergeSpilled() { options.setCompressionBlockSize(options_.compress_block_size); options.setCompressionStrategy(CompressionStrategy_COMPRESSION); std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); - std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); void* bufferOut = nullptr; int sizeOut = 0; for (int pid = 0; pid < num_partitions_; pid++) { ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut); - LogsDebug(" MergeSplled traversal partition( %d ) ",pid); + LogsDebug(" MergeSpilled traversal partition( %d ) ", pid); for (auto &pair : spilled_tmp_files_info_) { auto tmpDataFilePath = pair.first + ".data"; auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid]; @@ -1047,10 +1233,56 @@ void Splitter::MergeSpilled() { outStream->close(); } +void Splitter::MergeSpilledByRow() { + std::unique_ptr outStream = writeLocalFile(options_.data_file); + LogsDebug(" Merge Spilled Tmp File: %s ", options_.data_file.c_str()); + WriterOptions options; + options.setCompression(options_.compression_type); + options.setCompressionBlockSize(options_.compress_block_size); + options.setCompressionStrategy(CompressionStrategy_COMPRESSION); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + + void* bufferOut = nullptr; + int sizeOut = 0; + for (int pid = 0; pid < num_partitions_; pid++) { + ProtoWritePartitionByRow(pid, bufferOutPutStream, bufferOut, sizeOut); + LogsDebug(" MergeSpilled traversal partition( %d ) ", pid); + for (auto &pair : spilled_tmp_files_info_) { + auto tmpDataFilePath = pair.first + ".data"; + auto tmpPartitionOffset = reinterpret_cast(pair.second->data_)[pid]; + auto tmpPartitionSize = reinterpret_cast(pair.second->data_)[pid + 1] - reinterpret_cast(pair.second->data_)[pid]; + LogsDebug(" get Partition Stream...tmpPartitionOffset %d tmpPartitionSize %d path %s", + tmpPartitionOffset, tmpPartitionSize, tmpDataFilePath.c_str()); + std::unique_ptr inputStream = readLocalFile(tmpDataFilePath); + uint64_t targetLen = tmpPartitionSize; + uint64_t seekPosit = tmpPartitionOffset; + uint64_t onceReadLen = 0; + while ((targetLen > 0) && bufferOutPutStream->Next(&bufferOut, &sizeOut)) { + onceReadLen = targetLen > sizeOut ? sizeOut : targetLen; + inputStream->read(bufferOut, onceReadLen, seekPosit); + targetLen -= onceReadLen; + seekPosit += onceReadLen; + if (onceReadLen < sizeOut) { + // Reached END. 
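                    // Editor's note (illustrative, not part of the original patch): Next() lends a writable block
                    // of sizeOut bytes from the compressed output stream, the tmp-file read fills onceReadLen of
                    // it, and when the partition's remaining bytes end mid-block, BackUp() hands the unused tail
                    // (sizeOut - onceReadLen) back to the stream. E.g. if 10 KiB of partition data remain and a
                    // 64 KiB block is lent, the read uses 10 KiB and BackUp() returns the remaining 54 KiB.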
+ bufferOutPutStream->BackUp(sizeOut - onceReadLen); + break; + } + } + + uint64_t flushSize = bufferOutPutStream->flush(); + total_bytes_written_ += flushSize; + LogsDebug(" Merge Flush Partition[%d] flushSize: %ld ", pid, flushSize); + partition_lengths_[pid] += flushSize; + } + } + outStream->close(); +} + void Splitter::WriteSplit() { for (auto pid = 0; pid < num_partitions_; ++pid) { CacheVectorBatch(pid, true); - partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + partition_buffer_size_[pid] = 0; // 溢写之后将其清零,条件溢写需要重新分配内存 } std::unique_ptr outStream = writeLocalFile(options_.data_file); @@ -1059,11 +1291,11 @@ void Splitter::WriteSplit() { options.setCompressionBlockSize(options_.compress_block_size); options.setCompressionStrategy(CompressionStrategy_COMPRESSION); std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); - std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); void* bufferOut = nullptr; int32_t sizeOut = 0; - for (auto pid = 0; pid < num_partitions_; ++ pid) { + for (auto pid = 0; pid < num_partitions_; ++pid) { ProtoWritePartition(pid, bufferOutPutStream, bufferOut, sizeOut); } @@ -1074,6 +1306,23 @@ void Splitter::WriteSplit() { outStream->close(); } +void Splitter::WriteSplitByRow() { + std::unique_ptr outStream = writeLocalFile(options_.data_file); + WriterOptions options; + options.setCompression(options_.compression_type); + options.setCompressionBlockSize(options_.compress_block_size); + options.setCompressionStrategy(CompressionStrategy_COMPRESSION); + std::unique_ptr streamsFactory = createStreamsFactory(options, outStream.get()); + std::unique_ptr bufferOutPutStream = streamsFactory->createStream(); + + void* bufferOut = nullptr; + int32_t sizeOut = 0; + for (auto pid = 0; pid < num_partitions_; ++pid) { + ProtoWritePartitionByRow(pid, bufferOutPutStream, bufferOut, sizeOut); + } + outStream->close(); +} + int Splitter::DeleteSpilledTmpFile() { for (auto &pair : spilled_tmp_files_info_) { auto tmpDataFilePath = pair.first + ".data"; @@ -1092,7 +1341,7 @@ int Splitter::DeleteSpilledTmpFile() { int Splitter::SpillToTmpFile() { for (auto pid = 0; pid < num_partitions_; ++pid) { CacheVectorBatch(pid, true); - partition_buffer_size_[pid] = 0; //溢写之后将其清零,条件溢写需要重新分配内存 + partition_buffer_size_[pid] = 0; // 溢写之后将其清零,条件溢写需要重新分配内存 } options_.next_spilled_file_dir = CreateTempShuffleFile(NextSpilledFileDir()); @@ -1105,6 +1354,14 @@ int Splitter::SpillToTmpFile() { return 0; } +int Splitter::SpillToTmpFileByRow() { + options_.next_spilled_file_dir = CreateTempShuffleFile(NextSpilledFileDir()); + WriteDataFileProtoByRow(); + std::shared_ptr ptrTmp = CaculateSpilledTmpFilePartitionOffsets(); + spilled_tmp_files_info_[options_.next_spilled_file_dir] = ptrTmp; + return 0; +} + Splitter::Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_partitions, SplitOptions options, bool flag) : input_col_types(inputDataTypes), singlePartitionFlag(flag), @@ -1116,23 +1373,18 @@ Splitter::Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_ ToSplitterTypeId(num_cols); } -std::shared_ptr Create(InputDataTypes inputDataTypes, +Splitter *Create(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_partitions, SplitOptions options, bool flag) { - std::shared_ptr res( - new Splitter(inputDataTypes, - num_cols, - num_partitions, - std::move(options), - flag)); + auto res = new Splitter(inputDataTypes, num_cols, num_partitions, 
std::move(options), flag); res->Split_Init(); return res; } -std::shared_ptr Splitter::Make( +Splitter *Splitter::Make( const std::string& short_name, InputDataTypes inputDataTypes, int32_t num_cols, @@ -1168,14 +1420,19 @@ int Splitter::Stop() { if (nullptr == vecBatchProto) { throw std::runtime_error("delete nullptr error for free protobuf vecBatch memory"); } - delete vecBatchProto; //free protobuf vecBatch memory - delete partition_id_cnt_cur_; - delete partition_id_cnt_cache_; - delete fixed_valueBuffer_size_; - delete fixed_nullBuffer_size_; - delete partition_buffer_size_; - delete partition_buffer_idx_base_; - delete partition_buffer_idx_offset_; - delete partition_serialization_size_; + return 0; +} + +int Splitter::StopByRow() { + if (isSpill) { + TIME_NANO_OR_RAISE(total_write_time_, MergeSpilledByRow()); + TIME_NANO_OR_RAISE(total_write_time_, DeleteSpilledTmpFile()); + LogsDebug(" Spill For Splitter Stopped. total_spill_row_num_: %ld ", total_spill_row_num_); + } else { + TIME_NANO_OR_RAISE(total_write_time_, WriteSplitByRow()); + } + if (nullptr == protoRowBatch) { + throw std::runtime_error("delete nullptr error for free protobuf rowBatch memory"); + } return 0; } diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h index ec0cc661f0a49d531b47dab87299cb6a8dfbde2a..9f0e8fa582aa30e7ef6933546367854e50fa986c 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/splitter.h @@ -35,6 +35,7 @@ #include "../common/common.h" #include "vec_data.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" +#include "vector/omni_row.h" using namespace std; using namespace spark; @@ -51,13 +52,16 @@ struct SplitRowInfo { }; class Splitter { - virtual int DoSplit(VectorBatch& vb); int WriteDataFileProto(); + int WriteDataFileProtoByRow(); + std::shared_ptr CaculateSpilledTmpFilePartitionOffsets(); + spark::VecType::VecTypeId CastOmniTypeIdToProtoVecType(int32_t omniType); + void SerializingFixedColumns(int32_t partitionId, spark::Vec& vec, int fixColIndexTmp, @@ -70,8 +74,12 @@ class Splitter { int protoSpillPartition(int32_t partition_id, std::unique_ptr &bufferStream); + int protoSpillPartitionByRow(int32_t partition_id, std::unique_ptr &bufferStream); + int32_t ProtoWritePartition(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut); + int32_t ProtoWritePartitionByRow(int32_t partition_id, std::unique_ptr &bufferStream, void *bufferOut, int32_t &sizeOut); + int ComputeAndCountPartitionId(VectorBatch& vb); int AllocatePartitionBuffers(int32_t partition_id, int32_t new_size); @@ -93,9 +101,28 @@ class Splitter { void MergeSpilled(); + void MergeSpilledByRow(); + void WriteSplit(); + void WriteSplitByRow(); + + // Common structures for row formats and col formats bool isSpill = false; + int64_t total_bytes_written_ = 0; + int64_t total_bytes_spilled_ = 0; + int64_t total_write_time_ = 0; + int64_t total_spill_time_ = 0; + int64_t total_spill_row_num_ = 0; + + // configured local dirs for spilled file + int32_t dir_selection_ = 0; + std::vector sub_dir_selection_; + std::vector configured_dirs_; + + // Data structures required to handle col formats + int64_t total_compute_pid_time_ = 0; + std::vector partition_lengths_; std::vector partition_id_; // 记录当前vb每一行的pid int32_t *partition_id_cnt_cur_; // 统计不同partition记录的行数(当前处理中的vb) uint64_t *partition_id_cnt_cache_; // 统计不同partition记录的行数,cache住的 @@ 
-117,12 +144,6 @@ class Splitter { int32_t *partition_buffer_idx_base_; //当前已缓存的各partition行数据记录,用于定位缓冲buffer当前可用位置 int32_t *partition_buffer_idx_offset_; //split定长列时用于统计offset的临时变量 uint32_t *partition_serialization_size_; // 记录序列化后的各partition大小,用于stop返回partition偏移 in bytes - - // configured local dirs for spilled file - int32_t dir_selection_ = 0; - std::vector sub_dir_selection_; - std::vector configured_dirs_; - std::vector>>>> partition_cached_vectorbatch_; /* * varchar buffers: @@ -130,14 +151,13 @@ class Splitter { * */ std::vector>> vc_partition_array_buffers_; + spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf 序列化对象结构 - int64_t total_bytes_written_ = 0; - int64_t total_bytes_spilled_ = 0; - int64_t total_write_time_ = 0; - int64_t total_spill_time_ = 0; - int64_t total_compute_pid_time_ = 0; - int64_t total_spill_row_num_ = 0; - std::vector partition_lengths_; + // Data structures required to handle row formats + std::vector> partition_rows; // pid : std::vector + uint64_t total_input_size = 0; // total row size in bytes + uint32_t expansion = 2; // expansion coefficient + spark::ProtoRowBatch *protoRowBatch = new ProtoRowBatch(); private: void ReleaseVarcharVector() @@ -167,37 +187,41 @@ private: delete vb; } + // Data structures required to handle col formats std::set varcharVectorCache; - std::vector vector_batch_col_types_; - InputDataTypes input_col_types; - std::vector binary_array_empirical_size_; - omniruntime::vec::VectorBatch *inputVecBatch = nullptr; public: + // Common structures for row formats and col formats bool singlePartitionFlag = false; int32_t num_partitions_; SplitOptions options_; // 分区数 int32_t num_fields_; - + InputDataTypes input_col_types; + std::vector proto_col_types_; // Avoid repeated type conversion during the split process. 
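    // Editor's sketch (illustrative, not part of the original patch): expected driver sequence for the new
    // row-based shuffle path, mirroring SparkJniWrapper and the shuffle benchmark; variable names are
    // placeholders and error handling is omitted.
    //
    //   auto splitter = Splitter::Make("hash", inputDataTypes, numCols, numPartitions, std::move(options));
    //   for (VectorBatch *vb : batches) {
    //       splitter->SplitByRow(vb);   // buffers rows per partition and takes ownership of vb
    //   }
    //   splitter->StopByRow();          // merges row spills or writes the final data file
    //   delete splitter;                // nativeMake now hands out a raw pointer; close() deletes it
    //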
+ omniruntime::vec::VectorBatch *inputVecBatch = nullptr; std::map> spilled_tmp_files_info_; - spark::VecBatch *vecBatchProto = new VecBatch(); // protobuf 序列化对象结构 - virtual int Split_Init(); virtual int Split(VectorBatch& vb); + virtual int SplitByRow(VectorBatch* vb); + int Stop(); + int StopByRow(); + int SpillToTmpFile(); + int SpillToTmpFileByRow(); + Splitter(InputDataTypes inputDataTypes, int32_t num_cols, int32_t num_partitions, SplitOptions options, bool flag); - static std::shared_ptr Make( + static Splitter *Make( const std::string &short_name, InputDataTypes inputDataTypes, int32_t num_cols, @@ -220,6 +244,24 @@ public: const std::vector& PartitionLengths() const { return partition_lengths_; } + virtual ~Splitter() + { + delete vecBatchProto; //free protobuf vecBatch memory + delete protoRowBatch; //free protobuf rowBatch memory + delete[] partition_id_cnt_cur_; + delete[] partition_id_cnt_cache_; + delete[] partition_buffer_size_; + delete[] partition_buffer_idx_base_; + delete[] partition_buffer_idx_offset_; + delete[] partition_serialization_size_; + delete[] fixed_valueBuffer_size_; + delete[] fixed_nullBuffer_size_; + partition_fixed_width_buffers_.clear(); + partition_binary_builders_.clear(); + partition_cached_vectorbatch_.clear(); + spilled_tmp_files_info_.clear(); + } + omniruntime::vec::VectorBatch *GetInputVecBatch() { return inputVecBatch; diff --git a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h index 04d90130dea30a83651fff3526c08dc0992f9928..61b4bc1498b2dbc3b2435f428401c36bb17c5725 100644 --- a/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h +++ b/omnioperator/omniop-spark-extension/cpp/src/shuffle/type.h @@ -43,7 +43,8 @@ struct SplitOptions { Allocator *allocator = Allocator::GetAllocator(); uint64_t spill_batch_row_num = 4096; // default value - uint64_t spill_mem_threshold = 1024 * 1024 * 1024; // default value + uint64_t task_spill_mem_threshold = 1024 * 1024 * 1024; // default value + uint64_t executor_spill_mem_threshold = UINT64_MAX; // default value uint64_t compress_block_size = 64 * 1024; // default value static SplitOptions Defaults(); diff --git a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt index f53ac2ad45e95bddfe3a15a4da78b245e373981e..287223e1dfbd63405f6d0ae7d1605698b9566c3d 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt +++ b/omnioperator/omniop-spark-extension/cpp/test/CMakeLists.txt @@ -2,12 +2,14 @@ aux_source_directory(${CMAKE_CURRENT_LIST_DIR} TEST_ROOT_SRCS) add_subdirectory(shuffle) add_subdirectory(utils) +add_subdirectory(benchmark) # configure set(TP_TEST_TARGET tptest) set(MY_LINK shuffletest utilstest + benchmark_test ) # find gtest package @@ -27,7 +29,7 @@ target_link_libraries(${TP_TEST_TARGET} pthread stdc++ dl - boostkit-omniop-vector-1.4.0-aarch64 + boostkit-omniop-vector-1.5.0-aarch64 securec spark_columnar_plugin) diff --git a/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt b/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e4b47d29549ad375eea27fbb9a7105569f5e8a73 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/benchmark/CMakeLists.txt @@ -0,0 +1,8 @@ +aux_source_directory(${CMAKE_CURRENT_LIST_DIR} BENCHMARK_LIST) +set(BENCHMARK_TEST_TARGET benchmark_test) +add_library(${BENCHMARK_TEST_TARGET} STATIC 
${BENCHMARK_LIST}) +target_compile_options(${BENCHMARK_TEST_TARGET} PUBLIC ) +target_link_libraries(${BENCHMARK_TEST_TARGET} utilstest) +target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC ${CMAKE_BINARY_DIR}/src) +target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include) +target_include_directories(${BENCHMARK_TEST_TARGET} PUBLIC $ENV{JAVA_HOME}/include/linux) \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp b/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db273192fc88cde342ca49951e1d740cdd8874a2 --- /dev/null +++ b/omnioperator/omniop-spark-extension/cpp/test/benchmark/shuffle_benchmark.cpp @@ -0,0 +1,386 @@ +/** + * Copyright (C) 2020-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "gtest/gtest.h" +#include "../utils/test_utils.h" + +using namespace omniruntime::type; +using namespace omniruntime; + +static constexpr int ROWS = 300; +static constexpr int COLS = 300; +static constexpr int PARTITION_SIZE = 600; +static constexpr int BATCH_COUNT = 20; + +static int generateRandomNumber() { + return std::rand() % PARTITION_SIZE; +} + +// construct data +static std::vector constructVecs(int rows, int cols, int* inputTypeIds, double nullProbability) { + std::srand(time(nullptr)); + std::vector vecs; + vecs.resize(cols); + + for (int i = 0; i < cols; ++i) { + BaseVector *vector = VectorHelper::CreateFlatVector(inputTypeIds[i], rows); + if (inputTypeIds[i] == OMNI_VARCHAR) { + auto strVec = reinterpret_cast> *>(vector); + for(int j = 0; j < rows; ++j) { + auto randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < nullProbability) { + strVec->SetNull(j); + } else { + std::string_view str("hello world"); + strVec->SetValue(j, str); + } + } + } else if (inputTypeIds[i] == OMNI_LONG) { + auto longVec = reinterpret_cast *>(vector); + for (int j = 0; j < rows; ++j) { + auto randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < nullProbability) { + longVec->SetNull(j); + } else { + long value = generateRandomNumber(); + longVec->SetValue(j, value); + } + } + } + vecs[i] = vector; + } + return vecs; +} + +// generate partitionId +static Vector* constructPidVec(int rows) { + srand(time(nullptr)); + auto pidVec = new Vector(rows); + for (int j = 0; j < rows; ++j) { + int pid = generateRandomNumber(); + pidVec->SetValue(j, pid); + } + return pidVec; +} + +static std::vector generateData(int rows, int cols, int* inputTypeIds, double nullProbability) { + std::vector vecBatches; + vecBatches.resize(BATCH_COUNT); + for (int i = 0; i < BATCH_COUNT; ++i) { + 
auto vecBatch = new VectorBatch(rows); + auto pidVec = constructPidVec(rows); + vecBatch->Append(pidVec); + auto vecs = constructVecs(rows, cols, inputTypeIds, nullProbability); + for (int j = 0; j < vecs.size(); ++j) { + vecBatch->Append(vecs[j]); + } + vecBatches[i] = vecBatch; + } + return vecBatches; +} + +static std::vector copyData(const std::vector& origin) { + std::vector vecBatches; + vecBatches.resize(origin.size()); + for (int i = 0; i < origin.size(); ++i) { + auto originBatch = origin[i]; + auto vecBatch = new VectorBatch(originBatch->GetRowCount()); + + for (int j = 0; j < originBatch->GetVectorCount(); ++j) { + BaseVector *vec = originBatch->Get(j); + BaseVector *sliceVec = VectorHelper::SliceVector(vec, 0, originBatch->GetRowCount()); + vecBatch->Append(sliceVec); + } + vecBatches[i] = vecBatch; + } + return vecBatches; +} + +static void bm_row_handle(const std::vector& vecBatches, int *inputTypeIds, int cols) { + Timer timer; + timer.SetStart(); + + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputTypeIds; + + auto splitOptions = SplitOptions::Defaults(); + splitOptions.buffer_size = 4096; + + auto compression_type_result = GetCompressionType("lz4"); + splitOptions.compression_type = compression_type_result; + auto splitter = Splitter::Make("hash", inputDataTypes, cols, PARTITION_SIZE, std::move(splitOptions)); + + for ( int i = 0; i < vecBatches.size(); ++i) { + VectorBatch *vb = vecBatches[i]; + splitter->SplitByRow(vb); + } + splitter->StopByRow(); + + timer.CalculateElapse(); + double wallElapsed = timer.GetWallElapse(); + double cpuElapsed = timer.GetCpuElapse(); + std::cout << "row time, wall " << wallElapsed << " cpu " << cpuElapsed << std::endl; + + delete splitter; +} + +static void bm_col_handle(const std::vector& vecBatches, int *inputTypeIds, int cols) { + Timer timer; + timer.SetStart(); + + InputDataTypes inputDataTypes; + inputDataTypes.inputVecTypeIds = inputTypeIds; + + auto splitOptions = SplitOptions::Defaults(); + splitOptions.buffer_size = 4096; + + auto compression_type_result = GetCompressionType("lz4"); + splitOptions.compression_type = compression_type_result; + auto splitter = Splitter::Make("hash", inputDataTypes, cols, PARTITION_SIZE, std::move(splitOptions)); + + for ( int i = 0; i < vecBatches.size(); ++i) { + VectorBatch *vb = vecBatches[i]; + splitter->Split(*vb); + } + splitter->Stop(); + + timer.CalculateElapse(); + double wallElapsed = timer.GetWallElapse(); + double cpuElapsed = timer.GetCpuElapse(); + std::cout << "col time, wall " << wallElapsed << " cpu " << cpuElapsed << std::endl; + + delete splitter; +} + +TEST(shuffle_benchmark, null_0) { + double strProbability = 0.25; + double nullProbability = 0; + + int *inputTypeIds = new int32_t[COLS]; + for (int i = 0; i < COLS; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, COLS); + bm_col_handle(vecBatches2, inputTypeIds, COLS); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_25) { + double strProbability = 0.25; + double nullProbability = 0.25; + + int *inputTypeIds = new int32_t[COLS]; + for (int i = 0; i < COLS; ++i) { + double randNum = 
static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, COLS); + bm_col_handle(vecBatches2, inputTypeIds, COLS); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_50) { + double strProbability = 0.25; + double nullProbability = 0.5; + + int *inputTypeIds = new int32_t[COLS]; + for (int i = 0; i < COLS; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, COLS); + bm_col_handle(vecBatches2, inputTypeIds, COLS); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_75) { + double strProbability = 0.25; + double nullProbability = 0.75; + + int *inputTypeIds = new int32_t[COLS]; + for (int i = 0; i < COLS; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, COLS); + bm_col_handle(vecBatches2, inputTypeIds, COLS); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_100) { + double strProbability = 0.25; + double nullProbability = 1; + + int *inputTypeIds = new int32_t[COLS]; + for (int i = 0; i < COLS; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(ROWS, COLS, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << ROWS << ", cols: " << COLS << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, COLS); + bm_col_handle(vecBatches2, inputTypeIds, COLS); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_25_row_900_col_100) { + double strProbability = 0.25; + double nullProbability = 0.25; + int rows = 900; + int cols = 100; + + int *inputTypeIds = new int32_t[cols]; + for (int i = 0; i < cols; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, cols); + bm_col_handle(vecBatches2, inputTypeIds, cols); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_25_row_1800_col_50) { + double strProbability = 0.25; + double 
nullProbability = 0.25; + int rows = 1800; + int cols = 50; + + int *inputTypeIds = new int32_t[cols]; + for (int i = 0; i < cols; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, cols); + bm_col_handle(vecBatches2, inputTypeIds, cols); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_25_row_9000_col_10) { + double strProbability = 0.25; + double nullProbability = 0.25; + int rows = 9000; + int cols = 10; + + int *inputTypeIds = new int32_t[cols]; + for (int i = 0; i < cols; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, cols); + bm_col_handle(vecBatches2, inputTypeIds, cols); + delete[] inputTypeIds; +} + +TEST(shuffle_benchmark, null_25_row_18000_col_5) { + double strProbability = 0.25; + double nullProbability = 0.25; + int rows = 18000; + int cols = 5; + + int *inputTypeIds = new int32_t[cols]; + for (int i = 0; i < cols; ++i) { + double randNum = static_cast(std::rand()) / RAND_MAX; + if (randNum < strProbability) { + inputTypeIds[i] = OMNI_VARCHAR; + } else { + inputTypeIds[i] = OMNI_LONG; + } + } + + auto vecBatches1 = generateData(rows, cols, inputTypeIds, nullProbability); + auto vecBatches2 = copyData(vecBatches1); + + std::cout << "rows: " << rows << ", cols: " << cols << ", null probability: " << nullProbability << std::endl; + bm_row_handle(vecBatches1, inputTypeIds, cols); + bm_col_handle(vecBatches2, inputTypeIds, cols); + delete[] inputTypeIds; +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp index 3031943eeae22b591de4c4b3693eb1e1744b3ac3..27e1297e75f155862bbbf4cd5d82db89827cca80 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/shuffle/shuffle_test.cpp @@ -39,7 +39,6 @@ protected: if (IsFileExist(tmpTestingDir)) { DeletePathAll(tmpTestingDir.c_str()); } - testShuffleSplitterHolder.Clear(); } // run before each case... 
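
The benchmark above times the two shuffle paths over identical data: the row path drives the same Splitter as the columnar path, but through SplitByRow/StopByRow instead of Split/Stop, and Splitter::Make now hands back a raw pointer that the caller deletes. A condensed sketch of that driver, mirroring bm_row_handle; the include path and the helper name are assumptions of this sketch, while the calls themselves are the ones the benchmark uses:

#include <cstdint>
#include <iostream>
#include <vector>
#include "../utils/test_utils.h"   // assumed include path; brings in Splitter, SplitOptions and Timer

static void RunRowShuffleOnce(const std::vector<VectorBatch *> &batches, int32_t *typeIds,
                              int32_t cols, int32_t numPartitions)
{
    InputDataTypes inputDataTypes;
    inputDataTypes.inputVecTypeIds = typeIds;                 // per-column Omni type ids

    auto options = SplitOptions::Defaults();
    options.buffer_size = 4096;
    options.compression_type = GetCompressionType("lz4");
    options.task_spill_mem_threshold = 1024L * 1024 * 1024;   // per-task spill cap (type.h default)
    options.executor_spill_mem_threshold = UINT64_MAX;        // executor-wide cap, unlimited by default

    Timer timer;
    timer.SetStart();

    // Make() returns a raw Splitter*; the caller owns it and deletes it when done.
    Splitter *splitter = Splitter::Make("hash", inputDataTypes, cols, numPartitions, std::move(options));
    for (VectorBatch *vb : batches) {
        splitter->SplitByRow(vb);       // row-format path; Split(*vb) plus Stop() is the columnar path
    }
    splitter->StopByRow();              // flush the row buffers and finish the spill files

    timer.CalculateElapse();
    std::cout << "row path wall " << timer.GetWallElapse()
              << " cpu " << timer.GetCpuElapse() << std::endl;
    delete splitter;
}
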
@@ -63,7 +62,7 @@ TEST_F (ShuffleTest, Split_SingleVarChar) { inputDataTypes.inputVecTypeIds = inputVecTypeIds; inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", 4, inputDataTypes, colNumber, @@ -73,21 +72,21 @@ TEST_F (ShuffleTest, Split_SingleVarChar) { 0, tmpTestingDir); VectorBatch* vb1 = CreateVectorBatch_1row_varchar_withPid(3, "N"); - Test_splitter_split(splitterId, vb1); + Test_splitter_split(splitterAddr, vb1); VectorBatch* vb2 = CreateVectorBatch_1row_varchar_withPid(2, "F"); - Test_splitter_split(splitterId, vb2); + Test_splitter_split(splitterAddr, vb2); VectorBatch* vb3 = CreateVectorBatch_1row_varchar_withPid(3, "N"); - Test_splitter_split(splitterId, vb3); + Test_splitter_split(splitterAddr, vb3); VectorBatch* vb4 = CreateVectorBatch_1row_varchar_withPid(2, "F"); - Test_splitter_split(splitterId, vb4); + Test_splitter_split(splitterAddr, vb4); VectorBatch* vb5 = CreateVectorBatch_1row_varchar_withPid(2, "F"); - Test_splitter_split(splitterId, vb5); + Test_splitter_split(splitterAddr, vb5); VectorBatch* vb6 = CreateVectorBatch_1row_varchar_withPid(1, "R"); - Test_splitter_split(splitterId, vb6); + Test_splitter_split(splitterAddr, vb6); VectorBatch* vb7 = CreateVectorBatch_1row_varchar_withPid(3, "N"); - Test_splitter_split(splitterId, vb7); - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_split(splitterAddr, vb7); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -101,7 +100,7 @@ TEST_F (ShuffleTest, Split_Fixed_Cols) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -112,10 +111,10 @@ TEST_F (ShuffleTest, Split_Fixed_Cols) { tmpTestingDir); for (uint64_t j = 0; j < 1; j++) { VectorBatch* vb = CreateVectorBatch_5fixedCols_withPid(partitionNum, 999); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -129,7 +128,7 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 1; - int splitterId = Test_splitter_nativeMake("single", + long splitterAddr = Test_splitter_nativeMake("single", partitionNum, inputDataTypes, colNumber, @@ -140,10 +139,10 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullRow) { tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { VectorBatch* vb = CreateVectorBatch_someNullRow_vectorBatch(); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -157,7 +156,7 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullCol) { 
inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 1; - int splitterId = Test_splitter_nativeMake("single", + long splitterAddr = Test_splitter_nativeMake("single", partitionNum, inputDataTypes, colNumber, @@ -168,10 +167,10 @@ TEST_F (ShuffleTest, Split_Fixed_SinglePartition_SomeNullCol) { tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { VectorBatch* vb = CreateVectorBatch_someNullCol_vectorBatch(); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -205,7 +204,7 @@ TEST_F (ShuffleTest, Split_Mix_LargeSize) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -216,10 +215,10 @@ TEST_F (ShuffleTest, Split_Mix_LargeSize) { tmpTestingDir); for (uint64_t j = 0; j < 999; j++) { VectorBatch* vb = CreateVectorBatch_4col_withPid(partitionNum, 999); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -233,7 +232,7 @@ TEST_F (ShuffleTest, Split_Short_10WRows) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 10; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -244,10 +243,10 @@ TEST_F (ShuffleTest, Split_Short_10WRows) { tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, ShortType()); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -261,7 +260,7 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 10; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -272,10 +271,10 @@ TEST_F (ShuffleTest, Split_Boolean_10WRows) { tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 1000, BooleanType()); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -289,7 +288,7 @@ TEST_F (ShuffleTest, Split_Long_100WRows) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; 
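
These tests also show the handle change: Test_splitter_nativeMake used to return an id into a shared ConcurrentMap, and now returns the splitter's address as a long, so split/stop/close simply cast that value back to a pointer and close deletes it. A self-contained sketch of the round trip; FakeSplitter is a stand-in type used only for illustration:

#include <cstdint>
#include <stdexcept>
#include <string>

struct FakeSplitter { std::string name; };   // stand-in for the real Splitter

// Hand a native object to the Java side as an opaque 64-bit handle...
int64_t MakeHandle(FakeSplitter *s) {
    return reinterpret_cast<int64_t>(s);
}

// ...and turn the handle back into a pointer on the way in; no lookup table is needed,
// which is why the tests switched from a splitterId to a splitterAddr.
FakeSplitter *FromHandle(int64_t addr) {
    auto *s = reinterpret_cast<FakeSplitter *>(addr);
    if (s == nullptr) {
        throw std::runtime_error("Invalid splitter address " + std::to_string(addr));
    }
    return s;
}

int main() {
    auto *s = new FakeSplitter{"hash"};
    int64_t handle = MakeHandle(s);      // what Test_splitter_nativeMake now returns
    delete FromHandle(handle);           // what Test_splitter_close now does
    return 0;
}
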
inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 10; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -300,10 +299,10 @@ TEST_F (ShuffleTest, Split_Long_100WRows) { tmpTestingDir); for (uint64_t j = 0; j < 100; j++) { VectorBatch* vb = CreateVectorBatch_1FixCol_withPid(partitionNum, 10000, LongType()); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -317,7 +316,7 @@ TEST_F (ShuffleTest, Split_VarChar_LargeSize) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -328,10 +327,10 @@ TEST_F (ShuffleTest, Split_VarChar_LargeSize) { tmpTestingDir); for (uint64_t j = 0; j < 99; j++) { VectorBatch* vb = CreateVectorBatch_4varcharCols_withPid(partitionNum, 99); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -345,7 +344,7 @@ TEST_F (ShuffleTest, Split_VarChar_First) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -355,27 +354,27 @@ TEST_F (ShuffleTest, Split_VarChar_First) { 0, tmpTestingDir); VectorBatch* vb0 = CreateVectorBatch_2column_1row_withPid(0, "corpbrand #4", 1); - Test_splitter_split(splitterId, vb0); + Test_splitter_split(splitterAddr, vb0); VectorBatch* vb1 = CreateVectorBatch_2column_1row_withPid(3, "brandmaxi #4", 1); - Test_splitter_split(splitterId, vb1); + Test_splitter_split(splitterAddr, vb1); VectorBatch* vb2 = CreateVectorBatch_2column_1row_withPid(1, "edu packnameless #9", 1); - Test_splitter_split(splitterId, vb2); + Test_splitter_split(splitterAddr, vb2); VectorBatch* vb3 = CreateVectorBatch_2column_1row_withPid(1, "amalgunivamalg #11", 1); - Test_splitter_split(splitterId, vb3); + Test_splitter_split(splitterAddr, vb3); VectorBatch* vb4 = CreateVectorBatch_2column_1row_withPid(0, "brandcorp #2", 1); - Test_splitter_split(splitterId, vb4); + Test_splitter_split(splitterAddr, vb4); VectorBatch* vb5 = CreateVectorBatch_2column_1row_withPid(0, "scholarbrand #2", 1); - Test_splitter_split(splitterId, vb5); + Test_splitter_split(splitterAddr, vb5); VectorBatch* vb6 = CreateVectorBatch_2column_1row_withPid(2, "edu packcorp #6", 1); - Test_splitter_split(splitterId, vb6); + Test_splitter_split(splitterAddr, vb6); VectorBatch* vb7 = CreateVectorBatch_2column_1row_withPid(2, "edu packamalg #1", 1); - Test_splitter_split(splitterId, vb7); + Test_splitter_split(splitterAddr, vb7); VectorBatch* vb8 = CreateVectorBatch_2column_1row_withPid(0, "brandnameless #8", 1); - Test_splitter_split(splitterId, vb8); + Test_splitter_split(splitterAddr, vb8); 
VectorBatch* vb9 = CreateVectorBatch_2column_1row_withPid(2, "univmaxi #2", 1); - Test_splitter_split(splitterId, vb9); - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_split(splitterAddr, vb9); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -389,7 +388,7 @@ TEST_F (ShuffleTest, Split_Dictionary) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -400,10 +399,10 @@ TEST_F (ShuffleTest, Split_Dictionary) { tmpTestingDir); for (uint64_t j = 0; j < 2; j++) { VectorBatch* vb = CreateVectorBatch_2dictionaryCols_withPid(partitionNum); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -417,7 +416,7 @@ TEST_F (ShuffleTest, Split_Char) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -428,10 +427,10 @@ TEST_F (ShuffleTest, Split_Char) { tmpTestingDir); for (uint64_t j = 0; j < 99; j++) { VectorBatch* vb = CreateVectorBatch_4charCols_withPid(partitionNum, 99); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -445,7 +444,7 @@ TEST_F (ShuffleTest, Split_Decimal128) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -456,10 +455,10 @@ TEST_F (ShuffleTest, Split_Decimal128) { tmpTestingDir); for (uint64_t j = 0; j < 999; j++) { VectorBatch* vb = CreateVectorBatch_1decimal128Col_withPid(partitionNum, 999); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -473,7 +472,7 @@ TEST_F (ShuffleTest, Split_Decimal64) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -484,10 +483,10 @@ TEST_F (ShuffleTest, Split_Decimal64) { tmpTestingDir); for (uint64_t j = 0; j < 999; j++) { VectorBatch* vb = CreateVectorBatch_1decimal64Col_withPid(partitionNum, 999); - Test_splitter_split(splitterId, vb); + 
Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } @@ -501,7 +500,7 @@ TEST_F (ShuffleTest, Split_Decimal64_128) { inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = 4; - int splitterId = Test_splitter_nativeMake("hash", + long splitterAddr = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -512,10 +511,10 @@ TEST_F (ShuffleTest, Split_Decimal64_128) { tmpTestingDir); for (uint64_t j = 0; j < 999; j++) { VectorBatch* vb = CreateVectorBatch_2decimalCol_withPid(partitionNum, 999); - Test_splitter_split(splitterId, vb); + Test_splitter_split(splitterAddr, vb); } - Test_splitter_stop(splitterId); - Test_splitter_close(splitterId); + Test_splitter_stop(splitterAddr); + Test_splitter_close(splitterAddr); delete[] inputDataTypes.inputDataPrecisions; delete[] inputDataTypes.inputDataScales; } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp index 9c30ed17e05f13bb438856f459501f4978a1220f..abf9f8074c08d8c7cb749b7f5c3295b1cbd4e3a9 100644 --- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp +++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.cpp @@ -47,7 +47,9 @@ BaseVector *CreateDictionaryVector(DataType &dataType, int32_t rowCount, int32_t va_start(args, idsCount); BaseVector *dictionary = CreateVector(dataType, rowCount, args); va_end(args); - return DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount); + BaseVector *dictVector = DYNAMIC_TYPE_DISPATCH(CreateDictionary, dataType.GetId(), dictionary, ids, idsCount); + delete dictionary; + return dictVector; } /** @@ -379,7 +381,7 @@ void Test_Shuffle_Compression(std::string compStr, int32_t numPartition, int32_t inputDataTypes.inputDataPrecisions = new uint32_t[colNumber]; inputDataTypes.inputDataScales = new uint32_t[colNumber]; int partitionNum = numPartition; - int splitterId = Test_splitter_nativeMake("hash", + long splitterId = Test_splitter_nativeMake("hash", partitionNum, inputDataTypes, colNumber, @@ -422,31 +424,31 @@ long Test_splitter_nativeMake(std::string partitioning_name, splitOptions.compression_type = compression_type_result; splitOptions.data_file = data_file_jstr; auto splitter = Splitter::Make(partitioning_name, inputDataTypes, numCols, num_partitions, std::move(splitOptions)); - return testShuffleSplitterHolder.Insert(std::shared_ptr(splitter)); + return reinterpret_cast(static_cast(splitter)); } -void Test_splitter_split(long splitter_id, VectorBatch* vb) { - auto splitter = testShuffleSplitterHolder.Lookup(splitter_id); +void Test_splitter_split(long splitter_addr, VectorBatch* vb) { + auto splitter = reinterpret_cast(splitter_addr); // Initialize split global variables splitter->Split(*vb); } -void Test_splitter_stop(long splitter_id) { - auto splitter = testShuffleSplitterHolder.Lookup(splitter_id); +void Test_splitter_stop(long splitter_addr) { + auto splitter = reinterpret_cast(splitter_addr); if (!splitter) { - std::string error_message = "Invalid splitter id " + std::to_string(splitter_id); + std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr); throw 
std::runtime_error("Test no splitter.");
     }
     splitter->Stop();
 }
 
-void Test_splitter_close(long splitter_id) {
-    auto splitter = testShuffleSplitterHolder.Lookup(splitter_id);
+void Test_splitter_close(long splitter_addr) {
+    auto splitter = reinterpret_cast<Splitter *>(splitter_addr);
     if (!splitter) {
-        std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+        std::string error_message = "Invalid splitter id " + std::to_string(splitter_addr);
         throw std::runtime_error("Test no splitter.");
     }
-    testShuffleSplitterHolder.Erase(splitter_id);
+    delete splitter;
 }
 
 void GetFilePath(const char *path, const char *filename, char *filepath) {
diff --git a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
index b7380254a687ed6f3eaf8234df944feac9087404..b588ea6f2427f185686ca6357570ec3a44b52691 100644
--- a/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
+++ b/omnioperator/omniop-spark-extension/cpp/test/utils/test_utils.h
@@ -25,10 +25,7 @@
 #include
 #include
 #include
-#include "shuffle/splitter.h"
-#include "jni/concurrent_map.h"
-
-static ConcurrentMap<std::shared_ptr<Splitter>> testShuffleSplitterHolder;
+#include "../../src/shuffle/splitter.h"
 
 static std::string s_shuffle_tests_dir = "/tmp/shuffleTests";
 
@@ -131,4 +128,71 @@ void GetFilePath(const char *path, const char *filename, char *filepath);
 
 void DeletePathAll(const char* path);
 
+class Timer {
+public:
+    Timer() : wallElapsed(0), cpuElapsed(0) {}
+
+    ~Timer() {}
+
+    void SetStart() {
+        clock_gettime(CLOCK_REALTIME, &wallStart);
+        clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+    }
+
+    void CalculateElapse() {
+        clock_gettime(CLOCK_REALTIME, &wallEnd);
+        clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuEnd);
+        long secondsWall = wallEnd.tv_sec - wallStart.tv_sec;
+        long secondsCpu = cpuEnd.tv_sec - cpuStart.tv_sec;
+        long nsWall = wallEnd.tv_nsec - wallStart.tv_nsec;
+        long nsCpu = cpuEnd.tv_nsec - cpuStart.tv_nsec;
+        wallElapsed = secondsWall + nsWall * 1e-9;
+        cpuElapsed = secondsCpu + nsCpu * 1e-9;
+    }
+
+    void Start(const char *TestTitle) {
+        wallElapsed = 0;
+        cpuElapsed = 0;
+        clock_gettime(CLOCK_REALTIME, &wallStart);
+        clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+        this->title = TestTitle;
+    }
+
+    void End() {
+        clock_gettime(CLOCK_REALTIME, &wallEnd);
+        clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuEnd);
+        long secondsWall = wallEnd.tv_sec - wallStart.tv_sec;
+        long secondsCpu = cpuEnd.tv_sec - cpuStart.tv_sec;
+        long nsWall = wallEnd.tv_nsec - wallStart.tv_nsec;
+        long nsCpu = cpuEnd.tv_nsec - cpuStart.tv_nsec;
+        wallElapsed = secondsWall + nsWall * 1e-9;
+        cpuElapsed = secondsCpu + nsCpu * 1e-9;
+        std::cout << title << "\t: wall " << wallElapsed << " \tcpu " << cpuElapsed << std::endl;
+    }
+
+    double GetWallElapse() {
+        return wallElapsed;
+    }
+
+    double GetCpuElapse() {
+        return cpuElapsed;
+    }
+
+    void Reset() {
+        wallElapsed = 0;
+        cpuElapsed = 0;
+        clock_gettime(CLOCK_REALTIME, &wallStart);
+        clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &cpuStart);
+    }
+
+private:
+    double wallElapsed;
+    double cpuElapsed;
+    struct timespec cpuStart;
+    struct timespec wallStart;
+    struct timespec cpuEnd;
+    struct timespec wallEnd;
+    const char *title;
+};
+
 #endif //SPARK_THESTRAL_PLUGIN_TEST_UTILS_H
\ No newline at end of file
diff --git a/omnioperator/omniop-spark-extension/java/pom.xml b/omnioperator/omniop-spark-extension/java/pom.xml
index 62c407dc3dedf2df0c4ca7ab891083467a8279f7..1fd6bb40d13a5670f440a468d5fb02022642116c 100644
---
a/omnioperator/omniop-spark-extension/java/pom.xml +++ b/omnioperator/omniop-spark-extension/java/pom.xml @@ -7,7 +7,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent - 3.3.1-1.4.0 + 3.3.1-1.5.0 ../pom.xml @@ -46,13 +46,13 @@ com.huawei.boostkit boostkit-omniop-bindings - 1.4.0 + 1.5.0 aarch64 com.huawei.boostkit boostkit-omniop-native-reader - 3.3.1-1.4.0 + 3.3.1-1.5.0 junit diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 1d858a5e3f22353ddfd18588ec69f50d96de5852..73438aa4355b090b78bed387d0c6fa81a308f339 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -20,11 +20,8 @@ package com.huawei.boostkit.spark.jni; import com.huawei.boostkit.scan.jni.OrcColumnarBatchJniReader; import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.type.Decimal128DataType; import nova.hetu.omniruntime.vector.*; -import org.apache.hadoop.security.UserGroupInformation; -import org.apache.hadoop.security.token.Token; import org.apache.spark.sql.catalyst.util.RebaseDateTime; import org.apache.hadoop.hive.ql.io.sarg.ExpressionTree; import org.apache.hadoop.hive.ql.io.sarg.PredicateLeaf; @@ -33,13 +30,10 @@ import org.apache.orc.Reader.Options; import org.json.JSONObject; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.orc.TypeDescription; -import java.io.IOException; import java.net.URI; import java.sql.Date; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; public class OrcColumnarBatchScanReader { @@ -186,13 +180,6 @@ public class OrcColumnarBatchScanReader { } job.put("tailLocation", 9223372036854775807L); - // handle delegate token for native orc reader - OrcColumnarBatchScanReader.tokenDebug("initializeReader"); - JSONObject tokenJsonObj = constructTokensJSONObject(); - if (null != tokenJsonObj) { - job.put("tokens", tokenJsonObj); - } - job.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); job.put("host", uri.getHost() == null ? 
"" : uri.getHost()); job.put("port", uri.getPort()); @@ -228,12 +215,7 @@ public class OrcColumnarBatchScanReader { } job.put("includedColumns", colToInclu.toArray()); - // handle delegate token for native orc reader - OrcColumnarBatchScanReader.tokenDebug("initializeRecordReader"); - JSONObject tokensJsonObj = constructTokensJSONObject(); - if (null != tokensJsonObj) { - job.put("tokens", tokensJsonObj); - } + recordReader = jniReader.initializeRecordReader(reader, job); return recordReader; } @@ -271,8 +253,7 @@ public class OrcColumnarBatchScanReader { } } - public int next(Vec[] vecList) { - int[] typeIds = new int[realColsCnt]; + public int next(Vec[] vecList, int[] typeIds) { long[] vecNativeIds = new long[realColsCnt]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { @@ -342,39 +323,4 @@ public class OrcColumnarBatchScanReader { return hexString.toString().toLowerCase(); } - - public static JSONObject constructTokensJSONObject() { - JSONObject tokensJsonItem = new JSONObject(); - try { - ArrayList child = new ArrayList(); - for (Token token : UserGroupInformation.getCurrentUser().getTokens()) { - JSONObject tokenJsonItem = new JSONObject(); - tokenJsonItem.put("identifier", bytesToHexString(token.getIdentifier())); - tokenJsonItem.put("password", bytesToHexString(token.getPassword())); - tokenJsonItem.put("kind", token.getKind().toString()); - tokenJsonItem.put("service", token.getService().toString()); - child.add(tokenJsonItem); - } - tokensJsonItem.put("token", child.toArray()); - } catch (IOException e) { - tokensJsonItem = null; - } finally { - LOGGER.debug("\n\n================== tokens-json ==================\n" + tokensJsonItem.toString()); - return tokensJsonItem; - } - } - - public static void tokenDebug(String mesg) { - try { - LOGGER.debug("\n\n=============" + mesg + "=============\n" + UserGroupInformation.getCurrentUser().toString()); - for (Token token : UserGroupInformation.getCurrentUser().getTokens()) { - LOGGER.debug("\n\ntoken identifier:" + bytesToHexString(token.getIdentifier())); - LOGGER.debug("\ntoken password:" + bytesToHexString(token.getPassword())); - LOGGER.debug("\ntoken kind:" + token.getKind()); - LOGGER.debug("\ntoken service:" + token.getService()); - } - } catch (IOException e) { - LOGGER.debug("\n\n**********" + mesg + " exception **********\n"); - } - } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java index 9aa7c414bc841fc10f2e0daae30fd168b1690055..9a49812e678286ba734c3b7d65ffd864ed20df4e 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/SparkJniWrapper.java @@ -35,7 +35,8 @@ public class SparkJniWrapper { String localDirs, long shuffleCompressBlockSize, int shuffleSpillBatchRowNum, - long shuffleSpillMemoryThreshold) { + long shuffleTaskSpillMemoryThreshold, + long shuffleExecutorSpillMemoryThreshold) { return nativeMake( part.getPartitionName(), part.getPartitionNum(), @@ -48,7 +49,8 @@ public class SparkJniWrapper { localDirs, shuffleCompressBlockSize, shuffleSpillBatchRowNum, - shuffleSpillMemoryThreshold); + shuffleTaskSpillMemoryThreshold, + shuffleExecutorSpillMemoryThreshold); } public native long nativeMake( @@ -63,7 +65,8 @@ public class 
SparkJniWrapper { String localDirs, long shuffleCompressBlockSize, int shuffleSpillBatchRowNum, - long shuffleSpillMemoryThreshold + long shuffleTaskSpillMemoryThreshold, + long shuffleExecutorSpillMemoryThreshold ); /** @@ -75,6 +78,16 @@ public class SparkJniWrapper { */ public native void split(long splitterId, long nativeVectorBatch); + /** + * Split one record batch represented by bufAddrs and bufSizes into several batches. The batch is converted to row + * formats for split according to the first column as partition id. During splitting, the data in native + * buffers will be written to disk when the buffers are full. + * + * @param splitterId Addresses of splitter + * @param nativeVectorBatch Addresses of nativeVectorBatch + */ + public native void rowSplit(long splitterId, long nativeVectorBatch); + /** * Write the data remained in the buffers hold by native splitter to each partition's temporary * file. And stop processing splitting @@ -84,6 +97,15 @@ public class SparkJniWrapper { */ public native SplitResult stop(long splitterId); + /** + * Write the data remained in the row buffers hold by native splitter to each partition's temporary + * file. And stop processing splitting + * + * @param splitterId splitter instance id + * @return SplitResult + */ + public native SplitResult rowStop(long splitterId); + /** * Release resources associated with designated splitter instance. * diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java index 6a0c1b27c4282016aecada2ba4ef0c48c320f20f..99759e4a30cc8d88189e703ac5fc36541446f1e7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/serialize/ShuffleDataSerializer.java @@ -20,6 +20,8 @@ package com.huawei.boostkit.spark.serialize; import com.google.protobuf.InvalidProtocolBufferException; +import com.huawei.boostkit.spark.jni.NativeLoader; +import nova.hetu.omniruntime.type.*; import nova.hetu.omniruntime.utils.OmniRuntimeException; import nova.hetu.omniruntime.vector.BooleanVec; import nova.hetu.omniruntime.vector.Decimal128Vec; @@ -29,16 +31,28 @@ import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.ShortVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; +import nova.hetu.omniruntime.vector.serialize.OmniRowDeserializer; import org.apache.spark.sql.execution.vectorized.OmniColumnVector; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ShuffleDataSerializer { + private static final Logger LOG = LoggerFactory.getLogger(NativeLoader.class); - public static ColumnarBatch deserialize(byte[] bytes) { + public static ColumnarBatch deserialize(boolean isRowShuffle, byte[] bytes) { + if (!isRowShuffle) { + return deserializeByColumn(bytes); + } else { + return deserializeByRow(bytes); + } + } + + public static ColumnarBatch deserializeByColumn(byte[] bytes) { ColumnVector[] vecs = null; try { VecData.VecBatch vecBatch = VecData.VecBatch.parseFrom(bytes); @@ -64,6 +78,30 @@ public class ShuffleDataSerializer { } 
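
For the rowSplit and rowStop natives declared in SparkJniWrapper above, the native library has to export symbols that follow the standard JNI naming scheme. The stub below only illustrates that contract; it is not the plugin's actual bridge code, which is outside this excerpt, and the delegation shown in the comments simply mirrors the SplitByRow call used by the C++ benchmark:

#include <jni.h>

// Illustrative stub: the exported name is derived mechanically from the Java package,
// class and method (com.huawei.boostkit.spark.jni.SparkJniWrapper.rowSplit).
extern "C" JNIEXPORT void JNICALL Java_com_huawei_boostkit_spark_jni_SparkJniWrapper_rowSplit(
    JNIEnv *env, jobject obj, jlong splitterAddr, jlong nativeVectorBatchAddr)
{
    // Typical pattern on the native side: recover the objects from the jlong handles
    // and delegate to the row path, e.g.
    //   auto *splitter = reinterpret_cast<Splitter *>(splitterAddr);
    //   auto *vb = reinterpret_cast<omniruntime::vec::VectorBatch *>(nativeVectorBatchAddr);
    //   splitter->SplitByRow(vb);
}
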
} + public static ColumnarBatch deserializeByRow(byte[] bytes) { + try { + VecData.ProtoRowBatch rowBatch = VecData.ProtoRowBatch.parseFrom(bytes); + int vecCount = rowBatch.getVecCnt(); + int rowCount = rowBatch.getRowCnt(); + OmniColumnVector[] columnarVecs = new OmniColumnVector[vecCount]; + long[] omniVecs = new long[vecCount]; + int[] omniTypes = new int[vecCount]; + createEmptyVec(rowBatch, omniTypes, omniVecs, columnarVecs, vecCount, rowCount); + OmniRowDeserializer deserializer = new OmniRowDeserializer(omniTypes); + + for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) { + VecData.ProtoRow protoRow = rowBatch.getRows(rowIdx); + byte[] array = protoRow.getData().toByteArray(); + deserializer.parse(array, omniVecs, rowIdx); + } + + deserializer.close(); + return new ColumnarBatch(columnarVecs, rowCount); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException("deserialize failed. errmsg:" + e.getMessage()); + } + } + private static ColumnVector buildVec(VecData.Vec protoVec, int vecSize) { VecData.VecType protoTypeId = protoVec.getVecType(); Vec vec; @@ -128,4 +166,75 @@ public class ShuffleDataSerializer { vecTmp.setVec(vec); return vecTmp; } + + public static void createEmptyVec(VecData.ProtoRowBatch rowBatch, int[] omniTypes, long[] omniVecs, OmniColumnVector[] columnarVectors, int vecCount, int rowCount) { + for (int i = 0; i < vecCount; i++) { + VecData.VecType protoTypeId = rowBatch.getVecTypes(i); + DataType sparkType; + Vec omniVec; + switch (protoTypeId.getTypeId()) { + case VEC_TYPE_INT: + sparkType = DataTypes.IntegerType; + omniTypes[i] = IntDataType.INTEGER.getId().toValue(); + omniVec = new IntVec(rowCount); + break; + case VEC_TYPE_DATE32: + sparkType = DataTypes.DateType; + omniTypes[i] = Date32DataType.DATE32.getId().toValue(); + omniVec = new IntVec(rowCount); + break; + case VEC_TYPE_LONG: + sparkType = DataTypes.LongType; + omniTypes[i] = LongDataType.LONG.getId().toValue(); + omniVec = new LongVec(rowCount); + break; + case VEC_TYPE_DATE64: + sparkType = DataTypes.DateType; + omniTypes[i] = Date64DataType.DATE64.getId().toValue(); + omniVec = new LongVec(rowCount); + break; + case VEC_TYPE_DECIMAL64: + sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); + omniTypes[i] = new Decimal64DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue(); + omniVec = new LongVec(rowCount); + break; + case VEC_TYPE_SHORT: + sparkType = DataTypes.ShortType; + omniTypes[i] = ShortDataType.SHORT.getId().toValue(); + omniVec = new ShortVec(rowCount); + break; + case VEC_TYPE_BOOLEAN: + sparkType = DataTypes.BooleanType; + omniTypes[i] = BooleanDataType.BOOLEAN.getId().toValue(); + omniVec = new BooleanVec(rowCount); + break; + case VEC_TYPE_DOUBLE: + sparkType = DataTypes.DoubleType; + omniTypes[i] = DoubleDataType.DOUBLE.getId().toValue(); + omniVec = new DoubleVec(rowCount); + break; + case VEC_TYPE_VARCHAR: + case VEC_TYPE_CHAR: + sparkType = DataTypes.StringType; + omniTypes[i] = VarcharDataType.VARCHAR.getId().toValue(); + omniVec = new VarcharVec(rowCount); + break; + case VEC_TYPE_DECIMAL128: + sparkType = DataTypes.createDecimalType(protoTypeId.getPrecision(), protoTypeId.getScale()); + omniTypes[i] = new Decimal128DataType(protoTypeId.getPrecision(), protoTypeId.getScale()).getId().toValue(); + omniVec = new Decimal128Vec(rowCount); + break; + case VEC_TYPE_TIME32: + case VEC_TYPE_TIME64: + case VEC_TYPE_INTERVAL_DAY_TIME: + case VEC_TYPE_INTERVAL_MONTHS: + default: + throw new 
IllegalStateException("Unexpected value: " + protoTypeId.getTypeId()); + } + + omniVecs[i] = omniVec.getNativeVector(); + columnarVectors[i] = new OmniColumnVector(rowCount, sparkType, false); + columnarVectors[i].setVec(omniVec); + } + } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 108562dc66926aac1d5adadccc3d7568d79071da..d4a26303181bfa476f01f9293c0bca172e61dacc 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -20,23 +20,26 @@ package com.huawei.boostkit.spark import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.PhysicalPlanSelector +import nova.hetu.omniruntime.memory.MemoryManager +import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.expressions.{Ascending, DynamicPruningSubquery, Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Partial, PartialMerge} -import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, RewriteSelfJoinInInPredicate} +import org.apache.spark.sql.catalyst.optimizer.{DelayCartesianProduct, HeuristicJoinReorder, MergeSubqueryFilters, CombineJoinedAggregates, RewriteSelfJoinInInPredicate} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowToOmniColumnarExec, _} import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, OmniAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{DummyLogicalPlan, ExtendedAggUtils, HashAggregateExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.execution.window.WindowExec +import org.apache.spark.sql.execution.window.{WindowExec, TopNPushDownForWindow} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener case class ColumnarPreOverrides() extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf @@ -67,6 +70,8 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val dedupLeftSemiJoinThreshold: Int = columnarConf.dedupLeftSemiJoinThreshold val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization + val enableRowShuffle: Boolean = columnarConf.enableRowShuffle + val columnsThreshold: Int = columnarConf.columnsThreshold def apply(plan: SparkPlan): SparkPlan = { replaceWithColumnarPlan(plan) @@ -544,11 +549,18 @@ case class ColumnarPreOverrides() extends Rule[SparkPlan] { val children = 
plan.children.map(replaceWithColumnarPlan) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarUnionExec(children) - case plan: ShuffleExchangeExec if enableColumnarShuffle => + case plan: ShuffleExchangeExec if enableColumnarShuffle || enableRowShuffle => val child = replaceWithColumnarPlan(plan.child) if (child.output.nonEmpty) { logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") - new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin) + if (child.isInstanceOf[ColumnarHashAggregateExec] && child.output.size > columnsThreshold + && enableRowShuffle) { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, true) + } else if (enableColumnarShuffle) { + new ColumnarShuffleExchangeExec(plan.outputPartitioning, child, plan.shuffleOrigin, false) + } else { + plan + } } else { plan } @@ -751,5 +763,25 @@ class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging { extensions.injectOptimizerRule(_ => DelayCartesianProduct) extensions.injectOptimizerRule(_ => HeuristicJoinReorder) extensions.injectOptimizerRule(_ => MergeSubqueryFilters) + extensions.injectOptimizerRule(_ => CombineJoinedAggregates) + extensions.injectQueryStagePrepRule(_ => TopNPushDownForWindow) + } +} + +private class OmniTaskStartExecutorPlugin extends ExecutorPlugin { + override def onTaskStart(): Unit = { + addLeakSafeTaskCompletionListener[Unit](_ => { + MemoryManager.clearMemory() + }) + } +} + +class OmniSparkPlugin extends SparkPlugin { + override def executorPlugin(): ExecutorPlugin = { + new OmniTaskStartExecutorPlugin() + } + + override def driverPlugin(): DriverPlugin = { + null } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index e87122e87c6b675dfaeab352af390143111510d9..5fc711a3982ba4633e7cff785414198e751fcd94 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -135,9 +135,9 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val columnarShuffleSpillBatchRowNum = conf.getConfString("spark.shuffle.columnar.shuffleSpillBatchRowNum", "10000").toInt - // columnar shuffle spill memory threshold - val columnarShuffleSpillMemoryThreshold = - conf.getConfString("spark.shuffle.columnar.shuffleSpillMemoryThreshold", + // columnar shuffle spill memory threshold in task level + val columnarShuffleTaskSpillMemoryThreshold = + conf.getConfString("spark.shuffle.columnar.shuffleTaskSpillMemoryThreshold", "2147483648").toLong // columnar shuffle compress block size @@ -156,12 +156,15 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val columnarShuffleNativeBufferSize = conf.getConfString("spark.sql.execution.columnar.maxRecordsPerBatch", "4096").toInt + val columnarSpillWriteBufferSize: Long = + conf.getConfString("spark.omni.sql.columnar.spill.writeBufferSize", "4121440").toLong + // columnar spill threshold - Percentage of memory usage, associate with the "spark.memory.offHeap" together val columnarSpillMemPctThreshold: Integer = conf.getConfString("spark.omni.sql.columnar.spill.memFraction", "90").toInt // columnar spill dir disk reserve Size, default 10GB - val 
columnarSpillDirDiskReserveSize:Long = + val columnarSpillDirDiskReserveSize: Long = conf.getConfString("spark.omni.sql.columnar.spill.dirDiskReserveSize", "10737418240").toLong // enable or disable columnar sort spill @@ -244,6 +247,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { conf.getConfString("spark.omni.sql.columnar.dedupLeftSemiJoinThreshold", "3").toInt val filterMergeEnable: Boolean = conf.getConfString("spark.sql.execution.filterMerge.enabled", "false").toBoolean + val combineJoinedAggregatesEnabled: Boolean = conf.getConfString("spark.sql.execution.combineJoinedAggregates.enabled", "false").toBoolean val filterMergeThreshold: Double = conf.getConfString("spark.sql.execution.filterMerge.maxCost", "100.0").toDouble @@ -259,6 +263,13 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { val radixSortThreshold: Int = conf.getConfString("spark.omni.sql.columnar.radixSortThreshold", "1000000").toInt + + // enable or disable row shuffle + val enableRowShuffle: Boolean = + conf.getConfString("spark.omni.sql.columnar.rowShuffle.enabled", "true").toBoolean + + val columnsThreshold: Int = + conf.getConfString("spark.omni.sql.columnar.columnsThreshold", "10").toInt } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 11ff8e12b54519484e6aa04885e3a50933d859a4..cfc95ae37506fc7ec2a39d45b8c8747783fcf298 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution.ColumnarBloomFilterSubquery import org.apache.spark.sql.expression.ColumnarExpressionConverter import org.apache.spark.sql.hive.HiveUdfAdaptorUtil -import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} import org.json.{JSONArray, JSONObject} import java.util.Locale @@ -74,10 +74,10 @@ object OmniExpressionAdaptor extends Logging { } } - private def unsupportedCastCheck(expr: Expression, cast: Cast): Unit = { + private def unsupportedCastCheck(expr: Expression, cast: CastBase): Unit = { def doSupportCastToString(dataType: DataType): Boolean = { if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[IntegerType] - || dataType.isInstanceOf[LongType]) { + || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DateType] || dataType.isInstanceOf[DoubleType]) { true } else { false @@ -86,7 +86,7 @@ object OmniExpressionAdaptor extends Logging { def doSupportCastFromString(dataType: DataType): Boolean = { if (dataType.isInstanceOf[DecimalType] || dataType.isInstanceOf[StringType] || dataType.isInstanceOf[DateType] - || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType]) { + || dataType.isInstanceOf[IntegerType] || dataType.isInstanceOf[LongType] || dataType.isInstanceOf[DoubleType]) { true } else 
{ false @@ -103,10 +103,6 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported expression: $expr") } - // not support Cast(double as decimal) - if (cast.dataType.isInstanceOf[DecimalType] && cast.child.dataType.isInstanceOf[DoubleType]) { - throw new UnsupportedOperationException(s"Unsupported expression: $expr") - } } def rewriteToOmniJsonExpressionLiteral(expr: Expression, @@ -242,6 +238,7 @@ object OmniExpressionAdaptor extends Logging { case alias: Alias => rewriteToOmniJsonExpressionLiteralJsonObject(alias.child, exprsIndexMap) case literal: Literal => toOmniJsonLiteral(literal) + case not: Not => not.child match { case isnull: IsNull => @@ -263,6 +260,7 @@ object OmniExpressionAdaptor extends Logging { .put("operator", "not") .put("expr", rewriteToOmniJsonExpressionLiteralJsonObject(not.child, exprsIndexMap)) } + case isnotnull: IsNotNull => new JSONObject().put("exprType", "UNARY") .addOmniExpJsonType("returnType", BooleanType) @@ -287,7 +285,7 @@ object OmniExpressionAdaptor extends Logging { .put(rewriteToOmniJsonExpressionLiteralJsonObject(subString.len, exprsIndexMap))) // Cast - case cast: Cast => + case cast: CastBase => unsupportedCastCheck(expr, cast) cast.dataType match { case StringType => @@ -302,8 +300,8 @@ object OmniExpressionAdaptor extends Logging { .addOmniExpJsonType("returnType", cast.dataType) .put("function_name", "CAST") .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(cast.child, exprsIndexMap))) - } + // Abs case abs: Abs => new JSONObject().put("exprType", "FUNCTION") @@ -414,6 +412,13 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.str, exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(inStr.substr, exprsIndexMap))) + case rlike: RLike => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", rlike.dataType) + .put("function_name", "RLike") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.left, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(rlike.right, exprsIndexMap))) + // for floating numbers normalize case normalizeNaNAndZero: NormalizeNaNAndZero => new JSONObject().put("exprType", "FUNCTION") @@ -450,6 +455,25 @@ object OmniExpressionAdaptor extends Logging { throw new UnsupportedOperationException(s"Unsupported right expression in like expression: $endsWith") } + case truncDate: TruncDate => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", truncDate.dataType) + .put("function_name", "trunc_date") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(truncDate.left, exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(truncDate.right, exprsIndexMap))) + + case md5: Md5 => + md5.child match { + case Cast(inputExpression, outputType, _, _) if outputType == BinaryType => + inputExpression match { + case AttributeReference(_, dataType, _, _) if dataType == StringType => + new JSONObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", md5.dataType) + .put("function_name", "Md5") + .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(inputExpression, exprsIndexMap))) + } + } + case _ => if (HiveUdfAdaptorUtil.isHiveUdf(expr) && ColumnarPluginConfig.getSessionConf.enableColumnarUdf) { val hiveUdf = HiveUdfAdaptorUtil.asHiveSimpleUDF(expr) @@ -723,6 +747,10 
@@ object OmniExpressionAdaptor extends Logging { } } + def sparkTypeToOmniType(dataType: DataType): Int = { + sparkTypeToOmniType(dataType, Metadata.empty).getId.ordinal() + } + def sparkTypeToOmniType(dataType: DataType, metadata: Metadata = Metadata.empty): nova.hetu.omniruntime.`type`.DataType = { dataType match { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala index 07ac07e8f81a47811ce8b979c9efcc3b52f882e3..26e2b7a3e057662fd9abee0db0345beeeb63f523 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/serialize/ColumnarBatchSerializer.scala @@ -28,15 +28,18 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch -class ColumnarBatchSerializer(readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) +class ColumnarBatchSerializer(readBatchNumRows: SQLMetric, + numOutputRows: SQLMetric, + isRowShuffle: Boolean = false) extends Serializer with Serializable { /** Creates a new [[SerializerInstance]]. */ override def newInstance(): SerializerInstance = - new ColumnarBatchSerializerInstance(readBatchNumRows, numOutputRows) + new ColumnarBatchSerializerInstance(isRowShuffle, readBatchNumRows, numOutputRows) } private class ColumnarBatchSerializerInstance( + isRowShuffle: Boolean, readBatchNumRows: SQLMetric, numOutputRows: SQLMetric) extends SerializerInstance with Logging { @@ -62,7 +65,6 @@ private class ColumnarBatchSerializerInstance( new DataInputStream(new BufferedInputStream(in)) } private[this] var columnarBuffer: Array[Byte] = new Array[Byte](1024) - val ibuffer: ByteBuffer = ByteBuffer.allocateDirect(4) private[this] val EOF: Int = -1 @@ -85,7 +87,7 @@ private class ColumnarBatchSerializerInstance( } ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) // protobuf serialize - val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize)) + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize)) dataSize = readSize() if (dataSize == EOF) { dIn.close() @@ -114,7 +116,7 @@ private class ColumnarBatchSerializerInstance( } ByteStreams.readFully(dIn, columnarBuffer, 0, dataSize) // protobuf serialize - val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(columnarBuffer.slice(0, dataSize)) + val columnarBatch: ColumnarBatch = ShuffleDataSerializer.deserialize(isRowShuffle, columnarBuffer.slice(0, dataSize)) numBatchesTotal += 1 numRowsTotal += columnarBatch.numRows() columnarBatch.asInstanceOf[T] diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index 875fe939dcbd932877165079a6f16400fc968865..113e88399ba897dcc38c3c40841a18dafd2a9315 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -330,12 +330,12 @@ 
object OmniAdaptorUtil { operator } - def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = { - if (projectList.nonEmpty) { + def pruneOutput(output: Seq[Attribute], projectExprIdList: Seq[ExprId]): Seq[Attribute] = { + if (projectExprIdList.nonEmpty) { val projectOutput = ListBuffer[Attribute]() - for (project <- projectList) { + for (index <- projectExprIdList.indices) { for (col <- output) { - if (col.exprId.equals(getProjectAliasExprId(project))) { + if (col.exprId.equals(projectExprIdList(index))) { projectOutput += col } } @@ -346,13 +346,13 @@ object OmniAdaptorUtil { } } - def getIndexArray(output: Seq[Attribute], projectList: Seq[NamedExpression]): Array[Int] = { - if (projectList.nonEmpty) { + def getIndexArray(output: Seq[Attribute], projectExprIdList: Seq[ExprId]): Array[Int] = { + if (projectExprIdList.nonEmpty) { val indexList = ListBuffer[Int]() - for (project <- projectList) { + for (index <- projectExprIdList.indices) { for (i <- output.indices) { val col = output(i) - if (col.exprId.equals(getProjectAliasExprId(project))) { + if (col.exprId.equals(projectExprIdList(index))) { indexList += i } } @@ -363,23 +363,50 @@ object OmniAdaptorUtil { } } - def reorderVecs(prunedOutput: Seq[Attribute], projectList: Seq[NamedExpression], resultVecs: Array[nova.hetu.omniruntime.vector.Vec], vecs: Array[OmniColumnVector]) = { - val used = new Array[Boolean](resultVecs.length) - for (index <- projectList.indices) { - val project = projectList(index) + def reorderOutputVecs(projectListIndex: Array[Int], omniVecs: Array[nova.hetu.omniruntime.vector.Vec], + outputVecs: Array[OmniColumnVector]) = { + for (index <- projectListIndex.indices) { + val outputVec = outputVecs(index) + outputVec.reset() + val projectIndex = projectListIndex(index) + outputVec.setVec(omniVecs(projectIndex)) + } + } + + def getProjectListIndex(projectExprIdList: Seq[ExprId], probeOutput: Seq[Attribute], + buildOutput: Seq[Attribute]): Array[Int] = { + val projectListIndex = ListBuffer[Int]() + var probeIndex = 0 + var buildIndex = probeOutput.size + for (index <- projectExprIdList.indices) { breakable { - for (i <- prunedOutput.indices) { - val col = prunedOutput(i) - if (!used(i) && col.exprId.equals(getProjectAliasExprId(project))) { - val v = vecs(index) - v.reset() - v.setVec(resultVecs(i)) - used(i) = true; + for (probeAttr <- probeOutput) { + if (probeAttr.exprId.equals(projectExprIdList(index))) { + projectListIndex += probeIndex + probeIndex += 1 break } } } + breakable { + for (buildAttr <- buildOutput) { + if (buildAttr.exprId.equals(projectExprIdList(index))) { + projectListIndex += buildIndex + buildIndex += 1 + break + } + } + } + } + projectListIndex.toArray + } + + def getExprIdForProjectList(projectList: Seq[NamedExpression]): Seq[ExprId] = { + val exprIdList = ListBuffer[ExprId]() + for (project <- projectList) { + exprIdList += getProjectAliasExprId(project) } + exprIdList } def getProjectAliasExprId(project: NamedExpression): ExprId = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala index 4c27688cb741eec4913ea51e23e14dfa16aa6b64..215be3846b17c6f902987364acb33693a6e8750b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala +++ 
b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleDependency.scala @@ -48,6 +48,7 @@ class ColumnarShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( override val aggregator: Option[Aggregator[K, V, C]] = None, override val mapSideCombine: Boolean = false, override val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, + val handleRow: Boolean, val partitionInfo: PartitionInfo, val dataSize: SQLMetric, val bytesSpilled: SQLMetric, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala index 615ddb6b7449d3ba3f2b6839df8eee67b6d5b05e..a8b7d9eab39374404a5bff4f176569e97e2bd0de 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala @@ -25,6 +25,7 @@ import nova.hetu.omniruntime.vector.VecBatch import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.scheduler.MapStatus +import org.apache.spark.sql.execution.util.SparkMemoryUtils import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -49,7 +50,8 @@ class ColumnarShuffleWriter[K, V]( val columnarConf = ColumnarPluginConfig.getSessionConf val shuffleSpillBatchRowNum = columnarConf.columnarShuffleSpillBatchRowNum - val shuffleSpillMemoryThreshold = columnarConf.columnarShuffleSpillMemoryThreshold + val shuffleTaskSpillMemoryThreshold = columnarConf.columnarShuffleTaskSpillMemoryThreshold + val shuffleExecutorSpillMemoryThreshold = columnarConf.columnarSpillMemPctThreshold * SparkMemoryUtils.offHeapSize val shuffleCompressBlockSize = columnarConf.columnarShuffleCompressBlockSize val shuffleNativeBufferSize = columnarConf.columnarShuffleNativeBufferSize val enableShuffleCompress = columnarConf.enableShuffleCompress @@ -87,7 +89,8 @@ class ColumnarShuffleWriter[K, V]( localDirs, shuffleCompressBlockSize, shuffleSpillBatchRowNum, - shuffleSpillMemoryThreshold) + shuffleTaskSpillMemoryThreshold, + shuffleExecutorSpillMemoryThreshold) } @@ -104,14 +107,22 @@ class ColumnarShuffleWriter[K, V]( dep.dataSize += input(col).getRealOffsetBufCapacityInBytes } val vb = new VecBatch(input, cb.numRows()) - jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch) + if (!dep.handleRow) { + jniWrapper.split(nativeSplitter, vb.getNativeVectorBatch) + } else { + jniWrapper.rowSplit(nativeSplitter, vb.getNativeVectorBatch) + } dep.splitTime.add(System.nanoTime() - startTime) dep.numInputRows.add(cb.numRows) writeMetrics.incRecordsWritten(cb.numRows) } } val startTime = System.nanoTime() - splitResult = jniWrapper.stop(nativeSplitter) + if (!dep.handleRow) { + splitResult = jniWrapper.stop(nativeSplitter) + } else { + splitResult = jniWrapper.rowStop(nativeSplitter) + } dep.splitTime.add(System.nanoTime() - startTime - splitResult.getTotalSpillTime - splitResult.getTotalWriteTime - splitResult.getTotalComputePidTime) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala new file mode 100644 index 
0000000000000000000000000000000000000000..94e3e35e4d3bf29d24c2397c046b264e7eedd092 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer +import com.huawei.boostkit.spark.ColumnarPluginConfig +import scala.collection.mutable.ArrayBuffer +import org.apache.spark.internal.Logging + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import scala.collection.mutable +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.optimizer._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, DYNAMIC_PRUNING_EXPRESSION, DYNAMIC_PRUNING_SUBQUERY, EXISTS_SUBQUERY, HIGH_ORDER_FUNCTION, IN, IN_SUBQUERY, INSET, INVOKE, JOIN, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF} +import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType} + +/** + * This rule eliminates the [[Join]] if all the join sides are [[Aggregate]]s, by combining these + * [[Aggregate]]s. This rule also supports nested [[Join]]s, as long as all the join sides of + * every [[Join]] are [[Aggregate]]s. + * + * Note: this rule does not support the following cases: + * 1. Merging [[Aggregate]]s when at least one of them has no predicate, or when a predicate + * has low selectivity. + * 2. Cases where a [[Join]] exists upstream of the [[Aggregate]]s to be merged.
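+ *
+ * As a hypothetical illustration (the table and column names here are invented for this comment):
+ * a join of two scalar aggregates such as
+ *   SELECT sum(a) FROM t WHERE b > 10   joined with   SELECT count(a) FROM t WHERE b < 5
+ * can be merged into a single aggregate over one scan of t, roughly
+ *   SELECT sum(a) FILTER (WHERE b > 10), count(a) FILTER (WHERE b < 5)
+ *   FROM t WHERE b > 10 OR b < 5
+ * where the two cheap filter predicates are OR-ed together and each original predicate is
+ * propagated into the FILTER clause of the matching side's aggregate expressions.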
+ */ +object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper { + + private def isSupportedJoinType(joinType: JoinType): Boolean = + Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType) + + private def maxTreeNodeNumOfPredicate: Int = 10 + + private def isCheapPredicate(e: Expression): Boolean = { + !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, + REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, DYNAMIC_PRUNING_SUBQUERY, DYNAMIC_PRUNING_EXPRESSION, + HIGH_ORDER_FUNCTION, IN_SUBQUERY, IN, INSET, EXISTS_SUBQUERY) + //countPredicatesInExpressions(e) + } + +private def checkCondition(leftCondition: Expression, rightCondition: Expression): Boolean = { + val normalizedLeft = normalizeExpression(leftCondition) + val normalizedRight = normalizeExpression(rightCondition) + if (normalizedLeft.isDefined && normalizedRight.isDefined) { + (normalizedLeft.get, normalizedRight.get) match { + case (a GreaterThan b, c LessThan d) if a.semanticEquals(c) => + isGreaterOrEqualTo(b, d, a.dataType) + case (a LessThan b, c GreaterThan d) if a.semanticEquals(c) => + isGreaterOrEqualTo(d, b, a.dataType) + case (a GreaterThanOrEqual b, c LessThan d) if a.semanticEquals(c) => + isGreaterOrEqualTo(b, d, a.dataType) + case (a LessThan b, c GreaterThanOrEqual d) if a.semanticEquals(c) => + isGreaterOrEqualTo(d, b, a.dataType) + case (a GreaterThan b, c LessThanOrEqual d) if a.semanticEquals(c) => + isGreaterOrEqualTo(b, d, a.dataType) + case (a LessThanOrEqual b, c GreaterThan d) if a.semanticEquals(c) => + isGreaterOrEqualTo(d, b, a.dataType) + case (a EqualTo b, Not(c EqualTo d)) if a.semanticEquals(c) => + isEqualTo(b, d, a.dataType) + case _ => false + } + } else { + false + } + } + + private def normalizeExpression(expr: Expression): Option[Expression] = { + expr match { + case gt @ GreaterThan(_, r) if r.foldable => + Some(gt) + case l GreaterThan r if l.foldable => + Some(LessThanOrEqual(r, l)) + case lt @ LessThan(_, r) if r.foldable => + Some(lt) + case l LessThan r if l.foldable => + Some(GreaterThanOrEqual(r, l)) + case gte @ GreaterThanOrEqual(_, r) if r.foldable => + Some(gte) + case l GreaterThanOrEqual r if l.foldable => + Some(LessThan(r, l)) + case lte @ LessThanOrEqual(_, r) if r.foldable => + Some(lte) + case l LessThanOrEqual r if l.foldable => + Some(GreaterThan(r, l)) + case eq @ EqualTo(_, r) if r.foldable => + Some(eq) + case l EqualTo r if l.foldable => + Some(EqualTo(r, l)) + case not @ Not(EqualTo(l, r)) if r.foldable => + Some(not) + case Not(l EqualTo r) if l.foldable => + Some(Not(EqualTo(r, l))) + case _ => None + } + } + + private def isGreaterOrEqualTo( + left: Expression, right: Expression, dataType: DataType): Boolean = dataType match { + case ShortType => left.eval().asInstanceOf[Short] >= right.eval().asInstanceOf[Short] + case IntegerType => left.eval().asInstanceOf[Int] >= right.eval().asInstanceOf[Int] + case LongType => left.eval().asInstanceOf[Long] >= right.eval().asInstanceOf[Long] + case FloatType => left.eval().asInstanceOf[Float] >= right.eval().asInstanceOf[Float] + case DoubleType => left.eval().asInstanceOf[Double] >= right.eval().asInstanceOf[Double] + case DecimalType.Fixed(_, _) => + left.eval().asInstanceOf[Decimal] >= right.eval().asInstanceOf[Decimal] + case _ => false + } + + private def isEqualTo( + left: Expression, right: Expression, dataType: DataType): Boolean = dataType match { + case ShortType => left.eval().asInstanceOf[Short] == right.eval().asInstanceOf[Short] + case 
IntegerType => left.eval().asInstanceOf[Int] == right.eval().asInstanceOf[Int] + case LongType => left.eval().asInstanceOf[Long] == right.eval().asInstanceOf[Long] + case FloatType => left.eval().asInstanceOf[Float] == right.eval().asInstanceOf[Float] + case DoubleType => left.eval().asInstanceOf[Double] == right.eval().asInstanceOf[Double] + case DecimalType.Fixed(_, _) => + left.eval().asInstanceOf[Decimal] == right.eval().asInstanceOf[Decimal] + case _ => false + } + + def countPredicatesInExpressions(expression: Expression): Int = { + expression match { + // If the expression is a predicate, count it + case predicate: Predicate => 1 + // If the expression is a complex expression, recursively count predicates in its children + case complexExpression => + complexExpression.children.map(countPredicatesInExpressions).sum + } + } + + def normalizeJoinExpression(expr: Expression): Expression = expr match { + case BinaryComparison(left, right) => + val sortedChildren = Seq(left, right).sortBy(_.toString) + expr.withNewChildren(sortedChildren) + case _ => expr.transform { + case a: AttributeReference => UnresolvedAttribute(a.name) + } + } + + def extendedNormalizeExpression(expr: Expression): Expression = { + expr.transformUp { + // Normalize attributes by name, ignoring exprId + case attr: AttributeReference => + // You can adjust the normalization based on what aspects of the attributes are significant for your comparison + // Here, we're focusing on the name and data type, but excluding metadata and other identifiers + AttributeReference(attr.name, attr.dataType, attr.nullable)(exprId = NamedExpression.newExprId, qualifier = attr.qualifier) + + // Unwrap aliases to compare the underlying expressions directly + case Alias(child, _) => child + + case Cast(child, dataType,_,_) => + // Normalize child and retain the cast's target data type + Cast(extendedNormalizeExpression(child), dataType) + + // Handle commutative operations by sorting their children + case b: BinaryOperator if b.isInstanceOf[Add] || b.isInstanceOf[Multiply] => + val sortedChildren = b.children.sortBy(_.toString()) + b.withNewChildren(sortedChildren) + + // Further transformations can be added here to handle other specific cases as needed + } + } + + // Function to compare two join conditions after normalization + def isJoinConditionEqual(condition1: Option[Expression], condition2: Option[Expression]): Boolean = { + (condition1, condition2) match { + case (Some(expr1), Some(expr2)) => + + // Check join condition + val pattern = "#\\d+" + val result1 = expr1.toString().replaceAll(pattern, "") + val result2 = expr2.toString().replaceAll(pattern, "") + return result1 == result2 + + /*val normalizedExpr1 = normalizeJoinExpression(expr1) + val normalizedExpr2 = normalizeJoinExpression(expr2) + if(normalizedExpr1.semanticEquals(normalizedExpr2)) + return true + + val extendedNormalizedExpr1 = extendedNormalizeExpression(expr1) + val extendedNormalizedExpr2 = extendedNormalizeExpression(expr2) + if(extendedNormalizedExpr1.semanticEquals(extendedNormalizedExpr2)) + return true*/ + + case (None, None) => true // Both conditions are None + case _ => false // One condition is None and the other is not + } + } + + // Function to check if two joins are the same + def areJoinsEqual(join1: Join, join2: Join): Boolean = { + // Check join type + if (join1.joinType != join2.joinType) return false + + if (!isJoinConditionEqual(join1.condition, join2.condition)) return false + + // Joins are equal + true + } + + // class to hold Expression with 
boolean flag + case class ExpressionHolder(val expression: Expression, val propagate: Boolean) + /** + * Try to merge two `Aggregate`s by traverse down recursively. + * + * @return The optional tuple as follows: + * 1. the merged plan + * 2. the attribute mapping from the old to the merged version + * 3. optional filters of both plans that need to be propagated and merged in an + * ancestor `Aggregate` node if possible. + */ + private def mergePlan( + left: LogicalPlan, + right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[ExpressionHolder])] = { + (left, right) match { + case (la: Aggregate, ra: Aggregate) => + mergePlan(la.child, ra.child).map { case (newChild, outputMap, filters) => + val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap)) + + // Filter the sequence to include only those entries where the propagate is true + val filtersToBePropagated: Seq[ExpressionHolder] = filters.filter(_.propagate) + val mergedAggregateExprs = if (filtersToBePropagated.length == 2) { + Seq( + (la.aggregateExpressions, filtersToBePropagated.head.expression), + (rightAggregateExprs, filtersToBePropagated.last.expression) + ).flatMap { case (aggregateExpressions, propagatedFilter) => + aggregateExpressions.map { ne => + ne.transform { + case ae @ AggregateExpression(_, _, _, filterOpt, _) => + val newFilter = filterOpt.map { filter => + //And(propagatedFilter, filter) + filter + }.orElse(Some(propagatedFilter)) + ae.copy(filter = newFilter) + }.asInstanceOf[NamedExpression] + } + } + } else { + la.aggregateExpressions ++ rightAggregateExprs + } + + (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty) + } + case (lp: Project, rp: Project) => + val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*) + + mergePlan(lp.child, rp.child).map { case (newChild, outputMap, filters) => + val allFilterReferences = filters.flatMap(_.expression.references) + val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne => + val mapped = mapAttributes(ne, outputMap) + + val withoutAlias = mapped match { + case Alias(child, _) => child + case e => e + } + + val outputAttr = mergedProjectList.find { + case Alias(child, _) => child semanticEquals withoutAlias + case e => e semanticEquals withoutAlias + }.getOrElse { + mergedProjectList += mapped + mapped + }.toAttribute + ne.toAttribute -> outputAttr + }) + + (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters) + } + case (lf: Filter, rf: Filter) + if isCheapPredicate(lf.condition) && isCheapPredicate(rf.condition) => + + val pattern = "#\\d+" + // Replace the matched pattern with an empty string + val result1 = lf.condition.toString().replaceAll(pattern, "") + val result2 = rf.condition.toString().replaceAll(pattern, "") + + if (result1 == result2 || lf.condition == rf.condition || checkCondition(lf.condition, rf.condition)) { + // If both conditions are the same, proceed with one of them. 
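+ // For example (illustrative only): when both branches filter on the same predicate,
+ // say `ss_quantity > 0`, a single Filter is kept and its condition is returned with
+ // propagate = false, so the Aggregate case above will not push it into per-aggregate
+ // FILTER clauses.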
+ mergePlan(lf.child, rf.child).map { case (newChild, outputMap, filters) => + (Filter(lf.condition, newChild), outputMap, Seq(ExpressionHolder(lf.condition, false))) + } + } else { + mergePlan(lf.child, rf.child).map { + case (newChild, outputMap, filters) => + val mappedRightCondition = mapAttributes(rf.condition, outputMap) + val (newLeftCondition, newRightCondition) = if (filters.length == 2) { + (And(lf.condition, filters.head.expression), And(mappedRightCondition, filters.last.expression)) + } else { + (lf.condition, mappedRightCondition) + } + val newCondition = Or(newLeftCondition, newRightCondition) + (Filter(newCondition, newChild), outputMap, Seq(ExpressionHolder(newLeftCondition,true), ExpressionHolder(newRightCondition,true))) + } + } + case (lj: Join, rj: Join) => + if (areJoinsEqual(lj, rj)) { + mergePlan(lj.left, rj.left).flatMap { case (newLeft, leftOutputMap, leftFilters) => + mergePlan(lj.right, rj.right).map { case (newRight, rightOutputMap, rightFilters) => + val newJoin = Join(newLeft, newRight, lj.joinType, lj.condition, lj.hint) + val mergedOutputMap = leftOutputMap ++ rightOutputMap + val mergedFilters = leftFilters ++ rightFilters + (newJoin, mergedOutputMap, mergedFilters) + } + } + } else { + None + } + case (ll: LeafNode, rl: LeafNode) => + checkIdenticalPlans(rl, ll).map { outputMap => + (ll, outputMap, Seq.empty) + } + case (ls: SerializeFromObject, rs: SerializeFromObject) => + checkIdenticalPlans(rs, ls).map { outputMap => + (ls, outputMap, Seq.empty) + } + case _ => None + } + } + + def apply(plan: LogicalPlan): LogicalPlan = { + if (!ColumnarPluginConfig.getConf.combineJoinedAggregatesEnabled) return plan + //apply rule on children first then itself + plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE)) { + case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _) + if isSupportedJoinType(joinType) && + left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty => + val mergedAggregate = mergePlan(left, right) + mergedAggregate.map(_._1).getOrElse(j) + } + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala new file mode 100644 index 0000000000000000000000000000000000000000..6d70bc416d8c4817b8e72ef883f698b6a16bd450 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * The helper class used to merge scalar subqueries. + */ +trait MergeScalarSubqueriesHelper { + + // If 2 plans are identical return the attribute mapping from the left to the right. + protected def checkIdenticalPlans( + left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = { + if (left.canonicalized == right.canonicalized) { + Some(AttributeMap(left.output.zip(right.output))) + } else { + None + } + } + + protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = { + expr.transform { + case a: Attribute => outputMap.getOrElse(a, a) + }.asInstanceOf[T] + } +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 8eff1774a10663f0e7249b9ad1f0abe991ced544..55fba9f2b750d215326d6835ca849ba66d994caf 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -198,8 +198,9 @@ case class ColumnarHashAggregateExec( val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes finalOut.map( exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + val finalAttrExprsIdMap = getExprIdMap(finalOut) val projectExpressions: Array[AnyRef] = resultExpressions.map( - exp => rewriteToOmniJsonExpressionLiteral(exp, getExprIdMap(finalOut))).toArray + exp => rewriteToOmniJsonExpressionLiteral(exp, finalAttrExprsIdMap)).toArray if (!isSimpleColumnForAll(projectExpressions.map(expr => expr.toString))) { checkOmniJsonWhiteList("", projectExpressions) } @@ -297,13 +298,14 @@ case class ColumnarHashAggregateExec( child.executeColumnar().mapPartitionsWithIndex { (index, iter) => val columnarConf = ColumnarPluginConfig.getSessionConf - val hashAggSpillRowThreshold = columnarConf.columnarHashAggSpillRowThreshold + val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize val hashAggSpillEnable = columnarConf.enableHashAggSpill + val hashAggSpillRowThreshold = columnarConf.columnarHashAggSpillRowThreshold val spillDirectory = generateSpillDir(tmpSparkConf, "columnarHashAggSpill") val sparkSpillConf = new SparkSpillConfig(hashAggSpillEnable, spillDirectory, - spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold) + spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize) val startCodegen = System.nanoTime() val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, @@ -373,10 +375,11 @@ case class ColumnarHashAggregateExec( } if (finalStep) { val finalOut = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val finalAttrExprsIdMap = getExprIdMap(finalOut) val projectInputTypes = finalOut.map( exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray val projectExpressions = resultExpressions.map( - exp => rewriteToOmniJsonExpressionLiteral(exp, 
getExprIdMap(finalOut))).toArray + exp => rewriteToOmniJsonExpressionLiteral(exp, finalAttrExprsIdMap)).toArray dealPartitionData(null, null, addInputTime, omniCodegenTime, getOutputTime, projectInputTypes, projectExpressions, hashAggIter, this.schema) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala index 4ce57e12ab8cb9f939664c4ce8a59b441d65ffd2..3603ecccc9df432c2b2e19650204c2341ebf15d5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala @@ -267,6 +267,7 @@ case class ColumnarTakeOrderedAndProjectExec( child.output, SinglePartition, serializer, + handleRow = false, writeMetrics, longMetric("dataSize"), longMetric("bytesSpilled"), diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 6e6588304b3ba809fc9a9a561e66703a4f988122..e1e07dd48e69443e3f95c01d1f3260d5ae25d083 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -54,10 +54,13 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.random.XORShiftRandom +import nova.hetu.omniruntime.vector.IntVec + case class ColumnarShuffleExchangeExec( override val outputPartitioning: Partitioning, child: SparkPlan, - shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS, + handleRow: Boolean = false) extends ShuffleExchangeLike { private lazy val writeMetrics = @@ -78,13 +81,14 @@ case class ColumnarShuffleExchangeExec( "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions") ) ++ readMetrics ++ writeMetrics - override def nodeName: String = "OmniColumnarShuffleExchange" + override def nodeName: String = if (!handleRow) "OmniColumnarShuffleExchange" else "OmniRowShuffleExchange" override def supportsColumnar: Boolean = true val serializer: Serializer = new ColumnarBatchSerializer( longMetric("avgReadBatchNumRows"), - longMetric("numOutputRows")) + longMetric("numOutputRows"), + handleRow) @transient lazy val inputColumnarRDD: RDD[ColumnarBatch] = child.executeColumnar() @@ -118,6 +122,7 @@ case class ColumnarShuffleExchangeExec( child.output, outputPartitioning, serializer, + handleRow, writeMetrics, longMetric("dataSize"), longMetric("bytesSpilled"), @@ -189,6 +194,7 @@ object ColumnarShuffleExchangeExec extends Logging { outputAttributes: Seq[Attribute], newPartitioning: Partitioning, serializer: Serializer, + handleRow: Boolean, writeMetrics: Map[String, SQLMetric], dataSize: SQLMetric, bytesSpilled: SQLMetric, @@ -277,16 +283,25 @@ object ColumnarShuffleExchangeExec extends Logging { val addPid2ColumnBatch = addPidToColumnBatch() cbIter.filter(cb => cb.numRows != 0 && cb.numCols != 0).map { cb => - val pidArr = new Array[Int](cb.numRows) - (0 until cb.numRows).foreach { i => - val row = cb.getRow(i) - val pid = 
part.get.getPartition(partitionKeyExtractor(row)) - pidArr(i) = pid - } - val pidVec = new IntVec(cb.numRows) - pidVec.put(pidArr, 0, 0, cb.numRows) + var pidVec: IntVec = null + try { + val pidArr = new Array[Int](cb.numRows) + (0 until cb.numRows).foreach { i => + val row = cb.getRow(i) + val pid = part.get.getPartition(partitionKeyExtractor(row)) + pidArr(i) = pid + } + pidVec = new IntVec(cb.numRows) + pidVec.put(pidArr, 0, 0, cb.numRows) - addPid2ColumnBatch(pidVec, cb) + addPid2ColumnBatch(pidVec, cb) + } catch { + case e: Exception => + if (pidVec != null) { + pidVec.close() + } + throw e + } } } @@ -308,8 +323,17 @@ object ColumnarShuffleExchangeExec extends Logging { val getRoundRobinPid = getRoundRobinPartitionKey val addPid2ColumnBatch = addPidToColumnBatch() cbIter.map { cb => - val pidVec = getRoundRobinPid(cb, numPartitions) - addPid2ColumnBatch(pidVec, cb) + var pidVec: IntVec = null + try { + pidVec = getRoundRobinPid(cb, numPartitions) + addPid2ColumnBatch(pidVec, cb) + } catch { + case e: Exception => + if (pidVec != null) { + pidVec.close() + } + throw e + } } }, isOrderSensitive = isOrderSensitive) case RangePartitioning(sortingExpressions, _) => @@ -349,17 +373,26 @@ object ColumnarShuffleExchangeExec extends Logging { }) cbIter.map { cb => - val vecs = transColBatchToOmniVecs(cb, true) - op.addInput(new VecBatch(vecs, cb.numRows())) - val res = op.getOutput - if (res.hasNext) { - val retBatch = res.next() - val pidVec = retBatch.getVectors()(0) - // close return VecBatch - retBatch.close() - addPid2ColumnBatch(pidVec.asInstanceOf[IntVec], cb) - } else { - throw new Exception("Empty Project Operator Result...") + var pidVec: IntVec = null + try { + val vecs = transColBatchToOmniVecs(cb, true) + op.addInput(new VecBatch(vecs, cb.numRows())) + val res = op.getOutput + if (res.hasNext) { + val retBatch = res.next() + pidVec = retBatch.getVectors()(0).asInstanceOf[IntVec] + // close return VecBatch + retBatch.close() + addPid2ColumnBatch(pidVec, cb) + } else { + throw new Exception("Empty Project Operator Result...") + } + } catch { + case e: Exception => + if (pidVec != null) { + pidVec.close() + } + throw e } } }, isOrderSensitive = isOrderSensitive) @@ -393,6 +426,7 @@ object ColumnarShuffleExchangeExec extends Logging { rddWithPartitionId, new PartitionIdPassthrough(newPartitioning.numPartitions), serializer, + handleRow = handleRow, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), partitionInfo = partitionInfo, dataSize = dataSize, diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 55e4c6d5d38e0c7e7d4b33943089ba5dabd8e4cc..d94d256568e70a80e2b593f7281765826695fd3d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -93,13 +93,14 @@ case class ColumnarSortExec( child.executeColumnar().mapPartitionsWithIndexInternal { (_, iter) => val columnarConf = ColumnarPluginConfig.getSessionConf - val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold + val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize val 
sortSpillEnable = columnarConf.enableSortSpill + val sortSpillRowThreshold = columnarConf.columnarSortSpillRowThreshold val spillDirectory = generateSpillDir(tmpSparkConf, "columnarSortSpill") val sparkSpillConf = new SparkSpillConfig(sortSpillEnable, spillDirectory, spillDirDiskReserveSize, - sortSpillRowThreshold, spillMemPctThreshold) + sortSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize) val startCodegen = System.nanoTime() val radixSortEnable = columnarConf.enableRadixSort diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index 7d1828c27d9fa009afb343acbce821fbf0ca1c5d..837760ac89c7f714a2cdb8c75a41f6969840f2ca 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -133,7 +133,11 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val windowFrameEndTypes = new Array[OmniWindowFrameBoundType](winExpressions.size) val windowFrameEndChannels = new Array[Int](winExpressions.size) var attrMap: Map[String, Int] = Map() - + + if (winExpressions.isEmpty) { + throw new UnsupportedOperationException(s"Unsupported empty winExpressions") + } + for (sortAttr <- orderSpec) { if (!sortAttr.child.isInstanceOf[AttributeReference]) { throw new UnsupportedOperationException(s"Unsupported sort col : ${sortAttr.child.nodeName}") @@ -356,13 +360,14 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], val windowExpressionWithProjectConstant = windowExpressionWithProject child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => val columnarConf = ColumnarPluginConfig.getSessionConf - val windowSpillEnable = columnarConf.enableWindowSpill + val spillWriteBufferSize = columnarConf.columnarSpillWriteBufferSize + val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold val spillDirDiskReserveSize = columnarConf.columnarSpillDirDiskReserveSize + val windowSpillEnable = columnarConf.enableWindowSpill val windowSpillRowThreshold = columnarConf.columnarWindowSpillRowThreshold - val spillMemPctThreshold = columnarConf.columnarSpillMemPctThreshold val spillDirectory = generateSpillDir(tmpSparkConf, "columnarWindowSpill") val sparkSpillConfig = new SparkSpillConfig(windowSpillEnable, spillDirectory, - spillDirDiskReserveSize, windowSpillRowThreshold, spillMemPctThreshold) + spillDirDiskReserveSize, windowSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize) val startCodegen = System.nanoTime() val windowOperatorFactory = new OmniWindowWithExprOperatorFactory(sourceTypes, outputCols, diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java similarity index 96% rename from omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java rename to omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index aeaa10faab50bb0490529fadebb331db2c60efa5..93950e9f0f1b9339587893b95f99c70967024acc 
100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.orc; import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.conf.Configuration; @@ -79,6 +80,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader(); + // collect read cols types + ArrayList typeBuilder = new ArrayList<>(); for (int i = 0; i < requiredfieldNames.length; i++) { String target = requiredfieldNames[i]; boolean is_find = false; @@ -163,6 +168,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader - getIndexArray(buildOutput, projectList) + getIndexArray(buildOutput, projectExprIdList) case LeftExistence(_) => Array[Int]() case x => throw new UnsupportedOperationException(s"ColumnBroadcastHashJoin Join-type[$x] is not supported!") } + + val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)) val buildJoinColsExp = buildKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap) }.toArray val relation = buildPlan.executeBroadcast[ColumnarHashedRelation]() - val prunedBuildOutput = pruneOutput(buildOutput, projectList) + val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList) val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 prunedBuildOutput.zipWithIndex.foreach { case (att, i) => buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) @@ -324,14 +326,17 @@ case class ColumnarBroadcastHashJoinExec( streamedOutput.zipWithIndex.foreach { case (attr, i) => probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } - val probeOutputCols = getIndexArray(streamedOutput, projectList) // {0,1} + val probeOutputCols = getIndexArray(streamedOutput, projectExprIdList) // {0,1} + val probeOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)) val probeHashColsExp = streamedKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, probeOutputExprIdMap) }.toArray + val prunedStreamedOutput = pruneOutput(streamedOutput, projectExprIdList) + val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamedOutput, prunedBuildOutput) val lookupJoinType = OmniExpressionAdaptor.toOmniJoinType(joinType) val canShareBuildOp = (lookupJoinType != OMNI_JOIN_TYPE_RIGHT && lookupJoinType != OMNI_JOIN_TYPE_FULL) + streamedPlan.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => val filter: Optional[String] = condition match { case Some(expr) => @@ -341,7 +346,7 @@ case class ColumnarBroadcastHashJoinExec( Optional.empty() } - def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { + def createBuildOpFactoryAndOp(isShared: Boolean): 
(OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { val startBuildCodegen = System.nanoTime() val opFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1, @@ -350,6 +355,10 @@ case class ColumnarBroadcastHashJoinExec( val op = opFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, + opFactory, op) + } val deserializer = VecBatchSerializerFactory.create() relation.value.buildData.foreach { input => val startBuildInput = System.nanoTime() @@ -357,7 +366,19 @@ case class ColumnarBroadcastHashJoinExec( buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) } val startBuildGetOp = System.nanoTime() - op.getOutput + try { + op.getOutput + } catch { + case e: Exception => { + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id) + } else { + op.close() + opFactory.close() + } + throw new RuntimeException("HashBuilder getOutput failed") + } + } buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) (opFactory, op) } @@ -369,11 +390,9 @@ case class ColumnarBroadcastHashJoinExec( try { buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id) if (buildOpFactory == null) { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(true) buildOpFactory = opFactory buildOp = op - OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, - buildOpFactory, buildOp) } } catch { case e: Exception => { @@ -383,7 +402,7 @@ case class ColumnarBroadcastHashJoinExec( OmniHashBuilderWithExprOperatorFactory.gLock.unlock() } } else { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(false) buildOpFactory = opFactory buildOp = op } @@ -410,19 +429,17 @@ case class ColumnarBroadcastHashJoinExec( } }) - val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) - val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema val reverse = buildSide == BuildLeft var left = 0 - var leftLen = streamedPlanOutput.size - var right = streamedPlanOutput.size + var leftLen = prunedStreamedOutput.size + var right = prunedStreamedOutput.size var rightLen = output.size if (reverse) { - left = streamedPlanOutput.size + left = prunedStreamedOutput.size leftLen = output.size right = 0 - rightLen = streamedPlanOutput.size + rightLen = prunedStreamedOutput.size } val iterBatch = new Iterator[ColumnarBatch] { @@ -468,7 +485,7 @@ case class ColumnarBroadcastHashJoinExec( val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) if (projectList.nonEmpty) { - reorderVecs(prunedOutput, projectList, resultVecs, vecs) + reorderOutputVecs(projectListIndex, resultVecs, vecs) } else { var index = 0 for (i <- left until leftLen) { @@ -566,4 +583,4 @@ case class ColumnarBroadcastHashJoinExec( } -} \ No newline at end of file +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala index e041a1fb119bc63bf689361534c207304383e83e..6e1f76d75a540345b2e5dacad61361fd39f185e8 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarShuffledHashJoinExec.scala @@ -24,7 +24,7 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex,pruneOutput, reorderOutputVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.operator.join.{OmniHashBuilderWithExprOperatorFactory, OmniLookupJoinWithExprOperatorFactory, OmniLookupOuterJoinWithExprOperatorFactory} @@ -62,8 +62,8 @@ case class ColumnarShuffledHashJoinExec( s""" |$formattedNodeName |$simpleStringWithNodeId - |${ExplainUtils.generateFieldString("buildOutput", buildOutput ++ buildOutput.map(_.dataType))} - |${ExplainUtils.generateFieldString("streamedOutput", streamedOutput ++ streamedOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("buildInput", buildOutput ++ buildOutput.map(_.dataType))} + |${ExplainUtils.generateFieldString("streamedInput", streamedOutput ++ streamedOutput.map(_.dataType))} |${ExplainUtils.generateFieldString("leftKeys", leftKeys ++ leftKeys.map(_.dataType))} |${ExplainUtils.generateFieldString("rightKeys", rightKeys ++ rightKeys.map(_.dataType))} |${ExplainUtils.generateFieldString("condition", joinCondStr)} @@ -131,9 +131,9 @@ case class ColumnarShuffledHashJoinExec( buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) } + val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)) val buildJoinColsExp = buildKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap) }.toArray if (!isSimpleColumnForAll(buildJoinColsExp)) { @@ -145,9 +145,9 @@ case class ColumnarShuffledHashJoinExec( probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } + val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)) val probeHashColsExp = streamedKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap) }.toArray if (!isSimpleColumnForAll(probeHashColsExp)) { @@ -186,21 +186,22 @@ case class ColumnarShuffledHashJoinExec( buildTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) } + val projectExprIdList = getExprIdForProjectList(projectList) val buildOutputCols: Array[Int] = joinType match { case Inner | FullOuter | LeftOuter | RightOuter => - getIndexArray(buildOutput, projectList) + getIndexArray(buildOutput, projectExprIdList) case LeftExistence(_) => Array[Int]() case x => throw new 
UnsupportedOperationException(s"ColumnShuffledHashJoin Join-type[$x] is not supported!") } + val buildOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute)) val buildJoinColsExp = buildKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(buildOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, buildOutputExprIdMap) }.toArray - val prunedBuildOutput = pruneOutput(buildOutput, projectList) + val prunedBuildOutput = pruneOutput(buildOutput, projectExprIdList) val buildOutputTypes = new Array[DataType](prunedBuildOutput.size) // {2,2}, buildOutput:col1#12,col2#13 prunedBuildOutput.zipWithIndex.foreach { case (att, i) => buildOutputTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(att.dataType, att.metadata) @@ -210,12 +211,14 @@ case class ColumnarShuffledHashJoinExec( streamedOutput.zipWithIndex.foreach { case (attr, i) => probeTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } - val probeOutputCols = getIndexArray(streamedOutput, projectList) + val probeOutputCols = getIndexArray(streamedOutput, projectExprIdList) + val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute)) val probeHashColsExp = streamedKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(streamedOutput.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap) }.toArray + val prunedStreamedOutput = pruneOutput(streamedOutput, projectExprIdList) + val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamedOutput, prunedBuildOutput) streamedPlan.executeColumnar.zipPartitions(buildPlan.executeColumnar()) { (streamIter, buildIter) => val filter: Optional[String] = condition match { @@ -264,19 +267,17 @@ case class ColumnarShuffledHashJoinExec( buildOp.getOutput buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) - val streamedPlanOutput = pruneOutput(streamedPlan.output, projectList) - val prunedOutput = streamedPlanOutput ++ prunedBuildOutput val resultSchema = this.schema val reverse = buildSide == BuildLeft var left = 0 - var leftLen = streamedPlanOutput.size - var right = streamedPlanOutput.size + var leftLen = prunedStreamedOutput.size + var right = prunedStreamedOutput.size var rightLen = output.size if (reverse) { - left = streamedPlanOutput.size + left = prunedStreamedOutput.size leftLen = output.size right = 0 - rightLen = streamedPlanOutput.size + rightLen = prunedStreamedOutput.size } val joinIter: Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { @@ -320,7 +321,7 @@ case class ColumnarShuffledHashJoinExec( val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) if (projectList.nonEmpty) { - reorderVecs(prunedOutput, projectList, resultVecs, vecs) + reorderOutputVecs(projectListIndex, resultVecs, vecs) } else { var index = 0 for (i <- left until leftLen) { @@ -375,7 +376,7 @@ case class ColumnarShuffledHashJoinExec( val vecs = OmniColumnVector .allocateColumns(result.getRowCount, resultSchema, false) if (projectList.nonEmpty) { - reorderVecs(prunedOutput, projectList, resultVecs, vecs) + reorderOutputVecs(projectListIndex, resultVecs, vecs) } else { var index = 0 for (i <- left until leftLen) { diff --git 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index c3a22b1ea3a8c8a250e75142af71d47df9f2bcb5..a5baa6bde611724f1333fc5f644da9f2bc34a7a2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -25,7 +25,7 @@ import com.huawei.boostkit.spark.Constant.IS_SKIP_VERIFY_EXP import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{checkOmniJsonWhiteList, isSimpleColumn, isSimpleColumnForAll} import com.huawei.boostkit.spark.util.OmniAdaptorUtil -import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getIndexArray, pruneOutput, reorderVecs, transColBatchToOmniVecs} +import com.huawei.boostkit.spark.util.OmniAdaptorUtil.{getExprIdForProjectList, getIndexArray, getProjectListIndex,pruneOutput, reorderOutputVecs, transColBatchToOmniVecs} import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} @@ -197,18 +197,18 @@ case class ColumnarSortMergeJoinExec( left.output.zipWithIndex.foreach { case (attr, i) => streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } + val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute)) val streamedKeyColsExp: Array[AnyRef] = leftKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap) }.toArray val bufferedTypes = new Array[DataType](right.output.size) right.output.zipWithIndex.foreach { case (attr, i) => bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } + val bufferOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute)) val bufferedKeyColsExp: Array[AnyRef] = rightKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, bufferOutputExprIdMap) }.toArray if (!isSimpleColumnForAll(streamedKeyColsExp.map(expr => expr.toString))) { @@ -246,23 +246,24 @@ case class ColumnarSortMergeJoinExec( left.output.zipWithIndex.foreach { case (attr, i) => streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } + val streamOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute)) val streamedKeyColsExp = leftKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(left.output.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, streamOutputExprIdMap) }.toArray - val streamedOutputChannel = getIndexArray(left.output, projectList) + val projectExprIdList = getExprIdForProjectList(projectList) + val streamedOutputChannel = getIndexArray(left.output, projectExprIdList) val bufferedTypes = new Array[DataType](right.output.size) 
right.output.zipWithIndex.foreach { case (attr, i) => bufferedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) } + val bufferOutputExprIdMap = OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute)) val bufferedKeyColsExp = rightKeys.map { x => - OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, - OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) + OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, bufferOutputExprIdMap) }.toArray val bufferedOutputChannel: Array[Int] = joinType match { case Inner | LeftOuter | FullOuter => - getIndexArray(right.output, projectList) + getIndexArray(right.output, projectExprIdList) case LeftExistence(_) => Array[Int]() case x => @@ -275,6 +276,9 @@ case class ColumnarSortMergeJoinExec( OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) case _ => null } + val prunedStreamOutput = pruneOutput(left.output, projectExprIdList) + val prunedBufferOutput = pruneOutput(right.output, projectExprIdList) + val projectListIndex = getProjectListIndex(projectExprIdList, prunedStreamOutput, prunedBufferOutput) left.executeColumnar().zipPartitions(right.executeColumnar()) { (streamedIter, bufferedIter) => val filter: Optional[String] = Optional.ofNullable(filterString) @@ -304,9 +308,6 @@ case class ColumnarSortMergeJoinExec( streamedOpFactory.close() }) - val prunedStreamOutput = pruneOutput(left.output, projectList) - val prunedBufferOutput = pruneOutput(right.output, projectList) - val prunedOutput = prunedStreamOutput ++ prunedBufferOutput val resultSchema = this.schema val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf val enableSortMergeJoinBatchMerge: Boolean = columnarConf.enableSortMergeJoinBatchMerge @@ -415,7 +416,7 @@ case class ColumnarSortMergeJoinExec( val resultVecs = result.getVectors val vecs = OmniColumnVector.allocateColumns(result.getRowCount, resultSchema, false) if (projectList.nonEmpty) { - reorderVecs(prunedOutput, projectList, resultVecs, vecs) + reorderOutputVecs(projectListIndex, resultVecs, vecs) } else { for (index <- output.indices) { val v = vecs(index) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala index 946c90a9baf346dc4e47253ced50a53def22374b..2eb8fec002fb39b65284f33654c6df88a044474f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/SparkMemoryUtils.scala @@ -23,8 +23,8 @@ import org.apache.spark.{SparkEnv, TaskContext} object SparkMemoryUtils { - private val max: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g") - MemoryManager.setGlobalMemoryLimit(max) + val offHeapSize: Long = SparkEnv.get.conf.getSizeAsBytes("spark.memory.offHeap.size", "1g") + MemoryManager.setGlobalMemoryLimit(offHeapSize) def init(): Unit = {} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala index 94e566f9b571c57c2a0e1cc17143c055d7be9229..d53c6e0286c21e026c5073335e96a5a00010a71a 100644 --- 
a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/window/TopNPushDownForWindow.scala @@ -81,7 +81,7 @@ object TopNPushDownForWindow extends Rule[SparkPlan] with PredicateHelper { private def isTopNExpression(e: Expression): Boolean = e match { case Alias(child, _) => isTopNExpression(child) case WindowExpression(windowFunction, _) - if windowFunction.isInstanceOf[Rank] || windowFunction.isInstanceOf[RowNumber] => true + if windowFunction.isInstanceOf[Rank] => true case _ => false } diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java index d95be18832b926500b599821b6b6fd0baa8861c5..d1cd5b7f29c8249d81cdb458d89d82ecf38b8dbe 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleCompressionTest.java @@ -117,7 +117,8 @@ public class ColumnShuffleCompressionTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 1000, partitionNum, true, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java index c8fd474137a93ea8831d3dc3ab432e409018cc55..e0d271ab107073c85314c1b0b796506dd5892232 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffPartitionTest.java @@ -115,7 +115,8 @@ public class ColumnShuffleDiffPartitionTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 99; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, pidVec); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java index dc53fda8a1a04a15bf7ffb9919926d4812208fc0..0f935e68a156034d7973409b41098a3bff331330 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleDiffRowVBTest.java @@ -95,7 +95,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -125,7 +126,8 @@ public class ColumnShuffleDiffRowVBTest 
extends ColumnShuffleTest { shuffleTestDir, 0, 4096, - 1024*1024*1024); + 1024*1024*1024, + Long.MAX_VALUE); for (int i = 0; i < 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, true, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -155,7 +157,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 1024; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 1, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -185,7 +188,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 1; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 1024, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -214,7 +218,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 1; i < 1000; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, i, numPartition, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -243,7 +248,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); VecBatch vecBatchTmp1 = new VecBatch(buildValChar(3, "N")); jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); VecBatch vecBatchTmp2 = new VecBatch(buildValChar(2, "F")); @@ -282,7 +288,8 @@ public class ColumnShuffleDiffRowVBTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); VecBatch vecBatchTmp1 = new VecBatch(buildValInt(3, 1)); jniWrapper.split(splitterId, vecBatchTmp1.getNativeVectorBatch()); VecBatch vecBatchTmp2 = new VecBatch(buildValInt(2, 2)); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java index 2ef81ac49e545aa617136b9d4f3e7e769ea34652..dcd1e8b8571f5942cdb8e7ff29f842eb2c325528 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleGBSizeTest.java @@ -95,7 +95,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 6 * 1024; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -124,7 +125,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 10 * 8 * 1024; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -153,7 +155,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 
1024 * 1024, + Long.MAX_VALUE); // Do not split the same vb repeatedly: the split interface frees the vb's memory, so splitting it again would double-free and core dump for (int i = 0; i < 99; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 999, partitionNum, false, true); @@ -183,7 +186,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 10 * 3 * 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); @@ -213,7 +217,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); // Do not split the same vb repeatedly: the split interface frees the vb's memory, so splitting it again would double-free and core dump for (int i = 0; i < 6 * 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); @@ -244,7 +249,8 @@ public class ColumnShuffleGBSizeTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 3 * 9 * 999; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, partitionNum, false, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java index 98fc18dd8f3237928cc066887e6fcb2205686692..886c2f80643bc314ecc8e1f1fdd4272cc4fbebd2 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/ColumnShuffleNullTest.java @@ -94,7 +94,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); // Do not split the same vb repeatedly: the split interface frees the vb's memory, so splitting it again would double-free and core dump for (int i = 0; i < 1; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); @@ -124,7 +125,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); // Do not split the same vb repeatedly: the split interface frees the vb's memory, so splitting it again would double-free and core dump for (int i = 0; i < 1; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); @@ -155,7 +157,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); // Do not split the same vb repeatedly: the split interface frees the vb's memory, so splitting it again would double-free and core dump for (int i = 0; i < 1; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9, numPartition, true, true); @@ -186,7 +189,8 @@ public class ColumnShuffleNullTest extends ColumnShuffleTest { shuffleTestDir, 64 * 1024, 4096, - 1024 * 1024 * 1024); + 1024 * 1024 * 1024, + Long.MAX_VALUE); for (int i = 0; i < 1; i++) { VecBatch vecBatchTmp = buildVecBatch(idTypes, 9999, numPartition, true, true); jniWrapper.split(splitterId, vecBatchTmp.getNativeVectorBatch()); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java index 77283e4d07eb7e1d77317a026f5d04ea4032c393..fe1c55ffb6eeecc7a16521430759e3239ca77dd4
100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderDataTypeTest.java @@ -36,6 +36,10 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -96,7 +100,7 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); @@ -115,4 +119,27 @@ public class OrcColumnarBatchJniReaderDataTypeTest extends TestCase { vec3.close(); vec4.close(); } + + // Here we test OMNI_LONG type instead of OMNI_INT in 4th field. + @Test + public void testNextIfSchemaChange() { + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_LONG.ordinal()}; + long[] vecNativeId = new long[4]; + long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); + assertTrue(rtn == 4096); + LongVec vec1 = new LongVec(vecNativeId[0]); + VarcharVec vec2 = new VarcharVec(vecNativeId[1]); + VarcharVec vec3 = new VarcharVec(vecNativeId[2]); + LongVec vec4 = new LongVec(vecNativeId[3]); + assertTrue(vec1.get(10) == 11); + String tmp1 = new String(vec2.get(4080)); + assertTrue(tmp1.equals("AAAAAAAABPPAAAAA")); + String tmp2 = new String(vec3.get(4070)); + assertTrue(tmp2.equals("Particular, arab cases shall like less current, different names. Computers start for the changes. 
Scottish, trying exercises operate marks; long, supreme miners may ro")); + assertTrue(0 == vec4.get(1000)); + vec1.close(); + vec2.close(); + vec3.close(); + vec4.close(); + } } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java index 72587b3f36a469fe130abc76c51197fb2a16bd29..995c434f66a8b03a6a76830b8fb0b08be9a3223e 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderNotPushDownTest.java @@ -35,6 +35,9 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -89,7 +92,7 @@ public class OrcColumnarBatchJniReaderNotPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[2]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; long[] vecNativeId = new long[2]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java index 6c75eda79e38332e48e67df8a27c0e1394e4e477..c9ad9fadaf4ae13b1286467f965a23f66788c37f 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderPushDownTest.java @@ -35,6 +35,9 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -135,7 +138,7 @@ public class OrcColumnarBatchJniReaderPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[2]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; long[] vecNativeId = new long[2]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java 
b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java index 7fb87efa3a2899bc5375168f56d516243a10881d..8f4535338cc3715e33dbf336e5f15fd4eb91569f 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCNotPushDownTest.java @@ -36,6 +36,10 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -96,7 +100,7 @@ public class OrcColumnarBatchJniReaderSparkORCNotPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java index 4ba4579cc9340ea410d1d1cdbbf5e0a88ebe2888..27bcf5d7bdbd4787f42e90e620d1627678f70910 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderSparkORCPushDownTest.java @@ -36,6 +36,10 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_INT; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { public OrcColumnarBatchScanReader orcColumnarBatchScanReader; @@ -142,7 +146,7 @@ public class OrcColumnarBatchJniReaderSparkORCPushDownTest extends TestCase { @Test public void testNext() { - int[] typeId = new int[4]; + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_VARCHAR.ordinal(), OMNI_INT.ordinal()}; long[] vecNativeId = new long[4]; long rtn = orcColumnarBatchScanReader.jniReader.recordReaderNext(orcColumnarBatchScanReader.recordReader, orcColumnarBatchScanReader.batchReader, typeId, vecNativeId); assertTrue(rtn == 4096); diff --git a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java index 
c8581f35ebc605845896b5f35731b0908779f326..eab15fef660250780e0beb311b489e3ceeb8ff5b 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java +++ b/omnioperator/omniop-spark-extension/java/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java @@ -21,20 +21,15 @@ package com.huawei.boostkit.spark.jni; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.io.Input; import junit.framework.TestCase; -import nova.hetu.omniruntime.type.DataType; -import nova.hetu.omniruntime.vector.IntVec; import nova.hetu.omniruntime.vector.LongVec; import nova.hetu.omniruntime.vector.VarcharVec; import nova.hetu.omniruntime.vector.Vec; import org.apache.commons.codec.binary.Base64; import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; -import org.apache.orc.OrcConf; import org.apache.orc.OrcFile; -import org.apache.orc.Reader; import org.apache.orc.TypeDescription; import org.apache.orc.mapred.OrcInputFormat; -import org.json.JSONObject; import org.junit.After; import org.junit.Before; import org.junit.FixMethodOrder; @@ -51,6 +46,9 @@ import org.apache.orc.Reader.Options; import static org.junit.Assert.*; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_LONG; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; + @FixMethodOrder(value = MethodSorters.NAME_ASCENDING ) public class OrcColumnarBatchJniReaderTest extends TestCase { public Configuration conf = new Configuration(); @@ -152,7 +150,8 @@ public class OrcColumnarBatchJniReaderTest extends TestCase { @Test public void testNext() { Vec[] vecs = new Vec[2]; - long rtn = orcColumnarBatchScanReader.next(vecs); + int[] typeId = new int[] {OMNI_LONG.ordinal(), OMNI_VARCHAR.ordinal()}; + long rtn = orcColumnarBatchScanReader.next(vecs, typeId); assertTrue(rtn == 4096); assertTrue(((LongVec) vecs[0]).get(0) == 1); String str = new String(((VarcharVec) vecs[1]).get(0)); diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0f0eda4f765d4eddb57fcfd0615b3918fed842d1 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/shuffle/RowShuffleSerializerSuite.scala @@ -0,0 +1,249 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle + +import java.io.{File, FileInputStream} + +import com.huawei.boostkit.spark.serialize.ColumnarBatchSerializer +import com.huawei.boostkit.spark.vectorized.PartitionInfo +import nova.hetu.omniruntime.`type`.{DataType, _} +import nova.hetu.omniruntime.vector._ +import org.apache.spark.{HashPartitioner, SparkConf, TaskContext} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.sort.ColumnarShuffleHandle +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.vectorized.OmniColumnVector +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.{doAnswer, when} +import org.mockito.invocation.InvocationOnMock + +class RowShuffleSerializerSuite extends SharedSparkSession { + @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _ + @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _ + @Mock(answer = RETURNS_SMART_NULLS) private var dependency + : ColumnarShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = _ + + override def sparkConf: SparkConf = + super.sparkConf + .setAppName("test row shuffle serializer") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") + .set("spark.shuffle.compress", "true") + .set("spark.io.compression.codec", "lz4") + + private var taskMetrics: TaskMetrics = _ + private var tempDir: File = _ + private var outputFile: File = _ + + private var shuffleHandle: ColumnarShuffleHandle[Int, ColumnarBatch] = _ + private val numPartitions = 1 + + protected var avgBatchNumRows: SQLMetric = _ + protected var outputNumRows: SQLMetric = _ + + override def beforeEach(): Unit = { + super.beforeEach() + + avgBatchNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer avg read batch num rows") + outputNumRows = SQLMetrics.createAverageMetric(spark.sparkContext, + "test serializer number of output rows") + + tempDir = Utils.createTempDir() + outputFile = File.createTempFile("shuffle", null, tempDir) + taskMetrics = new TaskMetrics + + MockitoAnnotations.initMocks(this) + + shuffleHandle = + new ColumnarShuffleHandle[Int, ColumnarBatch](shuffleId = 0, dependency = dependency) + + val types : Array[DataType] = Array[DataType]( + IntDataType.INTEGER, + ShortDataType.SHORT, + LongDataType.LONG, + DoubleDataType.DOUBLE, + new Decimal64DataType(18, 3), + new Decimal128DataType(28, 11), + VarcharDataType.VARCHAR, + BooleanDataType.BOOLEAN) + val inputTypes = DataTypeSerializer.serialize(types) + + when(dependency.partitioner).thenReturn(new HashPartitioner(numPartitions)) + when(dependency.serializer).thenReturn(new JavaSerializer(sparkConf)) + when(dependency.handleRow).thenReturn(true) // adapt row shuffle + when(dependency.partitionInfo).thenReturn( + new PartitionInfo("hash", numPartitions, types.length, inputTypes)) + when(dependency.dataSize) + .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "data size")) + when(dependency.bytesSpilled) + .thenReturn(SQLMetrics.createSizeMetric(spark.sparkContext, "shuffle bytes spilled")) + when(dependency.numInputRows) + .thenReturn(SQLMetrics.createMetric(spark.sparkContext, 
"number of input rows")) + when(dependency.splitTime) + .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_split")) + when(dependency.spillTime) + .thenReturn(SQLMetrics.createNanoTimingMetric(spark.sparkContext, "totaltime_spill")) + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + + doAnswer { (invocationOnMock: InvocationOnMock) => + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + }.when(blockResolver) + .writeMetadataFileAndCommit(anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])) + } + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def afterAll(): Unit = { + super.afterAll() + } + + test("row shuffle serialize and deserialize") { + val pidArray: Array[java.lang.Integer] = Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + val intArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) + val shortArray: Array[java.lang.Integer] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) + val longArray: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 18L, 19L, 20L) + val doubleArray: Array[java.lang.Double] = Array(0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.10, 11.11, 12.12, + 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.20) + val decimal64Array: Array[java.lang.Long] = Array(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L, 12L, 13L, 14L, 15L, 16L, + 17L, 18L, 19L, 20L) + val decimal128Array: Array[Array[Long]] = Array( + Array(0L, 0L), Array(1L, 1L), Array(2L, 2L), Array(3L, 3L), Array(4L, 4L), Array(5L, 5L), Array(6L, 6L), + Array(7L, 7L), Array(8L, 8L), Array(9L, 9L), Array(10L, 10L), Array(11L, 11L), Array(12L, 12L), Array(13L, 13L), + Array(14L, 14L), Array(15L, 15L), Array(16L, 16L), Array(17L, 17L), Array(18L, 18L), Array(19L, 19L), Array(20L, 20L)) + val stringArray: Array[java.lang.String] = Array("", "a", "bb", "ccc", "dddd", "eeeee", "ffffff", "ggggggg", + "hhhhhhhh", "iiiiiiiii", "jjjjjjjjjj", "kkkkkkkkkkk", "llllllllllll", "mmmmmmmmmmmmm", "nnnnnnnnnnnnnn", + "ooooooooooooooo", "pppppppppppppppp", "qqqqqqqqqqqqqqqqq", "rrrrrrrrrrrrrrrrrr", "sssssssssssssssssss", + "tttttttttttttttttttt") + val booleanArray: Array[java.lang.Boolean] = Array(true, true, true, true, true, true, true, true, true, true, + false, false, false, false, false, false, false, false, false, false, false) + + val pidVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) + val intVector0 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) + val shortVector0 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) + val longVector0 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) + val doubleVector0 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) + val decimal64Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) + val decimal128Vector0 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) + val varcharVector0 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) + val booleanVector0 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) + + val cb0 = 
ColumnarShuffleWriterSuite.makeColumnarBatch( + pidVector0.getVec.getSize, + List(pidVector0, intVector0, shortVector0, longVector0, doubleVector0, + decimal64Vector0, decimal128Vector0, varcharVector0, booleanVector0) + ) + + val pidVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(pidArray) + val intVector1 = ColumnarShuffleWriterSuite.initOmniColumnIntVector(intArray) + val shortVector1 = ColumnarShuffleWriterSuite.initOmniColumnShortVector(shortArray) + val longVector1 = ColumnarShuffleWriterSuite.initOmniColumnLongVector(longArray) + val doubleVector1 = ColumnarShuffleWriterSuite.initOmniColumnDoubleVector(doubleArray) + val decimal64Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal64Vector(decimal64Array) + val decimal128Vector1 = ColumnarShuffleWriterSuite.initOmniColumnDecimal128Vector(decimal128Array) + val varcharVector1 = ColumnarShuffleWriterSuite.initOmniColumnVarcharVector(stringArray) + val booleanVector1 = ColumnarShuffleWriterSuite.initOmniColumnBooleanVector(booleanArray) + + val cb1 = ColumnarShuffleWriterSuite.makeColumnarBatch( + pidVector1.getVec.getSize, + List(pidVector1, intVector1, shortVector1, longVector1, doubleVector1, + decimal64Vector1, decimal128Vector1, varcharVector1, booleanVector1) + ) + + def records: Iterator[(Int, ColumnarBatch)] = Iterator((0, cb0), (0, cb1)) + + val writer = new ColumnarShuffleWriter[Int, ColumnarBatch]( + blockResolver, + shuffleHandle, + 0L, // MapId + taskContext.taskMetrics().shuffleWriteMetrics) + + // exercise the row-based shuffle write path + writer.write(records) + writer.stop(success = true) + + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 0) + // with a single output partition, no partition length should be zero + + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === pidArray.length * 2) + + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + + // the shuffle writer wrote row-format data, so the stream must be deserialized row by row.
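Before the stream is reopened, it is worth noting where the asserted counts come from; a minimal sketch of the arithmetic, assuming (as the final length == 1 check also implies) that both written batches come back as one deserialized batch and that the leading pid vector only drives partitioning and is not part of the shuffled payload:

// Every test array above has 21 elements, and two batches (cb0, cb1) are written.
val rowsPerInputBatch = 21
val writtenBatches = 2
val expectedRows = rowsPerInputBatch * writtenBatches // 42, matches batch.numRows below
// Nine vectors go in per batch (pid + 8 data columns); the pid column is consumed
// by partitioning, leaving 8 columns in the shuffled output.
val expectedCols = 9 - 1                              // matches batch.numCols below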
+ val serializer = new ColumnarBatchSerializer(avgBatchNumRows, outputNumRows, true).newInstance() + val deserializedStream = serializer.deserializeStream(new FileInputStream(outputFile)) + + try { + val kv = deserializedStream.asKeyValueIterator + var length = 0 + kv.foreach { + case (_, batch: ColumnarBatch) => + length += 1 + assert(batch.numRows == 42) + assert(batch.numCols == 8) + assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(0) == 0) + assert(batch.column(0).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[IntVec].get(19) == 19) + assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(0) == 0) + assert(batch.column(1).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[ShortVec].get(19) == 19) + assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0) + assert(batch.column(2).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19) + assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(0) == 0.0) + assert(batch.column(3).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[DoubleVec].get(19) == 19.19) + assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(0) == 0L) + assert(batch.column(4).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[LongVec].get(19) == 19L) + assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(0) sameElements Array(0L, 0L)) + assert(batch.column(5).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[Decimal128Vec].get(19) sameElements Array(19L, 19L)) + assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(0) sameElements "") + assert(batch.column(6).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[VarcharVec].get(19) sameElements "sssssssssssssssssss") + assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(0) == true) + assert(batch.column(7).asInstanceOf[OmniColumnVector].getVec.asInstanceOf[BooleanVec].get(19) == false) + (0 until batch.numCols).foreach { i => + val valueVector = batch.column(i).asInstanceOf[OmniColumnVector].getVec + assert(valueVector.getSize == batch.numRows) + } + batch.close() + } + assert(length == 1) + } finally { + deserializedStream.close() + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala index a788501ed8ed7d06e8939cd45560f59648a56acf..fa8c1390ebe2af2a3441234868c65321d2055847 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarTopNSortExecSuite.scala @@ -49,16 +49,17 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { test("Test topNSort") { val sql1 = "select * from (SELECT city, rank() OVER (ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" - assertColumnarTopNSortExecAndSparkResultEqual(sql1, true) + assertColumnarTopNSortExecAndSparkResultEqual(sql1, true, true) val sql2 = "select * from (SELECT city, row_number() OVER (ORDER BY sales) AS rn FROM dealer) where rn < 4 order by rn;" - assertColumnarTopNSortExecAndSparkResultEqual(sql2, false) + assertColumnarTopNSortExecAndSparkResultEqual(sql2, false, 
false) val sql3 = "select * from (SELECT city, rank() OVER (PARTITION BY city ORDER BY sales) AS rk FROM dealer) where rk < 4 order by rk;" - assertColumnarTopNSortExecAndSparkResultEqual(sql3, true) + assertColumnarTopNSortExecAndSparkResultEqual(sql3, true, true) } - private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true): Unit = { + private def assertColumnarTopNSortExecAndSparkResultEqual(sql: String, hasColumnarTopNSortExec: Boolean = true, + hasTopNSortExec: Boolean = false): Unit = { // run ColumnarTopNSortExec config spark.conf.set("spark.omni.sql.columnar.topNSort", true) spark.conf.set("spark.sql.execution.topNPushDownForWindow.enabled", true) @@ -79,8 +80,10 @@ class ColumnarTopNSortExecSuite extends ColumnarSparkPlanTest { val sparkPlan = sparkResult.queryExecution.executedPlan.toString() assert(!sparkPlan.contains("ColumnarTopNSort"), s"SQL:${sql}\n@SparkEnv have ColumnarTopNSortExec, sparkPlan:${sparkPlan}") - assert(sparkPlan.contains("TopNSort"), - s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}") + if (hasTopNSortExec) { + assert(sparkPlan.contains("TopNSort"), + s"SQL:${sql}\n@SparkEnv no TopNSortExec, sparkPlan:${sparkPlan}") + } // DataFrame do not support comparing with equals method, use DataFrame.except instead // DataFrame.except can do equal for rows misorder(with and without order by are same) assert(omniResult.except(sparkResult).isEmpty, diff --git a/omnioperator/omniop-spark-extension/pom.xml b/omnioperator/omniop-spark-extension/pom.xml index b7315c5b49145c0805d940cd088eb70cbc9c6265..24a42654b4f07b86944ff098639a3d3e237e9402 100644 --- a/omnioperator/omniop-spark-extension/pom.xml +++ b/omnioperator/omniop-spark-extension/pom.xml @@ -8,7 +8,7 @@ com.huawei.kunpeng boostkit-omniop-spark-parent pom - 3.3.1-1.4.0 + 3.3.1-1.5.0 BoostKit Spark Native Sql Engine Extension Parent Pom @@ -20,7 +20,7 @@ UTF-8 3.13.0-h19 FALSE - 1.4.0 + 1.5.0 java
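For context on the new hasTopNSortExec flag in ColumnarTopNSortExecSuite: the TopNPushDownForWindow hunk earlier in this patch narrows the push-down predicate so that only rank() window functions qualify; row_number() no longer does. The predicate as it reads after this patch (quoted from TopNPushDownForWindow.scala above, with imports added for completeness):

import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Rank, WindowExpression}

def isTopNExpression(e: Expression): Boolean = e match {
  case Alias(child, _) => isTopNExpression(child)
  case WindowExpression(windowFunction, _)
    if windowFunction.isInstanceOf[Rank] => true
  case _ => false
}

As a result, the row_number() query (sql2) is no longer expected to produce a TopNSort or ColumnarTopNSort node, which is why the suite now asserts the presence of TopNSort only when the caller passes hasTopNSortExec = true.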