diff --git a/tf_adapter/kernels/aicpu/dataset_function.cc b/tf_adapter/kernels/aicpu/dataset_function.cc index eb4605a45dd33b883310f0053319ea50dfd47836..49287e31140bf5b8ff247c6fed5d25d92a203b73 100644 --- a/tf_adapter/kernels/aicpu/dataset_function.cc +++ b/tf_adapter/kernels/aicpu/dataset_function.cc @@ -185,6 +185,7 @@ static Status TransTfTensorsToGeTensors(std::vector &tf_tensors, std::ve } DatasetFunction::~DatasetFunction() { + GePlugin::GetInstance()->Finalize(); ADP_LOG(EVENT) << "[DatasetFunction] ~DatasetFunction"; } diff --git a/tf_adapter/util/ge_plugin.cc b/tf_adapter/util/ge_plugin.cc index c2b1ea0fb85a6a0e4f0e9560f53ba8c3f3d02c27..07d9fd242c5e99517efd086d4047fbd53a2d213e 100644 --- a/tf_adapter/util/ge_plugin.cc +++ b/tf_adapter/util/ge_plugin.cc @@ -81,6 +81,7 @@ GePlugin *GePlugin::GetInstance() { void GePlugin::Init(std::map &init_options, const bool is_global) { std::lock_guard lock(mutex_); + IncreaseGraphCounter(); if (isInit_) { ADP_LOG(INFO) << "[GePlugin] Ge has already initialized"; return; @@ -269,6 +270,18 @@ void GePlugin::Init(std::map &init_options, const bool isGlobal_ = is_global; } +void GePlugin::IncreaseGraphCounter() { + graph_counter_.fetch_add(1); +} + +void GePlugin::DecreaseGraphCounter() { + graph_counter_.fetch_sub(1); +} + +bool GePlugin::IsGraphCounterZero() { + return graph_counter_ == 0; +} + std::map GePlugin::GetInitOptions() { return init_options_; } @@ -289,7 +302,11 @@ void GePlugin::Finalize() { ADP_LOG(INFO) << "[GePlugin] Ge has already finalized."; return; } - + DecreaseGraphCounter(); + if (!IsGraphCounterZero()) { + ADP_LOG(INFO) << "[GePlugin] It is not a good time to finalize GE."; + return; + } // ge finalize GeFinalize(); @@ -444,3 +461,5 @@ int32_t MallocSharedMem(const ge::TensorInfo &tensor_info, uint64_t &dev_addr, u ADP_LOG(INFO) << "[GePlugin] malloc shared memory success."; return 0; } + +std::atomic_int GePlugin::graph_counter_ = {0}; diff --git a/tf_adapter/util/ge_plugin.h b/tf_adapter/util/ge_plugin.h index e3d6f028dc0bc2e9eb4c296bee9953e37baec56e..9fc000ceac69271666caa8513ffc09bf355ec5f5 100644 --- a/tf_adapter/util/ge_plugin.h +++ b/tf_adapter/util/ge_plugin.h @@ -17,6 +17,7 @@ #define TENSORFLOW_GE_PLUGIN_H_ #include +#include #include #include #include "tensorflow/core/platform/types.h" @@ -30,6 +31,12 @@ class GePlugin { void Init(std::map &init_options, const bool is_global = false); + void IncreaseGraphCounter(); + + void DecreaseGraphCounter(); + + bool IsGraphCounterZero(); + void Finalize(); bool IsGlobal(); @@ -48,6 +55,7 @@ class GePlugin { bool isGlobal_; std::map init_options_; std::mutex mutex_; + static std::atomic_int graph_counter_; }; tensorflow::Status RegisterNpuCancellationCallback(std::function callback,