diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java index 31b26a6ca1afddd9a4044d410d7f10e047d2ed21..0e270f6ccf3f6718091ec1ab77f8dd3433222d61 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniExecuteWithHookContext.java @@ -22,6 +22,7 @@ import static com.huawei.boostkit.hive.expression.TypeUtils.checkOmniJsonWhiteLi import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedArithmetic; import static com.huawei.boostkit.hive.expression.TypeUtils.checkUnsupportedCast; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; import static org.apache.hadoop.hive.serde.serdeConstants.SERIALIZATION_LIB; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.BOOLEAN; import static org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory.CHAR; @@ -490,41 +491,7 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { } return result; case PTF: - PTFDesc conf = (PTFDesc) operator.getConf(); - List expressions = new ArrayList<>(conf.getFuncDef().getPartition().getExpressions()); - expressions.addAll(conf.getFuncDef().getOrder().getExpressions()); - for (PTFExpressionDef expression : expressions) { - ExprNodeDesc exprNode = expression.getExprNode(); - if (exprNode instanceof ExprNodeGenericFuncDesc) { - return false; - } - if (exprNode instanceof ExprNodeColumnDesc && !SUPPORTED_TYPE - .contains(((PrimitiveTypeInfo) exprNode.getTypeInfo()).getPrimitiveCategory())) { - return false; - } - } - if (!(conf.getFuncDef() instanceof WindowTableFunctionDef)) { - return false; - } - List windowFunctionDefs = ((WindowTableFunctionDef) conf.getFuncDef()).getWindowFunctions(); - for (int i = 0; i < windowFunctionDefs.size(); i++) { - WindowFunctionDef windowFunctionDef = windowFunctionDefs.get(i); - FunctionType windowFunctionType = TypeUtils.getWindowFunctionType(windowFunctionDef); - if (windowFunctionType == null || windowFunctionType == OMNI_AGGREGATION_TYPE_AVG && isAvgUnsupported(windowFunctionDef.getArgs())) { - return false; - } - List args = windowFunctionDefs.get(i).getArgs(); - boolean isCountAll = (windowFunctionDef.getName().equals("count") && windowFunctionDef.isStar()); - if (args != null) { - if (!isCountAll && args.size() > 1) { - return false; - } - if (!(args.get(0).getExprNode() instanceof ExprNodeColumnDesc)) { - return false; - } - } - } - return true; + return PTFReplaceable(operator); case GROUPBY: return groupByReplaceable(operator, reduceSinkCanReplace); case REDUCESINK: @@ -562,8 +529,49 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { } } + private boolean PTFReplaceable(Operator operator) { + PTFDesc conf = (PTFDesc) operator.getConf(); + List expressions = new ArrayList<>(conf.getFuncDef().getPartition().getExpressions()); + expressions.addAll(conf.getFuncDef().getOrder().getExpressions()); + for (PTFExpressionDef expression : expressions) { + ExprNodeDesc exprNode = expression.getExprNode(); + if (exprNode instanceof ExprNodeGenericFuncDesc) { + return false; + } + if (exprNode instanceof ExprNodeColumnDesc && !SUPPORTED_TYPE + .contains(((PrimitiveTypeInfo) exprNode.getTypeInfo()).getPrimitiveCategory())) { + return false; + } + } + if (!(conf.getFuncDef() instanceof WindowTableFunctionDef)) { + return false; + } + List windowFunctionDefs = ((WindowTableFunctionDef) conf.getFuncDef()).getWindowFunctions(); + for (int i = 0; i < windowFunctionDefs.size(); i++) { + WindowFunctionDef windowFunctionDef = windowFunctionDefs.get(i); + FunctionType windowFunctionType = TypeUtils.getWindowFunctionType(windowFunctionDef); + List args = windowFunctionDefs.get(i).getArgs(); + boolean isCountAll = (windowFunctionDef.getName().equals("count") && windowFunctionDef.isStar()); + if (args != null) { + if (windowFunctionType == null || isPTFAggUnsupported(windowFunctionType, args)) { + return false; + } + if (!isCountAll && args.size() > 1) { + return false; + } + if (!(args.get(0).getExprNode() instanceof ExprNodeColumnDesc)) { + return false; + } + } + } + return true; + } + private boolean groupByReplaceable(Operator operator, boolean reduceSinkCanReplace) { GroupByDesc groupByDesc = (GroupByDesc) operator.getConf(); + if (groupByDesc.isGroupingSetsPresent() && groupByDesc.getMode() == GroupByDesc.Mode.PARTIALS) { + return false; + } if (!isUDFSupport(groupByDesc.getKeys())) { return false; } @@ -611,17 +619,21 @@ public class OmniExecuteWithHookContext implements ExecuteWithHookContext { return SUPPORTED_JOIN.contains(joinCondDescs[0].getType()); } - private boolean isAvgUnsupported(List expressions) { + private boolean isPTFAggUnsupported(FunctionType windowFunctionType, List expressions) { for (PTFExpressionDef expression : expressions) { ObjectInspector oi = expression.getOI(); if (!(oi instanceof PrimitiveObjectInspector)) { continue; } PrimitiveTypeInfo primitiveTypeInfo = ((PrimitiveObjectInspector) oi).getTypeInfo(); - if (primitiveTypeInfo.getPrimitiveCategory().equals(DECIMAL) + if (windowFunctionType == OMNI_AGGREGATION_TYPE_AVG && primitiveTypeInfo.getPrimitiveCategory().equals(DECIMAL) && ((DecimalTypeInfo) primitiveTypeInfo).getPrecision() > DECIMAL64_MAX_PRECISION) { return true; } + if ((windowFunctionType == OMNI_AGGREGATION_TYPE_AVG || windowFunctionType == OMNI_AGGREGATION_TYPE_SUM) + && primitiveTypeInfo.getPrimitiveCategory().equals(STRING)) { + return true; + } } return false; }