diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala index 6082dc141195d4d6cb8f22540335b32db9b639db..85a1f2212f96fa1c699d5a845efabf96dd22e4d7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/expression/OmniExpressionAdaptor.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import com.huawei.boostkit.spark.Constant.{DEFAULT_STRING_TYPE_LENGTH, IS_CHECK_OMNI_EXP, OMNI_BOOLEAN_TYPE, OMNI_DATE_TYPE, OMNI_DECIMAL128_TYPE, OMNI_DECIMAL64_TYPE, OMNI_DOUBLE_TYPE, OMNI_INTEGER_TYPE, OMNI_LONG_TYPE, OMNI_SHOR_TYPE, OMNI_TIMESTAMP_TYPE, OMNI_VARCHAR_TYPE} import nova.hetu.omniruntime.`type`.{BooleanDataType, DataTypeSerializer, Date32DataType, Decimal128DataType, Decimal64DataType, DoubleDataType, IntDataType, LongDataType, ShortDataType, VarcharDataType, TimestampDataType} import nova.hetu.omniruntime.constants.FunctionType -import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} +import nova.hetu.omniruntime.constants.FunctionType.{OMNI_AGGREGATION_TYPE_AVG, OMNI_AGGREGATION_TYPE_COUNT_ALL, OMNI_AGGREGATION_TYPE_COUNT_COLUMN, OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL, OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL, OMNI_AGGREGATION_TYPE_MAX, OMNI_AGGREGATION_TYPE_MIN, OMNI_AGGREGATION_TYPE_SAMP, OMNI_AGGREGATION_TYPE_SUM, OMNI_WINDOW_TYPE_RANK, OMNI_WINDOW_TYPE_ROW_NUMBER} import nova.hetu.omniruntime.constants.JoinType._ import nova.hetu.omniruntime.operator.OmniExprVerify import com.huawei.boostkit.spark.ColumnarPluginConfig @@ -809,6 +809,7 @@ object OmniExpressionAdaptor extends Logging { case Max(_) => OMNI_AGGREGATION_TYPE_MAX case Average(_, _) => OMNI_AGGREGATION_TYPE_AVG case Min(_) => OMNI_AGGREGATION_TYPE_MIN + case StddevSamp(_,_) => OMNI_AGGREGATION_TYPE_SAMP case Count(Literal(1, IntegerType) :: Nil) | Count(ArrayBuffer(Literal(1, IntegerType))) => if (isMergeCount) { OMNI_AGGREGATION_TYPE_COUNT_COLUMN diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index da1b5f54e6123571f286d2806cfe64e3ce55db62..15f36511e14004ef6475cb49891ef37e4dc87f8c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -305,6 +305,10 @@ object OmniAdaptorUtil { OverflowConfig.OverflowConfigId.OVERFLOW_CONFIG_NULL } + private def GetIsStatisticalAggregate(): Boolean = { + SQLConf.get.legacyStatisticalAggregate + } + def getAggOperator(groupingExpressions: Seq[NamedExpression], omniGroupByChanel: Array[String], omniAggChannels: Array[Array[String]], @@ -329,8 +333,8 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()), - IS_SKIP_VERIFY_EXP)) + new OperatorConfig(sparkSpillConf, GetIsStatisticalAggregate(), + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) operator = hashAggregationWithExprOperatorFactory.createOperator } else { aggregationWithExprOperatorFactory = new OmniAggregationWithExprOperatorFactory( @@ -342,7 +346,8 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + new OperatorConfig(SpillConfig.NONE, GetIsStatisticalAggregate(), + new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) operator = aggregationWithExprOperatorFactory.createOperator } (operator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index 800dcf1a0c047b206603b24a2a963ef7f3e48db1..c14eb74548facd1b0602c20ed3ebee02fef5461a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -605,7 +605,7 @@ abstract class BaseColumnarFileSourceScanExec( throw new UnsupportedOperationException(s"Unsupported final aggregate expression in operator fusion, exp: $exp") } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) => + case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) | StddevSamp(_,_) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> @@ -623,7 +623,7 @@ abstract class BaseColumnarFileSourceScanExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) => + case Sum(_, _) | Min(_) | Average(_, _) | Max(_) | Count(_) | First(_, _) | StddevSamp(_,_) => val aggExp = exp.aggregateFunction.children.head omniOutputExressionOrder += { exp.aggregateFunction.inputAggBufferAttributes.head.exprId -> diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index fe4531ffe3e6cb8351df515df1ca88dc07bf865a..6ad6f3bed535091c456d750ed97da72949c5337a 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -132,7 +132,7 @@ case class ColumnarHashAggregateExec( } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) | StddevSamp(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) omniAggChannels(index) = @@ -143,7 +143,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == PartialMerge) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) | StddevSamp(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -155,7 +155,7 @@ case class ColumnarHashAggregateExec( } } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_,_) | StddevSamp(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes) @@ -262,7 +262,7 @@ case class ColumnarHashAggregateExec( } if (exp.mode == Final) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) | StddevSamp(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.dataType) @@ -283,7 +283,7 @@ case class ColumnarHashAggregateExec( omniAggChannels) } else if (exp.mode == Partial) { exp.aggregateFunction match { - case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) => + case Sum(_, _) | Min(_) | Max(_) | Count(_) | Average(_, _) | First(_, _) | StddevSamp(_,_) => omniAggFunctionTypes(index) = toOmniAggFunType(exp, true, true) omniAggOutputTypes(index) = toOmniAggInOutType(exp.aggregateFunction.inputAggBufferAttributes)