diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index ca54a7d15629f20214bef097ce53e124c8e79257..3da674a1b04af00d0b618f836a4a76c87e6419c3 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -495,6 +495,7 @@ std::map NpuAttrs::GetSessOptions(const OpKernelConstr std::string graph_max_parallel_model_num = "1"; std::string input_batch_cpy; std::string jit_compile; + std::string aicore_num; if (ctx != nullptr && ctx->GetAttr("_NpuOptimizer", &npuOptimizer) == Status::OK()) { (void) ctx->GetAttr("_variable_format_optimize", &variable_format_optimize); (void) ctx->GetAttr("_hcom_parallel", &hcom_parallel); @@ -572,6 +573,7 @@ std::map NpuAttrs::GetSessOptions(const OpKernelConstr } (void) ctx->GetAttr("_graph_compiler_cache_dir", &graph_compiler_cache_dir); (void) ctx->GetAttr("_input_batch_cpy", &input_batch_cpy); + (void) ctx->GetAttr("_aicore_num", &aicore_num); } // session options @@ -644,6 +646,8 @@ std::map NpuAttrs::GetSessOptions(const OpKernelConstr sess_options["ge.inputBatchCpy"] = input_batch_cpy; sess_options["input_batch_cpy"] = input_batch_cpy; SetForbiddenClosePassOn(sess_options); + sess_options["aicore_num"] = aicore_num; + sess_options["ge.aicoreNum"] = aicore_num; return sess_options; } @@ -1943,7 +1947,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options std::string accelerate_train_mode; int32_t execute_times = -1; int32_t export_compile_stat = 1; - std::string aicore_num; + std::string aicore_num; bool oo_constant_folding = true; bool input_batch_cpy = false; std::string shape_generalization_mode = "STRICT"; @@ -2618,6 +2622,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options sess_options["input_fusion_size"] = std::to_string(input_fusion_size); sess_options["input_batch_cpy"] = std::to_string(input_batch_cpy); sess_options["ge.inputBatchCpy"] = std::to_string(input_batch_cpy); + sess_options["aicore_num"] = aicore_num; + sess_options["ge.aicoreNum"] = aicore_num; init_options_["profiling_mode"] = std::to_string(static_cast(profiling_mode)); init_options_[ge::OPTION_EXEC_PROFILING_MODE] = std::to_string(static_cast(profiling_mode));