diff --git a/CMakeLists.txt b/CMakeLists.txt index d16549026fbeb9daa174e128b29877f9fdca09c5..e5dc351ea07f9ef7ebf2450303e9867b721012ed 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,3 +1,4 @@ + option(ENABLE_OPEN_SRC "Enable graphengine compile in opensource." FALSE) set(TFADAPTER_DIR ${CMAKE_CURRENT_LIST_DIR}) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 7266804a8a47cb7885e8e5966bc4cdeadc3be362..6325c6df1e695b98094d2cd7fc2378ad37510075 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -125,6 +125,9 @@ const std::unordered_set supported_origin_precision_mode_v2 = {"ori using geDataUniquePtr = std::unique_ptr>; +const std::unordered_set session_options_move_to_graph_options_set = + {"ge.dynamicDims", "ge.dynamicNodeType"}; + class NpuHostFixedAllocator : public tensorflow::Allocator, public tensorflow::core::RefCounted { public: static tensorflow::Allocator *Create(geDataUniquePtr ptr) { @@ -1026,7 +1029,11 @@ Status GeOp::CreateGeSession() { ADP_LOG(INFO) << "[GePlugin] Initialize ge success."; first = false; } - if (!SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, ge_session_, sess_options_) || + auto real_session_options = sess_options_; + for (const auto &ele : session_options_move_to_graph_options_set) { + real_session_options.erase(ele); + } + if (!SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, ge_session_, real_session_options) || tf_session_.empty() || ge_session_ == nullptr) { return errors::Internal("Get ge session failed."); } @@ -1131,6 +1138,13 @@ Status GeOp::AddGraph(OpKernelContext *ctx, const uint32_t &graph_id) { << jit_compile_ << ", graph_id: " << graph_id; graph_options["ge.exec.graphIOMemAllocMode"] = "ByGE"; + for (const auto &ele : session_options_move_to_graph_options_set) { + auto it = sess_options_.find(ele); + if ((it != sess_options_.end()) && (graph_options.find(ele) == graph_options.end())) { + graph_options[ele] = it->second; + } + } + const auto graph_option_ascend_string = ChangeStringToAscendString(graph_options); ADP_LOG(INFO) << "Graph options: "; NpuAttrs::LogOptions(graph_options);