diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java index 1456cd56564aa59161a2941a6d241b6a5195fd7d..e61c142d046c4bee25bf82bafc04b0a86725a2fb 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java @@ -19,9 +19,19 @@ package com.huawei.boostkit.hive.reader; import static com.huawei.boostkit.hive.cache.VectorCache.BATCH; +import static com.huawei.boostkit.hive.expression.TypeUtils.DEFAULT_VARCHAR_LENGTH; import static org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.getDesiredRowTypeDescr; import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR; +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.ShortDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; @@ -35,6 +45,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.SerDeStats; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.RecordReader; @@ -45,12 +56,31 @@ import org.apache.orc.TypeDescription; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; public class OmniOrcRecordReader implements RecordReader, StatsProvidingRecordReader { + private static final Map CATEGORY_TO_OMNI_TYPE = new HashMap() { + { + put(TypeDescription.Category.SHORT, ShortDataType.SHORT); + put(TypeDescription.Category.INT, IntDataType.INTEGER); + put(TypeDescription.Category.LONG, LongDataType.LONG); + put(TypeDescription.Category.BOOLEAN, BooleanDataType.BOOLEAN); + put(TypeDescription.Category.DOUBLE, DoubleDataType.DOUBLE); + put(TypeDescription.Category.STRING, new VarcharDataType(DEFAULT_VARCHAR_LENGTH)); + put(TypeDescription.Category.TIMESTAMP, LongDataType.LONG); + put(TypeDescription.Category.DATE, IntDataType.INTEGER); + put(TypeDescription.Category.BYTE, ShortDataType.SHORT); + put(TypeDescription.Category.FLOAT, DoubleDataType.DOUBLE); + put(TypeDescription.Category.DECIMAL, Decimal128DataType.DECIMAL128); + put(TypeDescription.Category.CHAR, VarcharDataType.VARCHAR); + put(TypeDescription.Category.VARCHAR, VarcharDataType.VARCHAR); + } + }; protected OrcColumnarBatchScanReader recordReader; protected Vec[] vecs; protected final long offset; @@ -59,6 +89,7 @@ public class OmniOrcRecordReader implements RecordReader included; protected Operator tableScanOp; + protected int[] typeIds; OmniOrcRecordReader(Configuration conf, FileSplit split) throws IOException { TypeDescription schema = getDesiredRowTypeDescr(conf, false, Integer.MAX_VALUE); @@ -70,6 +101,10 @@ public class OmniOrcRecordReader implements RecordReader(field); + case omniruntime::type::OMNI_SHORT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_INT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_LONG: + return CopyOptimizedForInt64(field); + case omniruntime::type::OMNI_DATE32: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_DATE64: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("dealLongVectorBatch not support for type: " + id); + } + } + return -1; +} + +uint64_t dealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DOUBLE: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("dealDoubleVectorBatch not support for type: " + id); + } + } + return -1; +} + +uint64_t dealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL64: + return CopyToOmniDecimal64Vec(field); + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128VecFrom64(field); + default: { + throw std::runtime_error("dealDecimal64VectorBatch not support for type: " + id); + } + } + return -1; +} + +uint64_t dealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128Vec(field); + default: { + throw std::runtime_error("dealDecimal128VectorBatch not support for type: " + id); + } + } + return -1; +} + +int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, bool isDecimal64Transfor128) { + DataTypeId dataTypeId = static_cast(omniTypeId); switch (type->getKind()) { case orc::TypeKind::BOOLEAN: - omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::SHORT: - omniTypeId = static_cast(OMNI_SHORT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::DATE: - omniTypeId = static_cast(OMNI_DATE32); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::INT: - omniTypeId = static_cast(OMNI_INT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::LONG: - omniTypeId = static_cast(OMNI_LONG); - omniVecId = CopyOptimizedForInt64(field); + omniVecId = dealLongVectorBatch(dataTypeId, field); break; case orc::TypeKind::DOUBLE: - omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = CopyOptimizedForInt64(field); + omniVecId = dealDoubleVectorBatch(dataTypeId, field); break; case orc::TypeKind::CHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char"); + } omniVecId = CopyCharType(field); break; case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar"); + } omniVecId = CopyVarWidth(field); break; case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128Vec(field); - } else if (isDecimal64Transfor128) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128VecFrom64(field); + omniVecId = dealDecimal128VectorBatch(dataTypeId, field); } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - omniVecId = CopyToOmniDecimal64Vec(field); + omniVecId = dealDecimal64VectorBatch(dataTypeId, field); } break; default: { @@ -576,16 +618,16 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea const orc::Type &baseTp = rowReaderPtr->getSelectedType(); int vecCnt = 0; long batchRowSize = 0; + auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE); if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; for (int id = 0; id < vecCnt; id++) { auto type = baseTp.getSubtype(id); - int omniTypeId = 0; + int omniTypeId = ptr[id]; uint64_t omniVecId = 0; CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id], isDecimal64Transfor128); - env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); jlong omniVec = static_cast(omniVecId); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h index 829f5c0744d3d563601ec6506ebfc82b5a020e93..e0c33b26c7f191771fc2c2bf7aaa1a278db8f424 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -141,7 +141,7 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList bool StringToBool(const std::string &boolStr); -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, +int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, bool isDecimal64Transfor128); #ifdef __cplusplus diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 8edbdf4622f0b5888cd4fd680fee6fc1eb3c4880..611a10826042a7916ff7d98f804c7ee5eba319b1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -34,7 +34,6 @@ import org.slf4j.LoggerFactory; import java.net.URI; import java.sql.Date; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; public class OrcColumnarBatchScanReader { @@ -253,8 +252,7 @@ public class OrcColumnarBatchScanReader { } } - public int next(Vec[] vecList) { - int[] typeIds = new int[realColsCnt]; + public int next(Vec[] vecList, int[] typeIds) { long[] vecNativeIds = new long[realColsCnt]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 24a93ede4cf4e17ccd0050157e04a848ab38ba0d..e8e7db3af876ce726d9b22e8ead47cf9f15a9fbe 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.orc; import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.conf.Configuration; @@ -79,6 +80,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader