diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 4ec7579009dbc9c6e8467dca377c6a3fc40d97b5..c760d6ee5f39314d7f4903314af7ef9910b1c5a3 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -21,7 +21,7 @@ package com.huawei.boostkit.spark.expression import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_DECIMAL_CHECK, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_SHOR_TYPE, OMNI_VARCHAR_TYPE} import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, ShortDataType, VarcharDataType} import nova.hetu.omniruntime.constants.FunctionType -import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} import nova.hetu.omniruntime.operator.OmniExprVerify import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ @@ -667,7 +667,7 @@ object OmniExpressionAdaptor extends Logging { } } - def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false): FunctionType = { + def toOmniAggFunType(agg: AggregateExpression, isHashAgg: Boolean = false, isFinal: Boolean = false): FunctionType = { agg.aggregateFunction match { case Sum(_) => { if (isHashAgg) { @@ -681,7 +681,11 @@ object OmniExpressionAdaptor extends Logging { case Average(_) => OMNI_AGGREGATION_TYPE_AVG case Min(_) => OMNI_AGGREGATION_TYPE_MIN case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => - throw new UnsupportedOperationException("Unsupported count(*) or count(1)") + if (isFinal) { + OMNI_AGGREGATION_TYPE_COUNT_COLUMN + } else { + OMNI_AGGREGATION_TYPE_COUNT_ALL + } case Count(_) => OMNI_AGGREGATION_TYPE_COUNT_COLUMN case _ => throw new UnsupportedOperationException(s"Unsupported aggregate function: $agg") } diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 1ecc0923d524c5c75291667c7a02f49c4adc9df0..d93e1dc54ace3160f28c4351e47c3d6b33115533 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -23,6 +23,7 @@ import com.huawei.boostkit.spark.expression.OmniExpressionAdaptor._ import com.huawei.boostkit.spark.util.OmniAdaptorUtil.transColBatchToOmniVecs import nova.hetu.omniruntime.`type`.DataType import nova.hetu.omniruntime.constants.FunctionType +import nova.hetu.omniruntime.constants.FunctionType.OMNI_AGGREGATION_TYPE_COUNT_ALL import nova.hetu.omniruntime.operator.aggregator.OmniHashAggregationWithExprOperatorFactory import nova.hetu.omniruntime.operator.config.{OperatorConfig, SpillConfig} import nova.hetu.omniruntime.vector.VecBatch @@ -76,7 +77,7 @@ case class ColumnarHashAggregateExec( val omniAggTypes = new Array[DataType](aggregateExpressions.size) val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[DataType](aggregateExpressions.size) - val omniAggChannels = new Array[AnyRef](aggregateExpressions.size) + var omniAggChannels = new Array[AnyRef](aggregateExpressions.size) var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { @@ -90,12 +91,12 @@ case class ColumnarHashAggregateExec( case Sum(_) | Min(_) | Max(_) | Count(_) => val aggExp = exp.aggregateFunction.inputAggBufferAttributes.head omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType, aggExp.metadata) - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = sparkTypeToOmniType(exp.aggregateFunction.dataType) omniAggChannels(index) = rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) - case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") + case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else if (exp.mode == Partial) { omniInputRaw = true @@ -109,6 +110,9 @@ case class ColumnarHashAggregateExec( sparkTypeToOmniType(exp.aggregateFunction.dataType) omniAggChannels(index) = rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: $exp") } } else { @@ -116,7 +120,7 @@ case class ColumnarHashAggregateExec( } index += 1 } - + omniAggChannels = omniAggChannels.filter(key => key != null) val omniSourceTypes = new Array[DataType](child.outputSet.size) val inputIter = child.outputSet.toIterator var i = 0 @@ -158,7 +162,7 @@ case class ColumnarHashAggregateExec( val omniAggTypes = new Array[DataType](aggregateExpressions.size) val omniAggFunctionTypes = new Array[FunctionType](aggregateExpressions.size) val omniAggOutputTypes = new Array[DataType](aggregateExpressions.size) - val omniAggChannels = new Array[String](aggregateExpressions.size) + var omniAggChannels = new Array[String](aggregateExpressions.size) var index = 0 for (exp <- aggregateExpressions) { if (exp.filter.isDefined) { @@ -172,7 +176,7 @@ case class ColumnarHashAggregateExec( case Sum(_) | Min(_) | Max(_) | Count(_) => val aggExp = exp.aggregateFunction.inputAggBufferAttributes.head omniAggTypes(index) = sparkTypeToOmniType(aggExp.dataType, aggExp.metadata) - omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) + omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = sparkTypeToOmniType(exp.aggregateFunction.dataType) omniAggChannels(index) = @@ -191,6 +195,9 @@ case class ColumnarHashAggregateExec( sparkTypeToOmniType(exp.aggregateFunction.dataType) omniAggChannels(index) = rewriteToOmniJsonExpressionLiteral(aggExp, attrExpsIdMap) + if (omniAggFunctionTypes(index) == OMNI_AGGREGATION_TYPE_COUNT_ALL) { + omniAggChannels(index) = null + } case _ => throw new UnsupportedOperationException(s"Unsupported aggregate aggregateFunction: ${exp}") } } else { @@ -199,6 +206,7 @@ case class ColumnarHashAggregateExec( index += 1 } + omniAggChannels = omniAggChannels.filter(key => key != null) val omniSourceTypes = new Array[DataType](child.outputSet.size) val inputIter = child.outputSet.toIterator var i = 0 diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala index 1f74413f93852aa8c5526c8180296d4b3c27e5ed..5c732d6b97eff3be3e46d6d3afda3c776444e6b0 100644 --- a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExecSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.functions.sum +import org.apache.spark.sql.functions.{sum, count} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -58,4 +58,23 @@ class ColumnarHashAggregateExecSuite extends ColumnarSparkPlanTest { Seq(Row(1, 2.0), Row(2, 1.0)) ) } + + test("Test ColumnarHashAggregateExec happen and result is correct when execute count(*) api") { + val res = df.agg(count("*")) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + checkAnswer( + res, + Seq(Row(5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute count(*) api with group by") { + val res = df.groupBy("a").agg(count("*")) + assert(res.queryExecution.executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n${res.queryExecution.executedPlan}") + checkAnswer( + res, + Seq(Row(1, 2), Row(2, 1), Row(null, 2)) + ) + } } diff --git a/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..5b0d14f45cc25c3c90b67e99c73458d3d6c75b13 --- /dev/null +++ b/omnioperator/omniop-spark-extension/java/src/test/scala/org/apache/spark/sql/execution/forsql/ColumnarHashAggregateExecSqlSuite.scala @@ -0,0 +1,196 @@ +/* + * Copyright (C) 2022-2022. Huawei Technologies Co., Ltd. All rights reserved. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.forsql + +import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.{ColumnarHashAggregateExec, ColumnarSparkPlanTest} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{DataFrame, Row} + +class ColumnarHashAggregateExecSqlSuite extends ColumnarSparkPlanTest { + private var df: DataFrame = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0, 1L, "a"), + Row(1, 2.0, 2L, null), + Row(2, 1.0, 3L, "c"), + Row(null, null, 6L, "e"), + Row(null, 5.0, 7L, "f") + )), new StructType().add("intCol", IntegerType).add("doubleCol", DoubleType) + .add("longCol", LongType).add("stringCol", StringType)) + df.createOrReplaceTempView("test_table") + } + + test("Test ColumnarHashAggregateExec happen and result is correct when execute count(*)") { + val res = spark.sql("select count(*) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result is correct when execute count(1)") { + val res = spark.sql("select count(1) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result is correct when execute count(-1)") { + val res = spark.sql("select count(-1) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute otherAgg-count(*)") { + val res = spark.sql("select max(intCol), count(*) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(2, 5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute count(*)-otherAgg") { + val res = spark.sql("select count(*), max(intCol) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5, 2)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute count(*)-count(*)") { + val res = spark.sql("select count(*), count(*) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5, 5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute count(*)-otherAgg-count(*)") { + val res = spark.sql("select count(*), max(intCol), count(*) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5, 2, 5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute otherAgg-count(*)-otherAgg") { + val res = spark.sql("select max(intCol), count(*), min(intCol) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(2, 5, 1)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute otherAgg-count(*)-count(*)") { + val res = spark.sql("select max(intCol), count(*), count(*) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(2, 5, 5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result " + + "is correct when execute count(*) with group by") { + val res = spark.sql("select count(*) from test_table group by intCol") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(2), Row(1), Row(2)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result" + + " is correct when execute count(*) with calculation expr") { + val res = spark.sql("select count(*) / 2 from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(2.5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result" + + " is correct when execute count(*) with cast expr") { + val res = spark.sql("select cast(count(*) as bigint) from test_table") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(5)) + ) + } + + test("Test ColumnarHashAggregateExec happen and result" + + " is correct when execute count(*) with subQuery") { + val res = spark.sql("select count(*) from (select intCol," + + "count(*) from test_table group by intCol)") + val executedPlan = res.queryExecution.executedPlan + assert(executedPlan.find(_.isInstanceOf[ColumnarHashAggregateExec]).isDefined, s"ColumnarHashAggregateExec not happened, executedPlan as follows: \n$executedPlan") + assert(executedPlan.find(_.isInstanceOf[HashAggregateExec]).isEmpty, s"HashAggregateExec happened, executedPlan as follows: \n$executedPlan") + checkAnswer( + res, + Seq(Row(3)) + ) + } +} \ No newline at end of file