diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
index 113e88399ba897dcc38c3c40841a18dafd2a9315..2270c0c861413cd53910f9be1692f7d11f39e85f 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala
@@ -301,10 +301,13 @@ object OmniAdaptorUtil {
                      omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]],
                      omniInputRaws: Array[Boolean],
                      omniOutputPartials: Array[Boolean],
-                     sparkSpillConf: SpillConfig = SpillConfig.NONE): OmniOperator = {
+                     sparkSpillConf: SpillConfig = SpillConfig.NONE):
+  (OmniOperator, OmniHashAggregationWithExprOperatorFactory, OmniAggregationWithExprOperatorFactory) = {
+    var hashAggregationWithExprOperatorFactory: OmniHashAggregationWithExprOperatorFactory = null
+    var aggregationWithExprOperatorFactory: OmniAggregationWithExprOperatorFactory = null
     var operator: OmniOperator = null
     if (groupingExpressions.nonEmpty) {
-      operator = new OmniHashAggregationWithExprOperatorFactory(
+      hashAggregationWithExprOperatorFactory = new OmniHashAggregationWithExprOperatorFactory(
         omniGroupByChanel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -314,9 +317,10 @@ object OmniAdaptorUtil {
         omniInputRaws,
         omniOutputPartials,
         new OperatorConfig(sparkSpillConf, new OverflowConfig(OmniAdaptorUtil.overflowConf()),
-          IS_SKIP_VERIFY_EXP)).createOperator
+          IS_SKIP_VERIFY_EXP))
+      operator = hashAggregationWithExprOperatorFactory.createOperator
     } else {
-      operator = new OmniAggregationWithExprOperatorFactory(
+      aggregationWithExprOperatorFactory = new OmniAggregationWithExprOperatorFactory(
         omniGroupByChanel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -325,9 +329,10 @@ object OmniAdaptorUtil {
         omniAggOutputTypes,
         omniInputRaws,
         omniOutputPartials,
-        new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator
+        new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP))
+      operator = aggregationWithExprOperatorFactory.createOperator
     }
-    operator
+    (operator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory)
   }
 
   def pruneOutput(output: Seq[Attribute], projectExprIdList: Seq[ExprId]): Seq[Attribute] = {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
index 24c74d600b9f360db932d25f0e3a60ea87ade19c..f6f93a9208c695c2924849b77e605eaa7e1e7a5c 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala
@@ -320,7 +320,7 @@ case class ColumnarOptRollupExec(
         factory.createOperator
       })
 
-      val hashaggOperator = OmniAdaptorUtil.getAggOperator(groupingExpressions,
+      val (hashaggOperator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory) = OmniAdaptorUtil.getAggOperator(groupingExpressions,
         omniGroupByChannel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -339,6 +339,12 @@ case class ColumnarOptRollupExec(
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperators.foreach(operator => operator.close())
         hashaggOperator.close()
+        if (hashAggregationWithExprOperatorFactory != null) {
+          hashAggregationWithExprOperatorFactory.close()
+        }
+        if (aggregationWithExprOperatorFactory != null) {
+          aggregationWithExprOperatorFactory.close()
+        }
         results.foreach(vecBatch => {
           vecBatch.releaseAllVectors()
           vecBatch.close()
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
index ddabce36747a404dc95450b9c9f5e95ea9a1b508..800dcf1a0c047b206603b24a2a963ef7f3e48db1 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala
@@ -908,7 +908,7 @@ case class ColumnarMultipleOperatorExec(
       // for join
       val deserializer = VecBatchSerializerFactory.create()
       val startCodegen = System.nanoTime()
-      val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions,
+      val (aggOperator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory) = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions,
         omniGroupByChanel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -920,6 +920,12 @@ case class ColumnarMultipleOperatorExec(
       omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen)
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         aggOperator.close()
+        if (hashAggregationWithExprOperatorFactory != null) {
+          hashAggregationWithExprOperatorFactory.close()
+        }
+        if (aggregationWithExprOperatorFactory != null) {
+          aggregationWithExprOperatorFactory.close()
+        }
       })
 
       val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1,
@@ -928,6 +934,7 @@ case class ColumnarMultipleOperatorExec(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator1.close()
+        projectOperatorFactory1.close()
       })
 
       val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes1,
@@ -962,6 +969,7 @@ case class ColumnarMultipleOperatorExec(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator2.close()
+        projectOperatorFactory2.close()
       })
 
       val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes2,
@@ -997,6 +1005,7 @@ case class ColumnarMultipleOperatorExec(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator3.close()
+        projectOperatorFactory3.close()
       })
 
       val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes3,
@@ -1032,6 +1041,7 @@ case class ColumnarMultipleOperatorExec(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator4.close()
+        projectOperatorFactory4.close()
       })
 
       val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes4,
@@ -1069,6 +1079,7 @@ case class ColumnarMultipleOperatorExec(
       // close operator
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         condOperator.close()
+        condOperatorFactory.close()
       })
 
       while (batches.hasNext) {
@@ -1273,7 +1284,7 @@ case class ColumnarMultipleOperatorExec1(
       // for join
       val deserializer = VecBatchSerializerFactory.create()
       val startCodegen = System.nanoTime()
-      val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions,
+      val (aggOperator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory) = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions,
         omniGroupByChanel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -1285,6 +1296,12 @@ case class ColumnarMultipleOperatorExec1(
       omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen)
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         aggOperator.close()
+        if (hashAggregationWithExprOperatorFactory != null) {
+          hashAggregationWithExprOperatorFactory.close()
+        }
+        if (aggregationWithExprOperatorFactory != null) {
+          aggregationWithExprOperatorFactory.close()
+        }
       })
 
       val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1,
