From b118232cb83aa99e5bc4f3fdaed64d58394f0bfb Mon Sep 17 00:00:00 2001 From: guojunfei <970763131@qq.com> Date: Sat, 6 May 2023 11:29:06 +0800 Subject: [PATCH] add close for factory --- .../boostkit/spark/util/OmniAdaptorUtil.scala | 29 +++++++++---------- .../ColumnarBasicPhysicalOperators.scala | 2 ++ .../sql/execution/ColumnarExpandExec.scala | 4 +++ .../ColumnarFileSourceScanExec.scala | 17 +++++++++-- .../execution/ColumnarHashAggregateExec.scala | 4 ++- .../sql/execution/ColumnarProjection.scala | 1 + .../ColumnarShuffleExchangeExec.scala | 1 + .../sql/execution/ColumnarSortExec.scala | 1 + .../ColumnarTakeOrderedAndProjectExec.scala | 1 + .../sql/execution/ColumnarWindowExec.scala | 1 + .../sql/execution/util/MergeIterator.scala | 2 ++ 11 files changed, 44 insertions(+), 19 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index ed99f6b43..65ed58f38 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -274,19 +274,17 @@ object OmniAdaptorUtil { OverflowConfig.OverflowConfigId.OVERFLOW_CONFIG_NULL } - def getAggOperator(groupingExpressions: Seq[NamedExpression], - omniGroupByChanel: Array[String], - omniAggChannels: Array[Array[String]], - omniAggChannelsFilter: Array[String], - omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], - omniAggFunctionTypes: Array[FunctionType], - omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], - omniInputRaws: Array[Boolean], - omniOutputPartials: Array[Boolean]): OmniOperator = { - var operator: OmniOperator = null + def getAggOperatorFactory(groupingExpressions: Seq[NamedExpression], + omniGroupByChanel: Array[String], + omniAggChannels: Array[Array[String]], + omniAggChannelsFilter: Array[String], + omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], + omniAggFunctionTypes: Array[FunctionType], + omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], + omniInputRaws: Array[Boolean], + omniOutputPartials: Array[Boolean]) = { if (groupingExpressions.nonEmpty) { - operator = new OmniHashAggregationWithExprOperatorFactory( - omniGroupByChanel, + new OmniHashAggregationWithExprOperatorFactory(omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniSourceTypes, @@ -294,9 +292,9 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) } else { - operator = new OmniAggregationWithExprOperatorFactory( + new OmniAggregationWithExprOperatorFactory( omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -305,9 +303,8 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) } - operator } def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index cb23b68f0..d0e151207 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -204,6 +204,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) // close operator addLeakSafeTaskCompletionListener[Unit](_ => { filterOperator.close() + filterOperatorFactory.close() }) val localSchema = this.schema @@ -302,6 +303,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], // close operator addLeakSafeTaskCompletionListener[Unit](_ => { operator.close() + operatorFactory.close() }) val localSchema = this.schema diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 27b05b16c..27ba5629c 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -16,6 +16,7 @@ import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompl import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch +import scala.collection.mutable import scala.concurrent.duration.NANOSECONDS /** @@ -80,15 +81,18 @@ case class ColumnarExpandExec( child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => val startCodegen = System.nanoTime() + val projectOperatorFactories : mutable.Set[OmniProjectOperatorFactory] = mutable.Set() var projectOperators = omniExpressions.map(exps => { val factory = new OmniProjectOperatorFactory(exps, omniInputTypes, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + projectOperatorFactories.add(factory) factory.createOperator }) omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperators.foreach(operator => operator.close()) + projectOperatorFactories.foreach(factory => factory.close()) }) new Iterator[ColumnarBatch] { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index faf692baa..977242faf 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -859,7 +859,7 @@ case class ColumnarMultipleOperatorExec( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, + val aggFactory = OmniAdaptorUtil.getAggOperatorFactory(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -868,9 +868,11 @@ case class ColumnarMultipleOperatorExec( omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial) + val aggOperator = aggFactory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() + aggFactory.close() }) val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, @@ -879,6 +881,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator1.close() + projectOperatorFactory1.close() }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, @@ -912,6 +915,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator2.close() + projectOperatorFactory2.close() }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, @@ -946,6 +950,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator3.close() + projectOperatorFactory3.close() }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, @@ -980,6 +985,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator4.close() + projectOperatorFactory4.close() }) val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(buildTypes4, @@ -1016,6 +1022,7 @@ case class ColumnarMultipleOperatorExec( // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { condOperator.close() + condOperatorFactory.close() }) while (batches.hasNext) { @@ -1220,7 +1227,7 @@ case class ColumnarMultipleOperatorExec1( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, + val aggFactory = OmniAdaptorUtil.getAggOperatorFactory(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -1229,9 +1236,11 @@ case class ColumnarMultipleOperatorExec1( omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial) + val aggOperator = aggFactory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() + aggFactory.close() }) val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, @@ -1240,6 +1249,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator1.close() + projectOperatorFactory1.close() }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, @@ -1274,6 +1284,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator2.close() + projectOperatorFactory2.close() }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, @@ -1308,6 +1319,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator3.close() + projectOperatorFactory3.close() }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, @@ -1344,6 +1356,7 @@ case class ColumnarMultipleOperatorExec1( // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { condOperator.close() + condOperatorFactory.close() }) while (batches.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 6dc3cbef8..f7979d58d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -267,7 +267,7 @@ case class ColumnarHashAggregateExec( child.executeColumnar().mapPartitionsWithIndex { (index, iter) => val startCodegen = System.nanoTime() - val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, + val factory = OmniAdaptorUtil.getAggOperatorFactory(groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -276,11 +276,13 @@ case class ColumnarHashAggregateExec( omniAggOutputTypes, omniInputRaws, omniOutputPartials) + val operator = factory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { operator.close() + factory.close() }) while (iter.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala index 0ccdbd6de..6e347805d 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala @@ -50,6 +50,7 @@ object ColumnarProjection { // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator.close() + projectOperatorFactory.close() }) new Iterator[ColumnarBatch] { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index cea0a1438..f7f5d7468 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -303,6 +303,7 @@ object ColumnarShuffleExchangeExec extends Logging { // close operator addLeakSafeTaskCompletionListener[Unit](_ => { op.close() + factory.close() }) cbIter.map { cb => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 7c7001dbc..1e5275121 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -125,6 +125,7 @@ case class ColumnarSortExec( omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { sortOperator.close() + sortOperatorFactory.close() }) addAllAndGetIterator(sortOperator, iter, this.schema, longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala index 6fec9f9a0..c42120f2b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala @@ -116,6 +116,7 @@ case class ColumnarTakeOrderedAndProjectExec( longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { topNOperator.close() + topNOperatorFactory.close() }) addAllAndGetIterator(topNOperator, iter, schema, longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index b44c78803..3b8f8b7ed 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -338,6 +338,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { windowOperator.close() + windowOperatorFactory.close() }) while (iter.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index 93ec7d89b..68ac49cec 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -97,6 +97,8 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, src.close() } } + // close bufferedBatch + bufferedBatch.foreach(batch => batch.close()) } private def flush(): Unit = { -- Gitee