diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc
index 75c5a72b32095e3e2d2d8aa3331b376a4d90e31b..b0c61099a0022ddc78d7bc112af17c7409d2ee3d 100644
--- a/tf_adapter/kernels/geop_npu.cc
+++ b/tf_adapter/kernels/geop_npu.cc
@@ -379,7 +379,7 @@ void GeOp::Initialize(OpKernelConstruction *ctx) {
                 errors::InvalidArgument("dlsym Aoe initialize API failed, ", mmDlerror()));
     // aoe finalize
     aoe_finalize_ = (AoeFinalizeFunc) mmDlsym(handle_, "AoeFinalize");
-    OP_REQUIRES(ctx, aoe_initialize_ != nullptr,
+    OP_REQUIRES(ctx, aoe_finalize_ != nullptr,
                 errors::InvalidArgument("dlsym Aoe Finalize API failed, ", mmDlerror()));
     // aoe create session
     aoe_create_session_ = (AoeCreateSessionFunc) mmDlsym(handle_, "AoeCreateSession");
@@ -393,6 +393,10 @@ void GeOp::Initialize(OpKernelConstruction *ctx) {
     aoe_set_gesession_ = (AoeSetGeSessionFunc) mmDlsym(handle_, "AoeSetGeSession");
     OP_REQUIRES(ctx, aoe_set_gesession_ != nullptr,
                 errors::InvalidArgument("dlsym Aoe set session API failed, ", mmDlerror()));
+    // aoe set session options
+    aoe_set_gesession_options_ = (AoeSetGeSessionOptionsFunc) mmDlsym(handle_, "AoeSetGeSessionOptions");
+    OP_REQUIRES(ctx, aoe_set_gesession_options_ != nullptr,
+                errors::InvalidArgument("dlsym Aoe set session options API failed, ", mmDlerror()));
     // aoe set depend graphs
     aoe_set_dependgraphs_ = (AoeSetDependGraphFunc) mmDlsym(handle_, "AoeSetDependGraphs");
     OP_REQUIRES(ctx, aoe_set_dependgraphs_ != nullptr,
@@ -677,11 +681,11 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     ADP_LOG(INFO) << "[GEOP] in tuning func, aoe_mode:" << init_options_["ge.jobType"]
                   << ", work_path:" << init_options_["ge.tuningPath"]
                   << ", distribute_config:" << init_options_["distribute_config"];
-    tune_options_.insert(init_options_.cbegin(), init_options_.cend());
-    tune_options_.insert({"devices", std::to_string(device_id)});
-    tune_options_.insert(sess_options_.cbegin(), sess_options_.cend());
-    tune_options_.insert({"work_path", init_options_["ge.tuningPath"]});
-    tune_options_.insert({"job_type", init_options_["ge.jobType"]});
+    for (const auto &it : sess_options_) {
+      std::string key = it.first;
+      std::string value = it.second;
+      tune_options_.insert({Aoe::AscendString(key.c_str()), Aoe::AscendString(value.c_str())});
+    }
     // aoe ini
     if (!tuned_initialize_flag_) {
       std::map global_options;
@@ -1592,6 +1596,11 @@ int GeOp::RunTuning(std::vector &input_vec, std::vector &inp
       ADP_LOG(ERROR) << "exec aoe set session func failed[" << set_ret << "].";
       return -1;
     }
+    Aoe::AoeStatus set_options_ret = (*aoe_set_gesession_options_)(session_id_, tune_options_);
+    if (set_options_ret != Aoe::AOE_SUCCESS) {
+      ADP_LOG(ERROR) << "exec aoe set session options func failed[" << set_options_ret << "].";
+      return -1;
+    }
     // set tuning graph
     Aoe::AoeStatus tune_ret = (*aoe_set_tuninggraph_)(session_id_, ge_graph);
     if (tune_ret != Aoe::AOE_SUCCESS) {
diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h
index 647298a0dcc1b17fe94c251f4d27f0fd9710bf6d..047ce9510b3f9ebe397c866e8b1bad10b7d9c41e 100644
--- a/tf_adapter/kernels/geop_npu.h
+++ b/tf_adapter/kernels/geop_npu.h
@@ -39,6 +39,8 @@ using AoeFinalizeFunc = Aoe::AoeStatus (*)();
 using AoeCreateSessionFunc = Aoe::AoeStatus (*)(const std::map &, SessionId &);
 using AoeDestroySessionFunc = Aoe::AoeStatus (*)(SessionId);
 using AoeSetGeSessionFunc = Aoe::AoeStatus (*)(SessionId, ge::Session*);
+using AoeSetGeSessionOptionsFunc = Aoe::AoeStatus (*)(SessionId,
+    const std::map<Aoe::AscendString, Aoe::AscendString> &);
 using AoeSetDependGraphFunc = Aoe::AoeStatus (*)(SessionId, std::vector&);
 using AoeSetDependGraphsInputsFunc = Aoe::AoeStatus (*)(SessionId, std::vector> &);
 using AoeSetTuningGraphInputFunc = Aoe::AoeStatus (*)(SessionId, std::vector &);
@@ -179,7 +181,7 @@ private:
   std::string data_inputs_shape_range_;
   std::string getnext_inputs_shape_range_;
   bool need_compile_graph_first_;
-  std::map<std::string, std::string> tune_options_;
+  std::map<Aoe::AscendString, Aoe::AscendString> tune_options_;
   std::string is_dynamic_getnext_;
   std::string placeholder_index_;
   std::atomic_flag tuned_flag_;
@@ -199,6 +201,7 @@ private:
   AoeCreateSessionFunc aoe_create_session_;
   AoeDestroySessionFunc aoe_destroy_session_;
   AoeSetGeSessionFunc aoe_set_gesession_;
+  AoeSetGeSessionOptionsFunc aoe_set_gesession_options_;
   AoeSetDependGraphFunc aoe_set_dependgraphs_;
   AoeSetTuningGraphFunc aoe_set_tuninggraph_;
   AoeTuningGraphFunc aoe_tuning_graph_;
diff --git a/tf_adapter/tests/depends/aoe/src/aoe_stub.cc b/tf_adapter/tests/depends/aoe/src/aoe_stub.cc
index 4a589dfc2d2d06327bd2a87dd84410effbc5139f..260069e8202e6aaae479ff21780ceff9ad7ea527 100644
--- a/tf_adapter/tests/depends/aoe/src/aoe_stub.cc
+++ b/tf_adapter/tests/depends/aoe/src/aoe_stub.cc
@@ -44,6 +44,14 @@ extern "C" Aoe::AoeStatus AoeSetGeSession(Aoe::SessionId SessionId, ge::Session*
   return Aoe::AOE_SUCCESS;
 }
 
+extern "C" Aoe::AoeStatus AoeSetGeSessionOptions(Aoe::SessionId SessionId,
+    const std::map<Aoe::AscendString, Aoe::AscendString> &sessionOptions) {
+  if (SessionId >= 9999) {
+    return Aoe::AOE_FALLURE;
+  }
+  return Aoe::AOE_SUCCESS;
+}
+
 extern "C" Aoe::AoeStatus AoeSetDependGraphs(Aoe::SessionId SessionId, std::vector &dependGraph) {
   return Aoe::AOE_SUCCESS;
 }
@@ -65,4 +73,4 @@ extern "C" Aoe::AoeStatus AoeSetDependGraphsInputs(Aoe::SessionId SessionId,
 extern "C" Aoe::AoeStatus AoeSetTuningGraphInput(Aoe::SessionId SessionId, std::vector &input) {
   return Aoe::AOE_SUCCESS;
 }
-} // namespace Aoe
\ No newline at end of file
+} // namespace Aoe