diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java index 21e717a0d58d54bc52967f3416095426a0719c2b..ffcac1ab889e964030945b25e85dbddb1f082f54 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java @@ -18,10 +18,12 @@ package com.huawei.boostkit.hive; +import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedTimestamp; import static com.huawei.boostkit.hive.expression.TypeUtils.checkOmniJsonWhiteList; import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedArithmetic; import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedCast; import static com.huawei.boostkit.hive.expression.TypeUtils.convertHiveTypeToOmniType; +import static com.huawei.boostkit.hive.expression.TypeUtils.isValidConversion; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB; @@ -36,6 +38,7 @@ import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspe import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.STRING; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.TIMESTAMP; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.VARCHAR; +import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.VOID; import com.huawei.boostkit.hive.expression.BaseExpression; import 
com.huawei.boostkit.hive.expression.CastFunctionExpression; @@ -50,6 +53,7 @@ import nova.hetu.omniruntime.constants.FunctionType; import org.apache.hadoop.hive.common.type.HiveDecimal; import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.FieldSchema; import org.apache.hadoop.hive.ql.QueryPlan; import org.apache.hadoop.hive.ql.exec.ExplainTask; import org.apache.hadoop.hive.ql.exec.MapJoinOperator; @@ -93,6 +97,7 @@ import org.apache.hadoop.hive.ql.plan.ReduceWork; import org.apache.hadoop.hive.ql.plan.SelectDesc; import org.apache.hadoop.hive.ql.plan.TableDesc; import org.apache.hadoop.hive.ql.plan.TableScanDesc; +import org.apache.hadoop.hive.ql.plan.VectorTableScanDesc; import org.apache.hadoop.hive.ql.plan.TezEdgeProperty; import org.apache.hadoop.hive.ql.plan.TezWork; import org.apache.hadoop.hive.ql.plan.UnionWork; @@ -131,7 +136,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { public static final Set SUPPORTED_JOIN = new HashSet<>(Arrays.asList(JoinDesc.INNER_JOIN, JoinDesc.LEFT_OUTER_JOIN, JoinDesc.FULL_OUTER_JOIN, JoinDesc.LEFT_SEMI_JOIN)); private static final Set SUPPORTED_TYPE = new HashSet<>(Arrays.asList(BOOLEAN, - SHORT, INT, LONG, DOUBLE, STRING, DATE, DECIMAL, VARCHAR, CHAR)); + SHORT, INT, LONG, DOUBLE, STRING, DATE, DECIMAL, VARCHAR, CHAR, TIMESTAMP, VOID)); private static final int DECIMAL64_MAX_PRECISION = 19; @@ -530,6 +535,21 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { if (tableMetadata != null && (!tableMetadata.getInputFormatClass().equals(OrcInputFormat.class) || tableMetadata.getParameters().getOrDefault("transactional", "").equals("true"))) { return false; } + if (tableScanDesc.isVectorized()) { + TypeInfo[] columnTypeInfos = ((VectorTableScanDesc) tableScanDesc.getVectorDesc()).getProjectedColumnTypeInfos(); + for (int id : tableScanDesc.getNeededColumnIDs()) { + if ("timestamp".equals(columnTypeInfos[id].getTypeName())) { /* equals(): == would only compare references */ + return false; 
+ } + } + } else if (tableMetadata != null && tableMetadata.getCols() != null) { + List colList = tableMetadata.getCols(); + for (int id : tableScanDesc.getNeededColumnIDs()) { + if ("timestamp".equals(colList.get(id).getType())) { /* content comparison, also safe when the column type is null */ + return false; + } + } + } List> childOperators = op.getChildOperators(); for (Operator childOperator : childOperators) { if (childOperator.getType().equals(OperatorType.REDUCESINK) && reduceSinkDescUnReplaceable((ReduceSinkDesc) childOperator.getConf())) { @@ -619,7 +639,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return replaceable; case FILTER: List colList = Collections.singletonList(((FilterDesc) operator.getConf()).getPredicate()); - if (!isUDFSupport(colList) || !isLegalDeciConstant(colList)) { + if (!isUDFSupport(colList) || !isLegalDeci(colList)) { return false; } boolean result = true; @@ -629,7 +649,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { for (Operator child : operator.getChildOperators()) { if (child.getType() != null && child.getType().equals(OperatorType.SELECT)) { SelectDesc conf = (SelectDesc) child.getConf(); - result = result && isUDFSupport(conf.getColList()) && isLegalDeciConstant(conf.getColList()); + result = result && isUDFSupport(conf.getColList()) && isLegalDeci(conf.getColList()); } } return result; @@ -698,6 +718,16 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return false; } List windowFunctionDefs = ((WindowTableFunctionDef) conf.getFuncDef()).getWindowFunctions(); + for (WindowFunctionDef functionDef : windowFunctionDefs) { + if (functionDef.getArgs() == null) { + continue; + } + for (PTFExpressionDef expressionDef : functionDef.getArgs()) { + if (expressionDef.getExprNode() != null && "timestamp".equals(expressionDef.getExprNode().getTypeInfo().getTypeName())) { /* timestamp window-function arguments are rejected */ + return false; + } + } + } if (!PTFSupportedAgg(windowFunctionDefs)) { return false; } @@ -848,7 +878,8 @@ public class 
OmniExecuteWithHookContext implements ExecuteWithHookContext { expressions.add(expr.toString()); current.getChildren().forEach(queue::offer); } else if ((current instanceof ExprNodeColumnDesc || current instanceof ExprNodeConstantDesc) - && !SUPPORTED_TYPE.contains(((PrimitiveTypeInfo) current.getTypeInfo()).getPrimitiveCategory())) { + && (!SUPPORTED_TYPE.contains(((PrimitiveTypeInfo) current.getTypeInfo()).getPrimitiveCategory()) + || checkUnsupportedTimestamp(current))) { return false; } } @@ -859,40 +890,16 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return checkOmniJsonWhiteList("", expressions.toArray(new String[0])); } - private boolean isLegalDeciConstant(List colList) { + private boolean isLegalDeci(List colList) { for (ExprNodeDesc desc : colList) { - if (!checkDecimalConstant(desc)) { + if (!isValidConversion(desc)) { return false; } } if (colList.size() > 0 && colList.get(0).getChildren() != null) { List childList = colList.get(0).getChildren(); for (ExprNodeDesc desc : childList) { - if (!checkDecimalConstant(desc)) { - return false; - } - } - } - return true; - } - - private boolean checkDecimalConstant(ExprNodeDesc desc) { - if (desc instanceof ExprNodeGenericFuncDesc && desc.getChildren() != null && desc.getChildren().size() == 2) { - List child = desc.getChildren(); - if (child.get(0) instanceof ExprNodeConstantDesc && child.get(1) instanceof ExprNodeColumnDesc) { - Collections.swap(child, 0, 1); - } - if (child.get(0) instanceof ExprNodeColumnDesc && child.get(1) instanceof ExprNodeConstantDesc) { - TypeInfo deciInfo = child.get(0).getTypeInfo(); - TypeInfo constInfo = child.get(1).getTypeInfo(); - if (!(deciInfo instanceof DecimalTypeInfo && constInfo instanceof DecimalTypeInfo)) { - return true; - } - int deciPrecision = ((DecimalTypeInfo) deciInfo).getPrecision(); - int deciScale = ((DecimalTypeInfo) deciInfo).getScale(); - int constPrecision = ((DecimalTypeInfo) constInfo).getPrecision(); - int constScale = 
((DecimalTypeInfo) constInfo).getScale(); - if (constPrecision - constScale > deciPrecision - deciScale || constScale > deciScale) { + if (!isValidConversion(desc)) { return false; } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java index 8105b4c1a38579bc59def1225c6e919c01bc46af..485bf09eed0826d3963cd15165c5347a02d041c5 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java @@ -20,6 +20,8 @@ package com.huawei.boostkit.hive; import static com.huawei.boostkit.hive.expression.TypeUtils.buildInputDataType; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; import static org.apache.hadoop.hive.ql.exec.GroupByOperator.groupingSet2BitSet; import static org.apache.hadoop.hive.ql.exec.GroupByOperator.shouldEmitSummaryRow; @@ -29,6 +31,8 @@ import com.huawei.boostkit.hive.expression.TypeUtils; import javolution.util.FastBitSet; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniAggregationWithExprOperatorFactory; import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory; import nova.hetu.omniruntime.operator.config.OperatorConfig; import nova.hetu.omniruntime.operator.config.OverflowConfig; @@ -98,7 +102,7 @@ import java.util.Set; public class OmniGroupByOperator extends OmniHiveOperator implements Serializable, VectorizationContextRegion, IConfigureJobConf { private 
static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(OmniGroupByOperator.class.getName()); - private transient OmniHashAggregationWithExprOperatorFactory omniHashAggregationWithExprOperatorFactory; + private transient OmniOperatorFactory omniOperatorFactory; private transient OmniOperator omniOperator; private transient List keyFields; private transient boolean firstRow; @@ -286,10 +290,16 @@ public class OmniGroupByOperator extends OmniHiveOperator imple groupByChanel = getExprFromExprNode(keyFields); aggChannels = getTwoDimenExprFromExprNode(aggChannelFields); } - omniHashAggregationWithExprOperatorFactory = new OmniHashAggregationWithExprOperatorFactory(groupByChanel, aggChannels, - aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, - operatorConfig); - omniOperator = omniHashAggregationWithExprOperatorFactory.createOperator(); + if (numKeys == 0) { + omniOperatorFactory = new OmniAggregationWithExprOperatorFactory(groupByChanel, aggChannels, + aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } else { + omniOperatorFactory = new OmniHashAggregationWithExprOperatorFactory(groupByChanel, aggChannels, + aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } + omniOperator = omniOperatorFactory.createOperator(); } @Override @@ -475,14 +485,13 @@ public class OmniGroupByOperator extends OmniHiveOperator imple int rowCount = vec.getSize(); int groupingSetSize = groupingSets.size(); Vec newVec = VecFactory.createFlatVec(rowCount * groupingSetSize, vec.getType()); - Vec flatVec = vec; - if (vec instanceof DictionaryVec) { - flatVec = ((DictionaryVec) vec).expandDictionary(); - } + Vec flatVec = (vec instanceof DictionaryVec) ? 
((DictionaryVec) vec).expandDictionary() : vec; + byte[] rawValueNulls = vec.getRawValueNulls(); + DataType.DataTypeId dataTypeId = vec.getType().getId(); + int[] rawValueOffset = (dataTypeId == OMNI_VARCHAR || dataTypeId == OMNI_CHAR) ? ((VarcharVec) flatVec).getRawValueOffset() : new int[0]; for (int i = 0; i < groupingSetSize; i++) { + newVec.setNulls(i * rowCount, rawValueNulls, 0, rowCount); if ((groupingSets.get(i) & mask) == 0) { - DataType.DataTypeId dataTypeId = vec.getType().getId(); - newVec.setNulls(i * rowCount, vec.getValuesNulls(0, rowCount), 0, rowCount); switch (dataTypeId) { case OMNI_INT: case OMNI_DATE32: @@ -509,14 +518,14 @@ public class OmniGroupByOperator extends OmniHiveOperator imple case OMNI_VARCHAR: case OMNI_CHAR: ((VarcharVec) newVec).put(i * rowCount, ((VarcharVec) flatVec).get(0, rowCount), 0, - ((VarcharVec) flatVec).getValueOffset(0, rowCount), 0, rowCount); + rawValueOffset, 0, rowCount); break; default: throw new RuntimeException("Not support dataType, dataTypeId: " + dataTypeId); } } else { - boolean[] nulls = new boolean[rowCount]; - Arrays.fill(nulls, true); + byte[] nulls = new byte[rowCount]; + Arrays.fill(nulls, (byte) 1); newVec.setNulls(i * rowCount, nulls, 0, rowCount); } } @@ -570,38 +579,42 @@ public class OmniGroupByOperator extends OmniHiveOperator imple Vec newVec = VecFactory.createFlatVec(rowCount, dataType); DataType.DataTypeId dataTypeId = dataType.getId(); for (int i = 0; i < rowCount; i++) { + Object exprValue = exprNodeConstantEvaluator.getExpr().getValue(); + if (exprValue == null) { + newVec.setNull(i); + continue; + } switch (dataTypeId) { case OMNI_INT: case OMNI_DATE32: - ((IntVec) newVec).set(i, (int) exprNodeConstantEvaluator.getExpr().getValue()); + ((IntVec) newVec).set(i, (int) exprValue); break; case OMNI_LONG: case OMNI_DATE64: case OMNI_DECIMAL64: - Object exprValue = exprNodeConstantEvaluator.getExpr().getValue(); if (exprValue instanceof Timestamp) { ((LongVec) newVec).set(i, ((Timestamp) 
exprValue).toEpochMilli()); } else { - ((LongVec) newVec).set(i, (long) exprNodeConstantEvaluator.getExpr().getValue()); + ((LongVec) newVec).set(i, (long) exprValue); } break; case OMNI_DOUBLE: - ((DoubleVec) newVec).set(i, (double) exprNodeConstantEvaluator.getExpr().getValue()); + ((DoubleVec) newVec).set(i, (double) exprValue); break; case OMNI_BOOLEAN: - ((BooleanVec) newVec).set(i, (boolean) exprNodeConstantEvaluator.getExpr().getValue()); + ((BooleanVec) newVec).set(i, (boolean) exprValue); break; case OMNI_SHORT: - ((ShortVec) newVec).set(i, (short) exprNodeConstantEvaluator.getExpr().getValue()); + ((ShortVec) newVec).set(i, (short) exprValue); break; case OMNI_DECIMAL128: - HiveDecimal hiveDecimal = (HiveDecimal) exprNodeConstantEvaluator.getExpr().getValue(); + HiveDecimal hiveDecimal = (HiveDecimal) exprValue; DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo) exprNodeConstantEvaluator.getExpr().getTypeInfo(); ((Decimal128Vec) newVec).setBigInteger(i, hiveDecimal.bigIntegerBytesScaled(decimalTypeInfo.getScale()), hiveDecimal.signum() == -1); break; case OMNI_VARCHAR: case OMNI_CHAR: - ((VarcharVec) newVec).set(i, exprNodeConstantEvaluator.getExpr().getValue().toString().getBytes()); + ((VarcharVec) newVec).set(i, exprValue.toString().getBytes()); break; default: throw new RuntimeException("Not support dataType, dataTypeId: " + dataTypeId); @@ -712,7 +725,7 @@ public class OmniGroupByOperator extends OmniHiveOperator imple for (Vec vec : constantVec) { vec.close(); } - omniHashAggregationWithExprOperatorFactory.close(); + omniOperatorFactory.close(); omniOperator.close(); super.closeOp(abort); } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java index 75ab7c35dda65b9aa8b25bdfd4acacb4d4fe74a6..2235b2dcb02045f1222f027bf217df85094c151a 100644 --- 
a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniMapJoinOperator.java @@ -735,11 +735,17 @@ public class OmniMapJoinOperator extends AbstractMapJoinOperator inputColNameToExprName.put(exprNodeColumnDesc.getColumn(), entry.getKey()); } List fields = ((StructObjectInspector) inputObjInspectors[posBigTable]).getAllStructFieldRefs(); - List fieldNames = fields.stream().map(StructField::getFieldName).collect(Collectors.toList()); + List fieldNames = fields.stream().map(field -> inputColNameToExprName.getOrDefault( + field.getFieldName().replace("value.", "VALUE.").replace("key.", "KEY."), + field.getFieldName()).replace("value.", "").replace("key.", "") + ).collect(Collectors.toList()); List inspectors = fields.stream().map(StructField::getFieldObjectInspector).collect(Collectors.toList()); for (int buildIndex : buildIndexes) { fields = ((StructObjectInspector) inputObjInspectors[buildIndex]).getAllStructFieldRefs(); - fieldNames.addAll(fields.stream().map(field -> inputColNameToExprName.getOrDefault(field.getFieldName(), field.getFieldName())).collect(Collectors.toList())); + fieldNames.addAll(fields.stream().map(field -> inputColNameToExprName.getOrDefault( + field.getFieldName().replace("value.", "VALUE.").replace("key.", "KEY."), + field.getFieldName()).replace("value.", "").replace("key.", "") + ).collect(Collectors.toList())); inspectors.addAll(fields.stream().map(StructField::getFieldObjectInspector).collect(Collectors.toList())); } StructObjectInspector exprObjInspector = ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, inspectors); diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java index 
5aa10fb19e906fe2a5c0e1300da5ac14bebe2bcc..43f39df5af2b9f44c47da1d13d03d568e1bccc50 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/cache/VecBufferCache.java @@ -107,6 +107,7 @@ public class VecBufferCache { vec = new ShortVec(rowCount); break; case BOOLEAN: + case VOID: vec = new BooleanVec(rowCount); break; case DOUBLE: diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java index f28517f37a75bdaeb2c3c5c1b42c23adf75fcc50..e9892fe9f022f21d13628b70958803e8136b0848 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/converter/VecConverter.java @@ -48,6 +48,7 @@ public interface VecConverter { put(PrimitiveObjectInspector.PrimitiveCategory.TIMESTAMP, new TimestampVecConverter()); put(PrimitiveObjectInspector.PrimitiveCategory.DATE, new DateVecConverter()); put(PrimitiveObjectInspector.PrimitiveCategory.DECIMAL, new DecimalVecConverter()); + put(PrimitiveObjectInspector.PrimitiveCategory.VOID, new BooleanVecConverter()); } }; diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java index 0ecfe38bf7dbb89650b337b0bdf005d0b549cc61..97635be206ce73623b3e233722d3b0b69bbb8d32 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/expression/TypeUtils.java @@ -87,6 +87,7 @@ import org.apache.parquet.format.DecimalType; import org.slf4j.Logger; 
import org.slf4j.LoggerFactory; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -108,6 +109,7 @@ public class TypeUtils { put(PrimitiveObjectInspector.PrimitiveCategory.INTERVAL_DAY_TIME, LongDataType.LONG); put(PrimitiveObjectInspector.PrimitiveCategory.BYTE, ShortDataType.SHORT); put(PrimitiveObjectInspector.PrimitiveCategory.FLOAT, DoubleDataType.DOUBLE); + put(PrimitiveObjectInspector.PrimitiveCategory.VOID, BooleanDataType.BOOLEAN); } }; @@ -371,15 +373,88 @@ public class TypeUtils { return true; } } - boolean anyDecimal128 = children.stream().anyMatch(child -> child.getTypeInfo() instanceof DecimalTypeInfo - && ((DecimalTypeInfo) child.getTypeInfo()).getPrecision() > 18); - if ((functionName.equals("GenericUDFOPMultiply") || functionName.equals("GenericUDFOPDivide") - || functionName.equals("GenericUDFOPMod")) && anyDecimal128) { - return true; + + if (functionName.equals("GenericUDFOPMultiply") || functionName.equals("GenericUDFOPDivide") + || functionName.equals("GenericUDFOPMod")) { + return !isValidConversion(node); } return false; } + public static boolean checkUnsupportedTimestamp(ExprNodeDesc desc) { /* true when desc is a timestamp that is not a constant with whole-millisecond precision */ + TypeInfo typeInfo = desc.getTypeInfo(); + if (typeInfo instanceof PrimitiveTypeInfo) { + if (!"timestamp".equals(typeInfo.getTypeName())) { /* equals(): != would only compare references */ + return false; + } + if (desc instanceof ExprNodeConstantDesc) { + Timestamp timeValue = (Timestamp) ((ExprNodeConstantDesc) desc).getValue(); + if (timeValue != null && timeValue.getNanos() % 1000000 != 0) { /* a NULL constant has no sub-millisecond part to lose */ + return true; + } + } else { + return true; + } + } + return false; + } + + public static boolean isValidConversion(ExprNodeDesc node) { + if (node instanceof ExprNodeGenericFuncDesc && node.getChildren() != null && node.getChildren().size() == 2) { + List children = node.getChildren(); + int precision = 0; + int scale = 0; + int maxScale = 0; + if (node.getTypeInfo() instanceof DecimalTypeInfo) { + precision = ((DecimalTypeInfo) node.getTypeInfo()).getPrecision(); + scale = 
((DecimalTypeInfo) node.getTypeInfo()).getScale(); + } + if (children.get(0) instanceof ExprNodeConstantDesc && children.get(1) instanceof ExprNodeColumnDesc) { + Collections.swap(children, 0, 1); + } + if (children.get(0) instanceof ExprNodeColumnDesc && children.get(1) instanceof ExprNodeConstantDesc) { + ExprNodeDesc exprNodeDesc = children.get(0); + if (exprNodeDesc.getTypeInfo() instanceof DecimalTypeInfo) { + maxScale = ((DecimalTypeInfo) exprNodeDesc.getTypeInfo()).getScale(); + } + } else { + maxScale = getMaxScale(children, scale); + } + + int targetChildPrecision = 0; + int targetChildScale = 0; + for (ExprNodeDesc child : children) { + if (child.getTypeInfo() instanceof DecimalTypeInfo) { + int childScale = ((DecimalTypeInfo) child.getTypeInfo()).getScale(); + int childPrecision = ((DecimalTypeInfo) child.getTypeInfo()).getPrecision(); + if (maxScale != childScale) { + targetChildPrecision = Math.min(Math.max(childPrecision + maxScale - childScale, precision), 38); + targetChildScale = maxScale; + if (childPrecision - childScale > targetChildPrecision - targetChildScale || childScale > targetChildScale) { + return false; + } + } + } + } + return true; + } + return true; + } + + public static int getMaxScale(List children, int maxScale) { + for (ExprNodeDesc child : children) { + if (!(child.getTypeInfo() instanceof DecimalTypeInfo)) { + continue; + } + DecimalTypeInfo childTypeInfo = (DecimalTypeInfo) child.getTypeInfo(); + int childScale = childTypeInfo.getScale(); + if (childScale >= maxScale) { + maxScale = childScale; + } + } + return maxScale; + } + public static boolean checkOmniJsonWhiteList(String filterExpr, String[] projections) { // inputTypes will not be checked if parseFormat is json( == 1), // only if its parseFormat is String(==0) diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java 
b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java index 63cd5d7a4cf763a00819da3d2cd9c5c351b269a5..097d17cd63788d44e5d6dfb0ce35174f874147df 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/ArithmeticExpressionProcessor.java @@ -18,6 +18,8 @@ package com.huawei.boostkit.hive.processor; +import static com.huawei.boostkit.hive.expression.TypeUtils.getMaxScale; + import com.huawei.boostkit.hive.expression.BaseExpression; import com.huawei.boostkit.hive.expression.CastFunctionExpression; import com.huawei.boostkit.hive.expression.DivideExpression; @@ -104,18 +106,4 @@ public class ArithmeticExpressionProcessor implements ExpressionProcessor { TypeUtils.getCharWidth(node), childPrecision, childScale); compareExpression.add(ExpressionUtils.optimizeCast(childNode, functionExpression)); } - - private int getMaxScale(List children, int maxScale) { - for (ExprNodeDesc child : children) { - if (!(child.getTypeInfo() instanceof DecimalTypeInfo)) { - continue; - } - DecimalTypeInfo childTypeInfo = (DecimalTypeInfo) child.getTypeInfo(); - int childScale = childTypeInfo.getScale(); - if (childScale >= maxScale) { - maxScale = childScale; - } - } - return maxScale; - } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java index 4d0c9b5166dcd64faf91105c8e6f28b4be19dbb0..e302fdba930c36ed9ece817cd0970cdb397a8edf 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java +++ 
b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/processor/TimestampExpressionProcessor.java @@ -20,7 +20,9 @@ package com.huawei.boostkit.hive.processor; import com.huawei.boostkit.hive.expression.BaseExpression; import com.huawei.boostkit.hive.expression.CastFunctionExpression; +import com.huawei.boostkit.hive.expression.DivideExpression; import com.huawei.boostkit.hive.expression.ExpressionUtils; +import com.huawei.boostkit.hive.expression.LiteralFactor; import com.huawei.boostkit.hive.expression.TypeUtils; import com.sun.jdi.LongType; import nova.hetu.omniruntime.type.LongDataType; @@ -32,13 +34,26 @@ public class TimestampExpressionProcessor implements ExpressionProcessor { @Override public BaseExpression process(ExprNodeGenericFuncDesc node, String operator, ObjectInspector inspector) { ExprNodeDesc exprNodeDesc = node.getChildren().get(0); - CastFunctionExpression cast = new CastFunctionExpression(LongDataType.LONG.getId().toValue(), - TypeUtils.getCharWidth(node), null, null); + BaseExpression baseExpression; + int dataType = TypeUtils.convertHiveTypeToOmniType(exprNodeDesc.getTypeInfo()); if (exprNodeDesc instanceof ExprNodeGenericFuncDesc) { - cast.add(ExpressionUtils.build((ExprNodeGenericFuncDesc) exprNodeDesc, inspector)); + baseExpression = ExpressionUtils.build((ExprNodeGenericFuncDesc) exprNodeDesc, inspector); } else { - cast.add(ExpressionUtils.createNode(exprNodeDesc, inspector)); + baseExpression = ExpressionUtils.createNode(exprNodeDesc, inspector); } - return cast; + LiteralFactor longLiteralFactor = new LiteralFactor<>("LITERAL", null, null, + 86400000L, null, LongDataType.LONG.getId().toValue()); + DivideExpression divideExpression = new DivideExpression(LongDataType.LONG.getId().toValue(), "MULTIPLY", null, null); + if (dataType != LongDataType.LONG.getId().toValue()) { + CastFunctionExpression cast = new CastFunctionExpression(LongDataType.LONG.getId().toValue(), + null, null, null); + cast.add(baseExpression); + 
divideExpression.add(cast); + divideExpression.add(longLiteralFactor); + return divideExpression; + } + divideExpression.add(baseExpression); + divideExpression.add(longLiteralFactor); + return divideExpression; } } diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java index 1456cd56564aa59161a2941a6d241b6a5195fd7d..e61c142d046c4bee25bf82bafc04b0a86725a2fb 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/reader/OmniOrcRecordReader.java @@ -19,9 +19,19 @@ package com.huawei.boostkit.hive.reader; import static com.huawei.boostkit.hive.cache.VectorCache.BATCH; +import static com.huawei.boostkit.hive.expression.TypeUtils.DEFAULT_VARCHAR_LENGTH; import static org.apache.hadoop.hive.ql.io.orc.OrcInputFormat.getDesiredRowTypeDescr; import static org.apache.hadoop.hive.serde2.ColumnProjectionUtils.READ_COLUMN_IDS_CONF_STR; +import nova.hetu.omniruntime.type.BooleanDataType; +import nova.hetu.omniruntime.type.DataType; +import nova.hetu.omniruntime.type.Decimal128DataType; +import nova.hetu.omniruntime.type.DoubleDataType; +import nova.hetu.omniruntime.type.IntDataType; +import nova.hetu.omniruntime.type.LongDataType; +import nova.hetu.omniruntime.type.ShortDataType; +import nova.hetu.omniruntime.type.VarcharDataType; +import nova.hetu.omniruntime.vector.Decimal128Vec; import nova.hetu.omniruntime.vector.Vec; import nova.hetu.omniruntime.vector.VecBatch; @@ -35,6 +45,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgument; import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentImpl; import org.apache.hadoop.hive.ql.plan.api.OperatorType; import org.apache.hadoop.hive.serde2.SerDeStats; +import 
org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.mapred.FileSplit; import org.apache.hadoop.mapred.RecordReader; @@ -45,12 +56,31 @@ import org.apache.orc.TypeDescription; import java.io.IOException; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; public class OmniOrcRecordReader implements RecordReader, StatsProvidingRecordReader { + private static final Map CATEGORY_TO_OMNI_TYPE = new HashMap() { + { + put(TypeDescription.Category.SHORT, ShortDataType.SHORT); + put(TypeDescription.Category.INT, IntDataType.INTEGER); + put(TypeDescription.Category.LONG, LongDataType.LONG); + put(TypeDescription.Category.BOOLEAN, BooleanDataType.BOOLEAN); + put(TypeDescription.Category.DOUBLE, DoubleDataType.DOUBLE); + put(TypeDescription.Category.STRING, new VarcharDataType(DEFAULT_VARCHAR_LENGTH)); + put(TypeDescription.Category.TIMESTAMP, LongDataType.LONG); + put(TypeDescription.Category.DATE, IntDataType.INTEGER); + put(TypeDescription.Category.BYTE, ShortDataType.SHORT); + put(TypeDescription.Category.FLOAT, DoubleDataType.DOUBLE); + put(TypeDescription.Category.DECIMAL, Decimal128DataType.DECIMAL128); + put(TypeDescription.Category.CHAR, VarcharDataType.VARCHAR); + put(TypeDescription.Category.VARCHAR, VarcharDataType.VARCHAR); + } + }; protected OrcColumnarBatchScanReader recordReader; protected Vec[] vecs; protected final long offset; @@ -59,6 +89,7 @@ public class OmniOrcRecordReader implements RecordReader included; protected Operator tableScanOp; + protected int[] typeIds; OmniOrcRecordReader(Configuration conf, FileSplit split) throws IOException { TypeDescription schema = getDesiredRowTypeDescr(conf, false, Integer.MAX_VALUE); @@ -70,6 +101,10 @@ public class OmniOrcRecordReader implements RecordReader(field); + case omniruntime::type::OMNI_SHORT: 
+ return CopyFixedWidth(field); + case omniruntime::type::OMNI_INT: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_LONG: + return CopyOptimizedForInt64(field); + case omniruntime::type::OMNI_DATE32: + return CopyFixedWidth(field); + case omniruntime::type::OMNI_DATE64: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("dealLongVectorBatch not support for type: " + std::to_string(static_cast<int>(id))); /* to_string: literal + enum would be pointer arithmetic */ + } + } + return -1; +} + +uint64_t dealDoubleVectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DOUBLE: + return CopyOptimizedForInt64(field); + default: { + throw std::runtime_error("dealDoubleVectorBatch not support for type: " + std::to_string(static_cast<int>(id))); /* format the numeric type id into the message */ + } + } + return -1; +} + +uint64_t dealDecimal64VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL64: + return CopyToOmniDecimal64Vec(field); + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128VecFrom64(field); + default: { + throw std::runtime_error("dealDecimal64VectorBatch not support for type: " + std::to_string(static_cast<int>(id))); /* format the numeric type id into the message */ + } + } + return -1; +} + +uint64_t dealDecimal128VectorBatch(DataTypeId id, orc::ColumnVectorBatch *field) { + switch (id) { + case omniruntime::type::OMNI_DECIMAL128: + return CopyToOmniDecimal128Vec(field); + default: { + throw std::runtime_error("dealDecimal128VectorBatch not support for type: " + std::to_string(static_cast<int>(id))); /* format the numeric type id into the message */ + } + } + return -1; +} + +int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, bool isDecimal64Transfor128) { + DataTypeId dataTypeId = static_cast(omniTypeId); switch (type->getKind()) { case orc::TypeKind::BOOLEAN: - omniTypeId = static_cast(OMNI_BOOLEAN); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::SHORT: - omniTypeId = static_cast(OMNI_SHORT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::DATE: - omniTypeId = static_cast(OMNI_DATE32); - omniVecId = CopyFixedWidth(field); - 
break; case orc::TypeKind::INT: - omniTypeId = static_cast(OMNI_INT); - omniVecId = CopyFixedWidth(field); - break; case orc::TypeKind::LONG: - omniTypeId = static_cast(OMNI_LONG); - omniVecId = CopyOptimizedForInt64(field); + omniVecId = dealLongVectorBatch(dataTypeId, field); break; case orc::TypeKind::DOUBLE: - omniTypeId = static_cast(OMNI_DOUBLE); - omniVecId = CopyOptimizedForInt64(field); + omniVecId = dealDoubleVectorBatch(dataTypeId, field); break; case orc::TypeKind::CHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc char"); + } omniVecId = CopyCharType(field); break; case orc::TypeKind::STRING: case orc::TypeKind::VARCHAR: - omniTypeId = static_cast(OMNI_VARCHAR); + if (dataTypeId != OMNI_VARCHAR) { + throw std::runtime_error("Cannot transfer to other OMNI_TYPE but VARCHAR for orc string/varchar"); + } omniVecId = CopyVarWidth(field); break; case orc::TypeKind::DECIMAL: if (type->getPrecision() > MAX_DECIMAL64_DIGITS) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128Vec(field); - } else if (isDecimal64Transfor128) { - omniTypeId = static_cast(OMNI_DECIMAL128); - omniVecId = CopyToOmniDecimal128VecFrom64(field); + omniVecId = dealDecimal128VectorBatch(dataTypeId, field); } else { - omniTypeId = static_cast(OMNI_DECIMAL64); - omniVecId = CopyToOmniDecimal64Vec(field); + omniVecId = dealDecimal64VectorBatch(dataTypeId, field); } break; default: { @@ -576,16 +618,16 @@ JNIEXPORT jlong JNICALL Java_com_huawei_boostkit_scan_jni_OrcColumnarBatchJniRea const orc::Type &baseTp = rowReaderPtr->getSelectedType(); int vecCnt = 0; long batchRowSize = 0; + auto ptr = env->GetIntArrayElements(typeId, JNI_FALSE); if (rowReaderPtr->next(*columnVectorBatch)) { orc::StructVectorBatch *root = dynamic_cast(columnVectorBatch); vecCnt = root->fields.size(); batchRowSize = root->fields[0]->numElements; for (int id = 0; id < vecCnt; 
id++) { auto type = baseTp.getSubtype(id); - int omniTypeId = 0; + int omniTypeId = ptr[id]; uint64_t omniVecId = 0; CopyToOmniVec(type, omniTypeId, omniVecId, root->fields[id], isDecimal64Transfor128); - env->SetIntArrayRegion(typeId, id, 1, &omniTypeId); jlong omniVec = static_cast(omniVecId); env->SetLongArrayRegion(vecNativeId, id, 1, &omniVec); } diff --git a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h index 829f5c0744d3d563601ec6506ebfc82b5a020e93..e0c33b26c7f191771fc2c2bf7aaa1a278db8f424 100644 --- a/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h +++ b/omnioperator/omniop-native-reader/cpp/src/jni/OrcColumnarBatchJniReader.h @@ -141,7 +141,7 @@ int BuildLeaves(PredicateOperatorType leafOp, std::vector &litList bool StringToBool(const std::string &boolStr); -int CopyToOmniVec(const orc::Type *type, int &omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, +int CopyToOmniVec(const orc::Type *type, int omniTypeId, uint64_t &omniVecId, orc::ColumnVectorBatch *field, bool isDecimal64Transfor128); #ifdef __cplusplus diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java index 8edbdf4622f0b5888cd4fd680fee6fc1eb3c4880..611a10826042a7916ff7d98f804c7ee5eba319b1 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/com/huawei/boostkit/spark/jni/OrcColumnarBatchScanReader.java @@ -34,7 +34,6 @@ import org.slf4j.LoggerFactory; import java.net.URI; import java.sql.Date; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; public class OrcColumnarBatchScanReader { @@ -253,8 +252,7 @@ 
public class OrcColumnarBatchScanReader { } } - public int next(Vec[] vecList) { - int[] typeIds = new int[realColsCnt]; + public int next(Vec[] vecList, int[] typeIds) { long[] vecNativeIds = new long[realColsCnt]; long rtn = jniReader.recordReaderNext(recordReader, batchReader, typeIds, vecNativeIds); if (rtn == 0) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java index 24a93ede4cf4e17ccd0050157e04a848ab38ba0d..e8e7db3af876ce726d9b22e8ead47cf9f15a9fbe 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java +++ b/omnioperator/omniop-spark-extension/java/src/main/java/org/apache/spark/sql/execution/datasources/orc/OmniOrcColumnarBatchReader.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.orc; import com.google.common.annotations.VisibleForTesting; +import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor; import com.huawei.boostkit.spark.jni.OrcColumnarBatchScanReader; import nova.hetu.omniruntime.vector.Vec; import org.apache.hadoop.conf.Configuration; @@ -79,6 +80,8 @@ public class OmniOrcColumnarBatchReader extends RecordReader