From f0b5956a51e2606a7889e92658bd01144c9f9cd4 Mon Sep 17 00:00:00 2001
From: liqi1013
Date: Mon, 8 Sep 2025 14:58:36 +0800
Subject: [PATCH 1/2] Migrate the Ctrip branch code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../cpp/src/CMakeLists.txt                    |   2 +
 .../cpp/src/orcfile/OmniByteRLE.cc            | 104 +++--
 .../cpp/src/orcfile/OmniColReader.cc          | 176 ++++++--
 .../cpp/src/orcfile/OmniColReader.hh          |  52 ++-
 .../cpp/src/orcfile/OmniRLEv2.cc              |  14 +
 .../cpp/src/orcfile/OmniRLEv2.hh              |   6 +
 .../cpp/src/orcfile/OmniRowReaderImpl.cc      |   2 +-
 .../cpp/src/orcfile/OrcUtil.cpp               |  41 ++
 .../cpp/src/orcfile/OrcUtil.h                 |  13 +
 .../cpp/test/resources/orc_nested_type        | Bin 0 -> 2930 bytes
 .../cpp/test/tablescan/orc_scan_map_test.cpp  |  97 +++++
 .../test/tablescan/orc_scan_nested_test.cpp   | 103 +++++
 .../spark/jni/OrcColumnarBatchScanReader.java |  15 +-
 .../orc/OmniOrcColumnarBatchReader.java       |  39 +-
 .../vectorized/OmniColumnVector.java          | 113 +++++-
 .../boostkit/spark/ColumnarPlugin.scala       |   5 +-
 .../com/huawei/boostkit/spark/Constant.scala  |   2 +
 .../expression/OmniExpressionAdaptor.scala    |  77 +++-
 .../boostkit/spark/util/OmniAdaptorUtil.scala |  19 +
 .../ColumnarBasicPhysicalOperators.scala      |  12 +-
 .../spark/sql/execution/ColumnarExec.scala    | 382 +++++++++++++++---
 .../ColumnarFileSourceScanExec.scala          |   5 +
 .../sql/types/ColumnarBatchSupportUtil.scala  |   2 +-
 .../jni/OrcColumnarBatchJniMapReaderTest.java | 172 ++++++++
 .../jni/OrcColumnarBatchJniReaderTest.java    |   2 +-
 .../boostkit/spark/jni/orcsrc/orc_nested_type | Bin 0 -> 2930 bytes
 .../sql/catalyst/trees/TreePatterns.scala     |   4 +
 27 files changed, 1286 insertions(+), 173 deletions(-)
 create mode 100644 omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.cpp
 create mode 100644 omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.h
 create mode 100644 omnioperator/omniop-native-reader/cpp/test/resources/orc_nested_type
 create mode 100644 omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_map_test.cpp
 create mode 100644 omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_nested_test.cpp
 create mode 100644 omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniMapReaderTest.java
 create mode 100644 omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/orc_nested_type

diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
index d68c5b3b32..dbb89ad3a6 100644
--- a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
+++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt
@@ -22,6 +22,8 @@ set(SOURCE_FILES
         orcfile/OmniRLEv2.cc
         orcfile/OmniColReader.cc
         orcfile/OmniByteRLE.cc
+        orcfile/OrcUtil.cpp
+        common/UriInfo.cc
         filesystem/hdfs_file.cpp
         filesystem/hdfs_filesystem.cpp
         filesystem/io_exception.cpp

diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc
index d39af57b1a..ba60d84641 100644
--- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc
+++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniByteRLE.cc
@@ -59,60 +59,78 @@ namespace omniruntime::reader {
         }
     }

-    void OmniBooleanRleDecoder::nextNulls(char *data, uint64_t numValues, uint64_t *nulls) {
-        if (nulls) {
-            throw std::runtime_error("Not implemented yet for struct type!");
-        }
-
-        uint64_t nonNulls = numValues;
-
-        const uint32_t outputBytes = (numValues + 7) / 8;
-        if (nonNulls == 0) {
-            ::memset(data, 1, outputBytes);
-            return;
+    inline void SetNull(char *data, uint64_t position, bool isNull) {
+        if (isNull) {
+            data[position / 8] |= (1 << (position % 8));
+        } else {
+            data[position / 8] &= ~(1 << (position % 8));
         }
+    }

-        if (remainingBits >= nonNulls) {
-            // handle remaining bits, which can cover this round
-            data[0] = reversedAndFlipLastByte >> (8 - remainingBits) & 0xff >> (8 - nonNulls);
-            remainingBits -= nonNulls;
-        } else {
-            // put the remaining bits, if any, into previousByte.
-            uint8_t previousByte{0};
-            if (remainingBits > 0) {
-                previousByte = reversedAndFlipLastByte >> (8 - remainingBits);
-            }
-            // compute byte size that should read
-            uint64_t bytesRead = (nonNulls - remainingBits + 7) / 8;
-            OmniByteRleDecoder::next(data, bytesRead, nullptr);
+    void OmniBooleanRleDecoder::nextNulls(char *data, uint64_t numValues, uint64_t *nulls) {

-            ReverseAndFlipBytes(reinterpret_cast<uint8_t *>(data), bytesRead);
-            reversedAndFlipLastByte = data[bytesRead - 1];
+        uint64_t position = 0;

-            // now shift the data in place
-            if (remainingBits > 0 ) {
-                uint64_t nonNullDWords = nonNulls / 64;
-                for (uint64_t i = 0; i < nonNullDWords; i++) {
-                    uint64_t tmp = reinterpret_cast<uint64_t *>(data)[i];
-                    reinterpret_cast<uint64_t *>(data)[i] =
-                        previousByte | tmp << remainingBits; // previousByte is LSB
-                    previousByte = (tmp >> (64 - remainingBits)) & 0xff;
+        if (nulls) {
+            while (remainingBits > 0 && position < numValues) {
+                if (!BitUtil::IsBitSet(nulls, position)) {
+                    remainingBits -= 1;
+                    SetNull(data, position, !((static_cast<unsigned char>(lastByte) >> remainingBits) & 0x1));
+                } else {
+                    SetNull(data, position, true);
                 }
+                position += 1;
+            }
+        } else {
+            while (remainingBits > 0 && position < numValues) {
+                remainingBits -= 1;
+                SetNull(data, position, !((static_cast<unsigned char>(lastByte) >> remainingBits) & 0x1));
+                position += 1;
+            }
+        }

-                // shift 8 bits a time for the remaining bits
-                const uint64_t nonNullOutputBytes = (nonNulls + 7) / 8;
-                for (int32_t i = nonNullDWords * 8; i < nonNullOutputBytes; ++i) {
-                    uint8_t tmp = data[i]; // already reversed
-                    data[i] = previousByte | tmp << remainingBits; // previousByte is LSB
-                    previousByte = tmp >> (8 - remainingBits);
+        uint64_t nonNulls = numValues - position;
+        if (nulls) {
+            for (uint64_t i = position; i < numValues; ++i) {
+                if (BitUtil::IsBitSet(nulls, i)) {
+                    nonNulls -= 1;
                 }
             }
-            remainingBits = bytesRead * 8 + remainingBits - nonNulls;
         }

-        // clear the most significant bits in the last byte which will be processed in the next round
-        data[outputBytes - 1] &= 0xff >> (outputBytes * 8 - numValues);
+        // fill in the remaining values
+        if (nonNulls == 0) {
+            while (position < numValues) {
+                SetNull(data, position++, true);
+            }
+        } else if (position < numValues) {
+            // read the new bytes into the array
+            uint64_t bytesRead = (nonNulls + 7) / 8;
+            char bits[bytesRead];
+            OmniByteRleDecoder::next(bits, bytesRead, nullptr);
+            lastByte = bits[bytesRead - 1];
+            remainingBits = bytesRead * 8 - nonNulls;
+            // expand the array backwards so that we don't clobber the data
+            uint64_t bitsLeft = bytesRead * 8 - remainingBits;
+            if (nulls) {
+                for (int64_t i = static_cast<int64_t>(numValues) - 1; i >= static_cast<int64_t>(position); --i) {
+                    if (!BitUtil::IsBitSet(nulls, i)) {
+                        uint64_t shiftPosn = (-bitsLeft) % 8;
+                        SetNull(data, i, !((bits[(bitsLeft - 1) / 8] >> shiftPosn) & 0x1));
+                        bitsLeft -= 1;
+                    } else {
+                        SetNull(data, i, true);
+                    }
+                }
+            } else {
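+                // ORC byte-RLE packs boolean bits MSB-first within each byte, while
+                // SetNull fills the Omni null bitmap LSB-first, so the loop below reads
+                // bit 7 - (i % 8) of each fetched byte and inverts it: a set stream bit
+                // means the value is present, i.e. the output null bit must be 0.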
+                for (uint64_t i = 0; i < bitsLeft; i++) {
+                    uint64_t byteIndex = i / 8;
+                    uint64_t bitPosition = 7 - (i % 8);
+                    SetNull(data, position + i, !((bits[byteIndex] >> bitPosition) & 0x1));
+                }
+            }
+        }
     }
diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
index 38b43fd643..1948f68be8 100644
--- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
+++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
@@ -27,6 +27,7 @@

 using omniruntime::vec::VectorBatch;
 using omniruntime::vec::BaseVector;
+using omniruntime::vec::MapVector;
 using omniruntime::exception::OmniException;
 using omniruntime::vec::NullsBuffer;
 using orc::ColumnReader;
@@ -170,7 +171,8 @@ namespace omniruntime::reader {
             } else {
                 return std::make_unique(type, stripe);
             }
-
+            case orc::MAP:
+                return std::make_unique<OmniMapColumnReader>(type, stripe, julianPtr);
             case orc::FLOAT:
             case orc::DOUBLE:
                 return std::make_unique<OmniDoubleColumnReader>(type, stripe);
@@ -248,11 +250,130 @@ namespace omniruntime::reader {
             }
         }

+    OmniMapColumnReader::OmniMapColumnReader(const orc::Type& type, orc::StripeStreams& stripe,
+        common::JulianGregorianRebase *julianPtr): OmniColumnReader(type, stripe), orcType(&type) {
+        const std::vector<bool> selectedColumns = stripe.getSelectedColumns();
+        RleVersion vers = omniConvertRleVersion(stripe.getEncoding(columnId).kind());
+        std::unique_ptr<orc::SeekableInputStream> stream =
+            stripe.getStream(columnId, orc::proto::Stream_Kind_LENGTH, true);
+        rle = createOmniRleDecoder(std::move(stream), false, vers, memoryPool);
+
+        const Type* keyType = type.getSubtype(0);
+        if (selectedColumns[static_cast<size_t>(keyType->getColumnId())]) {
+            keyReader = omniBuildReader(*keyType, stripe, julianPtr);
+        }
+        const Type* valueType = type.getSubtype(1);
+        if (selectedColumns[static_cast<size_t>(valueType->getColumnId())]) {
+            valueReader = omniBuildReader(*valueType, stripe, julianPtr);
+        }
+    }
+
+    uint64_t OmniMapColumnReader::skip(uint64_t numValues) {
+        numValues = OmniColumnReader::skip(numValues);
+        ColumnReader *rawKeyReader = keyReader.get();
+        ColumnReader *rawValueReader = valueReader.get();
+        if (rawKeyReader || rawValueReader) {
+            const uint64_t BUFFER_SIZE = 1024;
+            int64_t buffer[BUFFER_SIZE];
+            uint64_t childrenElements = 0;
+            uint64_t lengthsRead = 0;
+            while (lengthsRead < numValues) {
+                uint64_t chunk = std::min(numValues - lengthsRead, BUFFER_SIZE);
+                rle->next(buffer, chunk, nullptr);
+                for (size_t i = 0; i < chunk; ++i) {
+                    childrenElements += static_cast<uint64_t>(buffer[i]);
+                }
+                lengthsRead += chunk;
+            }
+            if (rawKeyReader) {
+                rawKeyReader->skip(childrenElements);
+            }
+            if (rawValueReader) {
+                rawValueReader->skip(childrenElements);
+            }
+        } else {
+            rle->skip(numValues);
+        }
+        return numValues;
+    }
+
+    void OmniMapColumnReader::next(omniruntime::vec::BaseVector *vec, uint64_t numValues,
+        uint64_t *incomingNulls, int omniTypeId)
+    {
+        nextInternal<false>(vec, numValues, incomingNulls, omniTypeId);
+    }
+
+    template <bool encoded>
+    void OmniMapColumnReader::nextInternal(omniruntime::vec::BaseVector *vec, uint64_t numValues,
+        uint64_t *incomingNulls, int omniTypeId) {
+        if (encoded) {
+            std::string message("OmniMapColumnReader::nextInternal encoded is not finished!");
+            throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT", message);
+        }
+        auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec);
+        readNulls(this, numValues, incomingNulls, nulls);
+        bool hasNull = vec->HasNull();
+
+        auto mapvector = reinterpret_cast<MapVector *>(vec);
+
+        int64_t* offsets = mapvector->GetOffsets();
+        auto nullsTrans = reinterpret_cast<uint64_t *>(nulls);
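+        // rle->next() decodes the LENGTH stream into offsets[] (one entry count per
+        // non-null row); the loop below then rewrites offsets[] in place as a prefix
+        // sum, so row i's key/value pairs occupy [offsets[i], offsets[i + 1]) in the
+        // child vectors (e.g. per-row lengths {2, 2, 2} become offsets {0, 2, 4, 6}).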
rle->next(offsets, numValues, nullsTrans); + + uint64_t totalChildren = 0; + if (hasNull) { + for (size_t i = 0; i < numValues; ++i) { + if (!BitUtil::IsBitSet(nulls, i)) { + uint64_t tmp = static_cast(offsets[i]); + offsets[i] = static_cast(totalChildren); + totalChildren += tmp; + } else { + offsets[i] = static_cast(totalChildren); + } + } + } else { + for (size_t i = 0; i < numValues; ++i) { + uint64_t tmp = static_cast(offsets[i]); + offsets[i] = static_cast(totalChildren); + totalChildren += tmp; + } + } + offsets[numValues] = static_cast(totalChildren); + + ColumnReader *rawKeyReader = keyReader.get(); + if (rawKeyReader) { + const Type* keyOrcType = orcType->getSubtype(0); + auto keyDataTypeId = OrcUtil::tranOrcTypeToOmniType(keyOrcType); + std::shared_ptr keyVector = std::move(makeNewVector(totalChildren, keyOrcType, keyDataTypeId)); + mapvector->SetKeyVector(keyVector); + reinterpret_cast(rawKeyReader)->next((mapvector->GetKeyVector().get()), totalChildren, nullptr, keyDataTypeId); + } + ColumnReader *rawValueReader = valueReader.get(); + if (rawValueReader) { + const Type* valueOrcType = orcType->getSubtype(1); + auto valueDataTypeId= OrcUtil::tranOrcTypeToOmniType(valueOrcType); + std::shared_ptr valueVector = std::move(makeNewVector(totalChildren, valueOrcType, valueDataTypeId)); + mapvector->SetValueVector(valueVector); + reinterpret_cast(rawValueReader)->next((mapvector->GetValueVector().get()), totalChildren, nullptr, valueDataTypeId); + } + } + + void OmniMapColumnReader::seekToRowGroup(std::unordered_map& positions) { + OmniColumnReader::seekToRowGroup(positions); + rle->seek(positions.at(columnId)); + if (keyReader.get()) { + keyReader->seekToRowGroup(positions); + } + if (valueReader.get()) { + valueReader->seekToRowGroup(positions); + } + } + /** * OmniStructColumnReader funcs */ OmniStructColumnReader::OmniStructColumnReader(const Type& type, StripeStreams& stripe, - common::JulianGregorianRebase *julianPtr): OmniColumnReader(type, stripe) { + common::JulianGregorianRebase *julianPtr): OmniColumnReader(type, stripe), type_(&type) { // count the number of selected sub-columns const std::vector selectedColumns = stripe.getSelectedColumns(); switch (static_cast(stripe.getEncoding(columnId).kind())) { @@ -312,7 +433,7 @@ namespace omniruntime::reader { const Type* orcType = baseTp.getSubtype(i); omniruntime::type::DataTypeId dataTypeId; if (omniTypeId == nullptr) { - dataTypeId = getDefaultOmniType(orcType); + dataTypeId = OrcUtil::tranOrcTypeToOmniType(orcType); } else { dataTypeId = static_cast(omniTypeId[i]); } @@ -323,38 +444,22 @@ namespace omniruntime::reader { } } - omniruntime::type::DataTypeId OmniStructColumnReader::getDefaultOmniType(const Type* type) { - constexpr int32_t OMNI_MAX_DECIMAL64_DIGITS = 18; - switch (type->getKind()) { - case orc::TypeKind::BOOLEAN: - return omniruntime::type::OMNI_BOOLEAN; - case orc::TypeKind::SHORT: - return omniruntime::type::OMNI_SHORT; - case orc::TypeKind::DATE: - //To do check if the DATE is DATE64 type - return omniruntime::type::OMNI_DATE32; - case orc::TypeKind::INT: - return omniruntime::type::OMNI_INT; - case orc::TypeKind::LONG: - return omniruntime::type::OMNI_LONG; - case orc::TypeKind::TIMESTAMP: - case orc::TypeKind::TIMESTAMP_INSTANT: - return omniruntime::type::OMNI_TIMESTAMP; - case orc::TypeKind::DOUBLE: - return omniruntime::type::OMNI_DOUBLE; - case orc::TypeKind::CHAR: - case orc::TypeKind::STRING: - case orc::TypeKind::VARCHAR: - return omniruntime::type::OMNI_VARCHAR; - case orc::TypeKind::DECIMAL: - if 
(type->getPrecision() > OMNI_MAX_DECIMAL64_DIGITS) { - return omniruntime::type::OMNI_DECIMAL128; - } else { - return omniruntime::type::OMNI_DECIMAL64; - } - default: - throw omniruntime::exception::OmniException( - "EXPRESSION_NOT_SUPPORT", "Not Supported Type: " + type->getKind()); + void OmniStructColumnReader::next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) { + auto nulls = omniruntime::vec::unsafe::UnsafeBaseVector::GetNulls(vec); + readNulls(this, numValues, incomingNulls, nulls); + bool hasNull = vec->HasNull(); + + auto rowVector = reinterpret_cast(vec); + + uint64_t i = 0; + for(auto iter = children.begin(); iter != children.end(); ++iter, ++i) { + const orc::Type *child = type_->getSubtype(i); + auto dataTypeId = OrcUtil::tranOrcTypeToOmniType(child); + auto childVec = std::shared_ptr(makeNewVector(numValues, child, dataTypeId).release()); + reinterpret_cast(&(*iter->get()))->next(childVec.get(), numValues, + hasNull ? reinterpret_cast(nulls) : nullptr, dataTypeId); + rowVector->AddChild(childVec); } } @@ -1166,7 +1271,6 @@ namespace omniruntime::reader { rle->seek(positions.at(columnId)); } - OmniDoubleColumnReader::OmniDoubleColumnReader(const Type& type, StripeStreams& stripe ): OmniColumnReader(type, stripe), diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh index 6a08e69435..8964ba0b14 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.hh @@ -26,6 +26,7 @@ #include "OmniRLEv2.hh" #include "orc/Int128.hh" #include "OmniByteRLE.hh" +#include "OrcUtil.h" #include "common/JulianGregorianRebase.h" namespace omniruntime::reader { @@ -64,6 +65,7 @@ namespace omniruntime::reader { class OmniStructColumnReader: public OmniColumnReader { private: std::vector> children; + const orc::Type *type_; public: OmniStructColumnReader(const orc::Type& type, orc::StripeStreams& stipe, @@ -82,6 +84,16 @@ namespace omniruntime::reader { void next(void *&omniVecBatch, uint64_t numValues, char *notNull, const orc::Type& baseTp, int* omniTypeId) override; + /** + * direct read VectorBatch in next + * @param vec the vec to push + * @param numValues the numValues of VectorBatch + * @param incomingNulls the notNull array indicates value not null + * @param omniTypeId the omniTypeId to push + */ + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId) override; + void seekToRowGroup( std::unordered_map& positions) override; @@ -97,9 +109,45 @@ namespace omniruntime::reader { template void nextInternal(std::vector &vecs, uint64_t numValues, uint64_t *incomingNulls, const orc::Type& baseTp, int* omniTypeId); + }; + + class OmniMapColumnReader: public OmniColumnReader { + private: + std::unique_ptr keyReader; + std::unique_ptr valueReader; + std::unique_ptr rle; + const orc::Type* orcType; + + public: + OmniMapColumnReader(const orc::Type& type, orc::StripeStreams& stipe, + common::JulianGregorianRebase *julianPtr); + + uint64_t skip(uint64_t numValues) override; + + /** + * direct read VectorBatch in next + * @param vec the vec to push + * @param numValues the numValues of VectorBatch + * @param incomingNulls the notNull array indicates value not null + * @param omniTypeId the omniTypeId to push + */ + void next(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int 
omniTypeId) override; + + void seekToRowGroup(std::unordered_map& positions) override; + + private: + /** + * direct read VectorBatch in next for omni + * @param vec the vec to push + * @param numValues the numValues of VectorBatch + * @param incomingNulls the notNull array indicates value not null + * @param omniTypeId the omniTypeId to push + */ + template + void nextInternal(omniruntime::vec::BaseVector *vec, uint64_t numValues, + uint64_t *incomingNulls, int omniTypeId); - // Get default omni type from orc type. - omniruntime::type::DataTypeId getDefaultOmniType(const orc::Type *type); }; class OmniBooleanColumnReader: public OmniColumnReader { diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc index 2184d32da4..c4d4b0474b 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.cc @@ -76,6 +76,16 @@ namespace omniruntime::reader { } } + std::unique_ptr makeMapVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId) { + return std::make_unique(numValues); + } + + std::unique_ptr makeRowVector(uint64_t numValues, const orc::Type *structType, + omniruntime::type::DataTypeId dataTypeId) { + return std::make_unique(numValues); + } + std::unique_ptr makeNewVector(uint64_t numValues, const orc::Type* baseTp, omniruntime::type::DataTypeId dataTypeId) { switch (baseTp->getKind()) { @@ -95,6 +105,10 @@ namespace omniruntime::reader { return makeVarcharVector(numValues, dataTypeId); case orc::TypeKind::DECIMAL: return makeDecimalVector(numValues, dataTypeId); + case orc::TypeKind::MAP: + return makeMapVector(numValues, dataTypeId); + case orc::TypeKind::STRUCT: + return makeRowVector(numValues, baseTp, dataTypeId); default: { throw std::runtime_error("Not support For This ORC Type: " + baseTp->getKind()); } diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh index 0965ce31d0..2306a10bf4 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRLEv2.hh @@ -37,6 +37,12 @@ namespace omniruntime::reader { std::unique_ptr makeDecimalVector(uint64_t numValues, omniruntime::type::DataTypeId dataTypeId); + std::unique_ptr makeMapVector(uint64_t numValues, + omniruntime::type::DataTypeId dataTypeId); + + std::unique_ptr makeRowVector(uint64_t numValues, const orc::Type *structType, + omniruntime::type::DataTypeId dataTypeId); + std::unique_ptr makeNewVector(uint64_t numValues, const orc::Type* baseTp, omniruntime::type::DataTypeId dataTypeId); diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc index 95d3b6b453..7f6d73ee74 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniRowReaderImpl.cc @@ -159,7 +159,7 @@ namespace omniruntime::reader { *contents->stream, writerTimezone, readerTimezone); - reader = omniBuildReader(*contents->schema, stripeStreams, + reader = omniBuildReader(getSelectedType(), stripeStreams, (julianPtr == nullptr) ? 
nullptr : julianPtr.get());

         if (sargsApplier) {
diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.cpp b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.cpp
new file mode 100644
index 0000000000..2109346d80
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.cpp
@@ -0,0 +1,41 @@
+#include "OrcUtil.h"
+
+namespace omniruntime::reader {
+    omniruntime::type::DataTypeId OrcUtil::tranOrcTypeToOmniType(const orc::Type *type) {
+        constexpr int32_t OMNI_MAX_DECIMAL64_DIGITS = 18;
+        switch (type->getKind()) {
+            case orc::TypeKind::BOOLEAN:
+                return omniruntime::type::OMNI_BOOLEAN;
+            case orc::TypeKind::SHORT:
+                return omniruntime::type::OMNI_SHORT;
+            case orc::TypeKind::DATE:
+                return omniruntime::type::OMNI_DATE32;
+            case orc::TypeKind::INT:
+                return omniruntime::type::OMNI_INT;
+            case orc::TypeKind::LONG:
+                return omniruntime::type::OMNI_LONG;
+            case orc::TypeKind::TIMESTAMP:
+            case orc::TypeKind::TIMESTAMP_INSTANT:
+                return omniruntime::type::OMNI_TIMESTAMP;
+            case orc::TypeKind::DOUBLE:
+                return omniruntime::type::OMNI_DOUBLE;
+            case orc::TypeKind::CHAR:
+            case orc::TypeKind::STRING:
+            case orc::TypeKind::VARCHAR:
+                return omniruntime::type::OMNI_VARCHAR;
+            case orc::TypeKind::DECIMAL:
+                if (type->getPrecision() > OMNI_MAX_DECIMAL64_DIGITS) {
+                    return omniruntime::type::OMNI_DECIMAL128;
+                } else {
+                    return omniruntime::type::OMNI_DECIMAL64;
+                }
+            case orc::TypeKind::MAP:
+                return omniruntime::type::OMNI_MAP;
+            case orc::TypeKind::STRUCT:
+                return omniruntime::type::OMNI_ROW;
+            default:
+                throw omniruntime::exception::OmniException("EXPRESSION_NOT_SUPPORT",
+                    "OrcUtil::tranOrcTypeToOmniType unsupported ORC type: " + std::to_string(type->getKind()));
+        }
+    }
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.h b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.h
new file mode 100644
index 0000000000..787b7c797c
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcUtil.h
@@ -0,0 +1,13 @@
+#ifndef CPP_ORCUTIL_H
+#define CPP_ORCUTIL_H
+
+#include "orc/Type.hh"
+#include "vector/vector_common.h"
+
+namespace omniruntime::reader {
+    class OrcUtil {
+    public:
+        static omniruntime::type::DataTypeId tranOrcTypeToOmniType(const orc::Type *type);
+    };
+}
+#endif // CPP_ORCUTIL_H
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/test/resources/orc_nested_type b/omnioperator/omniop-native-reader/cpp/test/resources/orc_nested_type
new file mode 100644
index 0000000000000000000000000000000000000000..5aa0d82c0b7ba74a02d7debf14c8d3053e04183f
GIT binary patch
literal 2930
[2930 bytes of base85-encoded binary patch data for the ORC test file omitted]

literal 0
HcmV?d00001

diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_map_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_map_test.cpp
new file mode 100644
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_map_test.cpp
@@ -0,0 +1,97 @@
+#include <string>
+#include <vector>
+#include "jni/OrcColumnarBatchJniReader.h"
+#include "scan_test.h"
+#include <gtest/gtest.h>
+
+class ScanMapTest : public testing::Test {
+protected:
+    // run before each case...
+    virtual void SetUp() override
+    {
+        orc::ReaderOptions readerOpts;
+        orc::RowReaderOptions rowReaderOptions;
+        std::string filename = "/../resources/orc_nested_type";
+        filename = PROJECT_PATH + filename;
+        UriInfo uriInfo("file", filename, "", "-1");
+        std::unique_ptr<orc::Reader> reader = omniruntime::reader::omniCreateReader(orc::readFileOverride(uriInfo, false), readerOpts);
+
+        std::list<std::string> includedColumns = {"c1", "c2", "c3", "c4", "c5", "c17"};
+        rowReaderOptions.include(includedColumns);
+
+        std::unique_ptr julianPtr;
+        std::unique_ptr predicatePtr;
+        auto readerPtr = static_cast(reader.get());
+        rowReader = readerPtr->createRowReader(rowReaderOptions, julianPtr, predicatePtr).release();
+        omniruntime::reader::OmniRowReaderImpl *rowReaderPtr = (omniruntime::reader::OmniRowReaderImpl*) rowReader;
+        rowReaderPtr->next(&recordBatch, nullptr, 4096);
+    }
+
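+    // The fixture reads up to 4096 rows from orc_nested_type (the file holds the
+    // three rows whose CREATE TABLE / INSERT statements are reproduced in
+    // orc_scan_nested_test.cpp), selecting c1-c5 plus the map column c17, so
+    // recordBatch ends up holding one BaseVector per selected column.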
+    // run after each case...
+    virtual void TearDown() override {
+        for (auto vec : recordBatch) {
+            delete vec;
+        }
+        recordBatch.clear();
+        delete rowReader;
+        rowReader = nullptr;
+    }
+
+    orc::RowReader *rowReader;
+    std::vector<omniruntime::vec::BaseVector *> recordBatch;
+};
+
+TEST_F(ScanMapTest, map_test_correctness_stringVec)
+{
+    // string type, "c3"
+    auto *olbStr = (omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>> *)(
+        recordBatch[2]);
+    std::string_view actualStr = olbStr->GetValue(0);
+    ASSERT_EQ(actualStr, "string value 1");
+}
+
+TEST_F(ScanMapTest, map_test_correctness_map)
+{
+    // map type, "c17"
+    auto *mapVec = (omniruntime::vec::MapVector *)(recordBatch[5]);
+    ASSERT_EQ(mapVec->GetSize(0), 2);
+    ASSERT_EQ(mapVec->GetSize(1), 2);
+    ASSERT_EQ(mapVec->GetSize(2), 2);
+    ASSERT_EQ(mapVec->GetOffset(3), 6);
+
+    auto keyVector = reinterpret_cast<omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>> *>(mapVec->GetKeyVector().get());
+    ASSERT_EQ(keyVector->GetValue(0), "key1");
+    ASSERT_EQ(keyVector->GetValue(1), "key2");
+    ASSERT_EQ(keyVector->GetValue(2), "key3");
+    ASSERT_EQ(keyVector->GetValue(3), "key4");
+    ASSERT_EQ(keyVector->GetValue(4), "key5");
+    ASSERT_EQ(keyVector->GetValue(5), "key6");
+
+    auto valueVector = reinterpret_cast<omniruntime::vec::Vector<double> *>(mapVec->GetValueVector().get());
+    ASSERT_EQ(valueVector->GetValue(0), 1.1);
+    ASSERT_EQ(valueVector->GetValue(1), 2.2);
+    ASSERT_EQ(valueVector->GetValue(2), 3.3);
+    ASSERT_EQ(valueVector->GetValue(3), 4.4);
+    ASSERT_EQ(valueVector->GetValue(4), 5.5);
+    ASSERT_EQ(valueVector->GetValue(5), 6.6);
+}
\ No newline at end of file
diff --git a/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_nested_test.cpp b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_nested_test.cpp
new file mode 100644
index 0000000000..b22a545db6
--- /dev/null
+++ b/omnioperator/omniop-native-reader/cpp/test/tablescan/orc_scan_nested_test.cpp
@@ -0,0 +1,103 @@
+
+//CREATE TABLE `orc_nested_type` (
+//  `c1` int,
+//  `c2` varchar(60),
+//  `c3` string,
+//  `c4` bigint,
+//  `c5` char(40),
+//  `c6` float,
+//  `c7` double,
+//  `c8` decimal(9,8),
+//  `c9` decimal(18,5),
+//  `c10` boolean,
+//  `c11` smallint,
+//  `c12` timestamp,
+//  `c13` date,
+//  -- New nested type fields
+//  `c14` array<int>,                                  -- Integer array
+//  `c15` array<string>,                               -- String array
+//  `c16` struct<address:string,city:string,zip:int>,  -- Struct type
+//  `c17` map<string,double>,                          -- Map from string to double
+//  `c18` array<struct<name:string,value:double>>      -- Array of structs
+//) stored as orc;
+//
+//-- 1st row
+//INSERT INTO orc_nested_type VALUES (
+//  1,                                  -- c1 int
+//  'varchar value 1',                  -- c2 varchar(60)
+//  'string value 1',                   -- c3 string
+//  10000000000,                        -- c4 bigint
+//  'char value 1',                     -- c5 char(40)
+//  1.5,                                -- c6 float
+//  123.456,                            -- c7 double
+//  0.12345678,                         -- c8 decimal(9,8)
+//  1234567890.12345,                   -- c9 decimal(18,5)
+//  true,                               -- c10 boolean
+//  10,                                 -- c11 smallint
+//  timestamp('2023-01-01 12:00:00'),   -- c12 timestamp
+//  date('2023-01-01'),                 -- c13 date
+//
+//  array(1, 2, 3),                     -- c14 array<int>
+//  array('a', 'b', 'c'),               -- c15 array<string>
+//  named_struct('address', '123 Main St', 'city', 'New York', 'zip', 10001),  -- c16 struct
+//  map('key1', 1.1, 'key2', 2.2),      -- c17 map<string,double>
+//  array(                              -- c18 array<struct<name:string,value:double>>
+//    named_struct('name', 'item1', 'value', 1.1),
+//    named_struct('name', 'item2', 'value', 2.2)
+//  )
+//);
+//
+//-- 2nd row
+//INSERT INTO orc_nested_type VALUES (
+//  2,
+//  'varchar value 2',
+//  'string value 2',
+//  20000000000,
+//  'char value 2',
+//  2.5,
+//  234.567,
+//  0.23456789,
+//  2345678901.23456,
+//  false,
+//  20,
+//  timestamp('2023-02-02 13:30:00'),
+//  date('2023-02-02'),
+//  array(4, 5, 6),
+//  array('d', 'e', 'f'),
+//  named_struct('address', '456 Oak Ave', 'city', 'Boston', 'zip', 02108),
+//  map('key3', 3.3, 'key4', 4.4),
+//  array(
+//    named_struct('name', 'item3', 'value', 3.3),
+//    named_struct('name', 'item4', 'value', 4.4)
+//  )
+//);
+//
+//-- 3rd row
+//INSERT INTO orc_nested_type VALUES (
+//  3,
+//  'varchar value 3',
+//  'string value 3',
+//  30000000000,
+//  'char value 3',
+//  3.5,
+//  345.678,
+//  0.34567890,
+//  3456789012.34567,
+//  true,
+//  30,
+//  timestamp('2023-03-03 14:45:00'),
+//  date('2023-03-03'),
+//  array(7, 8, 9),
+//  array('g', 'h', 'i'),
+//  named_struct('address', '789 Pine Blvd', 'city', 'Chicago', 'zip', 60601),
+//  map('key5', 5.5, 'key6', 6.6),
+//  array(
+//    named_struct('name', 'item5', 'value', 5.5),
+//    named_struct('name', 'item6', 'value', 6.6)
+//  )
+//);
+
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
index 3f8b2db2cb..6ded8558ba 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java
@@ -23,6 +23,8 @@ import com.huawei.boostkit.spark.predicate.*;
 import com.huawei.boostkit.spark.timestamp.JulianGregorianRebase;
 import com.huawei.boostkit.spark.timestamp.TimestampUtil;
 import nova.hetu.omniruntime.type.DataType;
+import nova.hetu.omniruntime.type.MapDataType;
+import nova.hetu.omniruntime.type.StructDataType;
 import nova.hetu.omniruntime.vector.*;
 import org.apache.orc.impl.writer.TimestampTreeWriter;

@@ -63,6 +65,7 @@ import java.text.ParseException;
 import java.text.SimpleDateFormat;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
@@ -253,7 +256,7 @@ public class OrcColumnarBatchScanReader {
         }
     }

-    public int next(Vec[] vecList, int[] typeIds) {
+    public int next(Vec[] vecList, int[] typeIds, List<DataType> dataTypes) {
         long[] vecNativeIds = new long[typeIds.length];
         long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds);
         if (rtn == 0) {
@@ -264,7 +267,7 @@ public class OrcColumnarBatchScanReader {
             if (colsToGet[i] != 0) {
                 continue;
             }
-            switch (DataType.DataTypeId.values()[typeIds[nativeGetId]]) {
+            switch (DataType.DataTypeId.fromValue(typeIds[nativeGetId])) {
                 case OMNI_BOOLEAN: {
                     vecList[i] = new BooleanVec(vecNativeIds[nativeGetId]);
                     break;
                 }
@@ -306,6 +309,14 @@ public class OrcColumnarBatchScanReader {
                     vecList[i] = new Decimal128Vec(vecNativeIds[nativeGetId]);
                     break;
                 }
+                case OMNI_MAP: {
+                    vecList[i] = new MapVec(vecNativeIds[nativeGetId], (MapDataType) dataTypes.get(i), (int) rtn);
+                    break;
+                }
+                case OMNI_ROW: {
+                    vecList[i] = new StructVec(vecNativeIds[nativeGetId], (StructDataType) dataTypes.get(i), (int) rtn);
+                    break;
+                }
                 default: {
                     throw new RuntimeException("UnSupport type for ColumnarFileScan:" +
                         DataType.DataTypeId.values()[typeIds[i]]);
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
index 1f3c1a8ef0..03d06f07aa 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java
@@ -38,6 +38,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch;

 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.List;
+import java.util.Stack;

 /**
  * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch.
@@ -63,6 +65,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {

+    private List<nova.hetu.omniruntime.type.DataType> dataTypes = new ArrayList<>();
     private StructType requiredSchema;
     private Filter pushedFilter;

@@ -139,14 +142,25 @@ public class OmniOrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
         // collect read cols types
         ArrayList<Integer> typeBuilder = new ArrayList<>();
-
+        Stack<String> prefix = new Stack<>();
         for (int i = 0; i < requiredfieldNames.length; i++) {
             String target = requiredfieldNames[i];
             // if not found, set colsToGet to -1; otherwise set it to 0
             if (recordReader.allFieldsNames.contains(target)) {
                 recordReader.colsToGet[i] = 0;
-                recordReader.includedColumns.add(requiredfieldNames[i]);
-                typeBuilder.add(OmniExpressionAdaptor.sparkTypeToOmniType(requiredSchema.fields()[i].dataType()));
+                StructField field = requiredSchema.fields()[i];
+                if (field.dataType() instanceof StructType) {
+                    prefix.push(field.name());
+                    buildRequireSchema((StructType) field.dataType(), recordReader.includedColumns, prefix);
+                    prefix.pop();
+                } else {
+                    recordReader.includedColumns.add(target);
+                }
+                nova.hetu.omniruntime.type.DataType dataType =
+                    OmniExpressionAdaptor.sparkTypeToOmniTypeWithComplex(field.dataType(), field.metadata());
+
+                typeBuilder.add(dataType.getIdValue());
+                dataTypes.add(dataType);
             } else {
                 recordReader.colsToGet[i] = -1;
             }
@@ -155,6 +169,23 @@ public class OmniOrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
         }
     }

+    private void buildRequireSchema(StructType structType, ArrayList<String> includedColumns, Stack<String> prefix) {
+        StringBuilder sb = new StringBuilder();
+        for (String s : prefix) {
+            sb.append(s).append(".");
+        }
+        for (StructField field : structType.fields()) {
+            String name = field.name();
+            if (field.dataType() instanceof StructType) {
+                prefix.push(name);
+                buildRequireSchema((StructType) field.dataType(), includedColumns, prefix);
+                prefix.pop();
+            } else {
+                includedColumns.add(new StringBuilder(sb).append(name).toString());
+            }
+        }
+    }
+
     /**
      * Initialize columnar batch by setting required schema and partition information.
     * With this information, this creates ColumnarBatch with the full schema.
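The buildRequireSchema recursion above flattens nested struct fields into dotted column paths (c16 becomes c16.address, c16.city, c16.zip) before they reach the native reader's include list. A minimal standalone sketch of that flattening behavior, assuming only Spark's StructType API (the FlattenDemo class and field names are hypothetical, not code from this patch):

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class FlattenDemo {
    // Walk the schema depth-first, emitting a dotted path for every leaf field.
    static void flatten(StructType schema, String prefix, List<String> out) {
        for (StructField field : schema.fields()) {
            String path = prefix.isEmpty() ? field.name() : prefix + "." + field.name();
            if (field.dataType() instanceof StructType) {
                flatten((StructType) field.dataType(), path, out); // recurse into nested structs
            } else {
                out.add(path); // leaf column, e.g. "c16.city"
            }
        }
    }

    public static void main(String[] args) {
        StructType inner = new StructType()
            .add("address", DataTypes.StringType)
            .add("city", DataTypes.StringType)
            .add("zip", DataTypes.IntegerType);
        StructType schema = new StructType().add("c16", inner);

        List<String> cols = new ArrayList<>();
        flatten(schema, "", cols);
        System.out.println(cols); // [c16.address, c16.city, c16.zip]
    }
}

Note that map columns such as c17 are deliberately not expanded this way: the whole column name is added to the include list and the key/value children are materialized later by OmniMapColumnReader.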
@@ -220,7 +251,7 @@ public class OmniOrcColumnarBatchReader extends RecordReader - ColumnarConditionProjectExec(plan.projectList, condition, child) + // disable ColumnarConditionProjectExec +// case ColumnarFilterExec(condition, child) => +// ColumnarConditionProjectExec(plan.projectList, condition, child) case join: ColumnarBroadcastHashJoinExec => if (plan.projectList.forall(project => OmniExpressionAdaptor.isSimpleProjectForAll(project)) && enableColumnarProjectFusion) { ColumnarBroadcastHashJoinExec( diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/Constant.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/Constant.scala index f7ec101780..49d2f4000c 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/Constant.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/Constant.scala @@ -37,6 +37,8 @@ object Constant { val OMNI_DECIMAL64_TYPE: String = DataTypeId.OMNI_DECIMAL64.ordinal().toString val OMNI_DECIMAL128_TYPE: String = DataTypeId.OMNI_DECIMAL128.ordinal().toString val OMNI_TIMESTAMP_TYPE: String = DataTypeId.OMNI_TIMESTAMP.ordinal().toString + val OMNI_ROW_TYPE: String = DataTypeId.OMNI_ROW.toValue.toString + val OMNI_MAP_TYPE: String = DataTypeId.OMNI_MAP.toValue.toString // for UT val OMNI_IS_ADAPTIVE_CONTEXT = "omni.isAdaptiveContext" } \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 3cd2fb3826..d607343001 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -19,8 +19,8 @@ package com.huawei.boostkit.spark.expression import scala.collection.mutable.ArrayBuffer -import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_CHECK_OMNI_EXP, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_SHOR_TYPE, OMNI_TIMESTAMP_TYPE, OMNI_VARCHAR_TYPE} -import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, ShortDataType, TimestampDataType, VarcharDataType} +import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_CHECK_OMNI_EXP, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_MAP_TYPE, OMNI_ROW_TYPE, OMNI_SHOR_TYPE, OMNI_TIMESTAMP_TYPE, OMNI_VARCHAR_TYPE} +import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, MapDataType, ShortDataType, StructDataType, TimestampDataType, VarcharDataType} import nova.hetu.omniruntime.constants.FunctionType import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, 
OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SAMP, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} import nova.hetu.omniruntime.constants.JoinType._ @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.util.CharVarcharUtils.getRawTypeString import org.apache.spark.sql.execution import org.apache.spark.sql.hive.HiveUdfAdaptorUtil -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, Metadata, NullType, ShortType, StringType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, DataType, DateType, Decimal, DecimalType, DoubleType, IntegerType, LongType, MapType, Metadata, NullType, ShortType, StringType, StructField, StructType, TimestampType} import org.apache.spark.sql.util.ShimUtil import java.util.Locale @@ -84,9 +84,6 @@ object OmniExpressionAdaptor extends Logging { if (!ColumnarPluginConfig.getSessionConf.enableOmniUnixTimeFunc) { throw new UnsupportedOperationException(s"Not Enabled Omni UnixTime Function") } - if (ColumnarPluginConfig.getSessionConf.timeParserPolicy == "LEGACY") { - throw new UnsupportedOperationException(s"Unsupported Time Parser Policy: LEGACY") - } if (!timeZoneSet.contains(timeZone)) { throw new UnsupportedOperationException(s"Unsupported Time Zone: $timeZone") } @@ -121,6 +118,22 @@ object OmniExpressionAdaptor extends Logging { rewriteToOmniJsonExpressionLiteralJsonObject(expr, exprsIndexMap, expr.dataType) } + private def GenGetStructField(expr: Expression, exprsIndexMap: Map[ExprId, Int]) + : JsonObject = { + expr match { + case getStructField: GetStructField => + val json = GenGetStructField(getStructField.child, exprsIndexMap) + new JsonObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", sparkTypeToOmniExpType(getStructField.dataType).toInt) + .put("ordinal", getStructField.ordinal) + .put("input", json) + case attr: Attribute => + new JsonObject().put("exprType", "FIELD_REFERENCE") + .put("dataType", sparkTypeToOmniExpType(attr.dataType).toInt) + .put("colVal", exprsIndexMap(attr.exprId)) + } + } + private def rewriteToOmniJsonExpressionLiteralJsonObject(expr: Expression, exprsIndexMap: Map[ExprId, Int], returnDatatype: DataType): JsonObject = { @@ -538,6 +551,15 @@ object OmniExpressionAdaptor extends Logging { .addOmniExpJsonType("returnType", floor.dataType) .put("function_name", "floor") .put("arguments", new JsonArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(floor.child, exprsIndexMap))) + case dateDiff: DateDiff => + new JsonObject().put("exprType", "FUNCTION") + .addOmniExpJsonType("returnType", dateDiff.dataType) + .put("function_name", "date_diff") + .put("arguments", new JsonArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(dateDiff.children(0), exprsIndexMap)) + .put(rewriteToOmniJsonExpressionLiteralJsonObject(dateDiff.children(1), exprsIndexMap))) + + case getStructField: GetStructField => + GenGetStructField(getStructField, exprsIndexMap) case _ => val jsonObj = ModifyUtilAdaptor.rewriteToOmniJsonExpression(expr, exprsIndexMap, returnDatatype, rewriteToOmniJsonExpressionLiteralJsonObject) @@ -904,6 +926,10 @@ object OmniExpressionAdaptor extends Logging { } else { OMNI_DECIMAL128_TYPE } + case f: StructField => + sparkTypeToOmniExpType(f.dataType) + case f: StructType => OMNI_ROW_TYPE + case m: MapType => OMNI_MAP_TYPE case 
NullType => OMNI_BOOLEAN_TYPE case _ => throw new UnsupportedOperationException(s"Unsupported datatype: $datatype") @@ -1009,12 +1035,49 @@ object OmniExpressionAdaptor extends Logging { } } + def sparkTypeToOmniTypeWithComplex(dataType: DataType, metadata: Metadata = Metadata.empty): + nova.hetu.omniruntime.`type`.DataType = { + dataType match { + case ShortType => + ShortDataType.SHORT + case IntegerType => + IntDataType.INTEGER + case LongType => + LongDataType.LONG + case TimestampType => + TimestampDataType.TIMESTAMP + case DoubleType => + DoubleDataType.DOUBLE + case BooleanType => + BooleanDataType.BOOLEAN + case StringType => + new VarcharDataType(getStringLength(metadata)) + case DateType => + Date32DataType.DATE32 + case dt: DecimalType => + if (DecimalType.is64BitDecimalType(dt)) { + new Decimal64DataType(dt.precision, dt.scale) + } else { + new Decimal128DataType(dt.precision, dt.scale) + } + case m: MapType => + new MapDataType(sparkTypeToOmniTypeWithComplex(m.keyType), sparkTypeToOmniTypeWithComplex(m.valueType)) + case s: StructType => + val children = s.fields.map(f => sparkTypeToOmniTypeWithComplex(f.dataType, f.metadata)) + new StructDataType(children) + case s: StructField => + sparkTypeToOmniTypeWithComplex(s.dataType, s.metadata) + case _ => + throw new UnsupportedOperationException(s"Unsupported datatype: $dataType") + } + } + def sparkProjectionToOmniJsonProjection(attr: Attribute, colVal: Int): String = { val dataType: DataType = attr.dataType val metadata = attr.metadata val omniDataType: String = sparkTypeToOmniExpType(dataType) dataType match { - case ShortType | IntegerType | LongType | DoubleType | BooleanType | DateType | TimestampType => + case ShortType | IntegerType | LongType | DoubleType | BooleanType | DateType | TimestampType | StructType(_) => new JsonObject().put("exprType", "FIELD_REFERENCE") .put("dataType", omniDataType.toInt) .put("colVal", colVal).toString diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index 3892d0d0cd..bf55f638cf 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -29,6 +29,7 @@ import nova.hetu.omniruntime.operator.OmniOperator import nova.hetu.omniruntime.operator.aggregator.{OmniAggregationWithExprOperatorFactory, OmniHashAggregationWithExprOperatorFactory} import nova.hetu.omniruntime.operator.config.{OperatorConfig, OverflowConfig, SpillConfig} import nova.hetu.omniruntime.vector._ +import nova.hetu.omniruntime.`type`.{DataType => OmniDataType, MapDataType, StructDataType} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, NamedExpression, SortOrder} import org.apache.spark.sql.execution.datasources.orc.OrcColumnVector import org.apache.spark.sql.execution.metric.SQLMetric @@ -200,6 +201,24 @@ object OmniAdaptorUtil { } vec } + case StructType(fields) => + val vec = new StructVec(new StructDataType(fields.map(field => sparkTypeToOmniTypeWithComplex(field.dataType, Metadata.empty))), columnSize) + val numChildren = fields.length + for (i <- 0 until numChildren) { + vec.add (i, transColumnVector (columnVector.getChild (i), columnSize) ) + } + vec + 
case MapType(keyType, valueType, valueContainsNull) => + val offsets = new Array[Int](columnSize + 1) + offsets(0) = 0 + for (i <- 1 until columnSize + 1) { + offsets(i) = offsets(i - 1) + columnVector.getMap(i - 1).numElements() + } + val vec = new MapVec(new MapDataType(sparkTypeToOmniTypeWithComplex(keyType, Metadata.empty), sparkTypeToOmniTypeWithComplex(valueType, Metadata.empty)), offsets(columnSize)) + vec.AddKeys(transColumnVector(columnVector.getChild(0), offsets(columnSize))) + vec.AddValues(transColumnVector(columnVector.getChild(1), offsets(columnSize))) + vec.AddOffsets(offsets) + vec case _ => throw new UnsupportedOperationException("unsupport column vector!") } diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index 486d7bbc8b..566a5f59c2 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -62,7 +62,8 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl def buildCheck(): Unit = { val omniAttrExpsIdMap = getExprIdMap(child.output) child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + // TODO RowToColumnTest + exp => sparkTypeToOmniTypeWithComplex(exp.dataType, exp.metadata)).toArray val omniExpressions: Array[AnyRef] = projectList.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray checkOmniJsonWhiteList("", omniExpressions) @@ -77,7 +78,8 @@ case class ColumnarProjectExec(projectList: Seq[NamedExpression], child: SparkPl val omniAttrExpsIdMap = getExprIdMap(child.output) val omniInputTypes = child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + // TODO RowToColumnTest + exp => sparkTypeToOmniTypeWithComplex(exp.dataType, exp.metadata)).toArray val omniExpressions = projectList.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray @@ -180,7 +182,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) def buildCheck(): Unit = { val omniAttrExpsIdMap = getExprIdMap(child.output) child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + exp => sparkTypeToOmniTypeWithComplex(exp.dataType, exp.metadata)).toArray val filterExpression = rewriteToOmniJsonExpressionLiteral(condition, omniAttrExpsIdMap) checkOmniJsonWhiteList(filterExpression, new Array[AnyRef](0)) } @@ -196,7 +198,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) val omniAttrExpsIdMap = getExprIdMap(child.output) val omniInputTypes = child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + exp => sparkTypeToOmniTypeWithComplex(exp.dataType, exp.metadata)).toArray val omniProjectIndices = child.output.map( exp => sparkProjectionToOmniJsonProjection(exp, omniAttrExpsIdMap(exp.exprId))).toArray @@ -298,7 +300,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], val omniAttrExpsIdMap = getExprIdMap(child.output) val omniInputTypes = child.output.map( - exp => sparkTypeToOmniType(exp.dataType, exp.metadata)).toArray + exp => 
sparkTypeToOmniTypeWithComplex(exp.dataType, exp.metadata)).toArray val omniExpressions = projectList.map( exp => rewriteToOmniJsonExpressionLiteral(exp, omniAttrExpsIdMap)).toArray diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala index c573cb82f9..0acd742eaf 100644 --- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala +++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarExec.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import java.util.concurrent.TimeUnit.NANOSECONDS import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer - -import nova.hetu.omniruntime.vector.Vec +import nova.hetu.omniruntime.vector.{MapVec, StructVec, Vec} /** * Provides an optimized set of APIs to append row based data to an array of @@ -44,12 +43,14 @@ private[execution] class OmniRowToColumnConverter(schema: StructType) extends Se f => OmniRowToColumnConverter.getConverterForType(f.dataType, f.nullable) } - final def convert(row: InternalRow, vectors: Array[WritableColumnVector]): Unit = { + final def convert(rows: Seq[InternalRow], size: Int): Seq[WritableColumnVector] = { var idx = 0 - while (idx < row.numFields) { - converters(idx).append(row, idx, vectors(idx)) + val res = new Array[WritableColumnVector](schema.fields.length) + while (idx < schema.fields.length) { + res(idx) = converters(idx).add(rows, idx, size) idx += 1 } + res.toSeq } } @@ -62,16 +63,7 @@ private object OmniRowToColumnConverter { private abstract class TypeConverter extends Serializable { def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit - } - - private final case class BasicNullableTypeConverter(base: TypeConverter) extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { - if (row.isNullAt(column)) { - cv.appendNull - } else { - base.append(row, column, cv) - } - } + def add(rows: Seq[SpecializedGetters], column: Int, size: Int): WritableColumnVector } private def getConverterForType(dataType: DataType, nullable: Boolean): TypeConverter = { @@ -80,87 +72,346 @@ private object OmniRowToColumnConverter { case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter - case IntegerType | DateType => IntConverter - case LongType | TimestampType => LongConverter + case IntegerType | DateType => IntConverter(dataType) + case LongType | TimestampType => LongConverter(dataType) case DoubleType => DoubleConverter case StringType => StringConverter case CalendarIntervalType => CalendarConverter case dt: DecimalType => DecimalConverter(dt) + case struct: StructType => + val fieldConverters = struct.fields.map { f => + getConverterForType(f.dataType, f.nullable) + } + StructConverter(fieldConverters, struct) + case map: MapType => + val keyConverter = getConverterForType(map.keyType, nullable) + val valueConverter = getConverterForType(map.valueType, map.valueContainsNull) + MapConverter(keyConverter, valueConverter, map) case unknown => throw new UnsupportedOperationException( s"Type $unknown not supported") } + core + } - if (nullable) { - dataType match { - case _ => new BasicNullableTypeConverter(core) + private case 
class MapConverter(
+      keyConverter: TypeConverter,
+      valueConverter: TypeConverter,
+      dataType: MapType) extends TypeConverter {
+    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+      throw new UnsupportedOperationException("MapConverter does not support append()")
+    }
+
+    override def add(rows: Seq[SpecializedGetters],
+        column: Int, size: Int): WritableColumnVector = {
+      val cv = new OmniColumnVector(size, dataType, true)
+      // accumulate per-row lengths into a running offsets array and a null mask
+      var totalLen = 0
+      val offsets = new ListBuffer[Int]
+      val nulls = new ListBuffer[Byte]
+      offsets += 0
+      for (row <- rows) {
+        val mapData = if (row == null) null else row.getMap(column)
+        if (mapData == null) {
+          nulls += 1
+        } else {
+          nulls += 0
+          val num = mapData.numElements
+          totalLen += num
+        }
+        offsets += totalLen
+      }
+
+      val keyVector = new OmniColumnVector(totalLen, dataType.keyType, true)
+      val valueVector = new OmniColumnVector(totalLen, dataType.valueType, true)
+      for (row <- rows) {
+        val mapData = if (row == null) null else row.getMap(column)
+        if (mapData != null) {
+          val mapLength = mapData.numElements
+          for (i <- 0 until mapLength) {
+            keyConverter.append(mapData.keyArray(), i, keyVector)
+            valueConverter.append(mapData.valueArray(), i, valueVector)
+          }
+        }
+      }
+
+      cv.setChild(keyVector, 0)
+      cv.setChild(valueVector, 1)
+      cv.setOffsets(offsets.toArray)
+      cv.updateVec()
+      cv.putNulls(0, nulls.toArray, size)
+      cv
+    }
+  }
+
+  private case class StructConverter(childConverters: Array[TypeConverter], dataType: StructType)
+    extends TypeConverter {
+    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+      throw new UnsupportedOperationException("StructConverter does not support append()")
+    }
+
+    override def add(rows: Seq[SpecializedGetters],
+        column: Int, size: Int): WritableColumnVector = {
+      // child vectors are attached below, so do not initialize them here
+      val cv = new OmniColumnVector(size, dataType, true)
+      val structRows = new ListBuffer[SpecializedGetters]()
+      val nulls = new ListBuffer[Byte]()
+      for (row <- rows) {
+        val struct = if (row == null) null else row.getStruct(column, childConverters.length)
+        if (struct == null) {
+          nulls += 1
+        } else {
+          nulls += 0
+        }
+        structRows += struct
+      }
+      childConverters.zipWithIndex.foreach { case (childConverter, fieldIndex) =>
+        val vector = childConverter.add(structRows, fieldIndex, size).asInstanceOf[OmniColumnVector]
+        cv.setChild(vector, fieldIndex)
+      }
+      cv.updateVec()
+      cv.putNulls(0, nulls.toArray, size)
+      cv
+    }
+  }

   private object BinaryConverter extends TypeConverter {
     override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
-      val bytes = row.getBinary(column)
-      cv.appendByteArray(bytes, 0, bytes.length)
+      if (row == null || row.isNullAt(column)) {
+        cv.appendNull
+      } else {
+        val bytes = row.getBinary(column)
+        cv.appendByteArray(bytes, 0, bytes.length)
+      }
+    }
+
+    override def add(rows: Seq[SpecializedGetters],
+        column: Int, size: Int): WritableColumnVector = {
+      val cv = new OmniColumnVector(size, BinaryType, true)
+      for (row <- rows) {
+        if (row == null || row.isNullAt(column)) {
+          cv.appendNull
+        } else {
+          append(row, column, cv)
+        }
+      }
+      cv
     }
   }

   private object BooleanConverter extends TypeConverter {
-    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit =
-      cv.appendBoolean(row.getBoolean(column))
+    override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = {
+      if (row == null || row.isNullAt(column)) {
cv.appendNull + } else { + cv.appendBoolean(row.getBoolean(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, BooleanType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } private object ByteConverter extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = - cv.appendByte(row.getByte(column)) + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + cv.appendByte(row.getByte(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, ByteType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } private object ShortConverter extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = - cv.appendShort(row.getShort(column)) + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + cv.appendShort(row.getShort(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, ShortType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } - private object IntConverter extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = - cv.appendInt(row.getInt(column)) + private case class IntConverter(dataType: DataType) extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + cv.appendInt(row.getInt(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, dataType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } - private object LongConverter extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = - cv.appendLong(row.getLong(column)) + private case class LongConverter(dataType: DataType) extends TypeConverter { + override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + cv.appendLong(row.getLong(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, dataType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } private object DoubleConverter extends TypeConverter { - override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = - cv.appendDouble(row.getDouble(column)) + override def append(row: SpecializedGetters, 
column: Int, cv: WritableColumnVector): Unit = { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + cv.appendDouble(row.getDouble(column)) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, DoubleType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv + } } private object StringConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { - val data = row.getUTF8String(column).getBytes - cv.asInstanceOf[OmniColumnVector].appendString(data.length, data, 0) + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + val data = row.getUTF8String(column).getBytes + cv.asInstanceOf[OmniColumnVector].appendString(data.length, data, 0) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, StringType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv } } private object CalendarConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { - val c = row.getInterval(column) - cv.appendStruct(false) - cv.getChild(0).appendInt(c.months) - cv.getChild(1).appendInt(c.days) - cv.getChild(2).appendLong(c.microseconds) + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + val c = row.getInterval(column) + cv.appendStruct(false) + cv.getChild(0).appendInt(c.months) + cv.getChild(1).appendInt(c.days) + cv.getChild(2).appendLong(c.microseconds) + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, CalendarIntervalType, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } + } + cv } } private case class DecimalConverter(dt: DecimalType) extends TypeConverter { override def append(row: SpecializedGetters, column: Int, cv: WritableColumnVector): Unit = { - val d = row.getDecimal(column, dt.precision, dt.scale) - if (DecimalType.is64BitDecimalType(dt)) { - cv.appendLong(d.toUnscaledLong) + if (row == null || row.isNullAt(column)) { + cv.appendNull } else { - cv.asInstanceOf[OmniColumnVector].appendDecimal(d) + val d = row.getDecimal(column, dt.precision, dt.scale) + if (DecimalType.is64BitDecimalType(dt)) { + cv.appendLong(d.toUnscaledLong) + } else { + cv.asInstanceOf[OmniColumnVector].appendDecimal(d) + } + } + } + + override def add(rows: Seq[SpecializedGetters], + column: Int, size: Int): WritableColumnVector = { + val cv = new OmniColumnVector(size, dt, true) + for (row <- rows) { + if (row == null || row.isNullAt(column)) { + cv.appendNull + } else { + append(row, column, cv) + } } + cv } } } @@ -226,7 +477,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti _) ) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime)) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, + Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime)) broadcast } @@ -237,7 +489,8 @@ case class 
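Reviewer note: MapConverter above flattens all map entries into one key vector and one value vector, recording per-row boundaries in offsets and one null flag per row; a null map simply repeats the previous offset. A worked sketch of that layout computation (mapLayout is illustrative, not a plugin API):

    // Compute flattened entries, offsets, and null flags the way
    // MapConverter.add does above; None models a null map cell.
    def mapLayout(rows: Seq[Option[Map[String, Double]]])
        : (Array[Int], Array[Byte], Seq[(String, Double)]) = {
      val offsets = Array.newBuilder[Int]
      val nulls = Array.newBuilder[Byte]
      val flat = Seq.newBuilder[(String, Double)]
      var total = 0
      offsets += 0
      for (r <- rows) {
        r match {
          case None => nulls += 1
          case Some(m) =>
            nulls += 0
            total += m.size
            flat ++= m.toSeq
        }
        offsets += total // a null row contributes no elements
      }
      (offsets.result(), nulls.result(), flat.result())
    }

    // mapLayout(Seq(Some(Map("key1" -> 1.1)), None, Some(Map("a" -> 2.0, "b" -> 3.0))))
    // gives offsets [0, 1, 1, 3], nulls [0, 1, 0], flat [(key1,1.1), (a,2.0), (b,3.0)]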
@@ -226,7 +477,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti
         _)
     )
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime))
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
+      Seq(numInputRows, numOutputBatches, rowToOmniColumnarTime))
     broadcast
   }
 
@@ -237,7 +489,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti
   override lazy val metrics: Map[String, SQLMetric] = Map(
     "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"),
     "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of output batches"),
-    "rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in row to OmniColumnar")
+    "rowToOmniColumnarTime" -> SQLMetrics.createTimingMetric(sparkContext,
+      "time in row to OmniColumnar")
   )
 
   override protected def withNewChildInternal(newChild: SparkPlan): RowToOmniColumnarExec =
@@ -255,7 +508,8 @@ case class RowToOmniColumnarExec(child: SparkPlan) extends RowToColumnarTransiti
     // plan (this) in the closure.
     val localSchema = this.schema
     child.execute().mapPartitionsInternal { rowIterator =>
-      InternalRowToColumnarBatch.convert(enableOffHeapColumnVector, numInputRows, numOutputBatches, rowToOmniColumnarTime, numRows, localSchema, rowIterator)
+      InternalRowToColumnarBatch.convert(enableOffHeapColumnVector, numInputRows, numOutputBatches,
+        rowToOmniColumnarTime, numRows, localSchema, rowIterator)
     }
   }
 
@@ -275,7 +529,8 @@ case class OmniColumnarToRowExec(child: SparkPlan,
   override lazy val metrics: Map[String, SQLMetric] = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "numInputBatches" -> SQLMetrics.createMetric(sparkContext, "number of input batches"),
-    "omniColumnarToRowTime" -> SQLMetrics.createTimingMetric(sparkContext, "time in omniColumnar to row")
+    "omniColumnarToRowTime" ->
+      SQLMetrics.createTimingMetric(sparkContext, "time in omniColumnar to row")
   )
 
   override def verboseStringWithOperatorId(): String = {
@@ -337,22 +592,17 @@ object InternalRowToColumnarBatch {
       override def next(): ColumnarBatch = {
         val startTime = System.nanoTime()
-        val vectors: Seq[WritableColumnVector] = OmniColumnVector.allocateColumns(numRows,
-          localSchema, true)
-        val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray)
-        cb.setNumRows(0)
-        vectors.foreach(_.reset())
+        var rowCount = 0
+        val bufferRow = ListBuffer[InternalRow]()
         while (rowCount < numRows && rowIterator.hasNext) {
           val row = rowIterator.next()
-          converters.convert(row, vectors.toArray)
+          bufferRow += row.copy()
           rowCount += 1
         }
-        if (!enableOffHeapColumnVector) {
-          vectors.foreach { v =>
-            v.asInstanceOf[OmniColumnVector].getVec.setSize(rowCount)
-          }
-        }
+
+        val vectors = converters.convert(bufferRow, rowCount)
+        val cb: ColumnarBatch = new ColumnarBatch(vectors.toArray)
         cb.setNumRows(rowCount)
         numInputRows += rowCount
         numOutputBatches += 1
@@ -424,4 +674,4 @@ object ColumnarBatchToInternalRow {
     }
     batchIter
   }
-}
\ No newline at end of file
+}
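Reviewer note: the rewritten next() above buffers row copies first and converts column at a time second. The copy() is what makes this safe: Spark row iterators typically reuse one mutable InternalRow, so a deferred per-column read without copies would observe only the last row. A minimal sketch of the buffer-then-pivot pattern under that assumption (ReusableRow and drainBatch are illustrative names):

    // Buffer stable copies, then pivot rows into columns in one pass per column.
    final class ReusableRow(var id: Int, var name: String) {
      def copyRow(): ReusableRow = new ReusableRow(id, name)
    }

    def drainBatch(it: Iterator[ReusableRow], maxRows: Int): (Array[Int], Array[String]) = {
      val buf = scala.collection.mutable.ListBuffer.empty[ReusableRow]
      var n = 0
      while (n < maxRows && it.hasNext) {
        buf += it.next().copyRow() // snapshot: the iterator may reuse its row object
        n += 1
      }
      // analogous to converters.convert(bufferRow, rowCount) above
      (buf.map(_.id).toArray, buf.map(_.name).toArray)
    }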
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index 9991b1468b..43f881b717 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -385,6 +385,11 @@ abstract class BaseColumnarFileSourceScanExec(
   def buildCheck(): Unit = {
     output.zipWithIndex.foreach {
+      case (attr, i) =>
+        sparkTypeToOmniTypeWithComplex(attr.dataType, attr.metadata)
+    }
+    val partitionSchema = relation.partitionSchema
+    partitionSchema.zipWithIndex.foreach {
       case (attr, i) =>
         sparkTypeToOmniType(attr.dataType, attr.metadata)
     }
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala
index bb31d7f82b..40e634df48 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/main/scala/org/apache/spark/sql/types/ColumnarBatchSupportUtil.scala
@@ -37,7 +37,7 @@ object ColumnarBatchSupportUtil {
     val supportBatchReader: Boolean = {
       val partitionSchema = plan.relation.partitionSchema
       val resultSchema = StructType(plan.requiredSchema.fields ++ partitionSchema.fields)
-      (conf.orcVectorizedReaderEnabled || conf.parquetVectorizedReaderEnabled) && resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+      (conf.orcVectorizedReaderEnabled || conf.parquetVectorizedReaderEnabled)
     }
     supportBatchReader && isSupportFormat
   }
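Reviewer note: the new map-reader test below resolves each required column against the field names reported by the ORC file, marking missing columns with -1 in colsToGet (see initDataColIds). A compact sketch of that resolution step (resolveCols is an illustrative helper, not a plugin API):

    // Flag required columns as present (0) or missing (-1) and collect the
    // names that will actually be read, as initDataColIds does below.
    def resolveCols(required: Seq[String], available: Set[String]): (Array[Int], Seq[String]) = {
      val flags = required.map(n => if (available(n)) 0 else -1).toArray
      val included = required.filter(available)
      (flags, included)
    }

    // resolveCols(Seq("c1", "c17", "missing"), Set("c1", "c17"))
    // returns (Array(0, 0, -1), Seq("c1", "c17"))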
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniMapReaderTest.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniMapReaderTest.java
new file mode 100644
index 0000000000..79da8e7458
--- /dev/null
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniMapReaderTest.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.huawei.boostkit.spark.jni;
+
+import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor;
+import junit.framework.TestCase;
+import nova.hetu.omniruntime.vector.IntVec;
+import nova.hetu.omniruntime.vector.Vec;
+import nova.hetu.omniruntime.vector.MapVec;
+import nova.hetu.omniruntime.vector.VarcharVec;
+import nova.hetu.omniruntime.vector.DoubleVec;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.orc.Reader;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.types.MapType;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runners.MethodSorters;
+import org.apache.spark.sql.types.DataTypes;
+
+import java.io.File;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.spark.sql.types.DataTypes.DoubleType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+
+@FixMethodOrder(value = MethodSorters.NAME_ASCENDING)
+public class OrcColumnarBatchJniMapReaderTest extends TestCase {
+    private List<nova.hetu.omniruntime.type.DataType> dataTypes = new ArrayList<>();
+    public Configuration conf = new Configuration();
+    public OrcColumnarBatchScanReader orcColumnarBatchScanReader;
+    private int batchSize = 4096;
+
+    private StructType requiredSchema;
+    private int[] vecTypeIds;
+
+    private long offset = 0;
+
+    private long length = Integer.MAX_VALUE;
+
+    @Before
+    public void setUp() throws Exception {
+        orcColumnarBatchScanReader = new OrcColumnarBatchScanReader();
+        constructSchema();
+        initReaderJava();
+        initDataColIds();
+        initRecordReaderJava();
+        initBatch();
+    }
+
+    private void constructSchema() {
+        requiredSchema = new StructType()
+            .add("c1", IntegerType)
+            .add("c17", DataTypes.createMapType(StringType, DoubleType));
+    }
+
+    private void initDataColIds() {
+        // resolve the required field names against the file schema
+        String[] requiredFieldNames = requiredSchema.fieldNames();
+        // track which required columns are readable and how many there are
+        orcColumnarBatchScanReader.colsToGet = new int[requiredFieldNames.length];
+        orcColumnarBatchScanReader.includedColumns = new ArrayList<>();
+        // collect the omni type ids of the columns to read
+        ArrayList<Integer> typeBuilder = new ArrayList<>();
+
+        for (int i = 0; i < requiredFieldNames.length; i++) {
+            String target = requiredFieldNames[i];
+
+            // mark the column with 0 if it is found, -1 otherwise
+            boolean isFound = false;
+            for (int j = 0; j < orcColumnarBatchScanReader.allFieldsNames.size(); j++) {
+                if (target.equals(orcColumnarBatchScanReader.allFieldsNames.get(j))) {
+                    orcColumnarBatchScanReader.colsToGet[i] = 0;
+                    orcColumnarBatchScanReader.includedColumns.add(requiredFieldNames[i]);
+                    StructField field = requiredSchema.fields()[i];
+                    nova.hetu.omniruntime.type.DataType dataType =
+                        OmniExpressionAdaptor.sparkTypeToOmniTypeWithComplex(field.dataType(), field.metadata());
+                    typeBuilder.add(dataType.getId().ordinal());
+                    dataTypes.add(dataType);
+                    isFound = true;
+                    break;
+                }
+            }
+
+            if (!isFound) {
+                orcColumnarBatchScanReader.colsToGet[i] = -1;
+            }
+        }
+
+        vecTypeIds = typeBuilder.stream().mapToInt(Integer::intValue).toArray();
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        System.out.println("OrcColumnarBatchJniMapReaderTest test finished");
+    }
+
+    private void initReaderJava() {
+        File directory = new File("src/test/java/com/huawei/boostkit/spark/jni/orcsrc/orc_nested_type");
+        String path = directory.getAbsolutePath();
+        URI uri = null;
+        try {
+            uri = new URI(path);
+        } catch (URISyntaxException ignore) {
+            // if a URISyntaxException is thrown, the assertNotNull below fails the test
+        }
+        assertNotNull(uri);
+        orcColumnarBatchScanReader.reader = orcColumnarBatchScanReader.initializeReaderJava(uri);
+        assertTrue(orcColumnarBatchScanReader.reader != 0);
+    }
+
+    private void initRecordReaderJava() {
+        orcColumnarBatchScanReader.recordReader = orcColumnarBatchScanReader
+            .initializeRecordReaderJava(offset, length, null, requiredSchema, false, true);
+        assertTrue(orcColumnarBatchScanReader.recordReader != 0);
+    }
+
+    private void initBatch() {
+        orcColumnarBatchScanReader.initBatchJava(batchSize);
+        assertTrue(orcColumnarBatchScanReader.batchReader != 0);
+    }
+
+    @Test
+    public void testNext() {
+        Vec[] vecs = new Vec[2];
+        long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds, dataTypes);
+        assertTrue(((IntVec) vecs[0]).get(0) == 1);
+        VarcharVec keyVector = (VarcharVec) (((MapVec) vecs[1]).getKeyVec());
+        assertTrue(new String(keyVector.get(0)).equals("key1"));
+
+        DoubleVec valueVector = (DoubleVec) (((MapVec) vecs[1]).getValueVec());
+        assertTrue(valueVector.get(0) - 1.1 < 0.000001);
+        vecs[0].close();
+        vecs[1].close();
+    }
+
+    @Test
+    public void testGetProgress() {
+        String tmp = "";
+        try {
+            double progressValue = orcColumnarBatchScanReader.getProgress();
+        } catch (Exception e) {
+            tmp = e.getMessage();
+        } finally {
+            assertTrue(tmp.equals("recordReaderGetProgress is unsupported"));
+        }
+    }
+}
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
index 9b28daf37b..b0318c4c30 100644
--- a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
+++ b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchJniReaderTest.java
@@ -137,7 +137,7 @@ public class OrcColumnarBatchJniReaderTest extends TestCase {
     @Test
     public void testNext() {
         Vec[] vecs = new Vec[2];
-        long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds);
+        long rtn = orcColumnarBatchScanReader.next(vecs, vecTypeIds, null);
         assertTrue(rtn == 4096);
         assertTrue(((LongVec) vecs[0]).get(0) == 1);
         String str = new String(((VarcharVec) vecs[1]).get(0));
diff --git a/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/orc_nested_type b/omnioperator/omniop-spark-extension/spark-extension-core/src/test/java/com/huawei/boostkit/spark/jni/orcsrc/orc_nested_type
new file mode 100644
index 0000000000000000000000000000000000000000..5aa0d82c0b7ba74a02d7debf14c8d3053e04183f
GIT binary patch
literal 2930
[base85 payload omitted: binary ORC test resource orc_nested_type, 2930 bytes]

Date: Mon, 8 Sep 2025 11:58:22 +0800
Subject: [PATCH 2/2] fix map null scan

---
 .../omniop-native-reader/cpp/src/orcfile/OmniColReader.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
index 1948f68be8..578f083c2f 100644
--- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
+++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OmniColReader.cc
@@ -346,7 +346,7 @@ namespace omniruntime::reader {
             auto keyDataTypeId = OrcUtil::tranOrcTypeToOmniType(keyOrcType);
             std::shared_ptr keyVector = std::move(makeNewVector(totalChildren, keyOrcType, keyDataTypeId));
             mapvector->SetKeyVector(keyVector);
-            reinterpret_cast(rawKeyReader)->next((mapvector->GetKeyVector().get()), totalChildren, nullptr, keyDataTypeId);
+            reinterpret_cast(rawKeyReader)->next((mapvector->GetKeyVector().get()), totalChildren, hasNull ? reinterpret_cast(nulls) : nullptr, keyDataTypeId);
         }
         ColumnReader *rawValueReader = valueReader.get();
         if (rawValueReader) {
@@ -354,7 +354,7 @@ namespace omniruntime::reader {
             auto valueDataTypeId = OrcUtil::tranOrcTypeToOmniType(valueOrcType);
             std::shared_ptr valueVector = std::move(makeNewVector(totalChildren, valueOrcType, valueDataTypeId));
             mapvector->SetValueVector(valueVector);
-            reinterpret_cast(rawValueReader)->next((mapvector->GetValueVector().get()), totalChildren, nullptr, valueDataTypeId);
+            reinterpret_cast(rawValueReader)->next((mapvector->GetValueVector().get()), totalChildren, hasNull ? reinterpret_cast(nulls) : nullptr, valueDataTypeId);
         }
     }
-- 
Gitee
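Closing note on PATCH 2/2: the two one-line fixes forward the parent map's null mask to the key and value child readers instead of always passing nullptr, so rows whose map is null no longer contribute key/value slots. A small model of why the mask matters when sizing child reads (childOffsets is illustrative only):

    // With the mask, offsets skip null rows entirely; without it, the child
    // readers would be sized and advanced as if every row were present.
    def childOffsets(sizes: Array[Int], isNull: Array[Boolean]): Array[Int] = {
      require(sizes.length == isNull.length)
      val offsets = new Array[Int](sizes.length + 1)
      var total = 0
      for (i <- sizes.indices) {
        if (!isNull(i)) total += sizes(i)
        offsets(i + 1) = total
      }
      offsets
    }

    // childOffsets(Array(2, 5, 1), Array(false, true, false)) == Array(0, 2, 2, 3)
    // the null middle row adds nothing, whatever stale size it carries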