From 685db0793d2edb0fce3a46675a58b87dadc8fc45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AE=81?= Date: Mon, 22 Apr 2024 12:44:18 +0000 Subject: [PATCH] =?UTF-8?q?!2635=20delete=20email=20and=20employee=20ID=20?= =?UTF-8?q?Merge=20pull=20request=20!2635=20from=20=E6=9D=8E=E5=AE=81/code?= =?UTF-8?q?=5Fcheck?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/geop_npu.cc | 88 +++++++++++++++++++--------------- tf_adapter/kernels/geop_npu.h | 7 ++- 2 files changed, 55 insertions(+), 40 deletions(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 537994554..dd7baa180 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -386,7 +386,7 @@ void GeOp::Initialize(OpKernelConstruction *ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); ADP_LOG(INFO) << "Attr 'data_format' of " << ctx->def().name() << " is " << data_format; this->data_format_ = data_format; - + geop_name_ = ctx->def().name(); Status s = ctx->GetAttr("_session", &tf_session_); if (s.ok()) { ADP_LOG(INFO) << "[GEOP] get session info from attr, tf session: " << tf_session_; @@ -508,15 +508,14 @@ void GeOp::Finalize() { // global environment finalize, invoke once for each process { mutex_lock lock{mu_}; - uint32_t graph_id = -1; if (sess_init_flag_ || !tf_session_.empty()) { - bool ret = DecrementGraphIdCount(tf_session_, graph_id); + bool ret = DecrementGraphIdCount(); if (!ret) { ADP_LOG(ERROR) << "tf session " << tf_session_ << " sub graph id failed."; LOG(ERROR) << "tf session " << tf_session_ << " sub graph id failed."; return; } - if (graph_id == kInvalidGraphId) { + if (session_and_graph_id_map_[tf_session_].empty()) { SessionManager::GetInstance().DestroyGeSession(tf_session_); ClearGraphIdCount(); } @@ -800,12 +799,14 @@ Status GeOp::RecoverPrecisionMode() { } bool GeOp::IsGraphNeedRebuild(const uint32_t cache_graph_id) { - if (NeedRecompileWhenAccelerateTrainOn(need_recover_precision_mode_) != Status::OK()) { +/* if (NeedRecompileWhenAccelerateTrainOn(need_recover_precision_mode_) != Status::OK()) { ADP_LOG(ERROR) << "[GEOP] tf session " << tf_session_ << ", graph id: " << cache_graph_id << " prepare to accelerate for train failed"; return false; } - return ((need_recover_precision_mode_) || (ge_session_->IsGraphNeedRebuild(cache_graph_id))); + */ + // return ((need_recover_precision_mode_) || (ge_session_->IsGraphNeedRebuild(cache_graph_id))); + return false; } int32_t GeOp::InitRebuildFlag(uint32_t cache_graph_id) { @@ -852,37 +853,44 @@ bool GeOp::IncrementGraphIdCount(uint32_t &graph_id) { LOG(ERROR) << "[GEOP] Add graph id failed, tf session is empty."; return false; } + mutex_lock lock{mu_}; auto it = session_and_graph_id_map_.find(tf_session_); if (it != session_and_graph_id_map_.end()) { - it->second = it->second + kMaxCacheNum; - graph_id = it->second; + auto iter_graph_id = it->second.find(geop_name_); + if (iter_graph_id != it->second.end()) { + graph_id = iter_graph_id->second; + } else { + graph_id = it->second.size() * kMaxCacheNum + 1U; + it->second.insert(std::make_pair(geop_name_, graph_id)); + } return true; } - graph_id = 1; - session_and_graph_id_map_.insert(std::make_pair(tf_session_, graph_id)); + graph_id = 1U; + std::unordered_map graph_id_map = {{geop_name_, graph_id}}; + session_and_graph_id_map_.insert(std::make_pair(tf_session_, graph_id_map)); return true; } -bool GeOp::DecrementGraphIdCount(const std::string &tf_session, uint32_t &graph_id) { +bool GeOp::DecrementGraphIdCount() { if (tf_session_.empty()) { ADP_LOG(ERROR) << "[GEOP] Sub graph id failed, tf session is empty."; LOG(ERROR) << "[GEOP] Sub graph id failed, tf session is empty."; return false; } - + mutex_lock lock{mu_}; auto it = session_and_graph_id_map_.find(tf_session_); if (it != session_and_graph_id_map_.end()) { - if (it->second == 1) { - it->second = it->second - 1; - graph_id = it->second; + auto graph_id_it = it->second.find(geop_name_); + if (graph_id_it != it->second.end()) { + it->second.erase(graph_id_it); return true; } - it->second = it->second - kMaxCacheNum; - graph_id = it->second; - return true; + ADP_LOG(ERROR) << "[GEOP] Sub graph id failed, can not find geop: " << geop_name_; + LOG(ERROR) << "[GEOP] Sub graph id failed, can not find geop: " << geop_name_; + return false; } - ADP_LOG(ERROR) << "[GEOP] Sub graph id failed, can not find tf session " << tf_session; - LOG(ERROR) << "[GEOP] Sub graph id failed, can not find tf session " << tf_session; + ADP_LOG(ERROR) << "[GEOP] Sub graph id failed, can not find tf session " << tf_session_; + LOG(ERROR) << "[GEOP] Sub graph id failed, can not find tf session " << tf_session_; return false; } @@ -1160,14 +1168,16 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { if (is_lazy_recompile_mode) { GetExecGraphId(cache_graph_id, input_shapes); } - if (InitRebuildFlag(cache_graph_id) != 0) { + /* if (InitRebuildFlag(cache_graph_id) != 0) { OP_REQUIRES_ASYNC(ctx, false, errors::Internal("Failed to check rebuild flag"), done); return; } + */ } if (!build_flag_) { // Get Graph + mutex_lock lock{mu_}; OP_REQUIRES_ASYNC(ctx, ctx->function_library() != nullptr, errors::Internal("function library is nullptr"), done); FunctionLibraryDefinition *flib_def = const_cast(ctx->function_library()->GetFunctionLibraryDefinition()); @@ -1223,9 +1233,9 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { if (graph_options_.count("input_format") != 0) { ADP_LOG(INFO) << "graph_options_[\"input_format\"] = " << graph_options_["input_format"]; } - ge::Graph ge_graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); + ge_graph_ = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); if (iteration_per_loop_ > 1) { - ge_graph.SetNeedIteration(this->need_iteration_); + ge_graph_.SetNeedIteration(this->need_iteration_); graph_options_["iterations_per_loop"] = std::to_string(iteration_per_loop_); } @@ -1264,21 +1274,20 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { graph_options_["ge.graphLevelSat"] = (mix_compile_mode_ == "0") ? "1" : "0"; OP_REQUIRES_OK_ASYNC(ctx, DoAccelerateTrain(), done); // call ge session addGraph api - auto graph_options = graph_options_; if (is_aoe_) { - graph_options["ge.buildMode"] = "normal"; + graph_options_["ge.buildMode"] = "normal"; } if ((is_dynamic_getnext_ != "1") && (iteration_per_loop_ <= 1)) { - SetReuseOptions("ge.exec.inputReuseMemIndexes", ctx->num_inputs(), sess_options_, init_options_, graph_options); + SetReuseOptions("ge.exec.inputReuseMemIndexes", ctx->num_inputs(), sess_options_, init_options_, graph_options_); } - SetReuseOptions("ge.exec.outputReuseMemIndexes", ctx->num_outputs(), sess_options_, init_options_, graph_options); + SetReuseOptions("ge.exec.outputReuseMemIndexes", ctx->num_outputs(), sess_options_, init_options_, graph_options_); ADP_LOG(EVENT) << "[GEOP] call ge session add graph jit_compile: " << jit_compile_; - graph_options["ge.exec.graphIOMemAllocMode"] = "ByGE"; + graph_options_["ge.exec.graphIOMemAllocMode"] = "ByGE"; OP_REQUIRES_OK_ASYNC(ctx, CreateGeSession(), done); - auto const graph_option_ascend_string = ChangeStringToAscendString(graph_options); + auto const graph_option_ascend_string = ChangeStringToAscendString(graph_options_); ADP_LOG(INFO) << "Graph options: "; - NpuAttrs::LogOptions(graph_options); - auto status = ge_session_->AddGraph(cache_graph_id, ge_graph, graph_option_ascend_string); + NpuAttrs::LogOptions(graph_options_); + auto status = ge_session_->AddGraph(cache_graph_id, ge_graph_, graph_option_ascend_string); std::stringstream ss; if (status != ge::SUCCESS) { std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime)); @@ -1289,12 +1298,15 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { << ", graph id: " << cache_graph_id << std::endl << "Error Message is : " << std::endl << ge::GEGetErrorMsgV2().GetString(); } + build_count++; + if (build_count.load() >= 32) { + build_flag_ = true; + add_graph_flag_ = true; + } OP_REQUIRES_ASYNC(ctx, status == ge::SUCCESS, errors::Internal(ss.str()), done); - add_graph_flag_ = true; ADP_LOG(INFO) << "[GEOP] Add graph to ge session success, kernel_name: " << geop_name << ", tf session: " << tf_session_ << ", graph id: " << cache_graph_id; - build_flag_ = true; if (!is_set_dynamic_config && is_lazy_recompile_mode) { cache_graphs_.insert(std::make_pair(input_shapes, cache_graph_id)); graph_counts_.push_back(std::make_pair(input_shapes, 1)); @@ -2007,8 +2019,8 @@ int GeOp::RunTuning(std::vector &input_vec, std::vector &inp ADP_LOG(INFO) << "[GEOP] Tensorflow graph parse to ge graph success."; // convert to ge::graph - ge::Graph ge_graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - ge_graph.SetNeedIteration(false); + ge_graph_ = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); + ge_graph_.SetNeedIteration(false); if (is_host_graph_) { graph_options_["ge.exec.placement"] = "HOST"; } @@ -2052,7 +2064,7 @@ int GeOp::RunTuning(std::vector &input_vec, std::vector &inp return -1; } // set tuning graph - AoeStatus tune_ret = (*aoe_set_tuninggraph_)(session_id_, ge_graph); + AoeStatus tune_ret = (*aoe_set_tuninggraph_)(session_id_, ge_graph_); if (tune_ret != Aoe::AOE_SUCCESS) { ADP_LOG(ERROR) << "exec aoe set graph func failed[" << tune_ret << "]."; return -1; @@ -2559,7 +2571,7 @@ const std::string GeOp::SERIALIZE_FORMAT = "serialize_format"; const std::string GeOp::SERIALIZE_DATATYPE = "serialize_datatype"; const std::string GeOp::SERIALIZE_SHAPE = "serialize_shape"; const std::string GeOp::SubGraph = "SubGraph"; -std::unordered_map GeOp::session_and_graph_id_map_; - +std::unordered_map> GeOp::session_and_graph_id_map_; +std::unordered_map session_max_graph_id_; REGISTER_KERNEL_BUILDER(Name("GeOp").Device(DEVICE_CPU), GeOp); } // namespace tensorflow diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 975846463..269f522d2 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -120,7 +120,7 @@ public: Status RecoverPrecisionMode(); bool IncrementGraphIdCount(uint32_t &graph_id); - bool DecrementGraphIdCount(const std::string &tf_session, uint32_t &graph_id); + bool DecrementGraphIdCount(); void ClearGraphIdCount(); @@ -199,13 +199,14 @@ public: std::string tf_session_; ge::Session *ge_session_; std::string job_type_; + std::string geop_name_; std::string mix_compile_mode_; std::string accelerate_train_mode_; std::map, uint32_t> cache_graphs_; std::vector, uint32_t>> graph_counts_; std::map sess_options_; std::map init_options_; - static std::unordered_map session_and_graph_id_map_; + static std::unordered_map> session_and_graph_id_map_; uint32_t iteration_per_loop_; bool is_host_graph_; std::map graph_options_; @@ -250,6 +251,8 @@ public: AoeSetTuningGraphInputFunc aoe_set_tuning_graph_input_; // accelerate train AccelerateInfo accelerate_info_; + std::atomic build_count{0}; + ge::Graph ge_graph_; }; } // namespace tensorflow #endif // TENSORFLOW_KERNELS_GEOP_NPU_H_ -- Gitee