diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py index 3cd4f55b2ff752d00c53ec64ad1be0ba040aa14b..977c0caa18720689d3d70ab408382c0a240b884f 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py @@ -110,7 +110,8 @@ class NPURunConfig(run_config_lib.RunConfig): frozen_variable=False, variable_placement="Device", jit_compile="auto", - precision_mode_v2=None + precision_mode_v2=None, + host_scheduling_max_threshold=None ): """ Constructs a NPUConfig. @@ -265,6 +266,7 @@ class NPURunConfig(run_config_lib.RunConfig): self._external_weight = external_weight self.es_cluster_config = es_cluster_config self._jit_compile = jit_compile + self._host_scheduling_max_threshold = host_scheduling_max_threshold super(NPURunConfig, self).__init__( model_dir=model_dir, diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py index f66fd572212d3cd2348548758a61b9aa5b3001d3..3bb3a4df5296fc444d39a92d4092437e15883ceb 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py @@ -770,6 +770,8 @@ class NPUEstimator(estimator_lib.Estimator): custom_op.parameter_map["op_wait_timeout"].i = config._op_wait_timeout if config._op_execute_timeout is not None: custom_op.parameter_map["op_execute_timeout"].i = config._op_execute_timeout + if config._host_scheduling_max_threshold is not None: + custom_op.parameter_map["host_scheduling_max_threshold"].i = config._host_scheduling_max_threshold if config._HCCL_algorithm is not None: custom_op.parameter_map["HCCL_algorithm"].s = tf.compat.as_bytes(config._HCCL_algorithm) if config._customize_dtypes is not None: diff --git a/tf_adapter/util/ge_plugin.cc b/tf_adapter/util/ge_plugin.cc index 75389659fbf951bb9ebc031a5aedf91695f5dd07..7b5cf44b21d86209e7d6824d5909ccae30f15afe 100644 --- a/tf_adapter/util/ge_plugin.cc +++ b/tf_adapter/util/ge_plugin.cc @@ -123,6 +123,7 @@ void SetOptionNameMap(json &option_name_map) { option_name_map.emplace("ge.esClusterConfig", "es_cluster_config"); option_name_map.emplace(ge::OPTION_EXEC_DYNAMIC_EXECUTE_MODE, "dynamic_graph_execute_mode"); option_name_map.emplace(ge::OPTION_EXEC_DYNAMIC_INPUT, "dynamic_input"); + option_name_map.emplace("ge.exec.hostSchedulingMaxThreshold", "host_scheduling_max_threshold"); } } // namespace diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index fe9f1c3b480f44b87748151fa8c304bd874879ec..25ba1d3f533c3951741fb48836cf3eab193d56c0 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -651,6 +651,7 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr std::string stream_sync_timeout = "-1"; std::string event_sync_timeout = "-1"; std::string es_cluster_config; + std::string host_scheduling_max_threshold; if (ctx != nullptr && ctx->GetAttr("_NpuOptimizer", &npuOptimizer) == Status::OK()) { (void) ctx->GetAttr("_precision_mode", &precision_mode); @@ -691,6 +692,7 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr (void) ctx->GetAttr("_stream_sync_timeout", &stream_sync_timeout); (void) ctx->GetAttr("_event_sync_timeout", &event_sync_timeout); (void) ctx->GetAttr("_es_cluster_config", &es_cluster_config); + (void) ctx->GetAttr("_host_scheduling_max_threshold", &host_scheduling_max_threshold); } std::lock_guard lock(mutex_); @@ -743,6 +745,7 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr init_options_["stream_sync_timeout"] = stream_sync_timeout; init_options_["event_sync_timeout"] = event_sync_timeout; init_options_["ge.esClusterConfig"] = es_cluster_config; + init_options_["ge.exec.hostSchedulingMaxThreshold"] = host_scheduling_max_threshold; return init_options_; } @@ -1158,6 +1161,7 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & std::string graph_compiler_cache_dir; std::string graph_slice_mode; std::string accelerate_train_mode; + std::string host_scheduling_max_threshold; auto NpuOptimizer_value = attrs.Find("_NpuOptimizer"); auto enable_data_pre_proc_value = attrs.Find("_enable_data_pre_proc"); @@ -1248,6 +1252,7 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & auto jit_compile_value = attrs.Find("_jit_compile"); auto graph_compiler_cache_dir_val = attrs.Find("_graph_compiler_cache_dir"); auto accelerate_train_mode_value = attrs.Find("_accelerate_train_mode"); + auto host_scheduling_max_threshold_value = attrs.Find("_host_scheduling_max_threshold"); if (NpuOptimizer_value != nullptr) { do_npu_optimizer = "1"; @@ -1539,6 +1544,11 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & if (graph_compiler_cache_dir_val != nullptr) { graph_compiler_cache_dir = graph_compiler_cache_dir_val->s(); } + + if (host_scheduling_max_threshold_value != nullptr) { + host_scheduling_max_threshold = host_scheduling_max_threshold_value->s(); + } + } all_options["variable_format_optimize"] = variable_format_optimize; @@ -1644,6 +1654,7 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & all_options["enable_graph_parallel"] = enable_graph_parallel; all_options["frozen_variable"] = frozen_variable; all_options["variable_location"] = variable_location; + all_options["host_scheduling_max_threshold"] = host_scheduling_max_threshold; return all_options; } @@ -1761,6 +1772,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options std::string graph_slice_mode; std::string jit_compile; std::string accelerate_train_mode; + std::string host_scheduling_max_threshold; const RewriterConfig &rewrite_options = options.session_options->config.graph_options().rewrite_options(); for (const auto &custom_optimizer : rewrite_options.custom_optimizers()) { @@ -2280,6 +2292,9 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options return errors::Internal("graph_slice must be in ['auto', 'manual']"); } } + if (params.count("host_scheduling_max_threshold") > 0) { + host_scheduling_max_threshold = std::to_string(params.at("host_scheduling_max_threshold").i()); + } } } @@ -2423,6 +2438,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options pass_options["frozen_variable"] = std::to_string(static_cast(frozen_variable)); pass_options["variable_location"] = variable_location; pass_options["accelerate_train_mode"] = accelerate_train_mode; + init_options_["host_scheduling_max_threshold"] = host_scheduling_max_threshold; + init_options_["ge.exec.hostSchedulingMaxThreshold"] = host_scheduling_max_threshold; for (const auto &option : sess_options) { std::string attr_name = std::string("_") + option.first; diff --git a/tf_adapter_2.x/npu_device/core/npu_wrapper.cpp b/tf_adapter_2.x/npu_device/core/npu_wrapper.cpp index ce930125ed1425140725d9fbf229470dba22496c..9e50a2f254384ebf174804b46a87c6123666a35e 100644 --- a/tf_adapter_2.x/npu_device/core/npu_wrapper.cpp +++ b/tf_adapter_2.x/npu_device/core/npu_wrapper.cpp @@ -127,7 +127,8 @@ const std::map kConfigurableOptions = { {"_distribute.cm_worker_size", ge::OPTION_EXEC_CM_WORKER_SIZE}, {"jit_compile", "ge.jit_compile"}, {"graph_compiler_cache_dir", "ge.graph_compiler_cache_dir"}, - {"graph_slice", "ge.graphSliceMode"} + {"graph_slice", "ge.graphSliceMode"}, + {"host_scheduling_max_threshold", "ge.exec.hostSchedulingMaxThreshold"} }; } // namespace