@@ -1293,6 +1310,7 @@ case class ColumnarMultipleOperatorExec1(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator1.close()
+        projectOperatorFactory1.close()
       })
 
       val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes1,
@@ -1328,6 +1346,7 @@ case class ColumnarMultipleOperatorExec1(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator2.close()
+        projectOperatorFactory2.close()
       })
 
       val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes2,
@@ -1363,6 +1382,7 @@ case class ColumnarMultipleOperatorExec1(
       // close operator
       addLeakSafeTaskCompletionListener[Unit](_ => {
         projectOperator3.close()
+        projectOperatorFactory3.close()
       })
 
       val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(OMNI_JOIN_TYPE_INNER, buildTypes3,
@@ -1400,6 +1420,7 @@ case class ColumnarMultipleOperatorExec1(
       // close operator
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         condOperator.close()
+        condOperatorFactory.close()
       })
 
       while (batches.hasNext) {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
index 55fba9f2b750d215326d6835ca849ba66d994caf..2e5ecc6532bfe210d99c9c70679882a8ecde7c85 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala
@@ -308,7 +308,7 @@ case class ColumnarHashAggregateExec(
         spillDirDiskReserveSize, hashAggSpillRowThreshold, spillMemPctThreshold, spillWriteBufferSize)
 
       val startCodegen = System.nanoTime()
-      val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions,
+      val (operator, hashAggregationWithExprOperatorFactory, aggregationWithExprOperatorFactory) = OmniAdaptorUtil.getAggOperator(groupingExpressions,
         omniGroupByChanel,
         omniAggChannels,
         omniAggChannelsFilter,
@@ -324,6 +324,12 @@ case class ColumnarHashAggregateExec(
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         spillSize += operator.getSpilledBytes()
         operator.close()
+        if (hashAggregationWithExprOperatorFactory != null) {
+          hashAggregationWithExprOperatorFactory.close()
+        }
+        if (aggregationWithExprOperatorFactory != null) {
+          aggregationWithExprOperatorFactory.close()
+        }
       })
 
       while (iter.hasNext) {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
index 3603ecccc9df432c2b2e19650204c2341ebf15d5..c2318cddb49cfbc750a727d2b77c3844644d926d 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarLimit.scala
@@ -77,6 +77,7 @@ trait ColumnarBaseLimitExec extends LimitExec {
       // close operator
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         limitOperator.close()
+        limitOperatorFactory.close()
       })
 
       val localSchema = this.schema
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala
index 49e6968685ccd2079776ace720aaa2342791ca9e..eb493750c79a8b0c782e542d84373b1fdd021f53 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala
@@ -50,6 +50,7 @@ object ColumnarProjection {
     // close operator
     addLeakSafeTaskCompletionListener[Unit](_ => {
       projectOperator.close()
+      projectOperatorFactory.close()
     })
 
     new Iterator[ColumnarBatch] {
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
index d94d256568e70a80e2b593f7281765826695fd3d..594d0c512b9c061959fed016d0b6d9098358d751 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala
@@ -113,6 +113,7 @@ case class ColumnarSortExec(
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         spillSize += sortOperator.getSpilledBytes()
         sortOperator.close()
+        sortOperatorFactory.close()
       })
       addAllAndGetIterator(sortOperator, iter, this.schema,
         longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"),
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala
index 9e52282922c1aceb1da34d650349b7e9562804d7..5293522c17229008a16c488f5d47ecb57a292b94 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTopNSortExec.scala
@@ -93,6 +93,7 @@ case class ColumnarTopNSortExec(
       omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen)
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         topNSortOperator.close()
+        topNSortOperatorFactory.close()
       })
       addAllAndGetIterator(topNSortOperator, iter, this.schema,
         longMetric("addInputTime"), longMetric("numInputVecBatches"), longMetric("numInputRows"),
diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
index 837760ac89c7f714a2cdb8c75a41f6969840f2ca..c400dc9999ae802de6fe686f977fb02c7db7eee2 100644
--- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
+++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala
@@ -382,6 +382,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
       SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
         spillSize += windowOperator.getSpilledBytes
         windowOperator.close()
+        windowOperatorFactory.close()
       })
 
       while (iter.hasNext) {