diff --git a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java index 8105b4c1a38579bc59def1225c6e919c01bc46af..d9d54ee15acc7696b582fe9380eaeeeb2b9fcc99 100644 --- a/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java +++ b/omnioperator/omniop-hive-extension/src/main/java/com/huawei/boostkit/hive/OmniGroupByOperator.java @@ -20,6 +20,8 @@ package com.huawei.boostkit.hive; import static com.huawei.boostkit.hive.expression.TypeUtils.buildInputDataType; import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_CHAR; +import static nova.hetu.omniruntime.type.DataType.DataTypeId.OMNI_VARCHAR; import static org.apache.hadoop.hive.ql.exec.GroupByOperator.groupingSet2BitSet; import static org.apache.hadoop.hive.ql.exec.GroupByOperator.shouldEmitSummaryRow; @@ -29,6 +31,8 @@ import com.huawei.boostkit.hive.expression.TypeUtils; import javolution.util.FastBitSet; import nova.hetu.omniruntime.constants.FunctionType; import nova.hetu.omniruntime.operator.OmniOperator; +import nova.hetu.omniruntime.operator.OmniOperatorFactory; +import nova.hetu.omniruntime.operator.aggregator.OmniAggregationWithExprOperatorFactory; import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory; import nova.hetu.omniruntime.operator.config.OperatorConfig; import nova.hetu.omniruntime.operator.config.OverflowConfig; @@ -98,7 +102,7 @@ import java.util.Set; public class OmniGroupByOperator extends OmniHiveOperator implements Serializable, VectorizationContextRegion, IConfigureJobConf { private static final long serialVersionUID = 1L; private static final Logger LOG = LoggerFactory.getLogger(OmniGroupByOperator.class.getName()); - private transient OmniHashAggregationWithExprOperatorFactory omniHashAggregationWithExprOperatorFactory; + private transient OmniOperatorFactory omniOperatorFactory; private transient OmniOperator omniOperator; private transient List keyFields; private transient boolean firstRow; @@ -286,10 +290,16 @@ public class OmniGroupByOperator extends OmniHiveOperator imple groupByChanel = getExprFromExprNode(keyFields); aggChannels = getTwoDimenExprFromExprNode(aggChannelFields); } - omniHashAggregationWithExprOperatorFactory = new OmniHashAggregationWithExprOperatorFactory(groupByChanel, aggChannels, - aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, - operatorConfig); - omniOperator = omniHashAggregationWithExprOperatorFactory.createOperator(); + if (numKeys == 0) { + omniOperatorFactory = new OmniAggregationWithExprOperatorFactory(groupByChanel, aggChannels, + aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } else { + omniOperatorFactory = new OmniHashAggregationWithExprOperatorFactory(groupByChanel, aggChannels, + aggChannelsFilter, sourceTypes, aggFunctionTypes, aggOutputTypes, isInputRaws, isOutputPartials, + operatorConfig); + } + omniOperator = omniOperatorFactory.createOperator(); } @Override @@ -475,14 +485,13 @@ public class OmniGroupByOperator extends OmniHiveOperator imple int rowCount = vec.getSize(); int groupingSetSize = groupingSets.size(); Vec newVec = VecFactory.createFlatVec(rowCount * groupingSetSize, vec.getType()); - Vec flatVec = vec; - if (vec instanceof DictionaryVec) { - flatVec = ((DictionaryVec) vec).expandDictionary(); - } + Vec flatVec = (vec instanceof DictionaryVec) ? ((DictionaryVec) vec).expandDictionary() : vec; + byte[] rawValueNulls = vec.getRawValueNulls(); + DataType.DataTypeId dataTypeId = vec.getType().getId(); + int[] rawValueOffset = (dataTypeId == OMNI_VARCHAR || dataTypeId == OMNI_CHAR) ? ((VarcharVec) flatVec).getRawValueOffset() : new int[0]; for (int i = 0; i < groupingSetSize; i++) { + newVec.setNulls(i * rowCount, rawValueNulls, 0, rowCount); if ((groupingSets.get(i) & mask) == 0) { - DataType.DataTypeId dataTypeId = vec.getType().getId(); - newVec.setNulls(i * rowCount, vec.getValuesNulls(0, rowCount), 0, rowCount); switch (dataTypeId) { case OMNI_INT: case OMNI_DATE32: @@ -509,14 +518,14 @@ public class OmniGroupByOperator extends OmniHiveOperator imple case OMNI_VARCHAR: case OMNI_CHAR: ((VarcharVec) newVec).put(i * rowCount, ((VarcharVec) flatVec).get(0, rowCount), 0, - ((VarcharVec) flatVec).getValueOffset(0, rowCount), 0, rowCount); + rawValueOffset, 0, rowCount); break; default: throw new RuntimeException("Not support dataType, dataTypeId: " + dataTypeId); } } else { - boolean[] nulls = new boolean[rowCount]; - Arrays.fill(nulls, true); + byte[] nulls = new byte[rowCount]; + Arrays.fill(nulls, (byte) 1); newVec.setNulls(i * rowCount, nulls, 0, rowCount); } } @@ -712,7 +721,7 @@ public class OmniGroupByOperator extends OmniHiveOperator imple for (Vec vec : constantVec) { vec.close(); } - omniHashAggregationWithExprOperatorFactory.close(); + omniOperatorFactory.close(); omniOperator.close(); super.closeOp(abort); }