diff --git a/omnidata/omnidata-spark-connector/README.md b/omnidata/omnidata-spark-connector/README.md index c773c416e83311847a6aea19163aaf83aa5d9492..de2c8b8c72f7ae75d7346a6306e4dbe70c901bee 100644 --- a/omnidata/omnidata-spark-connector/README.md +++ b/omnidata/omnidata-spark-connector/README.md @@ -5,7 +5,7 @@ Introduction ============ -The omnidata spark connector library running on Kunpeng processors is a Spark SQL plugin that pushes computing-side operators to storage nodes for computing. It is developed based on original APIs of Apache [Spark 3.0.0](https://github.com/apache/spark/tree/v3.0.0). This library applies to the big data storage separation scenario or large-scale fusion scenario where a large number of compute nodes read data from remote nodes. In this scenario, a large amount of raw data is transferred from storage nodes to compute nodes over the network for processing, resulting in a low proportion of valid data and a huge waste of network bandwidth. You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions. +The omnidata spark connector library running on Kunpeng processors is a Spark SQL plugin that pushes computing-side operators to storage nodes for computing. It is developed based on original APIs of Apache [Spark 3.1.1](https://github.com/apache/spark/tree/v3.1.1). This library applies to the big data storage separation scenario or large-scale fusion scenario where a large number of compute nodes read data from remote nodes. In this scenario, a large amount of raw data is transferred from storage nodes to compute nodes over the network for processing, resulting in a low proportion of valid data and a huge waste of network bandwidth. You can find the latest documentation, including a programming guide, on the project web page. This README file only contains basic setup instructions. 
Building And Packageing diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp index 456519e9a8ee7edac294289f84273244f50c9d62..cc72f65d4d8ae03d7a4b8e5df28ed166658fb423 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.cpp @@ -20,11 +20,16 @@ static const char *exceptionClass = "java/lang/Exception"; static void JniInitialize(JNIEnv *env) { + if (UNLIKELY(env ==nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } std::lock_guard lk(gInitLock); if (UNLIKELY(gLongClass == nullptr)) { gLongClass = env->FindClass("java/lang/Long"); if (UNLIKELY(gLongClass == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to find class java/lang/Long"); + return; } gLongValueFieldId = env->GetFieldID(gLongClass, "value", "J"); @@ -38,24 +43,53 @@ static void JniInitialize(JNIEnv *env) JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *env, jobject, jintArray jTypeIds) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } + if (UNLIKELY(jTypeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), "jTypeIds is null."); + return 0; + } std::shared_ptr instance = std::make_shared(); if (UNLIKELY(instance == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), "Failed to create instance for ock merge reader"); return 0; } - bool result = instance->Initialize(env->GetIntArrayElements(jTypeIds, nullptr), env->GetArrayLength(jTypeIds)); + auto typeIds = env->GetIntArrayElements(jTypeIds, nullptr); + if (UNLIKELY(typeIds == nullptr)) { + env->ThrowNew(env->FindClass(exceptionClass), "Failed to get int array elements."); + return 0; + } + bool result = instance->Initialize(typeIds, env->GetArrayLength(jTypeIds)); if (UNLIKELY(!result)) { + 
env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); env->ThrowNew(env->FindClass(exceptionClass), "Failed to initialize ock merge reader"); return 0; } - + env->ReleaseIntArrayElements(jTypeIds, typeIds, JNI_ABORT); return gBlobReader.Insert(instance); } +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *env, jobject, jlong jReaderId) +{ + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIENV is null."); + return; + } + + gBlobReader.Erase(jReaderId); +} + JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVectorBatch(JNIEnv *env, jobject, jlong jReaderId, jlong jAddress, jint jRemain, jint jMaxRow, jint jMaxSize, jobject jRowCnt) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return -1; + } + auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -80,6 +114,10 @@ JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeGetVecValueLength(JNIEnv *env, jobject, jlong jReaderId, jint jColIndex) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto mergeReader = gBlobReader.Lookup(jReaderId); if (UNLIKELY(!mergeReader)) { std::string errMsg = "Invalid reader id " + std::to_string(jReaderId); @@ -100,6 +138,11 @@ JNIEXPORT jint JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeG JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_nativeCopyVecDataInVB(JNIEnv *env, jobject, jlong jReaderId, jlong dstNativeVec, jint jColIndex) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } + auto dstVector = reinterpret_cast(dstNativeVec); // get from scala which is real vector if (UNLIKELY(dstVector == nullptr)) { std::string errMsg = "Invalid dst vector address for reader id " + 
std::to_string(jReaderId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h index 80a63c403ef8ce43ee5be522ab6bfd5fea6c9b37..eb8a692a7dcde68fafed60820da44870c3fc3a3e 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniReader.h @@ -18,6 +18,12 @@ extern "C" { */ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_make(JNIEnv *, jobject, jintArray); +/* + * Class: com_huawei_ock_spark_jni_OckShuffleJniReader + * Method: close + * Signature: (JI)I + */ +JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniReader_close(JNIEnv *, jobject, jlong); /* * Class: com_huawei_ock_spark_jni_OckShuffleJniReader * Method: nativeGetVectorBatch diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp index 61633605eb8afbf26abeeea595fcfc48742f3498..e1bcdec442798804d80ba6bb51ca88f0ce74cc19 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/jni/OckShuffleJniWriter.cpp @@ -20,11 +20,15 @@ static const char *exceptionClass = "java/lang/Exception"; JNIEXPORT jboolean JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_initialize(JNIEnv *env, jobject) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return JNI_FALSE; + } gSplitResultClass = CreateGlobalClassReference(env, "Lcom/huawei/boostkit/spark/vectorized/SplitResult;"); gSplitResultConstructor = GetMethodID(env, gSplitResultClass, "", "(JJJJJ[J)V"); if (UNLIKELY(!OckShuffleSdk::Initialize())) { - std::cout << "Failed to load ock shuffle library." 
<< std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to load ock shuffle library.").c_str()); return JNI_FALSE; } @@ -36,9 +40,14 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jstring jPartitioningMethod, jint jPartitionNum, jstring jColTypes, jint jColNum, jint jRegionSize, jint jMinCapacity, jint jMaxCapacity, jboolean jIsCompress) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return 0; + } auto appIdStr = env->GetStringUTFChars(jAppId, JNI_FALSE); if (UNLIKELY(appIdStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("ApplicationId can't be empty").c_str()); + return 0; } auto appId = std::string(appIdStr); env->ReleaseStringUTFChars(jAppId, appIdStr); @@ -46,6 +55,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto partitioningMethodStr = env->GetStringUTFChars(jPartitioningMethod, JNI_FALSE); if (UNLIKELY(partitioningMethodStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Partitioning method can't be empty").c_str()); + return 0; } auto partitionMethod = std::string(partitioningMethodStr); env->ReleaseStringUTFChars(jPartitioningMethod, partitioningMethodStr); @@ -53,6 +63,7 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto colTypesStr = env->GetStringUTFChars(jColTypes, JNI_FALSE); if (UNLIKELY(colTypesStr == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Columns types can't be empty").c_str()); + return 0; } DataTypes colTypes = Deserialize(colTypesStr); @@ -63,7 +74,8 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native jmethodID jMethodId = env->GetStaticMethodID(jThreadCls, "currentThread", "()Ljava/lang/Thread;"); jobject jThread = env->CallStaticObjectMethod(jThreadCls, jMethodId); if (UNLIKELY(jThread == nullptr)) { - std::cout << "Failed to get current thread instance." 
<< std::endl; + env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to get current thread instance.").c_str()); + return 0; } else { jThreadId = env->CallLongMethod(jThread, env->GetMethodID(jThreadCls, "getId", "()J")); } @@ -71,16 +83,19 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native auto splitter = OckSplitter::Make(partitionMethod, jPartitionNum, colTypes.GetIds(), jColNum, (uint64_t)jThreadId); if (UNLIKELY(splitter == nullptr)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to make ock splitter").c_str()); + return 0; } bool ret = splitter->SetShuffleInfo(appId, jShuffleId, jStageId, jStageAttemptNum, jMapId, jTaskAttemptId); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to set shuffle information").c_str()); + return 0; } ret = splitter->InitLocalBuffer(jRegionSize, jMinCapacity, jMaxCapacity, (jIsCompress == JNI_TRUE)); if (UNLIKELY(!ret)) { env->ThrowNew(env->FindClass(exceptionClass), std::string("Failed to initialize local buffer").c_str()); + return 0; } return gOckSplitterMap.Insert(std::shared_ptr(splitter)); @@ -89,21 +104,28 @@ JNIEXPORT jlong JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_native JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(JNIEnv *env, jobject, jlong splitterId, jlong nativeVectorBatch) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } auto vecBatch = (VectorBatch *)nativeVectorBatch; if (UNLIKELY(vecBatch == nullptr)) { std::string errMsg = "Invalid address for native vector batch."; env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } if (UNLIKELY(!splitter->Split(*vecBatch))) { std::string errMsg 
= "Failed to split vector batch by splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } delete vecBatch; @@ -112,10 +134,15 @@ JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_split(J JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return nullptr; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string error_message = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), error_message.c_str()); + return nullptr; } splitter->Stop(); // free resource @@ -132,10 +159,15 @@ JNIEXPORT jobject JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_stop JNIEXPORT void JNICALL Java_com_huawei_ock_spark_jni_OckShuffleJniWriter_close(JNIEnv *env, jobject, jlong splitterId) { + if (UNLIKELY(env == nullptr)) { + LOG_ERROR("JNIEnv is null."); + return; + } auto splitter = gOckSplitterMap.Lookup(splitterId); if (UNLIKELY(!splitter)) { std::string errMsg = "Invalid splitter id " + std::to_string(splitterId); env->ThrowNew(env->FindClass(exceptionClass), errMsg.c_str()); + return; } gOckSplitterMap.Erase(splitterId); diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp index b9c6ced10a6742812d257ee6cf95c84b9e5b3ad0..a8d9a92e9643aad41dbd988d6fe80116d9655cc1 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_hash_write_buffer.cpp @@ -23,9 +23,21 @@ bool OckHashWriteBuffer::Initialize(uint32_t regionSize, uint32_t minCapacity, u mIsCompress = isCompress; uint32_t bufferNeed = regionSize * mPartitionNum; mDataCapacity = 
std::min(std::max(bufferNeed, minCapacity), maxCapacity); + if (UNLIKELY(mDataCapacity < mSinglePartitionAndRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSinglePartitionAndRegionUsedSize * mPartitionNum"); + return false; + } mRegionPtRecordOffset = mDataCapacity - mSinglePartitionAndRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity < mSingleRegionUsedSize * mPartitionNum)) { + LogError("mDataCapacity should be bigger than mSingleRegionUsedSize * mPartitionNum"); + return false; + } mRegionUsedRecordOffset = mDataCapacity - mSingleRegionUsedSize * mPartitionNum; + if (UNLIKELY(mDataCapacity / mPartitionNum < mSinglePartitionAndRegionUsedSize)) { + LogError("mDataCapacity / mPartitionNum should be bigger than mSinglePartitionAndRegionUsedSize"); + return false; + } mEachPartitionSize = mDataCapacity / mPartitionNum - mSinglePartitionAndRegionUsedSize; mDoublePartitionSize = reserveSize * mEachPartitionSize; @@ -76,6 +88,10 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t return ResultFlag::UNEXPECTED; } + if (UNLIKELY(mTotalSize > UINT32_MAX -length)) { + LogError("mTotalSize + length exceed UINT32_MAX"); + return ResultFlag::UNEXPECTED; + } // 1. 
get the new region id for partitionId uint32_t regionId = UINT32_MAX; if (newRegion && !GetNewRegion(partitionId, regionId)) { return ResultFlag::UNEXPECTED; } @@ -98,7 +114,7 @@ OckHashWriteBuffer::ResultFlag OckHashWriteBuffer::PreoccupiedDataSpace(uint32_t (mDoublePartitionSize - mRegionUsedSize[regionId] - mRegionUsedSize[nearRegionId]); if (remainBufLength >= length) { mRegionUsedSize[regionId] += length; - mTotalSize += length; // todo check + mTotalSize += length; return ResultFlag::ENOUGH; } @@ -111,8 +127,16 @@ uint8_t *OckHashWriteBuffer::GetEndAddressOfRegion(uint32_t partitionId, uint32_ regionId = mPtCurrentRegionId[partitionId]; if ((regionId % groupSize) == 0) { + if (UNLIKELY(regionId * mEachPartitionSize + mRegionUsedSize[regionId] < length)) { + LogError("regionId * mEachPartitionSize + mRegionUsedSize[regionId] should be bigger than length"); + return nullptr; + } offset = regionId * mEachPartitionSize + mRegionUsedSize[regionId] - length; } else { + if (UNLIKELY((regionId + 1) * mEachPartitionSize < mRegionUsedSize[regionId])) { + LogError("(regionId + 1) * mEachPartitionSize should be bigger than mRegionUsedSize[regionId]"); + return nullptr; + } offset = (regionId + 1) * mEachPartitionSize - mRegionUsedSize[regionId]; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp index 80ff1737977846dee4dad93049c35ffb44509f13..ca7af1baabcba918e7763c64bb062837abbfd874 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.cpp @@ -15,12 +15,17 @@ using namespace ock::dopspark; bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) { mColNum = colNum; - mVectorBatch = new (std::nothrow) VBDataDesc(colNum); + mVectorBatch = std::make_shared(); if (UNLIKELY(mVectorBatch == nullptr)) { LOG_ERROR("Failed to new instance for vector batch 
description"); return false; } + if (UNLIKELY(!mVectorBatch->Initialize(colNum))) { + LOG_ERROR("Failed to initialize vector batch."); + return false; + } + mColTypeIds.reserve(colNum); for (uint32_t index = 0; index < colNum; ++index) { mColTypeIds.emplace_back(typeIds[index]); @@ -29,44 +34,48 @@ bool OckMergeReader::Initialize(const int32_t *typeIds, uint32_t colNum) return true; } -bool OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) +bool OckMergeReader::GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress) { uint8_t *address = startAddress; - vector.SetValueNulls(static_cast(address)); - vector.SetSize(rowNum); + vector->SetValueNulls(static_cast(address)); + vector->SetSize(rowNum); address += rowNum; switch (typeId) { case OMNI_BOOLEAN: { - vector.SetCapacityInBytes(sizeof(uint8_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint8_t) * rowNum); break; } case OMNI_SHORT: { - vector.SetCapacityInBytes(sizeof(uint16_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint16_t) * rowNum); break; } case OMNI_INT: case OMNI_DATE32: { - vector.SetCapacityInBytes(sizeof(uint32_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint32_t) * rowNum); break; } case OMNI_LONG: case OMNI_DOUBLE: case OMNI_DECIMAL64: case OMNI_DATE64: { - vector.SetCapacityInBytes(sizeof(uint64_t) * rowNum); + vector->SetCapacityInBytes(sizeof(uint64_t) * rowNum); break; } case OMNI_DECIMAL128: { - vector.SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte + vector->SetCapacityInBytes(decimal128Size * rowNum); // 16 means value cost 16Byte break; } case OMNI_CHAR: case OMNI_VARCHAR: { // unknown length for value vector, calculate later // will add offset_vector_len when the length of values_vector is variable - vector.SetValueOffsets(static_cast(address)); + vector->SetValueOffsets(static_cast(address)); address += capacityOffset * (rowNum + 1); // 4 means value cost 
4Byte - vector.SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + vector->SetCapacityInBytes(*reinterpret_cast(address - capacityOffset)); + if (UNLIKELY(vector->GetCapacityInBytes() > maxCapacityInBytes)) { + LOG_ERROR("vector capacityInBytes exceed maxCapacityInBytes"); + return false; + } break; } default: { @@ -75,26 +84,26 @@ bool OckMergeReader::GenerateVector(OckVector &vector, uint32_t rowNum, int32_t } } - vector.SetValues(static_cast(address)); - address += vector.GetCapacityInBytes(); + vector->SetValues(static_cast(address)); + address += vector->GetCapacityInBytes(); startAddress = address; return true; } bool OckMergeReader::CalVectorValueLength(uint32_t colIndex, uint32_t &length) { - OckVector *vector = mVectorBatch->mColumnsHead[colIndex]; + auto vector = mVectorBatch->GetColumnHead(colIndex); + length = 0; for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(vector == nullptr)) { LOG_ERROR("Failed to calculate value length for column index %d", colIndex); return false; } - - mVectorBatch->mVectorValueLength[colIndex] += vector->GetCapacityInBytes(); + length += vector->GetCapacityInBytes(); vector = vector->GetNextVector(); } - length = mVectorBatch->mVectorValueLength[colIndex]; + mVectorBatch->SetColumnCapacity(colIndex, length); return true; } @@ -102,37 +111,27 @@ bool OckMergeReader::ScanOneVectorBatch(uint8_t *&startAddress) { uint8_t *address = startAddress; // get vector batch msg as vb_data_batch memory layout (upper) - mCurVBHeader = reinterpret_cast(address); - mVectorBatch->mHeader.rowNum += mCurVBHeader->rowNum; - mVectorBatch->mHeader.length += mCurVBHeader->length; + auto curVBHeader = reinterpret_cast(address); + mVectorBatch->AddTotalCapacity(curVBHeader->length); + mVectorBatch->AddTotalRowNum(curVBHeader->rowNum); address += sizeof(struct VBDataHeaderDesc); OckVector *curVector = nullptr; for (uint32_t colIndex = 0; colIndex < mColNum; colIndex++) { - curVector = mVectorBatch->mColumnsCur[colIndex]; - if 
(UNLIKELY(!GenerateVector(*curVector, mCurVBHeader->rowNum, mColTypeIds[colIndex], address))) { - LOG_ERROR("Failed to generate vector"); + auto curVector = mVectorBatch->GetCurColumn(colIndex); + if (UNLIKELY(curVector == nullptr)) { + LOG_ERROR("curVector is null, index %d", colIndex); return false; } - - if (curVector->GetNextVector() == nullptr) { - curVector = new (std::nothrow) OckVector(); - if (UNLIKELY(curVector == nullptr)) { - LOG_ERROR("Failed to new instance for ock vector"); - return false; - } - - // set next vector in the column merge list, and current column vector point to it - mVectorBatch->mColumnsCur[colIndex]->SetNextVector(curVector); - mVectorBatch->mColumnsCur[colIndex] = curVector; - } else { - mVectorBatch->mColumnsCur[colIndex] = curVector->GetNextVector(); + if (UNLIKELY(!GenerateVector(curVector, curVBHeader->rowNum, mColTypeIds[colIndex], address))) { + LOG_ERROR("Failed to generate vector"); + return false; } } - if (UNLIKELY((uint32_t)(address - startAddress) != mCurVBHeader->length)) { + if (UNLIKELY((uint32_t)(address - startAddress) != curVBHeader->length)) { LOG_ERROR("Failed to scan one vector batch as invalid date setting %d vs %d", - (uint32_t)(address - startAddress), mCurVBHeader->length); + (uint32_t)(address - startAddress), curVBHeader->length); return false; } @@ -159,34 +158,44 @@ bool OckMergeReader::GetMergeVectorBatch(uint8_t *&startAddress, uint32_t remain } mMergeCnt++; - if (mVectorBatch->mHeader.rowNum >= maxRowNum || mVectorBatch->mHeader.length >= maxSize) { + if (mVectorBatch->GetTotalRowNum() >= maxRowNum || mVectorBatch->GetTotalCapacity() >= maxSize) { break; } } startAddress = address; - return true; } -bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, - OckVector &srcVector, uint32_t colIndex) +bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, + uint32_t &remainingCapacity, OckVectorPtr &srcVector) { - errno_t ret = 
memcpy_s(nulls, srcVector.GetSize(), srcVector.GetValueNulls(), srcVector.GetSize()); + uint32_t srcSize = srcVector->GetSize(); + if (UNLIKELY(remainingSize < srcSize)) { + LOG_ERROR("Not enough resource. remainingSize %d, srcSize %d.", remainingSize, srcSize); + return false; + } + errno_t ret = memcpy_s(nulls, remainingSize, srcVector->GetValueNulls(), srcSize); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy null vector"); return false; } - nulls += srcVector.GetSize(); + nulls += srcSize; + remainingSize -= srcSize; - if (srcVector.GetCapacityInBytes() > 0) { - ret = memcpy_s(values, srcVector.GetCapacityInBytes(), srcVector.GetValues(), - srcVector.GetCapacityInBytes()); + uint32_t srcCapacity = srcVector->GetCapacityInBytes(); + if (UNLIKELY(remainingCapacity < srcCapacity)) { + LOG_ERROR("Not enough resource. remainingCapacity %d, srcCapacity %d", remainingCapacity, srcCapacity); + return false; + } + if (srcCapacity > 0) { + ret = memcpy_s(values, remainingCapacity, srcVector->GetValues(), srcCapacity); if (UNLIKELY(ret != EOK)) { LOG_ERROR("Failed to copy values vector"); return false; } - values += srcVector.GetCapacityInBytes(); + values += srcCapacity; + remainingCapacity -= srcCapacity; } return true; @@ -195,13 +204,20 @@ bool OckMergeReader::CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) { // point to first src vector in list - OckVector *srcVector = mVectorBatch->mColumnsHead[colIndex]; + auto srcVector = mVectorBatch->GetColumnHead(colIndex); auto *nullsAddress = (uint8_t *)dstVector->GetValueNulls(); auto *valuesAddress = (uint8_t *)dstVector->GetValues(); uint32_t *offsetsAddress = (uint32_t *)dstVector->GetValueOffsets(); + dstVector->SetNullFlag(true); uint32_t totalSize = 0; uint32_t currentSize = 0; + if (dstVector->GetSize() < 0 || dstVector->GetCapacityInBytes() < 0) { + LOG_ERROR("Invalid vector size %d or capacity %d", dstVector->GetSize(), 
dstVector->GetCapacityInBytes()); + return false; + } + uint32_t remainingSize = (uint32_t)dstVector->GetSize(); + uint32_t remainingCapacity = (uint32_t)dstVector->GetCapacityInBytes(); for (uint32_t cnt = 0; cnt < mMergeCnt; ++cnt) { if (UNLIKELY(srcVector == nullptr)) { @@ -209,7 +225,7 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) return false; } - if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, *srcVector, colIndex))) { + if (UNLIKELY(!CopyPartDataToVector(nullsAddress, valuesAddress, remainingSize, remainingCapacity, srcVector))) { return false; } @@ -226,9 +242,9 @@ bool OckMergeReader::CopyDataToVector(Vector *dstVector, uint32_t colIndex) if (mColTypeIds[colIndex] == OMNI_CHAR || mColTypeIds[colIndex] == OMNI_VARCHAR) { *offsetsAddress = totalSize; - if (UNLIKELY(totalSize != mVectorBatch->mVectorValueLength[colIndex])) { + if (UNLIKELY(totalSize != mVectorBatch->GetColumnCapacity(colIndex))) { LOG_ERROR("Failed to calculate variable vector value length, %d to %d", totalSize, - mVectorBatch->mVectorValueLength[colIndex]); + mVectorBatch->GetColumnCapacity(colIndex)); return false; } } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h index b5d5fba4d7ddd910146126201cc27776f6ad813b..7120b260dd3474993bcf02b3f36e5347b3cf9aa4 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_merge_reader.h @@ -15,33 +15,34 @@ public: bool Initialize(const int32_t *typeIds, uint32_t colNum); bool GetMergeVectorBatch(uint8_t *&address, uint32_t remain, uint32_t maxRowNum, uint32_t maxSize); - bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, OckVector &srcVector, uint32_t colIndex); + bool CopyPartDataToVector(uint8_t *&nulls, uint8_t *&values, uint32_t &remainingSize, uint32_t &remainingCapacity, + 
OckVectorPtr &srcVector); bool CopyDataToVector(omniruntime::vec::Vector *dstVector, uint32_t colIndex); [[nodiscard]] inline uint32_t GetVectorBatchLength() const { - return mVectorBatch->mHeader.length; + return mVectorBatch->GetTotalCapacity(); } [[nodiscard]] inline uint32_t GetRowNumAfterMerge() const { - return mVectorBatch->mHeader.rowNum; + return mVectorBatch->GetTotalRowNum(); } bool CalVectorValueLength(uint32_t colIndex, uint32_t &length); private: - static bool GenerateVector(OckVector &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); + static bool GenerateVector(OckVectorPtr &vector, uint32_t rowNum, int32_t typeId, uint8_t *&startAddress); bool ScanOneVectorBatch(uint8_t *&startAddress); static constexpr int capacityOffset = 4; static constexpr int decimal128Size = 16; + static constexpr int maxCapacityInBytes = 1073741824; private: // point to shuffle blob current vector batch data header uint32_t mColNum = 0; uint32_t mMergeCnt = 0; std::vector mColTypeIds {}; - VBHeaderPtr mCurVBHeader = nullptr; VBDataDescPtr mVectorBatch = nullptr; }; } diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp index 5c046686755c88ccf3e0bdb39e70633c49015aca..1732ceb37a3c5e2698c6eb309958b0808bb8422e 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.cpp @@ -47,7 +47,7 @@ bool OckSplitter::ToSplitterTypeId(const int32_t *vBColTypes) break; } case OMNI_CHAR: - case OMNI_VARCHAR: { // unknown length for value vector, calculate later + case OMNI_VARCHAR: { // unknown length for value vector, calculate later mMinDataLenInVBByRow += uint32Size; // 4 means offset mVBColShuffleTypes.emplace_back(ShuffleTypeId::SHUFFLE_BINARY); mColIndexOfVarVec.emplace_back(colIndex); @@ -70,11 +70,15 @@ bool OckSplitter::ToSplitterTypeId(const int32_t 
*vBColTypes) return true; } -void OckSplitter::InitCacheRegion() +bool OckSplitter::InitCacheRegion() { mCacheRegion.reserve(mPartitionNum); mCacheRegion.resize(mPartitionNum); + if (UNLIKELY(mOckBuffer->GetRegionSize() * 2 < mMinDataLenInVB || mMinDataLenInVBByRow == 0)) { + LOG_DEBUG("regionSize * doubleNum should be bigger than mMinDataLenInVB %d", mMinDataLenInVBByRow); + return false; + } uint32_t rowNum = (mOckBuffer->GetRegionSize() * 2 - mMinDataLenInVB) / mMinDataLenInVBByRow; LOG_INFO("Each region can cache row number is %d", rowNum); @@ -84,6 +88,7 @@ void OckSplitter::InitCacheRegion() region.mLength = 0; region.mRowNum = 0; } + return true; } bool OckSplitter::Initialize(const int32_t *colTypeIds) @@ -122,6 +127,10 @@ std::shared_ptr OckSplitter::Create(const int32_t *colTypeIds, int3 std::shared_ptr OckSplitter::Make(const std::string &partitionMethod, int partitionNum, const int32_t *colTypeIds, int32_t colNum, uint64_t threadId) { + if (UNLIKELY(colTypeIds == nullptr || colNum == 0)) { + LOG_ERROR("colTypeIds is null or colNum is 0, colNum %d", colNum); + return nullptr; + } if (partitionMethod == "hash" || partitionMethod == "rr" || partitionMethod == "range") { return Create(colTypeIds, colNum, partitionNum, false, threadId); } else if (UNLIKELY(partitionMethod == "single")) { @@ -176,27 +185,58 @@ bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::v T *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); + int32_t idsNum = mCurrentVB->GetRowCount(); + int64_t idsSizeInBytes = idsNum * sizeof(int32_t); + auto ids = static_cast(mAllocator->alloc(idsSizeInBytes)); if (UNLIKELY(ids == nullptr)) { LOG_ERROR("Failed to allocate space for fixed width value ids."); return false; } auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); + (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, idsNum, ids); 
if (UNLIKELY(dictionary == nullptr)) { LOG_ERROR("Failed to get dictionary"); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); return false; } srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + int32_t srcRowCount = dictionary->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]]]; // write value to local blob + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum %d.", idIndex, idsNum); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + *dstValues++ = srcValues[rowIndex]; // write value to local blob } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); } else { srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[rowIndexes[index]]; // write value to local blob + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } + *dstValues++ = srcValues[rowIndex]; // write value to local blob } } @@ -205,37 +245,67 @@ bool OckSplitter::WriteFixedWidthValueTemple(Vector *vector, bool isDict, std::v return true; } -bool 
OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, - uint32_t rowNum, uint64_t *&address) +bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, + uint64_t *&address) { uint64_t *dstValues = address; uint64_t *srcValues = nullptr; if (isDict) { - auto ids = static_cast(mAllocator->alloc(mCurrentVB->GetRowCount() * sizeof(int32_t))); + uint32_t idsNum = mCurrentVB->GetRowCount(); + int64_t idsSizeInBytes = idsNum * sizeof(int32_t); + auto ids = static_cast(mAllocator->alloc(idsSizeInBytes)); if (UNLIKELY(ids == nullptr)) { LOG_ERROR("Failed to allocate space for fixed width value ids."); return false; } - auto dictionary = - (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, mCurrentVB->GetRowCount(), ids); + auto dictionary = (reinterpret_cast(vector))->ExtractDictionaryAndIds(0, idsNum, ids); if (UNLIKELY(dictionary == nullptr)) { LOG_ERROR("Failed to get dictionary"); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); return false; } srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(dictionary)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + int32_t srcRowCount = dictionary->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { - *dstValues++ = srcValues[reinterpret_cast(ids)[rowIndexes[index]] << 1]; - *dstValues++ = srcValues[(reinterpret_cast(ids)[rowIndexes[index]] << 1) | 1]; + uint32_t idIndex = rowIndexes[index]; + if (UNLIKELY(idIndex >= idsNum)) { + LOG_ERROR("Invalid idIndex %d, idsNum.", idIndex, idsNum); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + uint32_t rowIndex = reinterpret_cast(ids)[idIndex]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); + return false; + } + 
*dstValues++ = srcValues[rowIndex << 1]; + *dstValues++ = srcValues[rowIndex << 1 | 1]; } - mAllocator->free((uint8_t *)(ids), mCurrentVB->GetRowCount() * sizeof(int32_t)); + mAllocator->free((uint8_t *)(ids), idsSizeInBytes); } else { srcValues = reinterpret_cast(VectorHelper::GetValuesAddr(vector)); + if (UNLIKELY(srcValues == nullptr)) { + LOG_ERROR("Source values address is null."); + return false; + } + int32_t srcRowCount = vector->GetSize(); for (uint32_t index = 0; index < rowNum; ++index) { + uint32_t rowIndex = rowIndexes[index]; + if (UNLIKELY(rowIndex >= srcRowCount)) { + LOG_ERROR("Invalid rowIndex %d, srcRowCount %d.", rowIndex, srcRowCount); + return false; + } *dstValues++ = srcValues[rowIndexes[index] << 1]; // write value to local blob - *dstValues++ = srcValues[(rowIndexes[index] << 1) | 1]; // write value to local blob + *dstValues++ = srcValues[rowIndexes[index] << 1 | 1]; // write value to local blob } } @@ -243,8 +313,8 @@ bool OckSplitter::WriteDecimal128(Vector *vector, bool isDict, std::vector &rowIndexes, uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, std::vector &rowIndexes, + uint32_t rowNum, uint8_t *&address) { bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); switch (typeId) { @@ -285,8 +355,8 @@ bool OckSplitter::WriteFixedWidthValue(Vector *vector, ShuffleTypeId typeId, return true; } -bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, - uint32_t rowNum, uint8_t *&address) +bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector &rowIndexes, uint32_t rowNum, + uint8_t *&address) { bool isDict = (vector->GetEncoding() == OMNI_VEC_ENCODING_DICTIONARY); auto *offsetAddress = reinterpret_cast(address); // point the offset space base address @@ -295,11 +365,17 @@ bool OckSplitter::WriteVariableWidthValue(Vector *vector, std::vector int32_t length = 0; uint8_t *srcValues = nullptr; + int32_t vectorSize = 
vector->GetSize(); for (uint32_t rowCnt = 0; rowCnt < rowNum; rowCnt++) { + uint32_t rowIndex = rowIndexes[rowCnt]; + if (UNLIKELY(rowIndex >= vectorSize)) { + LOG_ERROR("Invalid rowIndex %d, vectorSize %d.", rowIndex, vectorSize); + return false; + } if (isDict) { - length = reinterpret_cast(vector)->GetVarchar(rowIndexes[rowCnt], &srcValues); + length = reinterpret_cast(vector)->GetVarchar(rowIndex, &srcValues); } else { - length = reinterpret_cast(vector)->GetValue(rowIndexes[rowCnt], &srcValues); + length = reinterpret_cast(vector)->GetValue(rowIndex, &srcValues); } // write the null value in the vector with row index to local blob if (UNLIKELY(length > 0 && memcpy_s(valueAddress, length, srcValues, length) != EOK)) { @@ -353,6 +429,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) uint32_t regionId = 0; // backspace from local blob the region end address to remove preoccupied bytes for the vector batch region auto address = mOckBuffer->GetEndAddressOfRegion(partitionId, regionId, vbRegion->mLength); + if (UNLIKELY(address == nullptr)) { + LOG_ERROR("Failed to get address with partitionId %d", partitionId); + return false; + } // write the header information of the vector batch in local blob auto header = reinterpret_cast(address); header->length = vbRegion->mLength; @@ -361,6 +441,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) if (!mOckBuffer->IsCompress()) { // record write bytes when don't need compress mTotalWriteBytes += header->length; } + if (UNLIKELY(partitionId > mPartitionLengths.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } mPartitionLengths[partitionId] += header->length; // we can't get real length when compress address += vbHeaderSize; // 8 means header length so skip @@ -382,6 +466,10 @@ bool OckSplitter::WritePartVectorBatch(VectorBatch &vb, uint32_t partitionId) bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) { + if 
(UNLIKELY(mPartitionNum > mCacheRegion.size())) { + LOG_ERROR("Illegal mPartitionNum %d", mPartitionNum); + return false; + } for (uint32_t partitionId = 0; partitionId < mPartitionNum; ++partitionId) { if (mCacheRegion[partitionId].mRowNum == 0) { continue; @@ -421,6 +509,10 @@ bool OckSplitter::FlushAllRegionAndGetNewBlob(VectorBatch &vb) bool OckSplitter::PreoccupiedBufferSpace(VectorBatch &vb, uint32_t partitionId, uint32_t rowIndex, uint32_t rowLength, bool newRegion) { + if (UNLIKELY(partitionId > mCacheRegion.size())) { + LOG_ERROR("Illegal partitionId %d", partitionId); + return false; + } uint32_t preoccupiedSize = rowLength; if (mCacheRegion[partitionId].mRowNum == 0) { preoccupiedSize += mMinDataLenInVB; // means create a new vector batch, so will cost header diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h index fc81195099f49a2cae202a10324f0725ee5a08bb..8f26b84be3a42442876712ed4f7b6055a33dbdbd 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_splitter.h @@ -70,7 +70,10 @@ public: return false; } - InitCacheRegion(); + if (UNLIKELY(!InitCacheRegion())) { + LOG_ERROR("Failed to initialize CacheRegion"); + return false; + } return true; } @@ -98,7 +101,7 @@ private: return mIsSinglePt ? 
0 : mPtViewInCurVB->GetValue(rowIndex); } - void InitCacheRegion(); + bool InitCacheRegion(); inline void ResetCacheRegion() { @@ -159,6 +162,7 @@ private: static constexpr uint32_t uint64Size = 8; static constexpr uint32_t decimal128Size = 16; static constexpr uint32_t vbHeaderSize = 8; + static constexpr uint32_t doubleNum = 2; /* the region use for all vector batch ---------------------------------------------------------------- */ // this splitter which corresponding to one map task in one shuffle, so some params is same uint32_t mPartitionNum = 0; diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h index e07e67f17d7281f5df0e1d4ee17a4949bc1da697..03e444b6ce4e7284a36e859c327cc51546fb26ab 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_type.h @@ -6,7 +6,7 @@ #define SPARK_THESTRAL_PLUGIN_OCK_TYPE_H #include "ock_vector.h" -#include "common/debug.h" +#include "common/common.h" namespace ock { namespace dopspark { @@ -33,58 +33,118 @@ enum class ShuffleTypeId : int { using VBHeaderPtr = struct VBDataHeaderDesc { uint32_t length = 0; // 4Byte uint32_t rowNum = 0; // 4Byte -} __attribute__((packed)) * ; +} __attribute__((packed)) *; -using VBDataDescPtr = struct VBDataDesc { - explicit VBDataDesc(uint32_t colNum) +class VBDataDesc { +public: + VBDataDesc() = default; + ~VBDataDesc() { + for (auto &vector : mColumnsHead) { + if (vector == nullptr) { + continue; + } + auto currVector = vector; + while (currVector->GetNextVector() != nullptr) { + auto nextVector = currVector->GetNextVector(); + currVector->SetNextVector(nullptr); + currVector = nextVector; + } + } + } + + bool Initialize(uint32_t colNum) + { + this->colNum = colNum; mHeader.rowNum = 0; mHeader.length = 0; - mColumnsHead.reserve(colNum); mColumnsHead.resize(colNum); - mColumnsCur.reserve(colNum); 
mColumnsCur.resize(colNum); - mVectorValueLength.reserve(colNum); - mVectorValueLength.resize(colNum); + mColumnsCapacity.resize(colNum); - for (auto &index : mColumnsHead) { - index = new (std::nothrow) OckVector(); + for (auto &vector : mColumnsHead) { + vector = std::make_shared(); + if (vector == nullptr) { + mColumnsHead.clear(); + return false; + } } + return true; } inline void Reset() { mHeader.rowNum = 0; mHeader.length = 0; - std::fill(mVectorValueLength.begin(), mVectorValueLength.end(), 0); + std::fill(mColumnsCapacity.begin(), mColumnsCapacity.end(), 0); for (uint32_t index = 0; index < mColumnsCur.size(); ++index) { mColumnsCur[index] = mColumnsHead[index]; } } + std::shared_ptr GetColumnHead(uint32_t colIndex) { + if (colIndex >= colNum) { + return nullptr; + } + return mColumnsHead[colIndex]; + } + + void SetColumnCapacity(uint32_t colIndex, uint32_t length) { + mColumnsCapacity[colIndex] = length; + } + + uint32_t GetColumnCapacity(uint32_t colIndex) { + return mColumnsCapacity[colIndex]; + } + + std::shared_ptr GetCurColumn(uint32_t colIndex) + { + if (colIndex >= colNum) { + return nullptr; + } + auto currVector = mColumnsCur[colIndex]; + if (currVector->GetNextVector() == nullptr) { + auto newCurVector = std::make_shared(); + if (UNLIKELY(newCurVector == nullptr)) { + LOG_ERROR("Failed to new instance for ock vector"); + return nullptr; + } + currVector->SetNextVector(newCurVector); + mColumnsCur[colIndex] = newCurVector; + } else { + mColumnsCur[colIndex] = currVector->GetNextVector(); + } + return currVector; + } + + uint32_t GetTotalCapacity() + { + return mHeader.length; + } + + uint32_t GetTotalRowNum() + { + return mHeader.rowNum; + } + + void AddTotalCapacity(uint32_t length) { + mHeader.length += length; + } + + void AddTotalRowNum(uint32_t rowNum) + { + mHeader.rowNum +=rowNum; + } + +private: + uint32_t colNum = 0; VBDataHeaderDesc mHeader; - std::vector mVectorValueLength; - std::vector mColumnsCur; - std::vector mColumnsHead; // 
Array[List[OckVector *]] -} * ; + std::vector mColumnsCapacity; + std::vector mColumnsCur; + std::vector mColumnsHead; // Array[List[OckVector *]] +}; +using VBDataDescPtr = std::shared_ptr; } } -#define PROFILE_START_L1(name) \ - long tcDiff##name = 0; \ - struct timespec tcStart##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcStart##name); - -#define PROFILE_END_L1(name) \ - struct timespec tcEnd##name = { 0, 0 }; \ - clock_gettime(CLOCK_MONOTONIC, &tcEnd##name); \ - \ - long diffSec##name = tcEnd##name.tv_sec - tcStart##name.tv_sec; \ - if (diffSec##name == 0) { \ - tcDiff##name = tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } else { \ - tcDiff##name = diffSec##name * 1000000000 + tcEnd##name.tv_nsec - tcStart##name.tv_nsec; \ - } - -#define PROFILE_VALUE(name) tcDiff##name #endif // SPARK_THESTRAL_PLUGIN_OCK_TYPE_H \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h index 0cfca5d63173c04c37771900e1ac17c2c04e8bba..515f88db8355a58321a7290179e48b48802cb8cc 100644 --- a/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h +++ b/omnioperator/omniop-spark-extension-ock/cpp/src/shuffle/ock_vector.h @@ -69,12 +69,12 @@ public: valueOffsetsAddress = address; } - inline void SetNextVector(OckVector *next) + inline void SetNextVector(std::shared_ptr next) { mNext = next; } - inline OckVector *GetNextVector() + inline std::shared_ptr GetNextVector() { return mNext; } @@ -87,8 +87,9 @@ private: void *valueNullsAddress = nullptr; void *valueOffsetsAddress = nullptr; - OckVector *mNext = nullptr; + std::shared_ptr mNext = nullptr; }; +using OckVectorPtr = std::shared_ptr; } } #endif // SPARK_THESTRAL_PLUGIN_OCK_VECTOR_H diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java 
b/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java index ec294bdbf2208361846b4576ba0559abb9cfabc2..462ad9d105a54374bc867a9d83e45133fc238332 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java +++ b/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/jni/OckShuffleJniReader.java @@ -150,8 +150,18 @@ public class OckShuffleJniReader { nativeCopyVecDataInVB(nativeReader, dstVec.getNativeVector(), colIndex); } + /** + * close reader. + * + */ + public void doClose() { + close(nativeReader); + } + private native long make(int[] typeIds); + private native long close(long readerId); + private native int nativeGetVectorBatch(long readerId, long vbDataAddr, int capacity, int maxRow, int maxDataSize, Long rowCnt); diff --git a/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep b/omnioperator/omniop-spark-extension-ock/src/main/java/com/huawei/ock/spark/serialize/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/com/huawei/ock/spark/.keep deleted file mode 100644 index 
e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/.keep deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala index dc7e081555dfed6646beed6b85fc1f8356b8aa86..89bfcad6f877bcafb9c5166b5b9d59b21b3c49d7 100644 --- a/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala +++ 
b/omnioperator/omniop-spark-extension-ock/src/main/scala/org/apache/spark/shuffle/ock/OckColumnarShuffleBufferIterator.scala @@ -51,6 +51,9 @@ class OckColumnarShuffleBufferIterator[T]( NativeShuffle.destroyMapTaskInfo(mapTaskToHostInfo.getNativeObjHandle) mapTaskToHostInfo.setNativeObjHandle(0) } + blobMap.values.foreach(reader => { + reader.doClose() + }) } private[this] def throwFetchException(fetchError: FetchError): Unit = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index 58eef4125480f04cdb7e80acf247b1525c6e7313..b155bc3c5d85d4f5e547fcc6f2d0b3f0b6e7a5b8 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -148,7 +148,7 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .toBoolean val enableFusion: Boolean = conf - .getConfString("spark.omni.sql.columnar.fusion", "true") + .getConfString("spark.omni.sql.columnar.fusion", "false") .toBoolean // Pick columnar shuffle hash join if one side join count > = 0 to build local hash map, and is diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 26555cc23b21a78180403ada2e9a3c3921543c6b..d418a29526404e31111b3cbdb6c18c27a01ab9b5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -31,7 +31,7 @@ import 
com.huawei.boostkit.spark.ColumnarPluginConfig import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.hive.HiveUdfAdaptorUtil import org.apache.spark.sql.types.{BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, ShortType, StringType} @@ -970,6 +970,10 @@ object OmniExpressionAdaptor extends Logging { OMNI_JOIN_TYPE_LEFT case RightOuter => OMNI_JOIN_TYPE_RIGHT + case LeftSemi => + OMNI_JOIN_TYPE_LEFT_SEMI + case LeftAnti => + OMNI_JOIN_TYPE_LEFT_ANTI case _ => throw new UnsupportedOperationException(s"Join-type[$joinType] is not supported.") } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala index 632f718a1c1daac5ca9996174a4c8533870ba16f..92fb96b6799553ce3b18b287bab33c4ace2741cf 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarSortMergeJoinExec.scala @@ -94,8 +94,8 @@ class ColumnarSortMergeJoinExec( def buildCheck(): Unit = { joinType match { - case _: InnerLike | LeftOuter | FullOuter => - // SMJ join support InnerLike | LeftOuter | FullOuter + case _: InnerLike | LeftOuter | FullOuter | LeftSemi | LeftAnti => + // SMJ join support InnerLike | LeftOuter | FullOuter | LeftSemi | LeftAnti case _ => 
throw new UnsupportedOperationException(s"Join-type[${joinType}] is not supported " + s"in ${this.nodeName}") @@ -130,7 +130,7 @@ class ColumnarSortMergeJoinExec( condition match { case Some(expr) => val filterExpr: String = OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, - OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) + OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) if (!isSimpleColumn(filterExpr)) { checkOmniJsonWhiteList(filterExpr, new Array[AnyRef](0)) } @@ -150,15 +150,6 @@ class ColumnarSortMergeJoinExec( val streamVecBatchs = longMetric("numStreamVecBatchs") val bufferVecBatchs = longMetric("numBufferVecBatchs") - val omniJoinType : nova.hetu.omniruntime.constants.JoinType = joinType match { - case _: InnerLike => OMNI_JOIN_TYPE_INNER - case LeftOuter => OMNI_JOIN_TYPE_LEFT - case FullOuter => OMNI_JOIN_TYPE_FULL - case x => - throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported " + - s"in ${this.nodeName}") - } - val streamedTypes = new Array[DataType](left.output.size) left.output.zipWithIndex.foreach { case (attr, i) => streamedTypes(i) = OmniExpressionAdaptor.sparkTypeToOmniType(attr.dataType, attr.metadata) @@ -177,12 +168,19 @@ class ColumnarSortMergeJoinExec( OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(x, OmniExpressionAdaptor.getExprIdMap(right.output.map(_.toAttribute))) }.toArray - val bufferedOutputChannel = right.output.indices.toArray + val bufferedOutputChannel: Array[Int] = joinType match { + case _: InnerLike | LeftOuter | FullOuter => + right.output.indices.toArray + case LeftExistence(_) => + Array[Int]() + case x => + throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported!") + } val filterString: String = condition match { case Some(expr) => OmniExpressionAdaptor.rewriteToOmniJsonExpressionLiteral(expr, - OmniExpressionAdaptor.getExprIdMap(output.map(_.toAttribute))) + 
OmniExpressionAdaptor.getExprIdMap((left.output ++ right.output).map(_.toAttribute))) case _ => null } @@ -220,8 +218,8 @@ class ColumnarSortMergeJoinExec( val iterBatch = new Iterator[ColumnarBatch] { var isFinished : Boolean = joinType match { - case _: InnerLike => !streamedIter.hasNext || !bufferedIter.hasNext - case LeftOuter => !streamedIter.hasNext + case _: InnerLike | LeftSemi => !streamedIter.hasNext || !bufferedIter.hasNext + case LeftOuter | LeftAnti => !streamedIter.hasNext case FullOuter => !(streamedIter.hasNext || bufferedIter.hasNext) case x => throw new UnsupportedOperationException(s"ColumnSortMergeJoin Join-type[$x] is not supported!")