diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala index ed99f6b4311a48492438095a87d450f7d9d89a5a..65ed58f382654813c04102ab9b0afa748a99a4e5 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/com/huawei/boostkit/spark/util/OmniAdaptorUtil.scala @@ -274,19 +274,17 @@ object OmniAdaptorUtil { OverflowConfig.OverflowConfigId.OVERFLOW_CONFIG_NULL } - def getAggOperator(groupingExpressions: Seq[NamedExpression], - omniGroupByChanel: Array[String], - omniAggChannels: Array[Array[String]], - omniAggChannelsFilter: Array[String], - omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], - omniAggFunctionTypes: Array[FunctionType], - omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], - omniInputRaws: Array[Boolean], - omniOutputPartials: Array[Boolean]): OmniOperator = { - var operator: OmniOperator = null + def getAggOperatorFactory(groupingExpressions: Seq[NamedExpression], + omniGroupByChanel: Array[String], + omniAggChannels: Array[Array[String]], + omniAggChannelsFilter: Array[String], + omniSourceTypes: Array[nova.hetu.omniruntime.`type`.DataType], + omniAggFunctionTypes: Array[FunctionType], + omniAggOutputTypes: Array[Array[nova.hetu.omniruntime.`type`.DataType]], + omniInputRaws: Array[Boolean], + omniOutputPartials: Array[Boolean]) = { if (groupingExpressions.nonEmpty) { - operator = new OmniHashAggregationWithExprOperatorFactory( - omniGroupByChanel, + new OmniHashAggregationWithExprOperatorFactory(omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, omniSourceTypes, @@ -294,9 +292,9 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) } else { - operator = new OmniAggregationWithExprOperatorFactory( + new OmniAggregationWithExprOperatorFactory( omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -305,9 +303,8 @@ object OmniAdaptorUtil { omniAggOutputTypes, omniInputRaws, omniOutputPartials, - new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)).createOperator + new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) } - operator } def pruneOutput(output: Seq[Attribute], projectList: Seq[NamedExpression]): Seq[Attribute] = { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala index cb23b68f09bb085d86e133d3ef40628e8c5ca4c2..d0e151207d6691c06dcc18661c44f11033f06630 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarBasicPhysicalOperators.scala @@ -204,6 +204,7 @@ case class ColumnarFilterExec(condition: Expression, child: SparkPlan) // close operator addLeakSafeTaskCompletionListener[Unit](_ => { filterOperator.close() + filterOperatorFactory.close() }) val localSchema = this.schema @@ -302,6 +303,7 @@ case class ColumnarConditionProjectExec(projectList: Seq[NamedExpression], // close operator addLeakSafeTaskCompletionListener[Unit](_ => { operator.close() + operatorFactory.close() }) val localSchema = this.schema diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala index 27b05b16c017c43e73a0c3b6d4f05ea02d11f951..27ba5629c935c1e01790dfd143322ae4d59532b9 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarExpandExec.scala @@ -16,6 +16,7 @@ import org.apache.spark.sql.execution.util.SparkMemoryUtils.addLeakSafeTaskCompl import org.apache.spark.sql.execution.vectorized.OmniColumnVector import org.apache.spark.sql.vectorized.ColumnarBatch +import scala.collection.mutable import scala.concurrent.duration.NANOSECONDS /** @@ -80,15 +81,18 @@ case class ColumnarExpandExec( child.executeColumnar().mapPartitionsWithIndexInternal { (index, iter) => val startCodegen = System.nanoTime() + val projectOperatorFactories : mutable.Set[OmniProjectOperatorFactory] = mutable.Set() var projectOperators = omniExpressions.map(exps => { val factory = new OmniProjectOperatorFactory(exps, omniInputTypes, 1, new OperatorConfig(SpillConfig.NONE, new OverflowConfig(OmniAdaptorUtil.overflowConf()), IS_SKIP_VERIFY_EXP)) + projectOperatorFactories.add(factory) factory.createOperator }) omniCodegenTimeMetric += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperators.foreach(operator => operator.close()) + projectOperatorFactories.foreach(factory => factory.close()) }) new Iterator[ColumnarBatch] { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala index faf692baa2eafcd0a38fd1d9994845149458b783..977242faf5da3f9cf5a71304ab1d0ceb83ac473b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarFileSourceScanExec.scala @@ -859,7 +859,7 @@ case class ColumnarMultipleOperatorExec( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, + val aggFactory = OmniAdaptorUtil.getAggOperatorFactory(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -868,9 +868,11 @@ case class ColumnarMultipleOperatorExec( omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial) + val aggOperator = aggFactory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() + aggFactory.close() }) val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, @@ -879,6 +881,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator1.close() + projectOperatorFactory1.close() }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, @@ -912,6 +915,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator2.close() + projectOperatorFactory2.close() }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, @@ -946,6 +950,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator3.close() + projectOperatorFactory3.close() }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, @@ -980,6 +985,7 @@ case class ColumnarMultipleOperatorExec( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator4.close() + projectOperatorFactory4.close() }) val buildOpFactory4 = new OmniHashBuilderWithExprOperatorFactory(buildTypes4, @@ -1016,6 +1022,7 @@ case class ColumnarMultipleOperatorExec( // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { condOperator.close() + condOperatorFactory.close() }) while (batches.hasNext) { @@ -1220,7 +1227,7 @@ case class ColumnarMultipleOperatorExec1( // for join val deserializer = VecBatchSerializerFactory.create() val startCodegen = System.nanoTime() - val aggOperator = OmniAdaptorUtil.getAggOperator(aggregate.groupingExpressions, + val aggFactory = OmniAdaptorUtil.getAggOperatorFactory(aggregate.groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -1229,9 +1236,11 @@ case class ColumnarMultipleOperatorExec1( omniAggOutputTypes, omniAggInputRaw, omniAggOutputPartial) + val aggOperator = aggFactory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { aggOperator.close() + aggFactory.close() }) val projectOperatorFactory1 = new OmniProjectOperatorFactory(proj1OmniExpressions, proj1OmniInputTypes, 1, @@ -1240,6 +1249,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator1.close() + projectOperatorFactory1.close() }) val buildOpFactory1 = new OmniHashBuilderWithExprOperatorFactory(buildTypes1, @@ -1274,6 +1284,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator2.close() + projectOperatorFactory2.close() }) val buildOpFactory2 = new OmniHashBuilderWithExprOperatorFactory(buildTypes2, @@ -1308,6 +1319,7 @@ case class ColumnarMultipleOperatorExec1( // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator3.close() + projectOperatorFactory3.close() }) val buildOpFactory3 = new OmniHashBuilderWithExprOperatorFactory(buildTypes3, @@ -1344,6 +1356,7 @@ case class ColumnarMultipleOperatorExec1( // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { condOperator.close() + condOperatorFactory.close() }) while (batches.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala index 6dc3cbef8af69792837c870b172373705fe53cf5..f7979d58dd1b53b594a7554110d1ff1cd3207ef6 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarHashAggregateExec.scala @@ -267,7 +267,7 @@ case class ColumnarHashAggregateExec( child.executeColumnar().mapPartitionsWithIndex { (index, iter) => val startCodegen = System.nanoTime() - val operator = OmniAdaptorUtil.getAggOperator(groupingExpressions, + val factory = OmniAdaptorUtil.getAggOperatorFactory(groupingExpressions, omniGroupByChanel, omniAggChannels, omniAggChannelsFilter, @@ -276,11 +276,13 @@ case class ColumnarHashAggregateExec( omniAggOutputTypes, omniInputRaws, omniOutputPartials) + val operator = factory.createOperator() omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { operator.close() + factory.close() }) while (iter.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala index 0ccdbd6de43c3cbd62b455ef7adf6e487cc69c45..6e347805d1592bf1165a4fc76e5bb432e2a8d5cb 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarProjection.scala @@ -50,6 +50,7 @@ object ColumnarProjection { // close operator addLeakSafeTaskCompletionListener[Unit](_ => { projectOperator.close() + projectOperatorFactory.close() }) new Iterator[ColumnarBatch] { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index cea0a1438b1c64a0d1372e2a272742aa9be08502..f7f5d74687e596d628d4f27c955205c82f8d41d2 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -303,6 +303,7 @@ object ColumnarShuffleExchangeExec extends Logging { // close operator addLeakSafeTaskCompletionListener[Unit](_ => { op.close() + factory.close() }) cbIter.map { cb => diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala index 7c7001dbc1c468465a0115946aeff9849d51a3df..1e52751214ab6bb1f65d7e3092509706b0ec5e75 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarSortExec.scala @@ -125,6 +125,7 @@ case class ColumnarSortExec( omniCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { sortOperator.close() + sortOperatorFactory.close() }) addAllAndGetIterator(sortOperator, iter, this.schema, longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala index 6fec9f9a054f345a83bc20f278c2ed3be57e6dbd..c42120f2b8f2cc9cbef51563b999e84636dc11ad 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarTakeOrderedAndProjectExec.scala @@ -116,6 +116,7 @@ case class ColumnarTakeOrderedAndProjectExec( longMetric("omniCodegenTime") += NANOSECONDS.toMillis(System.nanoTime() - startCodegen) SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]( _ => { topNOperator.close() + topNOperatorFactory.close() }) addAllAndGetIterator(topNOperator, iter, schema, longMetric("addInputTime"), longMetric("numInputVecBatchs"), longMetric("numInputRows"), diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala index b44c78803258ca39565ea8a5e5e1f523e478987b..3b8f8b7ed6976c806c4257ba6f0d327260cac6b7 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/ColumnarWindowExec.scala @@ -338,6 +338,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], // close operator SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => { windowOperator.close() + windowOperatorFactory.close() }) while (iter.hasNext) { diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala index 93ec7d89b01fc4fee52c4abeb77b5618f1c48481..68ac49cec66b2da845c14b085f2f40316ce0c24b 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/util/MergeIterator.scala @@ -97,6 +97,8 @@ class MergeIterator(iter: Iterator[ColumnarBatch], localSchema: StructType, src.close() } } + // close bufferedBatch + bufferedBatch.foreach(batch => batch.close()) } private def flush(): Unit = {