From 7ae9afa13af0d6d99065c0ab369eb296112186af Mon Sep 17 00:00:00 2001
From: liyanglinhw
Date: Tue, 19 Mar 2024 17:27:44 +0800
Subject: [PATCH] support decimal64

---
 .../hive/OmniExecuteWithHookContext.java      |   2 +-
 .../boostkit/hive/OmniGroupByOperator.java    |   2 +
 .../hive/OmniVectorizedVectorOperator.java    |   5 +
 .../boostkit/hive/cache/VecBufferCache.java   |  14 ++-
 .../hive/converter/Decimal64VecConverter.java |  58 +--------
 .../hive/converter/DecimalVecConverter.java   | 113 ++++++++++++++++++
 .../hive/expression/ExpressionUtils.java      |  34 ++++--
 .../hive/expression/LiteralFactor.java        |   9 ++
 .../boostkit/hive/expression/TypeUtils.java   |  15 ++-
 .../ArithmeticExpressionProcessor.java        |  10 +-
 .../CaseWhenExpressionProcessor.java          |   2 +-
 .../processor/ComputeExpressionProcessor.java |  23 +++-
 .../hive/reader/OmniOrcRecordReader.java      |  64 ++++++++--
 .../reader/OrcColumnarBatchScanReader.java    |   1 -
 .../hive/shuffle/OmniVecBatchSerDe.java       |   8 +-
 15 files changed, 268 insertions(+), 92 deletions(-)

diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java
index b52b78e18..e16a35d02 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java
@@ -552,7 +552,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext {
         }
         List<Operator<? extends OperatorDesc>> childOperators = op.getChildOperators();
         for (Operator<? extends OperatorDesc> childOperator : childOperators) {
-            if (childOperator.getType().equals(OperatorType.REDUCESINK) && reduceSinkDescUnReplaceable((ReduceSinkDesc) childOperator.getConf())) {
+            if (childOperator.getType() != null && childOperator.getType().equals(OperatorType.REDUCESINK) && reduceSinkDescUnReplaceable((ReduceSinkDesc) childOperator.getConf())) {
                 return false;
             }
         }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java
index f78cb162e..917c42216 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java
@@ -585,6 +585,8 @@ public class OmniGroupByOperator extends OmniHiveOperator imple
             case OMNI_DECIMAL64:
                 if (exprValue instanceof Timestamp) {
                     ((LongVec) newVec).set(i, ((Timestamp) exprValue).toEpochMilli());
+                } else if (exprValue instanceof HiveDecimal) {
+                    ((LongVec) newVec).set(i, ((HiveDecimal) exprValue).unscaledValue().longValue());
                 } else {
                     ((LongVec) newVec).set(i, (long) exprValue);
                 }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniVectorizedVectorOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniVectorizedVectorOperator.java
index f692d3cb3..f4e4b402b 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniVectorizedVectorOperator.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniVectorizedVectorOperator.java
@@ -65,6 +65,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
@@ -282,6 +283,10 @@ public class OmniVectorizedVectorOperator extends OmniHiveOperator
[...]
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java
[...]
+            if (this.categories[i] == PrimitiveObjectInspector.PrimitiveCategory.DECIMAL) {
+                DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) typeInfos.get(i);
+                columnTypeLen[i] = decimalTypeInfo.getPrecision() > 18 ? 16 : 8;
+            } else {
+                columnTypeLen[i] = TYPE_LEN.getOrDefault(this.categories[i], 0);
+            }
             if (columnTypeLen[i] == 0) {
                 cache[i] = new VecBuffer(getEstimateLen((PrimitiveTypeInfo) typeInfos.get(i)), true);
             } else {
@@ -126,7 +132,11 @@ public class VecBufferCache {
                 }
                 break;
             case DECIMAL:
-                vec = new Decimal128Vec(rowCount);
+                if (columnTypeLen[index] == 8) {
+                    vec = new LongVec(rowCount);
+                } else {
+                    vec = new Decimal128Vec(rowCount);
+                }
                 break;
             default:
                 throw new IllegalStateException("Unexpected value: " + categories[index]);
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java
index e3ae983ec..cc7e76c78 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/Decimal64VecConverter.java
@@ -17,12 +17,8 @@
  */
 package com.huawei.boostkit.hive.converter;
-
-import com.huawei.boostkit.hive.cache.ColumnCache;
-import com.huawei.boostkit.hive.cache.LongColumnCache;
-
-import nova.hetu.omniruntime.vector.Decimal128Vec;
 import nova.hetu.omniruntime.vector.DictionaryVec;
+import nova.hetu.omniruntime.vector.LongVec;
 import nova.hetu.omniruntime.vector.Vec;
 
 import org.apache.hadoop.hive.common.type.DataTypePhysicalVariation;
@@ -33,12 +29,9 @@
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFLastValue;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
-import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.Map;
-
 public class Decimal64VecConverter extends LongVecConverter {
     static final Logger LOG = LoggerFactory.getLogger(GenericUDAFLastValue.class.getName());
@@ -59,59 +52,16 @@
             long value;
             if (vec instanceof DictionaryVec) {
                 DictionaryVec dictionaryVec = (DictionaryVec) vec;
-                value = dictionaryVec.getDecimal128(i)[0];
+                value = dictionaryVec.getLong(i);
             } else {
-                Decimal128Vec decimal128Vec = (Decimal128Vec) vec;
-                byte[] result = decimal128Vec.getBytes(i);
-                if (result.length > 8) {
-                    byte[] newBytes = new byte[8];
-                    System.arraycopy(result, result.length - 8, newBytes, 0, 8);
-                    value = Decimal128Vec.bytesToLong(newBytes);
-                } else {
-                    value = Decimal128Vec.bytesToLong(result);
-                }
+                LongVec longVec = (LongVec) vec;
+                value = longVec.get(i);
             }
             decimal64ColumnVector.vector[i - start] = value;
         }
         return decimal64ColumnVector;
     }
-
-    @Override
-    public Vec toOmniVec(Object[] col, int columnSize, PrimitiveTypeInfo primitiveTypeInfo) {
-        Decimal128Vec decimal128Vec = new Decimal128Vec(columnSize);
-        for (int i = 0; i < columnSize; i++) {
-            if (col[i] == null) {
-                decimal128Vec.setNull(i);
-            } else {
-                long value = (long) col[i];
-                decimal128Vec.setBigInteger(i,
-                        Decimal128Vec.longToBytes(value), value < 0L);
-            }
-        }
-        return decimal128Vec;
-    }
-
-    @Override
-    public Vec toOmniVec(ColumnCache columnCache, int columnSize) {
-        Decimal128Vec decimal128Vec = new Decimal128Vec(columnSize);
-        LongColumnCache longColumnCache = (LongColumnCache) columnCache;
-        if (longColumnCache.noNulls) {
-            for (int i = 0; i < columnSize; i++) {
-                long value = longColumnCache.dataCache[i];
-                decimal128Vec.setBigInteger(i, Decimal128Vec.longToBytes(value), value < 0L);
-            }
-        } else {
-            for (int i = 0; i < columnSize; i++) {
-                if (longColumnCache.isNull[i]) {
-                    decimal128Vec.setNull(i);
-                } else {
-                    long value = longColumnCache.dataCache[i];
-                    decimal128Vec.setBigInteger(i, Decimal128Vec.longToBytes(value), value < 0L);
-                }
-            }
-        }
-        return decimal128Vec;
-    }
-
     public static boolean isConvertedDecimal64(String fieldName, VectorizationContext vectorizationContext) {
         boolean convertedDecimal64 = false;
         try {
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/DecimalVecConverter.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/DecimalVecConverter.java
index 0e2fa94fb..3e00bb1d8 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/DecimalVecConverter.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/DecimalVecConverter.java
@@ -21,9 +21,12 @@ package com.huawei.boostkit.hive.converter;
 
 import com.huawei.boostkit.hive.cache.ColumnCache;
 import com.huawei.boostkit.hive.cache.DecimalColumnCache;
+import com.huawei.boostkit.hive.cache.LongColumnCache;
 
+import nova.hetu.omniruntime.type.LongDataType;
 import nova.hetu.omniruntime.utils.OmniRuntimeException;
 import nova.hetu.omniruntime.vector.Decimal128Vec;
 import nova.hetu.omniruntime.vector.DictionaryVec;
+import nova.hetu.omniruntime.vector.LongVec;
 import nova.hetu.omniruntime.vector.Vec;
 
 import org.apache.hadoop.hive.common.type.HiveDecimal;
@@ -54,7 +57,16 @@ public class DecimalVecConverter implements VecConverter {
         DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveObjectInspector.getTypeInfo();
         if (vec instanceof DictionaryVec) {
             DictionaryVec dictionaryVec = (DictionaryVec) vec;
+            if (dictionaryVec.getType().getId() == LongDataType.LONG.getId()) {
+                HiveDecimalWritable hiveDecimalWritable = new HiveDecimalWritable();
+                hiveDecimalWritable.setFromLongAndScale(dictionaryVec.getLong(index), decimalTypeInfo.getScale());
+                return hiveDecimalWritable;
+            }
             return getDecimalWritableFromLong(dictionaryVec.getDecimal128(index), decimalTypeInfo.getScale());
+        } else if (vec instanceof LongVec) {
+            HiveDecimalWritable hiveDecimalWritable = new HiveDecimalWritable();
+            hiveDecimalWritable.setFromLongAndScale(((LongVec) vec).get(index), decimalTypeInfo.getScale());
+            return hiveDecimalWritable;
         }
         Decimal128Vec decimal128Vec = (Decimal128Vec) vec;
         byte[] result = decimal128Vec.getBytes(index);
@@ -93,6 +105,9 @@
     @Override
     public Vec toOmniVec(Object[] col, int columnSize, PrimitiveTypeInfo primitiveTypeInfo) {
         DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo;
+        if (decimalTypeInfo.getPrecision() <= 18) {
+            return toOmniLongVec(col, columnSize, decimalTypeInfo);
+        }
         Decimal128Vec decimal128Vec = new Decimal128Vec(columnSize);
         for (int i = 0; i < columnSize; i++) {
             if (col[i] == null) {
@@ -106,8 +121,32 @@
         return decimal128Vec;
     }
 
+    private Vec toOmniLongVec(Object[] col, int columnSize, DecimalTypeInfo decimalTypeInfo) {
+        LongVec longVec = new LongVec(columnSize);
+        for (int i = 0; i < columnSize; i++) {
+            if (col[i] == null) {
+                longVec.setNull(i);
+            } else {
+                HiveDecimal hiveDecimal = (HiveDecimal) col[i];
+                longVec.set(i, getLongFromHiveDecimal(hiveDecimal, decimalTypeInfo));
+            }
+        }
+        return longVec;
+    }
+
+    private long getLongFromHiveDecimal(HiveDecimal hiveDecimal, DecimalTypeInfo decimalTypeInfo) {
+        if (hiveDecimal.scale() < decimalTypeInfo.getScale()) {
+            return hiveDecimal.scaleByPowerOfTen(decimalTypeInfo.getScale()).longValue();
+        } else {
+            return hiveDecimal.unscaledValue().longValue();
+        }
+    }
+
     @Override
     public Vec toOmniVec(ColumnCache columnCache, int columnSize, PrimitiveTypeInfo primitiveTypeInfo) {
+        if (columnCache instanceof LongColumnCache) {
+            return toOmniVecLong(columnCache, columnSize);
+        }
         Decimal128Vec decimal128Vec = new Decimal128Vec(columnSize);
         DecimalColumnCache decimal128ColumnCache = (DecimalColumnCache) columnCache;
         byte[] value = new byte[columnSize * 16];
@@ -130,10 +169,33 @@
         return decimal128Vec;
     }
 
+    private Vec toOmniVecLong(ColumnCache columnCache, int columnSize) {
+        LongVec longVec = new LongVec(columnSize);
+        LongColumnCache longColumnCache = (LongColumnCache) columnCache;
+        if (longColumnCache.noNulls) {
+            for (int i = 0; i < columnSize; i++) {
+                longVec.set(i, longColumnCache.dataCache[i]);
+            }
+        } else {
+            for (int i = 0; i < columnSize; i++) {
+                if (longColumnCache.isNull[i]) {
+                    longVec.setNull(i);
+                } else {
+                    longVec.set(i, longColumnCache.dataCache[i]);
+                }
+            }
+        }
+        return longVec;
+    }
+
     @Override
     public void setValueFromColumnVector(VectorizedRowBatch vectorizedRowBatch, int vectorColIndex,
             ColumnCache columnCache, int colIndex, int rowCount, PrimitiveTypeInfo primitiveTypeInfo) {
+        if (columnCache instanceof LongColumnCache) {
+            setValueFromColumnVectorLong(vectorizedRowBatch, vectorColIndex, columnCache, colIndex, rowCount, primitiveTypeInfo);
+            return;
+        }
         ColumnVector columnVector = vectorizedRowBatch.cols[vectorColIndex];
         DecimalColumnCache decimalColumnCache = (DecimalColumnCache) columnCache;
         DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo;
@@ -185,6 +247,57 @@
         }
     }
 
+    private void setValueFromColumnVectorLong(VectorizedRowBatch vectorizedRowBatch, int vectorColIndex, ColumnCache columnCache, int colIndex, int rowCount, PrimitiveTypeInfo primitiveTypeInfo) {
+        ColumnVector columnVector = vectorizedRowBatch.cols[vectorColIndex];
+        LongColumnCache longColumnCache = (LongColumnCache) columnCache;
+        DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) primitiveTypeInfo;
+        HiveDecimalWritable[] vector = ((DecimalColumnVector) columnVector).vector;
+        if (!columnVector.noNulls) {
+            longColumnCache.noNulls = false;
+        }
+        if (columnVector.isRepeating) {
+            if (columnVector.isNull[0]) {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    longColumnCache.isNull[rowCount + i] = true;
+                }
+            } else {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    longColumnCache.dataCache[rowCount + i] = getLongFromHiveDecimal(vector[0].getHiveDecimal(), decimalTypeInfo);
+                }
+            }
+        } else if (vectorizedRowBatch.selectedInUse) {
+            if (columnVector.noNulls) {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    longColumnCache.dataCache[rowCount + i] = getLongFromHiveDecimal(vector[vectorizedRowBatch.selected[i]].getHiveDecimal(),
+                            decimalTypeInfo);
+                }
+            } else {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    if (columnVector.isNull[vectorizedRowBatch.selected[i]]) {
+                        longColumnCache.isNull[rowCount + i] = true;
+                    } else {
+                        longColumnCache.dataCache[rowCount + i] = getLongFromHiveDecimal(
+                                vector[vectorizedRowBatch.selected[i]].getHiveDecimal(), decimalTypeInfo);
+                    }
+                }
+            }
+        } else {
+            if (columnVector.noNulls) {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    longColumnCache.dataCache[rowCount + i] = getLongFromHiveDecimal(vector[i].getHiveDecimal(), decimalTypeInfo);
+                }
+            } else {
+                for (int i = 0; i < vectorizedRowBatch.size; i++) {
+                    if (columnVector.isNull[i]) {
+                        longColumnCache.isNull[rowCount + i] = true;
+                    } else {
+                        longColumnCache.dataCache[rowCount + i] = getLongFromHiveDecimal(vector[i].getHiveDecimal(), decimalTypeInfo);
+                    }
+                }
+            }
+        }
+    }
+
     @Override
     public ColumnVector getColumnVectorFromOmniVec(Vec vec, int start, int end,
             PrimitiveObjectInspector primitiveObjectInspector) {
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/ExpressionUtils.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/ExpressionUtils.java
index 97f3e5537..112739d40 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/ExpressionUtils.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/ExpressionUtils.java
@@ -158,15 +158,17 @@ public class ExpressionUtils {
         BaseExpression leaf;
         Object value = ((ExprNodeConstantDesc) next).getValue();
         int omniType = TypeUtils.convertHiveTypeToOmniType(next.getTypeInfo());
-        if (omniType == Decimal128DataType.DECIMAL128.getId().toValue()) {
+        if (omniType == Decimal128DataType.DECIMAL128.getId().toValue() || omniType == Decimal64DataType.DECIMAL64.getId().toValue()) {
             int scale = ((DecimalTypeInfo) next.getTypeInfo()).getScale();
-            String realValue;
+            Object realValue;
             if (value == null) {
                 realValue = null;
-            } else {
+            } else if (omniType == Decimal128DataType.DECIMAL128.getId().toValue()) {
                 realValue = new BigInteger(((HiveDecimal) value).bigIntegerBytesScaled(scale)).toString();
+            } else {
+                realValue = ((HiveDecimal) value).scaleByPowerOfTen(scale).longValue();
             }
-            leaf = new DecimalLiteral(realValue, Decimal128DataType.DECIMAL128.getId().toValue(),
+            leaf = new DecimalLiteral(realValue, omniType,
                     ((DecimalTypeInfo) next.getTypeInfo()).getPrecision(), scale);
         } else {
             leaf = new LiteralFactor<>("LITERAL", null, null, TypeUtils.getLiteralValue(value, next.getTypeInfo()),
@@ -215,10 +217,13 @@ public class ExpressionUtils {
     public static BaseExpression preCast(BaseExpression castExpression, ExprNodeDesc castNodeDesc, ExprNodeDesc comparedNode) {
         TypeInfo baseTypeInfo = comparedNode.getTypeInfo();
-        if (castNodeDesc instanceof ExprNodeConstantDesc) {
-            if (baseTypeInfo instanceof DecimalTypeInfo && ((ExprNodeConstantDesc) castNodeDesc).getValue().equals(0)) {
+        if (castNodeDesc instanceof ExprNodeConstantDesc && baseTypeInfo instanceof DecimalTypeInfo && ((ExprNodeConstantDesc) castNodeDesc).getValue().equals(0)) {
+            if (((DecimalTypeInfo) baseTypeInfo).getPrecision() > 18) {
                 return new DecimalLiteral("0", TypeUtils.convertHiveTypeToOmniType(baseTypeInfo),
                         ((DecimalTypeInfo) baseTypeInfo).getPrecision(), ((DecimalTypeInfo) baseTypeInfo).getScale());
+            } else {
+                return new DecimalLiteral(0L, TypeUtils.convertHiveTypeToOmniType(baseTypeInfo),
+                        ((DecimalTypeInfo) baseTypeInfo).getPrecision(), ((DecimalTypeInfo) baseTypeInfo).getScale());
             }
         }
         Integer precision = null;
@@ -258,16 +263,23 @@ public class ExpressionUtils {
             Integer toCastPrecision = castFunctionExpression.getPrecision();
             Integer toCastScale = castFunctionExpression.getScale();
             Integer scale = ((DecimalLiteral) expression).getScale();
+            ((DecimalLiteral) expression).setDataType(castFunctionExpression.getReturnType());
             if (scale >= toCastScale) {
                 return expression;
             }
-            String value = (String) ((DecimalLiteral) expression).getValue();
+            Object value = ((DecimalLiteral) expression).getValue();
             ((DecimalLiteral) expression).setPrecision(toCastPrecision);
             ((DecimalLiteral) expression).setScale(toCastScale);
-            BigDecimal decimalValue = new BigDecimal(value);
-            BigDecimal newValue = decimalValue.multiply(BigDecimal.TEN.pow(toCastScale - scale));
-            ((DecimalLiteral) expression).setValue(newValue.toString());
-            return expression;
+            if (value instanceof Long) {
+                long newValue = (long) Math.pow(10, toCastScale - scale) * (Long) value;
+                ((DecimalLiteral) expression).setValue(newValue);
+                return expression;
+            } else {
+                BigDecimal decimalValue = new BigDecimal((String) value);
+                BigDecimal newValue = decimalValue.multiply(BigDecimal.TEN.pow(toCastScale - scale));
+                ((DecimalLiteral) expression).setValue(newValue.toString());
+                return expression;
+            }
         }
         castFunctionExpression.add(expression);
         return castFunctionExpression;
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/LiteralFactor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/LiteralFactor.java
index b1380be8d..4a668fd96 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/LiteralFactor.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/LiteralFactor.java
@@ -46,6 +46,11 @@ public class LiteralFactor extends BaseExpression {
     }
 
+    @Override
+    public Integer getReturnType() {
+        return dataType;
+    }
+
     @Override
     public boolean isFull() {
         return true;
@@ -71,6 +76,10 @@ public class LiteralFactor extends BaseExpression {
         return dataType;
     }
 
+    public void setDataType(Integer dataType) {
+        this.dataType = dataType;
+    }
+
     public T getValue() {
         return value;
     }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java
index bce826311..0a2608601 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java
@@ -43,6 +43,7 @@ import nova.hetu.omniruntime.type.BooleanDataType;
 import nova.hetu.omniruntime.type.DataType;
 import nova.hetu.omniruntime.type.DataTypeSerializer;
 import nova.hetu.omniruntime.type.Decimal128DataType;
+import nova.hetu.omniruntime.type.Decimal64DataType;
 import nova.hetu.omniruntime.type.DoubleDataType;
 import nova.hetu.omniruntime.type.IntDataType;
 import nova.hetu.omniruntime.type.LongDataType;
@@ -221,8 +222,13 @@ public class TypeUtils {
     public static DataType buildInputDataType(TypeInfo typeInfo) {
         if (typeInfo instanceof DecimalTypeInfo) {
-            return new Decimal128DataType(((DecimalTypeInfo) typeInfo).getPrecision(),
-                    ((DecimalTypeInfo) typeInfo).getScale());
+            if (((DecimalTypeInfo) typeInfo).getPrecision() > 18) {
+                return new Decimal128DataType(((DecimalTypeInfo) typeInfo).getPrecision(),
+                        ((DecimalTypeInfo) typeInfo).getScale());
+            } else {
+                return new Decimal64DataType(((DecimalTypeInfo) typeInfo).getPrecision(),
+                        ((DecimalTypeInfo) typeInfo).getScale());
+            }
         } else if (typeInfo instanceof BaseCharTypeInfo) {
             return new VarcharDataType(((BaseCharTypeInfo) typeInfo).getLength());
         }
@@ -341,11 +347,10 @@ public class TypeUtils {
     private static boolean checkUnsupportedDecimal(Integer dataType, Integer returnType,
             CastFunctionExpression castFunctionExpression) {
         // not support Cast(double/decimal as decimal)
-        if ((dataType == OMNI_DOUBLE.toValue() || dataType == OMNI_DECIMAL64.toValue())
-                && (returnType == OMNI_DECIMAL64.toValue() || returnType == OMNI_DECIMAL128.toValue())) {
+        if (dataType == OMNI_DOUBLE.toValue() && (returnType == OMNI_DECIMAL64.toValue() || returnType == OMNI_DECIMAL128.toValue())) {
             return true;
         }
-        if (dataType == OMNI_DECIMAL128.toValue() && returnType == OMNI_DECIMAL128.toValue()) {
+        if ((dataType == OMNI_DECIMAL128.toValue() || dataType == OMNI_DECIMAL64.toValue()) && (returnType == OMNI_DECIMAL128.toValue() || returnType == OMNI_DECIMAL64.toValue())) {
             BaseExpression child = castFunctionExpression.getArguments().get(0);
             if (child instanceof DecimalReference && !((DecimalReference) child).getScale().equals(castFunctionExpression.getScale())) {
                 return true;
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java
index 097d17cd6..ae587ba67 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java
@@ -26,6 +26,8 @@ import com.huawei.boostkit.hive.expression.DivideExpression;
 import com.huawei.boostkit.hive.expression.ExpressionUtils;
 import com.huawei.boostkit.hive.expression.TypeUtils;
 
+import nova.hetu.omniruntime.type.Decimal128DataType;
+import org.apache.hadoop.hive.common.type.Decimal128;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
@@ -60,12 +62,18 @@ public class ArithmeticExpressionProcessor implements ExpressionProcessor {
         boolean needUpscale = functionName.equals("GenericUDFOPMultiply") || functionName.equals("GenericUDFOPDivide")
                 || functionName.equals("GenericUDFOPMod");
         int maxScale = 0;
+        boolean hasDecimal128 = false;
         if (node.getTypeInfo() instanceof DecimalTypeInfo) {
             precision = ((DecimalTypeInfo) node.getTypeInfo()).getPrecision();
             scale = ((DecimalTypeInfo) node.getTypeInfo()).getScale();
             maxScale = getMaxScale(children, scale);
+            for (ExprNodeDesc child : children) {
+                if (child.getTypeInfo() instanceof DecimalTypeInfo && ((DecimalTypeInfo) child.getTypeInfo()).getPrecision() > 18) {
+                    hasDecimal128 = true;
+                }
+            }
         }
-        int returnType = TypeUtils.convertHiveTypeToOmniType(node.getTypeInfo());
+        int returnType = hasDecimal128 ? Decimal128DataType.DECIMAL128.getId().toValue() : TypeUtils.convertHiveTypeToOmniType(node.getTypeInfo());
         DivideExpression compareExpression = new DivideExpression(returnType,
                 OPERATOR.get(node.getGenericUDF().getClass()), precision, scale);
         for (ExprNodeDesc child : children) {
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/CaseWhenExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/CaseWhenExpressionProcessor.java
index 062c8f511..c5a862a40 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/CaseWhenExpressionProcessor.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/CaseWhenExpressionProcessor.java
@@ -90,7 +90,7 @@ public class CaseWhenExpressionProcessor implements ExpressionProcessor {
                     TypeUtils.convertHiveTypeToOmniType(root.getTypeInfo()));
         }
         if (root.getTypeInfo() instanceof DecimalTypeInfo && ((ExprNodeConstantDesc) nodeDesc).getValue().equals(0)) {
-            return new DecimalLiteral("0", TypeUtils.convertHiveTypeToOmniType(root.getTypeInfo()),
+            return new DecimalLiteral(((DecimalTypeInfo) root.getTypeInfo()).getPrecision() > 18 ? "0" : 0L, TypeUtils.convertHiveTypeToOmniType(root.getTypeInfo()),
                     ((DecimalTypeInfo) root.getTypeInfo()).getPrecision(),
                     ((DecimalTypeInfo) root.getTypeInfo()).getScale());
         }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ComputeExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ComputeExpressionProcessor.java
index 4a0cb378f..b197eee03 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ComputeExpressionProcessor.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ComputeExpressionProcessor.java
@@ -19,13 +19,17 @@
 package com.huawei.boostkit.hive.processor;
 
 import com.huawei.boostkit.hive.expression.BaseExpression;
+import com.huawei.boostkit.hive.expression.CastFunctionExpression;
 import com.huawei.boostkit.hive.expression.CompareExpression;
 import com.huawei.boostkit.hive.expression.ExpressionUtils;
 import com.huawei.boostkit.hive.expression.TypeUtils;
 
+import nova.hetu.omniruntime.type.Decimal128DataType;
+import nova.hetu.omniruntime.type.Decimal64DataType;
 import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 
 import java.util.List;
@@ -34,16 +38,31 @@ public class ComputeExpressionProcessor implements ExpressionProcessor {
     @Override
     public BaseExpression process(ExprNodeGenericFuncDesc node, String operator, ObjectInspector inspector) {
         List<ExprNodeDesc> children = node.getChildren();
+        boolean hasDecimal128 = false;
+        for (ExprNodeDesc child : children) {
+            if (child.getTypeInfo() instanceof DecimalTypeInfo && ((DecimalTypeInfo) child.getTypeInfo()).getPrecision() > 18) {
+                hasDecimal128 = true;
+            }
+        }
         conjureNodeType(children);
         CompareExpression compareExpression = new CompareExpression("BINARY",
                 TypeUtils.convertHiveTypeToOmniType(node.getTypeInfo()), operator);
         for (ExprNodeDesc child : children) {
+            BaseExpression baseExpression;
             if (child instanceof ExprNodeGenericFuncDesc) {
-                compareExpression.add(ExpressionUtils.build((ExprNodeGenericFuncDesc) child, inspector));
+                baseExpression = ExpressionUtils.build((ExprNodeGenericFuncDesc) child, inspector);
             } else {
-                compareExpression.add(ExpressionUtils.createNode(child, inspector));
+                baseExpression = ExpressionUtils.createNode(child, inspector);
+            }
+            if (baseExpression.getReturnType() == Decimal64DataType.DECIMAL64.getId().toValue() && hasDecimal128) {
+                DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) child.getTypeInfo();
+                CastFunctionExpression castFunctionExpression = new CastFunctionExpression(Decimal128DataType.DECIMAL128.getId().toValue(),
+                        TypeUtils.getCharWidth(child), decimalTypeInfo.getPrecision(), decimalTypeInfo.getScale());
+                castFunctionExpression.add(baseExpression);
+                baseExpression = castFunctionExpression;
             }
+            compareExpression.add(baseExpression);
         }
         return compareExpression;
     }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java
index e61c142d0..d223c6f1a 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java
@@ -26,12 +26,12 @@ import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_ID
 import nova.hetu.omniruntime.type.BooleanDataType;
 import nova.hetu.omniruntime.type.DataType;
 import nova.hetu.omniruntime.type.Decimal128DataType;
+import nova.hetu.omniruntime.type.Decimal64DataType;
 import nova.hetu.omniruntime.type.DoubleDataType;
 import nova.hetu.omniruntime.type.IntDataType;
 import nova.hetu.omniruntime.type.LongDataType;
 import nova.hetu.omniruntime.type.ShortDataType;
 import nova.hetu.omniruntime.type.VarcharDataType;
-import nova.hetu.omniruntime.vector.Decimal128Vec;
 import nova.hetu.omniruntime.vector.Vec;
 import nova.hetu.omniruntime.vector.VecBatch;
@@ -41,23 +41,27 @@ import org.apache.hadoop.hive.ql.exec.Operator;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.io.StatsProvidingRecordReader;
 import org.apache.hadoop.hive.ql.io.orc.OrcFile;
+import org.apache.hadoop.hive.ql.io.sarg.ConvertAstToSearchArg;
 import org.apache.hadoop.hive.ql.io.sarg.SearchArgument;
 import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl;
 import org.apache.hadoop.hive.ql.plan.api.OperatorType;
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils;
 import org.apache.hadoop.hive.serde2.SerDeStats;
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.io.NullWritable;
 import org.apache.hadoop.mapred.FileSplit;
 import org.apache.hadoop.mapred.RecordReader;
 import org.apache.hive.com.esotericsoftware.kryo.Kryo;
 import org.apache.hive.com.esotericsoftware.kryo.io.Input;
 import org.apache.orc.OrcConf;
+import org.apache.orc.OrcProto;
+import org.apache.orc.OrcUtils;
 import org.apache.orc.TypeDescription;
 
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -96,14 +100,23 @@ public class OmniOrcRecordReader implements RecordReader
[...]
-            typeIds[i] = CATEGORY_TO_OMNI_TYPE.get(requiredSchema.getChildren().get(i).getCategory()).getId().toValue();
+            if (requiredSchema.getChildren().get(i).getCategory() == TypeDescription.Category.DECIMAL) {
+                if (requiredSchema.getChildren().get(i).getPrecision() > 18) {
+                    typeIds[i] = Decimal128DataType.DECIMAL128.getId().toValue();
+                } else {
+                    typeIds[i] = Decimal64DataType.DECIMAL64.getId().toValue();
+                }
+            } else {
+                typeIds[i] = CATEGORY_TO_OMNI_TYPE.get(requiredSchema.getChildren().get(i).getCategory()).getId().toValue();
+            }
         }
         recordReader = new OrcColumnarBatchScanReader();
         recordReader.initializeReaderJava(split.getPath().toUri(), readerOptions);
@@ -125,23 +138,48 @@ public class OmniOrcRecordReader implements RecordReader
[...]
+            List<OrcProto.Type> types = OrcUtils.getOrcTypes(schema);
+            options.searchArgument(sarg, getSargColumnNames(neededColumnNames.split(","), types, options.getInclude()));
+        }
         }
         return options;
     }
 
+    private static String[] getSargColumnNames(String[] originalColumnNames,
+            List<OrcProto.Type> types, boolean[] includedColumns) {
+        int rootColumn = 0;
+        String[] columnNames = new String[types.size() - rootColumn];
+        int i = 0;
+        // The way this works is as such. originalColumnNames is the equivalent of getNeededColumns
+        // from TSOP. They are assumed to be in the same order as the columns in ORC file, AND they are
+        // assumed to be equivalent to the columns in includedColumns (because it was generated from
+        // the same column list at some point in the past), minus the subtype columns. Therefore, when
+        // we go thru all the top level ORC file columns that are included, in order, they match
+        // originalColumnNames. This way, we do not depend on names stored inside ORC for SARG leaf
+        // column name resolution (see mapSargColumns method).
+        for (int columnId : types.get(rootColumn).getSubtypesList()) {
+            if (includedColumns == null || includedColumns[columnId - rootColumn]) {
+                // this is guaranteed to be positive because types only have children
+                // ids greater than their own id.
+                columnNames[columnId - rootColumn] = originalColumnNames[i++];
+            }
+        }
+        return columnNames;
+    }
+
     private TypeDescription getRequiredSchema(TypeDescription schema) {
         Set<Integer> requiredIds = new HashSet<>(included);
         TypeDescription result = TypeDescription.createStruct();
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OrcColumnarBatchScanReader.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OrcColumnarBatchScanReader.java
index 41e293fb9..b7be77550 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OrcColumnarBatchScanReader.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OrcColumnarBatchScanReader.java
@@ -216,7 +216,6 @@ public class OrcColumnarBatchScanReader {
             }
         }
         job.put("includedColumns", colToInclu.toArray());
-        job.put("isDecimal64Transfor128", true);
         recordReader = jniReader.initializeRecordReader(reader, job);
         return recordReader;
     }
diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/shuffle/OmniVecBatchSerDe.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/shuffle/OmniVecBatchSerDe.java
index bdd00e40b..8400aaf8c 100644
--- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/shuffle/OmniVecBatchSerDe.java
+++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/shuffle/OmniVecBatchSerDe.java
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryUtils;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.BaseCharTypeInfo;
+import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
@@ -117,7 +118,12 @@ public class OmniVecBatchSerDe extends AbstractSerDe {
         for (int i = 0; i < columnTypes.size(); i++) {
             PrimitiveObjectInspector.PrimitiveCategory primitiveCategory = ((PrimitiveTypeInfo) columnTypes.get(i))
                     .getPrimitiveCategory();
-            columnTypeLen[i] = TYPE_LEN.getOrDefault(primitiveCategory, 0);
+            if (primitiveCategory == PrimitiveObjectInspector.PrimitiveCategory.DECIMAL) {
+                DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) columnTypes.get(i);
+                columnTypeLen[i] = decimalTypeInfo.getPrecision() > 18 ? 16 : 8;
+            } else {
+                columnTypeLen[i] = TYPE_LEN.getOrDefault(primitiveCategory, 0);
+            }
             if (columnTypeLen[i] == 0) {
                 writeLen = writeLen + getEstimateLen((PrimitiveTypeInfo) columnTypes.get(i)) + 4;
                 columnSerDes[i] = new VariableWidthColumnSerDe();
--
Gitee
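
The patch hinges on one representation rule: a Hive decimal(p, s) with precision p <= 18 travels through OmniRuntime as its unscaled value in a 64-bit LongVec (OMNI_DECIMAL64), while wider decimals keep the 16-byte Decimal128Vec path. The following sketch is illustrative only -- it is not part of the patch, uses java.math.BigDecimal in place of HiveDecimal, and its class and method names are invented:

import java.math.BigDecimal;
import java.math.BigInteger;

public final class Decimal64ConventionDemo {
    // Mirrors the intent of DecimalVecConverter.getLongFromHiveDecimal:
    // bring the value to the declared column scale, then keep only the
    // unscaled integer digits. Assumes the value's own scale does not
    // exceed the column scale, as the patch does.
    static long toUnscaledLong(BigDecimal value, int columnScale) {
        return value.setScale(columnScale).unscaledValue().longValueExact();
    }

    // Mirrors the read side (HiveDecimalWritable.setFromLongAndScale):
    // rebuild the decimal from the stored long plus the column scale.
    static BigDecimal fromUnscaledLong(long unscaled, int columnScale) {
        return BigDecimal.valueOf(unscaled, columnScale);
    }

    public static void main(String[] args) {
        // decimal(10, 2): 12345.67 is stored as the long 1234567.
        long stored = toUnscaledLong(new BigDecimal("12345.67"), 2);
        System.out.println(stored);                      // 1234567
        System.out.println(fromUnscaledLong(stored, 2)); // 12345.67

        // Why 18 digits is the cutoff: the largest unscaled value of a
        // decimal(18, s) is 10^18 - 1, which needs 60 bits and fits in a
        // signed 64-bit long; 10^19 - 1 needs 64 bits and does not.
        System.out.println(BigInteger.TEN.pow(18).subtract(BigInteger.ONE).bitLength()); // 60
        System.out.println(BigInteger.TEN.pow(19).subtract(BigInteger.ONE).bitLength()); // 64
    }
}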
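
ExpressionUtils also rescales decimal64 literals in place when a cast widens the scale: the stored unscaled long is multiplied by 10^(targetScale - literalScale). Below is a standalone sketch of just that arithmetic; the method name is invented, and the factor is computed by an integer loop rather than the patch's Math.pow (which is exact for powers of ten in this range):

public final class LiteralRescaleDemo {
    // Rescale an unscaled decimal64 value from fromScale to toScale,
    // assuming toScale >= fromScale as ExpressionUtils guarantees by
    // returning early when scale >= toCastScale.
    static long rescaleUnscaled(long unscaled, int fromScale, int toScale) {
        long factor = 1L;
        for (int i = 0; i < toScale - fromScale; i++) {
            factor *= 10L;
        }
        return unscaled * factor;
    }

    public static void main(String[] args) {
        // The literal 1.5 (unscaled 15 at scale 1) cast to scale 3
        // becomes unscaled 1500, i.e. 1.500.
        System.out.println(rescaleUnscaled(15L, 1, 3)); // 1500
    }
}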
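
The getSargColumnNames helper added to OmniOrcRecordReader maps the reader's needed-column names onto top-level ORC column ids by position, so SARG leaves resolve without trusting the names stored inside the file. Here is a schematic re-implementation with the ORC type list reduced to a plain list of subtype ids -- illustrative only, all names invented:

import java.util.Arrays;
import java.util.List;

public final class SargNameMappingDemo {
    static String[] mapNames(String[] neededNames, List<Integer> topLevelIds,
                             boolean[] included, int totalTypes) {
        String[] columnNames = new String[totalTypes];
        int next = 0;
        // Walk the root struct's children in order; each included
        // top-level column consumes the next needed name positionally.
        for (int columnId : topLevelIds) {
            if (included == null || included[columnId]) {
                columnNames[columnId] = neededNames[next++];
            }
        }
        return columnNames;
    }

    public static void main(String[] args) {
        // Root struct (id 0) with three top-level columns (ids 1..3);
        // only "a" (id 1) and "c" (id 3) are read.
        String[] result = mapNames(new String[] {"a", "c"},
                Arrays.asList(1, 2, 3), new boolean[] {true, true, false, true}, 4);
        System.out.println(Arrays.toString(result)); // [null, a, null, c]
    }
}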