diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 54001a191615699031443936ad72bdf0411f801d..67f2a1901e27ec0a9fbcb30ce2b2fec618b4012e 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -909,7 +909,15 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { done(); return; } - + static auto init_status = GePlugin::GetInstance()->GetInitStatus(); + if (init_status != ge::SUCCESS) { + std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime)); + ADP_LOG(FATAL) << "[GePlugin] Initialize ge failed, ret : " << ToString(status); + std::string error_message = ge::GEGetErrorMsg(); + LOG(FATAL) << "[GePlugin] Initialize ge failed, ret : " << ToString(status) << std::endl + << "Error Message is : " << std::endl + << error_message; + } // convert to ge::graph if (graph_options_.count("input_format") != 0) { ADP_LOG(INFO) << "graph_options_[\"input_format\"] = " << graph_options_["input_format"]; diff --git a/tf_adapter/util/ge_plugin.cc b/tf_adapter/util/ge_plugin.cc index 65c31e6d47b4f343f11df0b5483bc17343d1efc9..cd3ca56bcee375202ecb2d900ae1103f62081d36 100644 --- a/tf_adapter/util/ge_plugin.cc +++ b/tf_adapter/util/ge_plugin.cc @@ -301,16 +301,13 @@ void GePlugin::Init(std::map &init_options, const bool SetOptionNameMap(option_name_map); init_options["ge.optionNameMap"] = option_name_map.dump(); - // ge Initialize - ge::Status status = ge::GEInitialize(init_options); - if (status != ge::SUCCESS) { - std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime)); - ADP_LOG(FATAL) << "[GePlugin] Initialize ge failed, ret : " << ToString(status); - std::string error_message = ge::GEGetErrorMsg(); - LOG(FATAL) << "[GePlugin] Initialize ge failed, ret : " << ToString(status) << std::endl - << "Error Message is : " << std::endl - << error_message; - } + // ge Initialize async + future_ = std::async( + std::launch::async, + [&](const std::map &init_options) -> ge::Status { + return ge::GEInitialize(init_options); + }, + init_options); domi::GetContext().train_flag = true; ADP_LOG(INFO) << "[GePlugin] Initialize ge success."; diff --git a/tf_adapter/util/ge_plugin.h b/tf_adapter/util/ge_plugin.h index ac25effbb09f5b541247bf654f0a11508c8443d3..6f08ec82670a06b2255a169d772638a5117a5d0c 100644 --- a/tf_adapter/util/ge_plugin.h +++ b/tf_adapter/util/ge_plugin.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "tensorflow/core/platform/types.h" #include "tensorflow/core/lib/core/status.h" @@ -35,6 +36,10 @@ class GePlugin { bool IsGlobal(); + ge::Status GetInitStatus() { + return future_.get(); + } + std::map GetInitOptions(); void SetRankTableFileEnv(std::map &init_options, std::string &rankTableFile); @@ -58,6 +63,7 @@ class GePlugin { std::map init_options_; std::mutex mutex_; static std::atomic_int graph_counter_; + std::future< ge::Status> future_; }; tensorflow::Status RegisterNpuCancellationCallback(std::function callback,