From bad4870ec508d33145be8ddb5c9d5db149957057 Mon Sep 17 00:00:00 2001 From: wang-xiangX Date: Wed, 31 Aug 2022 18:26:22 +0800 Subject: [PATCH] graph counter --- tf_adapter/kernels/aicpu/dataset_function.cc | 1 + tf_adapter/util/ge_plugin.cc | 21 +++++++++++++++++++- tf_adapter/util/ge_plugin.h | 8 ++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/aicpu/dataset_function.cc b/tf_adapter/kernels/aicpu/dataset_function.cc index eb4605a45..49287e311 100644 --- a/tf_adapter/kernels/aicpu/dataset_function.cc +++ b/tf_adapter/kernels/aicpu/dataset_function.cc @@ -185,6 +185,7 @@ static Status TransTfTensorsToGeTensors(std::vector &tf_tensors, std::ve } DatasetFunction::~DatasetFunction() { + GePlugin::GetInstance()->Finalize(); ADP_LOG(EVENT) << "[DatasetFunction] ~DatasetFunction"; } diff --git a/tf_adapter/util/ge_plugin.cc b/tf_adapter/util/ge_plugin.cc index c2b1ea0fb..07d9fd242 100644 --- a/tf_adapter/util/ge_plugin.cc +++ b/tf_adapter/util/ge_plugin.cc @@ -81,6 +81,7 @@ GePlugin *GePlugin::GetInstance() { void GePlugin::Init(std::map &init_options, const bool is_global) { std::lock_guard lock(mutex_); + IncreaseGraphCounter(); if (isInit_) { ADP_LOG(INFO) << "[GePlugin] Ge has already initialized"; return; @@ -269,6 +270,18 @@ void GePlugin::Init(std::map &init_options, const bool isGlobal_ = is_global; } +void GePlugin::IncreaseGraphCounter() { + graph_counter_.fetch_add(1); +} + +void GePlugin::DecreaseGraphCounter() { + graph_counter_.fetch_sub(1); +} + +bool GePlugin::IsGraphCounterZero() { + return graph_counter_ == 0; +} + std::map GePlugin::GetInitOptions() { return init_options_; } @@ -289,7 +302,11 @@ void GePlugin::Finalize() { ADP_LOG(INFO) << "[GePlugin] Ge has already finalized."; return; } - + DecreaseGraphCounter(); + if (!IsGraphCounterZero()) { + ADP_LOG(INFO) << "[GePlugin] It is not a good time to finalize GE."; + return; + } // ge finalize GeFinalize(); @@ -444,3 +461,5 @@ int32_t MallocSharedMem(const ge::TensorInfo &tensor_info, uint64_t &dev_addr, u ADP_LOG(INFO) << "[GePlugin] malloc shared memory success."; return 0; } + +std::atomic_int GePlugin::graph_counter_ = {0}; diff --git a/tf_adapter/util/ge_plugin.h b/tf_adapter/util/ge_plugin.h index e3d6f028d..9fc000cea 100644 --- a/tf_adapter/util/ge_plugin.h +++ b/tf_adapter/util/ge_plugin.h @@ -17,6 +17,7 @@ #define TENSORFLOW_GE_PLUGIN_H_ #include +#include #include #include #include "tensorflow/core/platform/types.h" @@ -30,6 +31,12 @@ class GePlugin { void Init(std::map &init_options, const bool is_global = false); + void IncreaseGraphCounter(); + + void DecreaseGraphCounter(); + + bool IsGraphCounterZero(); + void Finalize(); bool IsGlobal(); @@ -48,6 +55,7 @@ class GePlugin { bool isGlobal_; std::map init_options_; std::mutex mutex_; + static std::atomic_int graph_counter_; }; tensorflow::Status RegisterNpuCancellationCallback(std::function callback, -- Gitee