diff --git a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt index 6986422c62493df2fafd2fb641c5a24c5852151a..090d2375f95b8b315c1c5964e64a94cacf8af1ec 100644 --- a/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt +++ b/omnioperator/omniop-native-reader/cpp/src/CMakeLists.txt @@ -4,6 +4,7 @@ set (PROJ_TARGET native_reader) set (SOURCE_FILES + jni/OrcColumnarBatchJniWriter.cpp jni/OrcColumnarBatchJniReader.cpp jni/jni_common.cpp jni/ParquetColumnarBatchJniReader.cpp diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h index ba5e0af9dc2d501873a0e75b63997d5c9309dc08..6ef616de86a9d0b2ae1ec8893d0d078f989fae27 100644 --- a/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/file_interface.h @@ -48,7 +48,26 @@ public: virtual int64_t Read(void *buffer, int32_t length) = 0; }; -} +class WriteableFile { +public: + // Virtual destructor + virtual ~WriteableFile() = default; + + // Close the file + virtual Status Close() = 0; + + // Open the file + virtual Status OpenFile() = 0; + + // Get the size of the file + virtual int64_t GetFileSize() = 0; + + // Write data from the current position into the buffer with the given + // length + virtual int64_t Write(const void *buffer, int32_t length) = 0; +}; + +} // namespace fs #endif //SPARK_THESTRAL_PLUGIN_FILE_INTERFACE_H diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp index 4b08d1b2152ae53967f368430a7fe38825ae6584..4998ff3acb41731eabb79bbf1434978b3f60dfe0 100644 --- a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.cpp @@ -97,5 +97,57 @@ int64_t HdfsReadableFile::Read(void *buffer, int32_t length) { return hdfsRead(fileSystem_->getFileSystem(), file_, buffer, length); } +HdfsWriteableFile::HdfsWriteableFile(std::shared_ptr fileSystemPtr, const std::string &path, + int64_t bufferSize) : + fileSystem_(std::move(fileSystemPtr)), + path_(path), bufferSize_(bufferSize) +{ +} + +HdfsWriteableFile::~HdfsWriteableFile() { this->TryClose(); } + +Status HdfsWriteableFile::Close() { return TryClose(); } + +Status HdfsWriteableFile::OpenFile() { + if (isOpen_) { + return Status::OK(); + } + hdfsFile handle = hdfsOpenFile(fileSystem_->getFileSystem(), path_.c_str(), O_WRONLY, bufferSize_, 0, 0); + if (handle == nullptr) { + return Status::IOError("Fail to open hdfs file, path is " + path_); + } + + this->file_ = handle; + this->isOpen_ = true; + return Status::OK(); +} + +int64_t HdfsWriteableFile::Write(const void *buffer, int32_t length) { + if (!OpenFile().IsOk()) { + return -1; + } + hdfsWrite(fileSystem_->getFileSystem(), file_, buffer, length); + return hdfsHFlush(fileSystem_->getFileSystem(), file_); +} + +Status HdfsWriteableFile::TryClose() { + if (!isOpen_) { + return Status::OK(); + } + int st = hdfsCloseFile(fileSystem_->getFileSystem(), file_); + if (st == -1) { + return Status::IOError("Fail to close hdfs file, path is " + path_); + } + this->isOpen_ = false; + return Status::OK(); +} + +int64_t HdfsWriteableFile::GetFileSize() { + if (!OpenFile().IsOk()) { + return -1; + } + FileInfo fileInfo = fileSystem_->GetFileInfo(path_); + return fileInfo.size(); +} -} \ No newline at end of file +} // namespace fs \ No newline 
at end of file diff --git a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h index ebfe0334fb2fc307612ed3fba8ff50f3710fcb4c..88d9533e5467ac839e7dc3bf78b1958be8dd45ce 100644 --- a/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h +++ b/omnioperator/omniop-native-reader/cpp/src/filesystem/hdfs_file.h @@ -59,7 +59,35 @@ private: hdfsFile file_; }; -} +class HdfsWriteableFile : public WriteableFile { +public: + HdfsWriteableFile(std::shared_ptr fileSystemPtr, const std::string &path, int64_t bufferSize = 0); + + ~HdfsWriteableFile(); + + Status Close() override; + + Status OpenFile() override; + + int64_t Write(const void *buffer, int32_t length) override; + + int64_t GetFileSize() override; + +private: + Status TryClose(); + + std::shared_ptr fileSystem_; + + const std::string &path_; + + int64_t bufferSize_; + + bool isOpen_ = false; + + hdfsFile file_{}; +}; + +} // namespace fs #endif //SPARK_THESTRAL_PLUGIN_HDFS_FILE_H diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bc81f018aaf0c77a0dde5b5f488418e46746c83c --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.cpp @@ -0,0 +1,410 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "OrcColumnarBatchJniWriter.h" +#include +#include +#include +#include +#include "jni_common.h" + +using namespace omniruntime::vec; +using namespace omniruntime::type; +using namespace orc; + +static constexpr int32_t DECIMAL_PRECISION_INDEX = 0; +static constexpr int32_t DECIMAL_SCALE_INDEX = 1; +static constexpr int32_t MINOR_VERSION_11 = 11; +static constexpr int32_t MINOR_VERSION_12 = 12; +static constexpr int32_t MAJOR_VERSION_0 = 0; + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeOutputStream( + JNIEnv *env, jobject jObj, jobject uriJson) +{ + JNI_FUNC_START + jstring schemaJstr = (jstring)env->CallObjectMethod(uriJson, jsonMethodString, env->NewStringUTF("scheme")); + const char *schemaPtr = env->GetStringUTFChars(schemaJstr, nullptr); + std::string schemaStr(schemaPtr); + env->ReleaseStringUTFChars(schemaJstr, schemaPtr); + jstring fileJstr = (jstring)env->CallObjectMethod(uriJson, jsonMethodString, env->NewStringUTF("path")); + const char *filePtr = env->GetStringUTFChars(fileJstr, nullptr); + std::string fileStr(filePtr); + env->ReleaseStringUTFChars(fileJstr, filePtr); + jstring hostJstr = (jstring)env->CallObjectMethod(uriJson, jsonMethodString, env->NewStringUTF("host")); + const char *hostPtr = env->GetStringUTFChars(hostJstr, nullptr); + std::string hostStr(hostPtr); + env->ReleaseStringUTFChars(hostJstr, hostPtr); + jint port = (jint)env->CallIntMethod(uriJson, jsonMethodInt, env->NewStringUTF("port")); + UriInfo uri{schemaStr, fileStr, hostStr, std::to_string(port)}; + std::unique_ptr outputStream = orc::writeFileOverride(uri); + orc::OutputStream *outputStreamNew = outputStream.release(); + return (jlong)(outputStreamNew); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeSchemaType( + JNIEnv *env, jobject jObj, jintArray orcTypeIds, jobjectArray schemaNames, jobjectArray decimalParam) +{ + JNI_FUNC_START + auto orcTypeIdPtr = env->GetIntArrayElements(orcTypeIds, JNI_FALSE); + if (orcTypeIdPtr == NULL) { + env->ThrowNew(runtimeExceptionClass, "Orc type ids should not be null."); + } + auto orcTypeIdLength = (int32_t)env->GetArrayLength(orcTypeIds); + auto writeType = createPrimitiveType(orc::TypeKind::STRUCT); + + for (int i = 0; i < orcTypeIdLength; ++i) { + jint orcType = orcTypeIdPtr[i]; + jstring schemaName = (jstring)env->GetObjectArrayElement(schemaNames, i); + const char *cSchemaName = env->GetStringUTFChars(schemaName, nullptr); + std::unique_ptr writeOrcType; + if (static_cast(orcType) == orc::TypeKind::DECIMAL) { + auto decimalParamArray = (jintArray)env->GetObjectArrayElement(decimalParam, i); + auto decimalParamArrayPtr = env->GetIntArrayElements(decimalParamArray, JNI_FALSE); + auto precision = decimalParamArrayPtr[DECIMAL_PRECISION_INDEX]; + auto scale = decimalParamArrayPtr[DECIMAL_SCALE_INDEX]; + writeOrcType = createDecimalType(precision, scale); + } + else { + writeOrcType = createPrimitiveType(static_cast(orcType)); + } + writeType->addStructField(std::string(cSchemaName), std::move(writeOrcType)); + env->ReleaseStringUTFChars(schemaName, cSchemaName); + } + + orc::Type *writerTypeNew = writeType.release(); + return (jlong)(writerTypeNew); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeWriter( + JNIEnv *env, jobject jObj, jlong outputStream, jlong schemaType, jobject writerOptionsJson) +{ + JNI_FUNC_START + // 
Set write options + // other param should set here, like padding tolerance, columns use bloom + // filter, bloom filter fpp ... + orc::MemoryPool *pool = orc::getDefaultPool(); + orc::WriterOptions writerOptions; + writerOptions.setMemoryPool(pool); + + // Parsing and setting file version + jobject versionJosnObj = + (jobject)env->CallObjectMethod(writerOptionsJson, jsonMethodJsonObj, env->NewStringUTF("file version")); + jint majorJint = (jint)env->CallIntMethod(versionJosnObj, jsonMethodInt, env->NewStringUTF("major")); + jint minorJint = (jint)env->CallIntMethod(versionJosnObj, jsonMethodInt, env->NewStringUTF("minor")); + uint32_t major = (uint32_t)majorJint; + uint32_t minor = (uint32_t)minorJint; + if (minor == MINOR_VERSION_11 && major == 0) { + writerOptions.setFileVersion(FileVersion::v_0_11()); + } + else if (minor == MINOR_VERSION_12 && major == 0) { + writerOptions.setFileVersion(FileVersion::v_0_12()); + } + else { + env->ThrowNew(runtimeExceptionClass, "Unsupported file version."); + } + + jint compressionJint = (jint)env->CallIntMethod(writerOptionsJson, jsonMethodInt, env->NewStringUTF("compression")); + writerOptions.setCompression(static_cast(compressionJint)); + + jlong stripSizeJint = + (jlong)env->CallLongMethod(writerOptionsJson, jsonMethodLong, env->NewStringUTF("strip size")); + writerOptions.setStripeSize(stripSizeJint); + + jint rowIndexStrideJint = + (jint)env->CallIntMethod(writerOptionsJson, jsonMethodInt, env->NewStringUTF("row index stride")); + writerOptions.setRowIndexStride((uint64_t)rowIndexStrideJint); + + jint compressionStrategyJint = + (jint)env->CallIntMethod(writerOptionsJson, jsonMethodInt, env->NewStringUTF("compression strategy")); + writerOptions.setCompressionStrategy(static_cast(compressionStrategyJint)); + + orc::OutputStream *stream = (orc::OutputStream *)outputStream; + orc::Type *writeType = (orc::Type *)schemaType; + + std::unique_ptr writer = createWriter((*writeType), stream, writerOptions); + orc::Writer *writerNew = writer.release(); + return (jlong)(writerNew); + JNI_FUNC_END(runtimeExceptionClass) +} + +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeBatch(JNIEnv *env, + jobject jObj, + jlong writer, + jlong batchSize) +{ + orc::Writer *writerPtr = (orc::Writer *)writer; + auto rowBatch = writerPtr->createRowBatch(batchSize); + orc::ColumnVectorBatch *batch = rowBatch.release(); + return (jlong)(batch); +} + +template +void WriteVector(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, long startPos = 0, + long endPos = 0) +{ + using T = typename NativeType::type; + auto vector = (Vector *)vec; + V *lvb = dynamic_cast(fieldBatch); + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + long index = 0; + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + if (vector->HasNull()) { + lvb->hasNulls = true; + for (long j = startPos; j < endPos; j++) { + if (vector->IsNull(j)) { + notNulls[index] = 0; + } + index++; + } + } + index = 0; + for (long j = startPos; j < endPos; j++) { + values[index] = vector->GetValue(j); + index++; + } +} + +void WriteDecimal128VectorBatch(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, + long startPos = 0, long endPos = 0) +{ + auto vector = (Vector *)vec; + auto *lvb = dynamic_cast(fieldBatch); + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + long index = 0; + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + if 
(vector->HasNull()) { + lvb->hasNulls = true; + for (long j = startPos; j < endPos; j++) { + if (vector->IsNull(j)) { + notNulls[index] = 0; + } + index++; + } + } + index = 0; + for (long j = startPos; j < endPos; j++) { + values[index] = vector->GetValue(j).ToInt128(); + index++; + } +} + +void WriteDecimal64VectorBatch(BaseVector *vec, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, + long startPos = 0, long endPos = 0) +{ + auto vector = (Vector *)vec; + auto *lvb = dynamic_cast(fieldBatch); + auto values = lvb->values.data(); + auto notNulls = lvb->notNull.data(); + long index = 0; + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + if (vector->HasNull()) { + lvb->hasNulls = true; + for (long j = startPos; j < endPos; j++) { + if (vector->IsNull(j)) { + notNulls[index] = 0; + } + index++; + } + } + index = 0; + for (long j = startPos; j < endPos; j++) { + values[index] = vector->GetValue(j); + index++; + } +} + +void WriteVarCharVectorBatch(BaseVector *baseVector, ColumnVectorBatch *fieldBatch, bool isSplitWrite = false, + long startPos = 0, long endPos = 0) +{ + auto vector = (Vector> *)baseVector; + auto *lvb = dynamic_cast(fieldBatch); + auto values = lvb->data.data(); + auto notNulls = lvb->notNull.data(); + auto lens = lvb->length.data(); + long index = 0; + if (!isSplitWrite) { + startPos = 0; + endPos = vector->GetSize(); + } + if (vector->HasNull()) { + lvb->hasNulls = true; + for (long j = startPos; j < endPos; j++) { + if (vector->IsNull(j)) { + notNulls[index] = 0; + } + index++; + } + } + index = 0; + for (long j = startPos; j < endPos; j++) { + values[index] = const_cast(vector->GetValue(j).data()); + lens[index] = vector->GetValue(j).size(); + index++; + } +} + +void WriteLongVectorBatch(JNIEnv *env, DataTypeId typeId, BaseVector *baseVector, ColumnVectorBatch *fieldBatch, + bool isSplitWrite = false, long startPos = 0, long endPos = 0) +{ + JNI_FUNC_START + switch (typeId) { + case OMNI_BOOLEAN: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + case OMNI_SHORT: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + case OMNI_INT: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + case OMNI_LONG: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + case OMNI_DATE32: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + case OMNI_DATE64: + return WriteVector(baseVector, fieldBatch, isSplitWrite, startPos, endPos); + default: + env->ThrowNew(runtimeExceptionClass, "DealLongVectorBatch not support for type:" + typeId); + } + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +void WriteVector(JNIEnv *env, long *vecNativeId, int colNums, orc::StructVectorBatch *batch, const int *omniTypes, + const unsigned char *dataColumnsIds, bool isSplitWrite = false, long startPos = 0, long endPos = 0) +{ + JNI_FUNC_START + for (int i = 0; i < colNums; ++i) { + if (!dataColumnsIds[i]) { + continue; + } + auto vec = (BaseVector *)vecNativeId[i]; + auto typeId = static_cast(omniTypes[i]); + auto fieldBatch = batch->fields[i]; + switch (typeId) { + case OMNI_BOOLEAN: + case OMNI_SHORT: + case OMNI_INT: + case OMNI_LONG: + case OMNI_DATE32: + case OMNI_DATE64: + WriteLongVectorBatch(env, typeId, vec, fieldBatch, isSplitWrite, startPos, endPos); + break; + case OMNI_DOUBLE: + WriteVector(vec, fieldBatch, isSplitWrite, startPos, endPos); + break; + case OMNI_VARCHAR: + WriteVarCharVectorBatch(vec, fieldBatch, isSplitWrite, 
startPos, endPos); + break; + case OMNI_DECIMAL64: + WriteDecimal64VectorBatch(vec, fieldBatch, isSplitWrite, startPos, endPos); + break; + case OMNI_DECIMAL128: + WriteDecimal128VectorBatch(vec, fieldBatch, isSplitWrite, startPos, endPos); + break; + default: + env->ThrowNew(runtimeExceptionClass, "Native columnar write not support for this type: " + typeId); + } + } + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_write( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, jbooleanArray dataColumnsIds, + jint numRows) +{ + JNI_FUNC_START + + auto vecNativeIdPtr = env->GetLongArrayElements(vecNativeId, JNI_FALSE); + auto colNums = env->GetArrayLength(vecNativeId); + auto omniTypesPtr = env->GetIntArrayElements(omniTypes, JNI_FALSE); + auto dataColumnsIdsPtr = env->GetBooleanArrayElements(dataColumnsIds, JNI_FALSE); + orc::Writer *writerPtr = (orc::Writer *)writer; + auto rowBatch = writerPtr->createRowBatch(numRows); + rowBatch->numElements = numRows; + orc::StructVectorBatch *batch = static_cast(rowBatch.get()); + WriteVector(env, vecNativeIdPtr, colNums, batch, omniTypesPtr, dataColumnsIdsPtr); + writerPtr->add(*batch); + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_splitWrite( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, jbooleanArray dataColumnsIds, + jlong startPos, jlong endPos) +{ + JNI_FUNC_START + auto vecNativeIdPtr = env->GetLongArrayElements(vecNativeId, JNI_FALSE); + auto colNums = env->GetArrayLength(vecNativeId); + auto omniTypesPtr = env->GetIntArrayElements(omniTypes, JNI_FALSE); + auto dataColumnsIdsPtr = env->GetBooleanArrayElements(dataColumnsIds, JNI_FALSE); + auto writeRows = endPos - startPos; + orc::Writer *writerPtr = (orc::Writer *)writer; + auto rowBatch = writerPtr->createRowBatch(writeRows); + rowBatch->numElements = writeRows; + orc::StructVectorBatch *batch = static_cast(rowBatch.get()); + WriteVector(env, vecNativeIdPtr, colNums, batch, omniTypesPtr, dataColumnsIdsPtr, true, startPos, endPos); + writerPtr->add(*batch); + JNI_FUNC_END_VOID(runtimeExceptionClass) +} + +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_close(JNIEnv *env, jobject jObj, + jlong outputStream, + jlong schemaType, + jlong writer) +{ + JNI_FUNC_START + + orc::Writer *writerPtr = (orc::Writer *)writer; + if (writerPtr == nullptr) { + env->ThrowNew(runtimeExceptionClass, "delete nullptr error for writer"); + } + + try { + writerPtr->close(); + } + catch (const char *e) { + std::string errorMsg = "close columnar writer fail:"; + errorMsg += e; + env->ThrowNew(runtimeExceptionClass, errorMsg.c_str()); + } + + orc::OutputStream *outputStreamPtr = (orc::OutputStream *)outputStream; + if (outputStreamPtr == nullptr) { + env->ThrowNew(runtimeExceptionClass, "delete nullptr error for write output stream"); + } + delete outputStreamPtr; + + orc::Type *schemaTypePtr = (orc::Type *)schemaType; + if (schemaTypePtr == nullptr) { + env->ThrowNew(runtimeExceptionClass, "delete nullptr error for write schema type"); + } + delete schemaTypePtr; + + delete writerPtr; + JNI_FUNC_END_VOID(runtimeExceptionClass) +} diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.h new file mode 100644 index 
0000000000000000000000000000000000000000..d7762dd7016f5ba8ff647e4b9a3c243c75bfeca6 --- /dev/null +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniWriter.h @@ -0,0 +1,92 @@ +/** + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Header for class com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter */ + +#ifndef OMNI_RUNTIME_ORCCOLUMNARBATCHJNIWRITER_H +#define OMNI_RUNTIME_ORCCOLUMNARBATCHJNIWRITER_H + +#include +#include +#include +#include +#include +#include "orcfile/OrcFileOverride.hh" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: initializeOutputStream + * Signature: + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeOutputStream( + JNIEnv *env, jobject jObj, jobject uriJson); + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: initializeSchemaType + * Signature: + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeSchemaType( + JNIEnv *env, jobject jObj, jintArray orcTypeIds, jobjectArray schemaNames, jobjectArray decimalParam); + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: initializeWriter + * Signature: + */ +JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_initializeWriter( + JNIEnv *env, jobject jObj, jlong outputStream, jlong schemaType, jobject writeOptionsJson); + + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: write + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_write( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, jbooleanArray dataColumnsIds, + jint numRows); + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: splitWrite + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_splitWrite( + JNIEnv *env, jobject jObj, jlong writer, jlongArray vecNativeId, jintArray omniTypes, jbooleanArray dataColumnsIds, + jlong startPos, jlong endPos); + +/* + * Class: com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter + * Method: close + * Signature: + */ +JNIEXPORT void JNICALL Java_com_huawei_boostkit_write_jni_OrcColumnarBatchJniWriter_close(JNIEnv *env, jobject jObj, + jlong outputStream, + jlong schemaType, + jlong writer); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc index 
b52401b1a3e3eea8cf5841cefdc871d96a1ec99d..c5e7719122261fd7778c5eda2821fd2b4225fe1f 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.cc @@ -28,4 +28,13 @@ namespace orc { return orc::readLocalFile(std::string(uri.Path())); } } + + std::unique_ptr<OutputStream> writeFileOverride(const UriInfo &uri) { + if (uri.Scheme() == "hdfs") { + return orc::createHdfsFileOutputStream(uri); + } + else { + return orc::writeLocalFile(std::string(uri.Path())); + } + } } diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh index 8d038627d788b3371e24cd6d4651430457489589..e42a7e370c02e326a0c76cf3066006d0fa494f47 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcFileOverride.hh @@ -30,17 +30,29 @@ namespace orc { + /** + * Create an input stream to a local file or HDFS file if path begins with "hdfs://" + * @param uri the UriInfo of HDFS + */ + ORC_UNIQUE_PTR<InputStream> readFileOverride(const UriInfo &uri); + /** - * Create a stream to a local file or HDFS file if path begins with "hdfs://" + * Create an output stream to a local file or HDFS file if path begins with "hdfs://" * @param uri the UriInfo of HDFS */ - ORC_UNIQUE_PTR<InputStream> readFileOverride(const UriInfo &uri); + ORC_UNIQUE_PTR<OutputStream> writeFileOverride(const UriInfo &uri); /** - * Create a stream to an HDFS file. + * Create an input stream to an HDFS file. * @param uri the UriInfo of HDFS */ ORC_UNIQUE_PTR<InputStream> createHdfsFileInputStream(const UriInfo &uri); + + /** + * Create an output stream to an HDFS file. + * @param uri the UriInfo of HDFS + */ + ORC_UNIQUE_PTR<OutputStream> createHdfsFileOutputStream(const UriInfo &uri); } #endif diff --git a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc index 2a877087b3ef0fc02052601a4fbb60273a68f790..6d8c77255bc264c5fae0163eb6823ba00e253afc 100644 --- a/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc +++ b/omnioperator/omniop-native-reader/cpp/src/orcfile/OrcHdfsFileOverride.cc @@ -105,4 +105,42 @@ namespace orc { std::unique_ptr<InputStream> createHdfsFileInputStream(const UriInfo &uri) { return std::unique_ptr<InputStream>(new HdfsFileInputStreamOverride(uri)); } + + class HdfsFileOutputStreamOverride : public OutputStream { + private: + std::string filename_; + std::unique_ptr hdfs_file_; + uint64_t total_length_{0}; + const uint64_t WRITE_SIZE_ = 1024 * 1024; + + public: + explicit HdfsFileOutputStreamOverride(const UriInfo &uri) { + this->filename_ = uri.Path(); + std::shared_ptr fileSystemPtr = getHdfsFileSystem(uri.Host(), uri.Port()); + this->hdfs_file_ = std::make_unique(fileSystemPtr, this->filename_, 0); + Status openFileSt = hdfs_file_->OpenFile(); + if (!openFileSt.IsOk()) { + throw IOException(openFileSt.ToString()); + } + + this->total_length_ = hdfs_file_->GetFileSize(); + } + + ~HdfsFileOutputStreamOverride() override {} + + [[nodiscard]] uint64_t getLength() const override { return total_length_; } + + + [[nodiscard]] uint64_t getNaturalWriteSize() const override { return WRITE_SIZE_; } + + void write(const void *buf, size_t length) override { hdfs_file_->Write(buf, length); } + + [[nodiscard]] const std::string &getName() const override { return filename_; } + + void close() override { hdfs_file_->Close(); } + }; + + std::unique_ptr<OutputStream> 
createHdfsFileOutputStream(const UriInfo &uri) { + return std::unique_ptr(new HdfsFileOutputStreamOverride(uri)); + } } diff --git a/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/OrcColumnarBatchJniWriter.java b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/OrcColumnarBatchJniWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..0788ec39c2c94639c4f9c347e580483715d63459 --- /dev/null +++ b/omnioperator/omniop-native-reader/java/src/main/java/com/huawei/boostkit/write/jni/OrcColumnarBatchJniWriter.java @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.write.jni; + +import com.huawei.boostkit.scan.jni.NativeReaderLoader; +import org.json.JSONObject; + +public class OrcColumnarBatchJniWriter { + public OrcColumnarBatchJniWriter(){ + NativeReaderLoader.getInstance(); + } + + public native long initializeOutputStream(JSONObject uriJson); + public native long initializeSchemaType(int[]orcTypeIds, String[] schemaNames, int[][] decimalParam); + public native long initializeWriter(long outputStream, long schemaType, JSONObject writerOptions); + public native long initializeBatch(long writer, long batchSize); + public native void write(long writer, long[] vecNativeId, int[] omniTypes, boolean[] dataColumnsIds, int rowNums); + public native void splitWrite(long writer, long[] vecNativeId, int[] omniTypes, boolean[] dataColumnsIds, long startPos, long endPos); + public native void close(long outputSteam, long schemaType, long writer); +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchWriter.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchWriter.java new file mode 100644 index 0000000000000000000000000000000000000000..d5a8bb64f2674308d67206ba5d3ace3541f1fe73 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchWriter.java @@ -0,0 +1,210 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.huawei.boostkit.spark.jni; + +import static org.apache.orc.CompressionKind.SNAPPY; +import static org.apache.orc.CompressionKind.ZLIB; + +import com.huawei.boostkit.write.jni.OrcColumnarBatchJniWriter; + +import nova.hetu.omniruntime.vector.Vec; + +import org.apache.orc.OrcFile; +import org.apache.spark.sql.execution.vectorized.OmniColumnVector; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.CharType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.VarcharType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.json.JSONObject; + +import java.net.URI; +import java.util.ArrayList; + +public class OrcColumnarBatchWriter { + public OrcColumnarBatchWriter() { + jniWriter = new OrcColumnarBatchJniWriter(); + } + + public enum OrcLibTypeKind { + BOOLEAN, + BYTE, + SHORT, + INT, + LONG, + FLOAT, + DOUBLE, + STRING, + BINARY, + TIMESTAMP, + LIST, + MAP, + STRUCT, + UNION, + DECIMAL, + DATE, + VARCHAR, + CHAR, + TIMESTAMP_INSTANT + } + + public void initializeOutputStreamJava(URI uri) { + JSONObject uriJson = new JSONObject(); + + uriJson.put("scheme", uri.getScheme() == null ? "" : uri.getScheme()); + uriJson.put("host", uri.getHost() == null ? "" : uri.getHost()); + uriJson.put("port", uri.getPort()); + uriJson.put("path", uri.getPath() == null ? "" : uri.getPath()); + + outputStream = jniWriter.initializeOutputStream(uriJson); + } + + public void initializeSchemaTypeJava(StructType dataSchema) { + schemaType = jniWriter.initializeSchemaType(sparkTypeToOrcLibType(dataSchema), extractSchemaName(dataSchema), + extractDecimalParam(dataSchema)); + } + + /** + * Init Orc writer. 
+ * + * @param uri of output file path + * @param options write file options + */ + public void initializeWriterJava(URI uri, StructType dataSchema, OrcFile.WriterOptions options) { + JSONObject writerOptionsJson = new JSONObject(); + + JSONObject versionJob = new JSONObject(); + versionJob.put("major", options.getVersion().getMajor()); + versionJob.put("minor", options.getVersion().getMinor()); + writerOptionsJson.put("file version", versionJob); + + writerOptionsJson.put("compression", options.getCompress().ordinal()); + writerOptionsJson.put("strip size", options.getStripeSize()); + writerOptionsJson.put("compression block size", options.getBlockSize()); + writerOptionsJson.put("row index stride", options.getRowIndexStride()); + writerOptionsJson.put("compression strategy", options.getCompressionStrategy().ordinal()); + writerOptionsJson.put("padding tolerance", options.getPaddingTolerance()); + writerOptionsJson.put("columns use bloom filter", options.getBloomFilterColumns()); + writerOptionsJson.put("bloom filter fpp", options.getBloomFilterFpp()); + + writer = jniWriter.initializeWriter(outputStream, schemaType, writerOptionsJson); + } + + public int[] sparkTypeToOrcLibType(StructType dataSchema) { + int[] orcLibType = new int[dataSchema.length()]; + for (int i = 0; i < dataSchema.length(); i++) { + orcLibType[i] = sparkTypeToOrcLibType(dataSchema.fields()[i].dataType()); + } + return orcLibType; + } + + public int sparkTypeToOrcLibType(DataType dataType) { + if (dataType instanceof BooleanType) { + return OrcLibTypeKind.BOOLEAN.ordinal(); + } else if (dataType instanceof ShortType) { + return OrcLibTypeKind.SHORT.ordinal(); + } else if (dataType instanceof IntegerType) { + return OrcLibTypeKind.INT.ordinal(); + } else if (dataType instanceof LongType) { + return OrcLibTypeKind.LONG.ordinal(); + } else if (dataType instanceof DateType) { + return OrcLibTypeKind.DATE.ordinal(); + } else if (dataType instanceof DoubleType) { + return OrcLibTypeKind.DOUBLE.ordinal(); + } else if (dataType instanceof VarcharType) { + return OrcLibTypeKind.VARCHAR.ordinal(); + } else if (dataType instanceof StringType) { + return OrcLibTypeKind.STRING.ordinal(); + } else if (dataType instanceof CharType) { + return OrcLibTypeKind.CHAR.ordinal(); + } else if (dataType instanceof DecimalType) { + return OrcLibTypeKind.DECIMAL.ordinal(); + } else { + throw new RuntimeException( + "UnSupport type convert spark type " + dataType.simpleString() + " to orc lib type"); + } + } + + public String[] extractSchemaName(StructType dataSchema) { + String[] schemaNames = new String[dataSchema.length()]; + for (int i = 0; i < dataSchema.length(); i++) { + schemaNames[i] = dataSchema.fields()[i].name(); + } + return schemaNames; + } + + public int[][] extractDecimalParam(StructType dataSchema) { + int paramNum = 2; + int precisionIndex = 0; + int scaleIndex = 1; + int[][] decimalParams = new int[dataSchema.length()][paramNum]; + for (int i = 0; i < dataSchema.length(); i++) { + DataType dataType = dataSchema.fields()[i].dataType(); + if (dataType instanceof DecimalType) { + DecimalType decimal = (DecimalType) dataType; + decimalParams[i][precisionIndex] = decimal.precision(); + decimalParams[i][scaleIndex] = decimal.scale(); + } + } + return decimalParams; + } + + public void write(int[] omniTypes, boolean[] dataColumnsIds, ColumnarBatch columBatch) { + + long[] vecNativeIds = new long[columBatch.numCols()]; + for (int i = 0; i < columBatch.numCols(); i++) { + OmniColumnVector omniVec = (OmniColumnVector) 
columBatch.column(i); + Vec vec = omniVec.getVec(); + vecNativeIds[i] = vec.getNativeVector(); + } + + jniWriter.write(writer, vecNativeIds, omniTypes, dataColumnsIds, columBatch.numRows()); + } + + public void splitWrite(int[] omniTypes, boolean[] dataColumnsIds, ColumnarBatch inputBatch, long startPos, long endPos) { + long[] vecNativeIds = new long[inputBatch.numCols()]; + for (int i = 0; i < inputBatch.numCols(); i++) { + OmniColumnVector omniVec = (OmniColumnVector) inputBatch.column(i); + Vec vec = omniVec.getVec(); + vecNativeIds[i] = vec.getNativeVector(); + } + + jniWriter.splitWrite(writer, vecNativeIds, omniTypes, dataColumnsIds, startPos, endPos); + } + + public void close() { + jniWriter.close(outputStream, schemaType, writer); + } + + public long outputStream; + + public long schemaType; + + public long writer; + + public OrcColumnarBatchJniWriter jniWriter; +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala index 76e0d2eb65e1bb0af6914d10f412b8cfe8b13598..675f493939e019176054065a8b87423491e95c96 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPlugin.scala @@ -18,6 +18,7 @@ package com.huawei.boostkit.spark +import nova.hetu.omniruntime.memory.MemoryManager import com.huawei.boostkit.spark.Constant.OMNI_IS_ADAPTIVE_CONTEXT import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor import com.huawei.boostkit.spark.util.PhysicalPlanSelector @@ -39,12 +40,13 @@ import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSup import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.command.{DataWritingCommand, DataWritingCommandExec} +import org.apache.spark.sql.execution.datasources.orc.{OmniOrcFileFormat, OrcFileFormat} +import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand, OmniInsertIntoHadoopFsRelationCommand} import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompletionListener import scala.collection.mutable.ListBuffer -import nova.hetu.omniruntime.memory.MemoryManager - case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) extends Rule[SparkPlan] { val columnarConf: ColumnarPluginConfig = ColumnarPluginConfig.getSessionConf @@ -57,6 +59,7 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) val enableRollupOptimization: Boolean = columnarConf.enableRollupOptimization val enableRowShuffle: Boolean = columnarConf.enableRowShuffle val columnsThreshold: Int = columnarConf.columnsThreshold + val enableColumnarDataWritingCommand: Boolean = columnarConf.enableColumnarDataWritingCommand def checkBhjRightChild(x: Any): Boolean = { x match { @@ -552,6 +555,41 @@ case class ColumnarPreOverrides(isSupportAdaptive: Boolean = true) val child = replaceWithColumnarPlan(plan.child) logInfo(s"Columnar Processing for ${plan.getClass} is currently supported.") ColumnarCoalesceExec(plan.numPartitions, child) + case plan: DataWritingCommandExec if enableColumnarDataWritingCommand => + val child = replaceWithColumnarPlan(plan.child) + logInfo(s"Columnar Processing for ${plan.getClass} is 
currently supported.") + var unSupportedColumnarCommand = false + var unSupportedFileFormat = false + val omniCmd = plan.cmd match { + case cmd: InsertIntoHadoopFsRelationCommand => + logInfo(s"Columnar Processing for ${cmd.getClass} is currently supported.") + val fileFormat: FileFormat = cmd.fileFormat match { + case _: OrcFileFormat => new OmniOrcFileFormat() + case format => + logInfo(s"Unsupported ${format.getClass} file " + + s"format for columnar data write command.") + unSupportedFileFormat = true + null + } + if (unSupportedFileFormat) { + cmd + } else { + OmniInsertIntoHadoopFsRelationCommand(cmd.outputPath, cmd.staticPartitions, + cmd.ifPartitionNotExists, cmd.partitionColumns, cmd.bucketSpec, fileFormat, + cmd.options, cmd.query, cmd.mode, cmd.catalogTable, + cmd.fileIndex, cmd.outputColumnNames + ) + } + case cmd: DataWritingCommand => + logInfo(s"Columnar Processing for ${cmd.getClass} is currently not supported.") + unSupportedColumnarCommand = true + cmd + } + if (!unSupportedColumnarCommand && !unSupportedFileFormat) { + ColumnarDataWritingCommandExec(omniCmd, child) + } else { + plan + } case p => val children = plan.children.map(replaceWithColumnarPlan) logInfo(s"Columnar Processing for ${p.getClass} is currently not supported.") diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala index 2622f9b25b539cdd553a275bd5490bcc4f45de92..7797323012fde52830af5988e3c3c48a8c94b092 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/ColumnarPluginConfig.scala @@ -111,6 +111,10 @@ class ColumnarPluginConfig(conf: SQLConf) extends Logging { .getConfString("spark.omni.sql.columnar.sortMergeJoin", "true") .toBoolean + val enableColumnarDataWritingCommand: Boolean = conf + .getConfString("spark.omni.sql.columnar.dataWritingCommand", "true") + .toBoolean + val enableTakeOrderedAndProject: Boolean = conf .getConfString("spark.omni.sql.columnar.takeOrderedAndProject", "true").toBoolean diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala index e37e701ffc9f1ee08d8f8fd0ee8621209dcb8f65..553463d56a7876adb129ba282481e34ec2909156 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/TransformHintRule.scala @@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, OmniAQEShuffleReadExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.command.DataWritingCommandExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ColumnarBroadcastHashJoinExec, ColumnarShuffledHashJoinExec, ColumnarSortMergeJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.window.WindowExec 
-import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} +import org.apache.spark.sql.execution.{CoalesceExec, CodegenSupport, ColumnarBroadcastExchangeExec, ColumnarCoalesceExec, ColumnarDataWritingCommandExec, ColumnarExpandExec, ColumnarFileSourceScanExec, ColumnarFilterExec, ColumnarGlobalLimitExec, ColumnarHashAggregateExec, ColumnarLocalLimitExec, ColumnarProjectExec, ColumnarShuffleExchangeExec, ColumnarSortExec, ColumnarTakeOrderedAndProjectExec, ColumnarTopNSortExec, ColumnarUnionExec, ColumnarWindowExec, ExpandExec, FileSourceScanExec, FilterExec, GlobalLimitExec, LocalLimitExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec, TopNSortExec, UnionExec} import org.apache.spark.sql.types.ColumnarBatchSupportUtil.checkColumnarBatchSupport trait TransformHint { @@ -125,6 +126,7 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { val enableLocalColumnarLimit: Boolean = columnarConf.enableLocalColumnarLimit val enableGlobalColumnarLimit: Boolean = columnarConf.enableGlobalColumnarLimit val enableColumnarCoalesce: Boolean = columnarConf.enableColumnarCoalesce + val enableColumnarDataWritingCommand: Boolean = columnarConf.enableColumnarDataWritingCommand override def apply(plan: SparkPlan): SparkPlan = { addTransformableTags(plan) @@ -381,6 +383,14 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { } ColumnarCoalesceExec(plan.numPartitions, plan.child).buildCheck() TransformHints.tagTransformable(plan) + case plan: DataWritingCommandExec => + if (!enableColumnarDataWritingCommand) { + TransformHints.tagNotTransformable( + plan, "columnar data writing is not support") + return + } + ColumnarDataWritingCommandExec(plan.cmd, plan.child).buildCheck() + TransformHints.tagTransformable(plan) case _ => TransformHints.tagTransformable(plan) } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index d1f911a8cde2a7361c3d23a3877ba7940f27a883..b13f82acc9b5fb7dfe49418c13f9a9ce12d2c96f 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -400,7 +400,9 @@ object OmniExpressionAdaptor extends Logging { .put("arguments", new JSONArray().put(rewriteToOmniJsonExpressionLiteralJsonObject(round.child, exprsIndexMap)) .put(rewriteToOmniJsonExpressionLiteralJsonObject(round.scale, exprsIndexMap))) - case attr: Attribute => toOmniJsonAttribute(attr, exprsIndexMap(attr.exprId)) + case attr: Attribute => toOmniJsonLeafExpression(attr, exprsIndexMap(attr.exprId)) + case boundReference: BoundReference => + toOmniJsonLeafExpression(boundReference, boundReference.ordinal) // might_contain case bloomFilterMightContain: BloomFilterMightContain => @@ 
-657,14 +659,17 @@ object OmniExpressionAdaptor extends Logging { jsonObject } - def toOmniJsonAttribute(attr: Attribute, colVal: Int): JSONObject = { + def toOmniJsonLeafExpression(attr: LeafExpression, colVal: Int): JSONObject = { val omniDataType = sparkTypeToOmniExpType(attr.dataType) attr.dataType match { case StringType => new JSONObject().put("exprType", "FIELD_REFERENCE") .put("dataType", omniDataType.toInt) .put("colVal", colVal) - .put("width", getStringLength(attr.metadata)) + .put("width", attr match { + case attribute: Attribute => getStringLength(attribute.metadata) + case _ => DEFAULT_STRING_TYPE_LENGTH + }) case dt: DecimalType => new JSONObject().put("exprType", "FIELD_REFERENCE") .put("colVal", colVal) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarDataWritingCommandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarDataWritingCommandExec.scala new file mode 100644 index 0000000000000000000000000000000000000000..b16c01fd06b054793b3d1ee1cc9188fc24875b1e --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarDataWritingCommandExec.scala @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.sparkTypeToOmniType +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A physical operator that executes the run method of a `ColumnarDataWritingCommand` and + * saves the result to prevent multiple executions. + * + * @param cmd the `ColumnarDataWritingCommand` this operator will run. + * @param child the physical plan child ran by the `DataWritingCommand`. 
+ */ +case class ColumnarDataWritingCommandExec(@transient cmd: DataWritingCommand, child: SparkPlan) + extends UnaryExecNode { + + override lazy val metrics: Map[String, SQLMetric] = cmd.metrics + + protected[sql] lazy val sideEffectResult: Seq[ColumnarBatch] = { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + val rows = cmd.run(session, child) + + rows.map(converter(_).asInstanceOf[ColumnarBatch]) + } + + override def output: Seq[Attribute] = cmd.output + + override def nodeName: String = "OmniExecute " + cmd.nodeName + + // override the default one, otherwise the `cmd.nodeName` will appear twice from simpleString + override def argString(maxFields: Int): String = cmd.argString(maxFields) + + override def executeCollect(): Array[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support executeCollect") + } + + override def executeToIterator(): Iterator[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support executeToIterator") + } + + override def executeTake(limit: Int): Array[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support executeTake") + } + + override def executeTail(limit: Int): Array[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support executeTail") + } + + override def supportsColumnar: Boolean = true + + def buildCheck(): Unit = { + child.output.foreach(exp => sparkTypeToOmniType(exp.dataType, exp.metadata)) + } + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException("This operator doesn't support doExecute") + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + sparkContext.parallelize(sideEffectResult, 1) + } + + override protected def withNewChildInternal(newChild: SparkPlan): ColumnarDataWritingCommandExec = + copy(child = newChild) + +} \ No newline at end of file diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..8983a6f180fdc618fe465144e461f41417902e4f --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatDataWriter.scala @@ -0,0 +1,312 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils +import org.apache.spark.sql.catalyst.expressions.{Cast, Concat, Expression, Literal, ScalaUDF, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.connector.write.DataWriter +import org.apache.spark.sql.execution.datasources.orc.OmniOrcOutputWriter +import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric} +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.Utils + +import scala.collection.mutable + + +/** Writes data to a single directory (used for non-dynamic-partition writes). */ +class OmniSingleDirectoryDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric] = Map.empty) + extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) { + private var fileCounter: Int = _ + private var recordsInFile: Long = _ + // Initialize currentWriter and statsTrackers + newOutputWriter() + + private def newOutputWriter(): Unit = { + recordsInFile = 0 + releaseResources() + + val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val currentPath = committer.newTaskTempFile( + taskAttemptContext, + None, + f"-c$fileCounter%03d" + ext) + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + + currentWriter match { + case _: OmniOrcOutputWriter => + currentWriter.asInstanceOf[OmniOrcOutputWriter] + .initialize(description.allColumns, description.dataColumns) + case _ => + throw new UnsupportedOperationException + (s"Unsupported ${currentWriter.getClass} Output writer!") + } + statsTrackers.foreach(_.newFile(currentPath)) + } + + override def write(record: InternalRow): Unit = { + assert(record.isInstanceOf[OmniInternalRow]) + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter() + } + + currentWriter.write(record) + statsTrackers.foreach(_.newRow(currentWriter.path, record)) + recordsInFile += record.asInstanceOf[OmniInternalRow].batch.numRows() + } +} + +/** + * Holds common logic for writing data with dynamic partition writes, meaning it can write to + * multiple directories (partitions) or files (bucketing). + */ +abstract class OmniBaseDynamicPartitionDataWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric]) + extends FileFormatDataWriter(description, taskAttemptContext, committer, customMetrics) { + + /** Flag saying whether or not the data to be written out is partitioned. */ + protected val isPartitioned = description.partitionColumns.nonEmpty + + /** Flag saying whether or not the data to be written out is bucketed. 
*/ + protected val isBucketed = description.bucketSpec.isDefined + + assert(isPartitioned || isBucketed, + s"""DynamicPartitionWriteTask should be used for writing out data that's either + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description + """.stripMargin) + + /** Number of records in current file. */ + protected var recordsInFile: Long = _ + + /** + * File counter for writing current partition or bucket. For same partition or bucket, + * we may have more than one file, due to number of records limit per file. + */ + protected var fileCounter: Int = _ + + /** Extracts the partition values out of an input row. */ + protected lazy val getPartitionValues: InternalRow => UnsafeRow = { + val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) + row => proj(row) + } + + /** Expression that given partition columns builds a path string like: col1=val/col2=val/... */ + private lazy val partitionPathExpression: Expression = Concat( + description.partitionColumns.zipWithIndex.flatMap { case (c, i) => + val partitionName = ScalaUDF( + ExternalCatalogUtils.getPartitionPathString _, + StringType, + Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) + if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) + }) + + /** + * Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. + */ + private lazy val getPartitionPath: InternalRow => String = { + val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) + row => proj(row).getString(0) + } + + /** Given an input row, returns the corresponding `bucketId` */ + protected lazy val getBucketId: InternalRow => Int = { + val proj = + UnsafeProjection.create(Seq(description.bucketSpec.get.bucketIdExpression), + description.allColumns) + row => proj(row).getInt(0) + } + + /** Returns the data columns to be written given an input row */ + protected val getOutputRow = + UnsafeProjection.create(description.dataColumns, description.allColumns) + + /** + * Opens a new OutputWriter given a partition key and/or a bucket id. + * If bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + * + * @param partitionValues the partition which all tuples being written by this OutputWriter + * belong to + * @param bucketId the bucket which all tuples being written by this OutputWriter belong to + * @param closeCurrentWriter close and release resource for current writer + */ + protected def renewCurrentWriter( + partitionValues: Option[InternalRow], + bucketId: Option[Int], + closeCurrentWriter: Boolean): Unit = { + + recordsInFile = 0 + if (closeCurrentWriter) { + releaseCurrentWriter() + } + + val partDir = partitionValues.map(getPartitionPath(_)) + partDir.foreach(updatedPartitions.add) + + val bucketIdStr = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("") + + // The prefix and suffix must be in a form that matches our bucketing format. See BucketingUtils + // for details. The prefix is required to represent bucket id when writing Hive-compatible + // bucketed table. 
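+      // e.g. for a Hive-compatible write bucketFileNamePrefix(2) yields "00002_0_", while a plain Spark bucketed write uses an empty prefix (see writerBucketSpec in OmniFileFormatWriter.write).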
+ val prefix = bucketId match { + case Some(id) => description.bucketSpec.get.bucketFileNamePrefix(id) + case _ => "" + } + val suffix = f"$bucketIdStr.c$fileCounter%03d" + + description.outputWriterFactory.getFileExtension(taskAttemptContext) + val fileNameSpec = FileNameSpec(prefix, suffix) + + val customPath = partDir.flatMap { dir => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + } + val currentPath = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, fileNameSpec) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, fileNameSpec) + } + + currentWriter = description.outputWriterFactory.newInstance( + path = currentPath, + dataSchema = description.dataColumns.toStructType, + context = taskAttemptContext) + currentWriter.asInstanceOf[OmniOrcOutputWriter] + .initialize(description.allColumns, description.dataColumns) + statsTrackers.foreach(_.newFile(currentPath)) + } + + /** + * Open a new output writer when number of records exceeding limit. + * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + protected def renewCurrentWriterIfTooManyRecords( + partitionValues: Option[InternalRow], + bucketId: Option[Int]): Unit = { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + renewCurrentWriter(partitionValues, bucketId, closeCurrentWriter = true) + } + + /** + * Writes the given record with current writer. + * + * @param record The record to write + */ + protected def writeRecord(record: InternalRow, startPos: Long, endPos: Long): Unit = { + // TODO After add OmniParquetOutPutWriter need extract + // a abstract interface named OmniOutPutWriter + assert(currentWriter.isInstanceOf[OmniOrcOutputWriter]) + currentWriter.asInstanceOf[OmniOrcOutputWriter].spiltWrite(record, startPos, endPos) + + statsTrackers.foreach(_.newRow(currentWriter.path, record)) + recordsInFile += record.asInstanceOf[OmniInternalRow].batch.numRows() + } +} + +/** + * Dynamic partition writer with single writer, meaning only one writer is opened at any time for + * writing. The records to be written are required to be sorted on partition and/or bucket + * column(s) before writing. 
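+ * Unlike the row-based DynamicPartitionDataSingleWriter, this writer receives whole ColumnarBatches wrapped in OmniInternalRow: splitWrite walks each batch row by row to detect partition/bucket boundaries and flushes every contiguous slice [start, end) through writeRecord.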
+ */ +class OmniDynamicPartitionDataSingleWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + customMetrics: Map[String, SQLMetric] = Map.empty) + extends OmniBaseDynamicPartitionDataWriter(description, taskAttemptContext, committer, + customMetrics) { + + private var currentPartitionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + override def write(record: InternalRow): Unit = { + assert(record.isInstanceOf[OmniInternalRow]) + splitWrite(record) + } + + private def splitWrite(omniInternalRow: InternalRow): Unit = { + val batch = omniInternalRow.asInstanceOf[OmniInternalRow].batch + val numRows = batch.numRows() + var lastIndex = 0 + for (i <- 0 until numRows) { + val record = batch.getRow(i) + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + } + + fileCounter = 0 + if (i != 0) { + writeRecord(omniInternalRow, lastIndex, i) + lastIndex = i + } + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) + } else if ( + description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile + ) { + if (i != 0) { + writeRecord(omniInternalRow, lastIndex, i) + lastIndex = i + } + renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId) + } + } + if (lastIndex < batch.numRows()) { + writeRecord(omniInternalRow, lastIndex, numRows) + } + } +} + + + diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..465123d8730d39656d242bd4785e22680312ce05 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniFileFormatWriter.scala @@ -0,0 +1,377 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.{Date, UUID} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileAlreadyExistsException, Path} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.execution.{ColumnarProjectExec, ColumnarSortExec, OmniColumnarToRowExec, ProjectExec, SQLExecution, SortExec, SparkPlan, UnsafeExternalRowSorter} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{SerializableConfiguration, Utils} + + +/** A helper object for writing FileFormat data out to a location. */ +object OmniFileFormatWriter extends Logging { + /** Describes how output files should be placed in the filesystem. */ + case class OutputSpec( + outputPath: String, + customPartitionLocations: Map[TablePartitionSpec, String], + outputColumns: Seq[Attribute]) + + /** A function that converts the empty string to null for partition values. */ + case class Empty2Null(child: Expression) extends UnaryExpression with String2StringExpression { + override def convert(v: UTF8String): UTF8String = if (v.numBytes() == 0) null else v + + override def nullable: Boolean = true + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + throw new UnsupportedOperationException("This operator doesn't support doGenCode") + } + + override protected def withNewChildInternal(newChild: Expression): Empty2Null = + copy(child = newChild) + } + + /** + * Basic work flow of this command is: + * 1. Driver side setup, including output committer initialization and data source specific + * preparation work for the write job to be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + * 5. If the job is successfully committed, perform post-commit operations such as + * processing statistics. + * + * @return The set of all partition paths that were updated during this write job. 
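+ * @note Columnar adaptation of FileFormatWriter.write: the plan is executed with executeColumnar(), every ColumnarBatch is wrapped in an OmniInternalRow, and the Omni*DataWriter implementations unwrap the batch again on the task side.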
+ */ + def write( + sparkSession: SparkSession, + plan: SparkPlan, + fileFormat: FileFormat, + committer: FileCommitProtocol, + outputSpec: OutputSpec, + hadoopConf: Configuration, + partitionColumns: Seq[Attribute], + bucketSpec: Option[BucketSpec], + statsTrackers: Seq[WriteJobStatsTracker], + options: Map[String, String]) + : Set[String] = { + + val job = Job.getInstance(hadoopConf) + job.setOutputKeyClass(classOf[Void]) + job.setOutputValueClass(classOf[InternalRow]) + FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) + + val partitionSet = AttributeSet(partitionColumns) + // cleanup the internal metadata information of + // the file source metadata attribute if any before write out + val finalOutputSpec = outputSpec.copy(outputColumns = outputSpec.outputColumns + .map(FileSourceMetadataAttribute.cleanupFileSourceMetadataInformation)) + val dataColumns = finalOutputSpec.outputColumns.filterNot(partitionSet.contains) + + var needConvert = false + val projectList: Seq[NamedExpression] = plan.output.map { + case p if partitionSet.contains(p) && p.dataType == StringType && p.nullable => + needConvert = true + Alias(Empty2Null(p), p.name)() + case attr => attr + } + val empty2NullPlan = if (needConvert) ColumnarProjectExec(projectList, plan) else plan + + val writerBucketSpec = bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + + if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == + "true") { + // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. + // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of + // columns is negative. See Hive implementation in + // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. + val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) + val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) + + // The bucket file name prefix is following Hive, Presto and Trino conversion, so this + // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. + // + // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" + WriterBucketSpec(bucketIdExpression, fileNamePrefix) + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + WriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + val sortColumns = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } + + val caseInsensitiveOptions = CaseInsensitiveMap(options) + + val dataSchema = dataColumns.toStructType + DataSourceUtils.verifySchema(fileFormat, dataSchema) + // Note: prepareWrite has side effect. It sets "job". 
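+    // For ORC this resolves to OmniOrcFileFormat.prepareWrite (added in this change), whose factory creates OmniOrcOutputWriter instances.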
+ val outputWriterFactory = + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataSchema) + + val description = new WriteJobDescription( + uuid = UUID.randomUUID.toString, + serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), + outputWriterFactory = outputWriterFactory, + allColumns = finalOutputSpec.outputColumns, + dataColumns = dataColumns, + partitionColumns = partitionColumns, + bucketSpec = writerBucketSpec, + path = finalOutputSpec.outputPath, + customPartitionLocations = finalOutputSpec.customPartitionLocations, + maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone), + statsTrackers = statsTrackers + ) + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val requiredOrdering = + partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns + // the sort order doesn't matter + val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + SQLExecution.checkSQLExecutionId(sparkSession) + + // propagate the description UUID into the jobs, so that committers + // get an ID guaranteed to be unique. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + committer.setupJob(job) + + try { + val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { + (empty2NullPlan.executeColumnar(), None) + } else { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. + val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), finalOutputSpec.outputColumns) + // val orderingExpr = requiredOrdering.map(SortOrder(_, Ascending)) + val sortPlan = ColumnarSortExec( + orderingExpr, + global = false, + child = empty2NullPlan) + + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + // TODO Concurrent output write + logInfo("Columnar concurrent write is not support now, use un concurrent write") + (sortPlan.executeColumnar(), None) + } else { + (sortPlan.executeColumnar(), None) + } + } + + // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single + // partition rdd to make sure we at least set up one write task to write the metadata. 
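+      // Each ColumnarBatch is wrapped in an OmniInternalRow so the existing InternalRow-based task and commit plumbing can be reused unchanged.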
+ val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) { + sparkSession.sparkContext.parallelize(Array.empty[OmniInternalRow], 1) + } else { + rdd.map(cb => new OmniInternalRow(cb)) + } + + val jobIdInstant = new Date().getTime + val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) + sparkSession.sparkContext.runJob( + rddWithNonEmptyPartitions, + (taskContext: TaskContext, iter: Iterator[InternalRow]) => { + executeTask( + description = description, + jobIdInstant = jobIdInstant, + sparkStageId = taskContext.stageId(), + sparkPartitionId = taskContext.partitionId(), + sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, + committer, + iterator = iter, + concurrentOutputWriterSpec = concurrentOutputWriterSpec) + }, + rddWithNonEmptyPartitions.partitions.indices, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res + }) + + val commitMsgs = ret.map(_.commitMsg) + + logInfo(s"Start to commit write Job ${description.uuid}.") + val (_, duration) = Utils.timeTakenMs { + committer.commitJob(job, commitMsgs) + } + logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.") + + processStats(description.statsTrackers, ret.map(_.summary.stats), duration) + logInfo(s"Finished processing stats for write job ${description.uuid}.") + + // return a set of all the partition paths that were updated during this job + ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty) + } catch { + case cause: Throwable => + logError(s"Aborting job ${description.uuid}.", cause) + committer.abortJob(job) + throw QueryExecutionErrors.jobAbortedError(cause) + } + } + + /** Writes data out in a single Spark task. */ + private def executeTask( + description: WriteJobDescription, + jobIdInstant: Long, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + iterator: Iterator[InternalRow], + concurrentOutputWriterSpec: + Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { + + val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId) + val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) + val taskAttemptId = new TaskAttemptID(taskId, sparkAttemptNumber) + + // Set up the attempt context required to use in the output committer. + val taskAttemptContext: TaskAttemptContext = { + // Set up the configuration object + val hadoopConf = description.serializableHadoopConf.value + hadoopConf.set("mapreduce.job.id", jobId.toString) + hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) + hadoopConf.setBoolean("mapreduce.task.ismap", true) + hadoopConf.setInt("mapreduce.task.partition", 0) + + new TaskAttemptContextImpl(hadoopConf, taskAttemptId) + } + + committer.setupTask(taskAttemptContext) + + val dataWriter = + if (sparkPartitionId != 0 && !iterator.hasNext) { + // In case of empty job, leave first partition to save meta for file format like parquet. 
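+        // Reused from vanilla Spark: this branch is only taken when the iterator is empty, so no columnar handling is needed.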
+ new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) + } else if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) { + new OmniSingleDirectoryDataWriter(description, taskAttemptContext, committer) + } else { + concurrentOutputWriterSpec match { + case Some(spec) => + new DynamicPartitionDataConcurrentWriter( + description, taskAttemptContext, committer, spec) + case _ => + new OmniDynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + } + } + + try { + Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + // Execute the task to write rows out and commit the task. + dataWriter.writeWithIterator(iterator) + dataWriter.commit() + })(catchBlock = { + // If there is an error, abort the task + dataWriter.abort() + logError(s"Job $jobId aborted.") + }, finallyBlock = { + dataWriter.close() + }) + } catch { + case e: FetchFailedException => + throw e + case f: FileAlreadyExistsException if SQLConf.get.fastFailFileFormatOutput => + // If any output file to write already exists, it does not make sense to re-run this task. + // We throw the exception and let Executor throw ExceptionFailure to abort the job. + throw new TaskOutputFileAlreadyExistException(f) + case t: Throwable => + throw QueryExecutionErrors.taskFailedWhileWritingRowsError(t) + } + } + + /** + * For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it + * the corresponding [[WriteTaskStats]] from all executors. + */ + private[datasources] def processStats( + statsTrackers: Seq[WriteJobStatsTracker], + statsPerTask: Seq[Seq[WriteTaskStats]], + jobCommitDuration: Long) + : Unit = { + + val numStatsTrackers = statsTrackers.length + assert(statsPerTask.forall(_.length == numStatsTrackers), + s"""Every WriteTask should have produced one `WriteTaskStats` object for every tracker. + |There are $numStatsTrackers statsTrackers, but some task returned + |${statsPerTask.find(_.length != numStatsTrackers).get.length} results instead. + """.stripMargin) + + val statsPerTracker = if (statsPerTask.nonEmpty) { + statsPerTask.transpose + } else { + statsTrackers.map(_ => Seq.empty) + } + + statsTrackers.zip(statsPerTracker).foreach { + case (statsTracker, stats) => statsTracker.processStats(stats, jobCommitDuration) + } + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala new file mode 100644 index 0000000000000000000000000000000000000000..9d0008e0b2cf1a57443e019bf41113f60c8855fd --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInsertIntoHadoopFsRelationCommand.scala @@ -0,0 +1,280 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.util.SchemaUtils +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. + * Writing to dynamic partitions is also supported. + * + * @param staticPartitions partial partitioning spec for write. This defines the scope of partition + * overwrites: when the spec is empty, all partitions are overwritten. + * When it covers a prefix of the partition keys, only partitions matching + * the prefix are overwritten. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. + */ +case class OmniInsertIntoHadoopFsRelationCommand( + outputPath: Path, + staticPartitions: TablePartitionSpec, + ifPartitionNotExists: Boolean, + partitionColumns: Seq[Attribute], + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String], + query: LogicalPlan, + mode: SaveMode, + catalogTable: Option[CatalogTable], + fileIndex: Option[FileIndex], + outputColumnNames: Seq[String]) + extends DataWritingCommand { + + private lazy val parameters = CaseInsensitiveMap(options) + + private[sql] lazy val dynamicPartitionOverwrite: Boolean = { + val partitionOverwriteMode = parameters.get(DataSourceUtils.PARTITION_OVERWRITE_MODE) + // scalastyle:off caselocale + .map(mode => PartitionOverwriteMode.withName(mode.toUpperCase)) + // scalastyle:on caselocale + .getOrElse(conf.partitionOverwriteMode) + val enableDynamicOverwrite = partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length + } + + // Return Seq[Row] but Seq[ColumBatch] since + // 1. reuse the origin interface of spark to avoid add duplicate code + // 2. 
this func return a Seq.empty[Row] and this data doesn't do anything else + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { + // Most formats don't do well with duplicate columns, so lets not allow that + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, + s"when inserting into $outputPath", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) + val fs = outputPath.getFileSystem(hadoopConf) + val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + + val partitionsTrackedByCatalog = sparkSession.sessionState.conf.manageFilesourcePartitions && + catalogTable.isDefined && + catalogTable.get.partitionColumnNames.nonEmpty && + catalogTable.get.tracksPartitionsInCatalog + + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil + var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + matchingPartitions = sparkSession.sessionState.catalog.listPartitions( + catalogTable.get.identifier, Some(staticPartitions)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + fs, catalogTable.get, qualifiedOutputPath, matchingPartitions) + } + + val jobId = java.util.UUID.randomUUID().toString + val committer = FileCommitProtocol.instantiate( + sparkSession.sessionState.conf.fileCommitProtocolClass, + jobId = jobId, + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) + + val doInsertion = if (mode == SaveMode.Append) { + true + } else { + val pathExists = fs.exists(qualifiedOutputPath) + (mode, pathExists) match { + case (SaveMode.ErrorIfExists, true) => + throw QueryCompilationErrors.outputPathAlreadyExistsError(qualifiedOutputPath) + case (SaveMode.Overwrite, true) => + if (ifPartitionNotExists && matchingPartitions.nonEmpty) { + false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true + } else { + deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) + true + } + case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => + true + case (SaveMode.Ignore, exists) => + !exists + case (s, exists) => + throw QueryExecutionErrors.saveModeUnsupportedError(s, exists) + } + } + + if (doInsertion) { + def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = { + val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(sparkSession) + } + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. 
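+          // For a full (non-dynamic) overwrite, partitions that received no new data are dropped from the metastore; their files were already deleted up front, hence retainData = true below.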
+ if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = false, + retainData = true /* already deleted */).run(sparkSession) + } + } + } + } + + // For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files + // will be renamed from staging path to final output path during commit job + val committerOutputPath = if (dynamicPartitionOverwrite) { + FileCommitProtocol.getStagingDir(outputPath.toString, jobId) + .makeQualified(fs.getUri, fs.getWorkingDirectory) + } else { + qualifiedOutputPath + } + + val updatedPartitionPaths = + OmniFileFormatWriter.write( + sparkSession = sparkSession, + plan = child, + fileFormat = fileFormat, + committer = committer, + outputSpec = OmniFileFormatWriter.OutputSpec( + committerOutputPath.toString, customPartitionLocations, outputColumns), + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), + options = options) + + + // update metastore partition metadata + if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty + && partitionColumns.length == staticPartitions.size) { + // Avoid empty static partition can't loaded to datasource table. + val staticPathFragment = + PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) + refreshUpdatedPartitions(Set(staticPathFragment)) + } else { + refreshUpdatedPartitions(updatedPartitionPaths) + } + + // refresh cached files in FileIndex + fileIndex.foreach(_.refresh()) + // refresh data cache if table is cached + sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, outputPath, fs) + + if (catalogTable.nonEmpty) { + CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } + + } else { + logInfo("Skipping insertion into a relation that already exists.") + } + + Seq.empty[Row] + } + + /** + * Deletes all partition files that match the specified static prefix. Partitions with custom + * locations are also cleared based on the custom locations map given to this class. + */ + private def deleteMatchingPartitions( + fs: FileSystem, + qualifiedOutputPath: Path, + customPartitionLocations: Map[TablePartitionSpec, String], + committer: FileCommitProtocol): Unit = { + val staticPartitionPrefix = if (staticPartitions.nonEmpty) { + "/" + partitionColumns.flatMap { p => + staticPartitions.get(p.name).map(getPartitionPathString(p.name, _)) + }.mkString("/") + } else { + "" + } + // first clear the path determined by the static partition keys (e.g. /table/foo=1) + val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix) + if (fs.exists(staticPrefixPath) && !committer.deleteWithJob(fs, staticPrefixPath, true)) { + throw QueryExecutionErrors.cannotClearOutputDirectoryError(staticPrefixPath) + } + // now clear all custom partition locations (e.g. 
/custom/dir/where/foo=2/bar=4) + for ((spec, customLoc) <- customPartitionLocations) { + assert( + (staticPartitions.toSet -- spec).isEmpty, + "Custom partition location did not match static partitioning keys") + val path = new Path(customLoc) + if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) { + throw QueryExecutionErrors.cannotClearPartitionDirectoryError(path) + } + } + } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. + * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + fs: FileSystem, + table: CatalogTable, + qualifiedOutputPath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + partitions.flatMap { p => + val defaultLocation = qualifiedOutputPath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } + + override protected def withNewChildInternal(newChild: LogicalPlan): + OmniInsertIntoHadoopFsRelationCommand = copy(query = newChild) +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInternalRow.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInternalRow.scala new file mode 100644 index 0000000000000000000000000000000000000000..e106bd39b638909df121544bec15abe6e19bb3b0 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/OmniInternalRow.scala @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +// This class used to reuse the native spark interface of table +// write to reduce the number new files. 
It is essentially a batch +class OmniInternalRow(val batch: ColumnarBatch) extends InternalRow { + override def numFields: Int = throw new UnsupportedOperationException() + + override def setNullAt(i: Int): Unit = throw new UnsupportedOperationException() + + override def update(i: Int, value: Any): Unit = throw new UnsupportedOperationException() + + override def copy(): InternalRow = throw new UnsupportedOperationException() + + override def isNullAt(ordinal: Int): Boolean = throw new UnsupportedOperationException() + + override def getBoolean(ordinal: Int): Boolean = throw new UnsupportedOperationException() + + override def getByte(ordinal: Int): Byte = throw new UnsupportedOperationException() + + override def getShort(ordinal: Int): Short = throw new UnsupportedOperationException() + + override def getInt(ordinal: Int): Int = throw new UnsupportedOperationException() + + override def getLong(ordinal: Int): Long = throw new UnsupportedOperationException() + + override def getFloat(ordinal: Int): Float = throw new UnsupportedOperationException() + + override def getDouble(ordinal: Int): Double = throw new UnsupportedOperationException() + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + throw new UnsupportedOperationException() + + override def getUTF8String(ordinal: Int): UTF8String = + throw new UnsupportedOperationException() + + override def getBinary(ordinal: Int): Array[Byte] = throw new UnsupportedOperationException() + + override def getInterval(ordinal: Int): CalendarInterval = + throw new UnsupportedOperationException() + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + throw new UnsupportedOperationException() + + override def getArray(ordinal: Int): ArrayData = throw new UnsupportedOperationException() + + override def getMap(ordinal: Int): MapData = throw new UnsupportedOperationException() + + override def get(ordinal: Int, dataType: DataType): AnyRef = + throw new UnsupportedOperationException() +} diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala index 334800f5111521feeda8ab1fa78a3ed9de436ac3..bfbdb94d896fa1d977897fa57799c3bc32276312 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcFileFormat.scala @@ -24,6 +24,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.orc.OrcConf.COMPRESS import org.apache.orc.{OrcConf, OrcFile, TypeDescription} import org.apache.orc.TypeDescription.Category._ import org.apache.orc.mapreduce.OrcInputFormat @@ -36,7 +37,6 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.sql.types.StringType - import org.apache.spark.sql.types.DecimalType class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializable { @@ -182,6 +182,22 @@ class OmniOrcFileFormat extends FileFormat with DataSourceRegister with Serializ job: Job, options: Map[String, String], dataSchema: 
StructType): OutputWriterFactory = { - throw new UnsupportedOperationException() + new OutputWriterFactory { + override def getFileExtension(context: TaskAttemptContext): String = { + val compressionExtension: String = { + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcUtils.extensionsForCompressionCodecNames.getOrElse(name, "") + } + + compressionExtension + ".orc" + } + + override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new OmniOrcOutputWriter(path, dataSchema, context) + } + } } } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcOutPutWriter.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcOutPutWriter.scala new file mode 100644 index 0000000000000000000000000000000000000000..2ac2ff33f84f359bdb555b49d36d171303dcd07e --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OmniOrcOutPutWriter.scala @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor.{sparkTypeToOmniExpType, sparkTypeToOmniType} +import com.huawei.boostkit.spark.jni.OrcColumnarBatchWriter +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.{OmniInternalRow, OutputWriter} +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.orc.{OrcConf, OrcFile} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType + +import java.net.URI + +private[sql] class OmniOrcOutputWriter(path: String, dataSchema: StructType, + context: TaskAttemptContext) extends OutputWriter { + + val writer = new OrcColumnarBatchWriter() + var omniTypes: Array[Int] = new Array[Int](0) + var dataColumnsIds: Array[Boolean] = new Array[Boolean](0) + + def initialize(allColumns: Seq[Attribute], dataColumns: Seq[Attribute]): Unit = { + val filePath = new Path(new URI(path)) + val conf = context.getConfiguration + val writerOptions = OrcFile.writerOptions(conf). 
+ fileSystem(new Path(new URI(path)).getFileSystem(conf)) + writer.initializeOutputStreamJava(filePath.toUri) + writer.initializeSchemaTypeJava(dataSchema) + writer.initializeWriterJava(filePath.toUri, dataSchema, writerOptions) + dataSchema.foreach(field => { + omniTypes = omniTypes :+ sparkTypeToOmniType(field.dataType, field.metadata).getId.ordinal() + }) + dataColumnsIds = allColumns.map(x => dataColumns.contains(x)).toArray + } + + override def write(row: InternalRow): Unit = { + assert(row.isInstanceOf[OmniInternalRow]) + writer.write(omniTypes, dataColumnsIds, row.asInstanceOf[OmniInternalRow].batch) + } + + def spiltWrite(row: InternalRow, startPos: Long, endPos: Long): Unit = { + assert(row.isInstanceOf[OmniInternalRow]) + writer.splitWrite(omniTypes, dataColumnsIds, + row.asInstanceOf[OmniInternalRow].batch, startPos, endPos) + } + + override def close(): Unit = { + writer.close() + } + + override def path(): String = { + path + } +} diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..d2c06b15b58b8f7da8ea6b12831ffdb3107e8dca --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/com/huawei/boostkit/spark/TableWriteBasicFunctionSuite.scala @@ -0,0 +1,80 @@ +/* + * Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.huawei.boostkit.spark + +import org.apache.spark.SparkConf +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ColumnarBroadcastHashJoinExec} +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarProjectExec, ColumnarTakeOrderedAndProjectExec, CommandResultExec, LeafExecNode, OmniColumnarToRowExec, ProjectExec, RowToOmniColumnarExec, SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode} +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.SharedSparkSession + +import scala.concurrent.Future + +class TableWriteBasicFunctionSuite extends QueryTest with SharedSparkSession { + + import testImplicits._ + + override def sparkConf: SparkConf = super.sparkConf + .setAppName("test TableWriteBasicFunctionSuite") + .set(StaticSQLConf.SPARK_SESSION_EXTENSIONS.key, "com.huawei.boostkit.spark.ColumnarPlugin") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.OmniColumnarShuffleManager") + + override def beforeAll(): Unit = { + super.beforeAll() + } + + test("Unsupported Scenarios") { + val data = Seq[(Int, Int)]( + (10000, 35), + ).toDF("id", "age") + + data.write.format("parquet").saveAsTable("table_write_ut_parquet_test") + var insert = spark.sql("insert into table_write_ut_parquet_test values(1,2)") + insert.collect() + var columnarDataWrite = insert.queryExecution.executedPlan.asInstanceOf[CommandResultExec] + .commandPhysicalPlan.find({ + case _: DataWritingCommandExec => true + case _ => false + } + ) + assert(columnarDataWrite.isDefined, "use columnar data writing command") + + val createTable = spark.sql("create table table_write_ut_map_test" + + " (id int, grades MAP<string, int>) using orc") + createTable.collect() + insert = spark.sql("insert into table_write_ut_map_test (id, grades)" + + " values(1, MAP('Math',90, 'English', 85))") + insert.collect() + columnarDataWrite = insert.queryExecution.executedPlan.asInstanceOf[CommandResultExec] + .commandPhysicalPlan.find({ + case _: DataWritingCommandExec => true + case _ => false + } + ) + } +}