From 642cef6be0825abedfe1d28d5ea8beab6c3237bb Mon Sep 17 00:00:00 2001 From: huanruizhi Date: Wed, 27 Sep 2023 19:40:43 +0800 Subject: [PATCH] fixed 3dbf97c from https://gitee.com/official-wisdom/tensorflow/pulls/2459 add lock to run one graph multi thread --- tf_adapter/kernels/geop_npu.cc | 6 +++++- tf_adapter/kernels/geop_npu.h | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 3f086db4f..1f637e69d 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -1050,6 +1050,7 @@ Status GeOp::DoGraphParser(ge::ComputeGraphPtr &compute_graph, FunctionLibraryDe } void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { + run_mtx_.lock(); // ctx is not nullptr OP_REQUIRES_ASYNC(ctx, init_flag_, errors::InvalidArgument("GeOp not Initialize success."), done); if (!sess_init_flag_) { @@ -1318,9 +1319,10 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { } int64 run_start_time = InferShapeUtil::GetCurrentTimestap(); - auto callback = [done, ctx, run_start_time](ge::Status ge_status, std::vector &outputs) { + auto callback = [done, ctx, run_start_time, this](ge::Status ge_status, std::vector &outputs) { if (ge_status == ge::SUCCESS) { if (BuildOutputTensorInfo(ctx, outputs) != Status::OK()) { + run_mtx_.unlock(); ADP_LOG(FATAL) << ctx->op_kernel().name() << " GEOP::DoRunAsync get output failed."; std::string error_message = ge::GEGetErrorMsg(); std::stringstream ss; @@ -1335,6 +1337,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(WARNING) << "[GEOP] Out of range: End of sequence."; LOG(WARNING) << "[GEOP] Out of range: End of sequence."; } else if (ge_status != ge::SUCCESS) { + run_mtx_.unlock(); std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime)); ADP_LOG(FATAL) << ctx->op_kernel().name() << "GEOP::::DoRunAsync Failed"; std::string error_message = ge::GEGetErrorMsg(); @@ -1348,6 +1351,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(INFO) << "[GEOP] RunGraphAsync callback, status:" << ge_status << ", kernel_name:" << ctx->op_kernel().name() << "[ " << (run_end_time - run_start_time) << "us]"; done(); + run_mtx_.unlock(); }; // call ge session runGraphAsync api diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index bf102b2ff..01ca513c0 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -247,6 +247,7 @@ public: AoeSetTuningGraphInputFunc aoe_set_tuning_graph_input_; // accelerate train AccelerateInfo accelerate_info_; + std::mutex run_mtx_; }; } // namespace tensorflow #endif // TENSORFLOW_KERNELS_GEOP_NPU_H_ -- Gitee