From d8384ef352ea19c65c6c9856a5825dfc758ca6b7 Mon Sep 17 00:00:00 2001 From: liusongling Date: Mon, 24 Oct 2022 15:34:48 +0800 Subject: [PATCH] Fix false fallback for aggregation --- .../operator/AbstractOmniOperatorFactory.java | 46 +++++++++++++++++++ .../olk/operator/AggregationOmniOperator.java | 41 +---------------- .../operator/HashAggregationOmniOperator.java | 37 +-------------- 3 files changed, 48 insertions(+), 76 deletions(-) diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AbstractOmniOperatorFactory.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AbstractOmniOperatorFactory.java index 508f38244..6009e516c 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AbstractOmniOperatorFactory.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AbstractOmniOperatorFactory.java @@ -20,13 +20,19 @@ import io.prestosql.operator.Operator; import io.prestosql.operator.OperatorFactory; import io.prestosql.spi.PrestoException; import io.prestosql.spi.StandardErrorCode; +import io.prestosql.spi.plan.AggregationNode; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; +import nova.hetu.omniruntime.constants.FunctionType; import java.util.List; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL; +import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; + public abstract class AbstractOmniOperatorFactory implements OperatorFactory { @@ -74,6 +80,46 @@ public abstract class AbstractOmniOperatorFactory } } + public static void checkDataTypes(final List sourceTypes, + final FunctionType[] aggregatorTypes, + final int[] aggregationInputChannels, + final AggregationNode.Step step) + { + int channelIndex = 0; + for (FunctionType fType : aggregatorTypes) { + if (fType == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + continue; + } + + String base = sourceTypes.get(aggregationInputChannels[channelIndex++]).getTypeSignature().getBase(); + switch (base) { + case StandardTypes.INTEGER: + case StandardTypes.BIGINT: + case StandardTypes.DOUBLE: + case StandardTypes.BOOLEAN: + case StandardTypes.VARCHAR: + case StandardTypes.CHAR: + case StandardTypes.DECIMAL: + case StandardTypes.DATE: + continue; + case StandardTypes.VARBINARY: + if (step == AggregationNode.Step.FINAL + && (fType == OMNI_AGGREGATION_TYPE_AVG || fType == OMNI_AGGREGATION_TYPE_SUM)) { + continue; + } + case StandardTypes.ROW: { + if (step == AggregationNode.Step.FINAL && fType == OMNI_AGGREGATION_TYPE_AVG) { + continue; + } + } + default: + throw new PrestoException( + StandardErrorCode.NOT_SUPPORTED, + "Not support data Type " + base + " for aggregation " + fType + " with step " + step); + } + } + } + public void checkDataType(Type type) { TypeSignature signature = type.getTypeSignature(); diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java index 8fa7b8a99..694782013 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/AggregationOmniOperator.java @@ -20,13 +20,9 @@ import io.prestosql.operator.Operator; import io.prestosql.operator.OperatorContext; import io.prestosql.operator.OperatorFactory; import io.prestosql.spi.Page; -import io.prestosql.spi.PrestoException; -import io.prestosql.spi.StandardErrorCode; import io.prestosql.spi.plan.AggregationNode; import io.prestosql.spi.plan.PlanNodeId; -import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; -import io.prestosql.spi.type.TypeSignature; import nova.hetu.olk.tool.OperatorUtils; import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; @@ -37,15 +33,12 @@ import nova.hetu.omniruntime.type.DataType; import nova.hetu.omniruntime.vector.VecAllocator; import nova.hetu.omniruntime.vector.VecBatch; -import java.util.Arrays; import java.util.List; import java.util.Optional; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.buildVecBatch; -import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; -import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; /** * The type Aggregation omni operator. @@ -198,7 +191,7 @@ public class AggregationOmniOperator maskChannelArray[i] = INVALID_MASK_CHANNEL; } } - checkDataTypes(this.sourceTypes); + checkDataTypes(this.sourceTypes, this.aggregatorTypes, this.aggregationInputChannels, this.step); this.omniFactory = new OmniAggregationOperatorFactory(sourceDataTypes, aggregatorTypes, aggregationInputChannels, maskChannelArray, aggregationOutputTypes, step.isInputRaw(), step.isOutputPartial()); @@ -226,37 +219,5 @@ public class AggregationOmniOperator return new AggregationOmniOperatorFactory(operatorId, planNodeId, sourceTypes, aggregatorTypes, aggregationInputChannels, maskChannels, aggregationOutputTypes, step); } - - @Override - public void checkDataType(Type type) - { - TypeSignature signature = type.getTypeSignature(); - String base = signature.getBase(); - - switch (base) { - case StandardTypes.INTEGER: - case StandardTypes.BIGINT: - case StandardTypes.DOUBLE: - case StandardTypes.BOOLEAN: - case StandardTypes.VARCHAR: - case StandardTypes.CHAR: - case StandardTypes.DECIMAL: - case StandardTypes.DATE: - return; - case StandardTypes.VARBINARY: - if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && - Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG || item == OMNI_AGGREGATION_TYPE_SUM)) { - return; - } - case StandardTypes.ROW: { - if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && - Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG)) { - return; - } - } - default: - throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support data Type " + base); - } - } } } diff --git a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java index 15efdcd33..0bf42ef00 100644 --- a/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java +++ b/omnioperator/omniop-openlookeng-extension/src/main/java/nova/hetu/olk/operator/HashAggregationOmniOperator.java @@ -32,7 +32,6 @@ import io.prestosql.spi.plan.AggregationNode.Step; import io.prestosql.spi.plan.PlanNodeId; import io.prestosql.spi.type.StandardTypes; import io.prestosql.spi.type.Type; -import io.prestosql.spi.type.TypeSignature; import nova.hetu.olk.tool.BlockUtils; import nova.hetu.olk.tool.VecAllocatorHelper; import nova.hetu.olk.tool.VecBatchToPageIterator; @@ -53,8 +52,6 @@ import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; import static nova.hetu.olk.tool.OperatorUtils.buildVecBatch; import static nova.hetu.olk.tool.OperatorUtils.createExpressions; -import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_AVG; -import static nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_SUM; /** * The type Hash aggregation omni operator. @@ -263,7 +260,7 @@ public class HashAggregationOmniOperator this.aggregationOutputTypes = Arrays.copyOf( requireNonNull(aggregationOutputTypes, "aggregationOutputTypes is null."), aggregationOutputTypes.length); - checkDataTypes(this.sourceTypes); + checkDataTypes(this.sourceTypes, this.aggregatorTypes, this.aggregationInputChannels, this.step); this.omniFactory = new OmniHashAggregationOperatorFactory(createExpressions(this.groupByInputChannels), this.groupByInputTypes, createExpressions(this.aggregationInputChannels), this.aggregationInputTypes, this.aggregatorTypes, maskChannelArray, this.aggregationOutputTypes, step.isInputRaw(), @@ -337,37 +334,5 @@ public class HashAggregationOmniOperator groupByInputTypes, aggregationInputChannels, aggregationInputTypes, aggregatorTypes, maskChannels, aggregationOutputTypes, step); } - - @Override - public void checkDataType(Type type) - { - TypeSignature signature = type.getTypeSignature(); - String base = signature.getBase(); - - switch (base) { - case StandardTypes.INTEGER: - case StandardTypes.BIGINT: - case StandardTypes.DOUBLE: - case StandardTypes.BOOLEAN: - case StandardTypes.VARCHAR: - case StandardTypes.CHAR: - case StandardTypes.DECIMAL: - case StandardTypes.DATE: - return; - case StandardTypes.VARBINARY: - if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && - Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG || item == OMNI_AGGREGATION_TYPE_SUM)) { - return; - } - case StandardTypes.ROW: { - if (this.step == AggregationNode.Step.FINAL && this.aggregatorTypes.length != 0 && - Arrays.stream(this.aggregatorTypes).allMatch(item -> item == OMNI_AGGREGATION_TYPE_AVG)) { - return; - } - } - default: - throw new PrestoException(StandardErrorCode.NOT_SUPPORTED, "Not support data Type " + base); - } - } } } -- Gitee