From 98bc9f883f70485944c779435a2aa676dceabd3a Mon Sep 17 00:00:00 2001 From: liujingxiang Date: Thu, 14 Mar 2024 23:07:48 +0800 Subject: [PATCH] hashjoin fix memory leak when task recovery --- .../joins/ColumnarBroadcastHashJoinExec.scala | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala index ed3ca244b..552188867 100644 --- a/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala +++ b/omnioperator/omniop-spark-extension/java/src/main/scala/org/apache/spark/sql/execution/joins/ColumnarBroadcastHashJoinExec.scala @@ -341,7 +341,7 @@ case class ColumnarBroadcastHashJoinExec( Optional.empty() } - def createBuildOpFactoryAndOp(): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { + def createBuildOpFactoryAndOp(isShared: Boolean): (OmniHashBuilderWithExprOperatorFactory, OmniOperator) = { val startBuildCodegen = System.nanoTime() val opFactory = new OmniHashBuilderWithExprOperatorFactory(lookupJoinType, buildTypes, buildJoinColsExp, 1, @@ -350,6 +350,10 @@ case class ColumnarBroadcastHashJoinExec( val op = opFactory.createOperator() buildCodegenTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildCodegen) + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, + opFactory, op) + } val deserializer = VecBatchSerializerFactory.create() relation.value.buildData.foreach { input => val startBuildInput = System.nanoTime() @@ -357,7 +361,19 @@ case class ColumnarBroadcastHashJoinExec( buildAddInputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildInput) } val startBuildGetOp = System.nanoTime() - op.getOutput + try { + op.getOutput + } catch { + case e: Exception => { + if (isShared) { + OmniHashBuilderWithExprOperatorFactory.dereferenceHashBuilderOperatorAndFactory(buildPlan.id) + } else { + op.close() + opFactory.close() + } + throw new RuntimeException("HashBuilder getOutput failed") + } + } buildGetOutputTime += NANOSECONDS.toMillis(System.nanoTime() - startBuildGetOp) (opFactory, op) } @@ -369,11 +385,9 @@ case class ColumnarBroadcastHashJoinExec( try { buildOpFactory = OmniHashBuilderWithExprOperatorFactory.getHashBuilderOperatorFactory(buildPlan.id) if (buildOpFactory == null) { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(true) buildOpFactory = opFactory buildOp = op - OmniHashBuilderWithExprOperatorFactory.saveHashBuilderOperatorAndFactory(buildPlan.id, - buildOpFactory, buildOp) } } catch { case e: Exception => { @@ -383,7 +397,7 @@ case class ColumnarBroadcastHashJoinExec( OmniHashBuilderWithExprOperatorFactory.gLock.unlock() } } else { - val (opFactory, op) = createBuildOpFactoryAndOp() + val (opFactory, op) = createBuildOpFactoryAndOp(false) buildOpFactory = opFactory buildOp = op } -- Gitee