From 47f9a464f3dae852e359f4c4d6f2644f4983e541 Mon Sep 17 00:00:00 2001
From: huanruizhi
Date: Mon, 4 Jan 2021 15:43:39 +0800
Subject: [PATCH] recommend network func

---
 tf_adapter/interface_spec/api_npu_config.pyh  |   3 +-
 tf_adapter/kernels/host_queue_dataset_op.cc   | 133 +++++++++++++++---
 tf_adapter/kernels/threads_pool.h             | 119 ++++++++++++++++
 .../optimizers/dp_tf_ge_conversion_pass.cc    |  21 ++-
 .../npu_bridge/estimator/npu/npu_config.py    |   8 +-
 .../npu_bridge/estimator/npu/npu_estimator.py |   4 +
 tf_adapter/util/ge_plugin.cc                  |   5 +-
 tf_adapter/util/npu_attrs.cc                  |  71 ++++++++++
 8 files changed, 333 insertions(+), 31 deletions(-)
 create mode 100644 tf_adapter/kernels/threads_pool.h

diff --git a/tf_adapter/interface_spec/api_npu_config.pyh b/tf_adapter/interface_spec/api_npu_config.pyh
index 4f7e29e58..ffefeac46 100644
--- a/tf_adapter/interface_spec/api_npu_config.pyh
+++ b/tf_adapter/interface_spec/api_npu_config.pyh
@@ -13,7 +13,8 @@ class NPURunConfig(run_config_lib.RunConfig):
                mstune_mode=None, work_path=None, buffer_optimize="l2_optimize", enable_small_channel=0,
                fusion_switch_file=None, enable_compress_weight=False, compress_weight_conf=None,
                op_compiler_cache_mode=None, op_compiler_cache_dir=None, debug_dir=None,
                hcom_multi_mode=False, dynamic_input=False,
-               dynamic_graph_execute_mode="dynamic_execute", dynamic_inputs_shape_range=None):
+               dynamic_graph_execute_mode="dynamic_execute", dynamic_inputs_shape_range=None,
+               local_rank_id=None, local_device_list=None):

 class ProfilingConfig():
   def __init__(self, enable_profiling=False, profiling_options=None):
diff --git a/tf_adapter/kernels/host_queue_dataset_op.cc b/tf_adapter/kernels/host_queue_dataset_op.cc
index b55e55736..4119022b7 100644
--- a/tf_adapter/kernels/host_queue_dataset_op.cc
+++ b/tf_adapter/kernels/host_queue_dataset_op.cc
@@ -37,6 +37,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tf_adapter/common/common.h"
+#include "tf_adapter/kernels/threads_pool.h"
 #include "tf_adapter/util/npu_attrs.h"
 #include
 #include
@@ -63,27 +64,68 @@ class HostQueueDatasetOp : public DatasetOpKernel {
  public:
   explicit HostQueueDatasetOp(OpKernelConstruction *ctx) : DatasetOpKernel(ctx) {
     // ctx is not nullptr
+    device_id_ = 0;
+    std::string tmp_rank_id;
+    std::string tmp_device_list;
     OP_REQUIRES_OK(ctx, ctx->GetAttr("channel_name", &channel_name_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
-    LOG(INFO) << "Start to init tdt.";
-    uint32_t device_id = 0;
-    OP_REQUIRES_OK(ctx, GetEnvDeviceID(device_id));
-    int32_t tdt_status = TdtHostInit(device_id);
-    OP_REQUIRES(ctx, tdt_status == 0, errors::InvalidArgument("Tdt client init failed."));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("_local_rank_id", &tmp_rank_id));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("_local_device_list", &tmp_device_list));
+    LOG(INFO) << "Get local rank id:" << tmp_rank_id << ", local device list:" << tmp_device_list;
+    // local rank id range 0-7
+    local_rank_id_ = std::atoi(tmp_rank_id.c_str());
+    for (int i = 0; i < tmp_device_list.size(); i += 2) {
+      int device_id = std::atoi(&tmp_device_list[i]);
+      OP_REQUIRES(ctx, device_id >= 0, errors::InvalidArgument("device id should be >= 0."));
+      local_device_list_.push_back(device_id);
+    }
+    if (local_rank_id_ == 0) {
+      LOG(INFO) << "Start to init all tdt host.";
+      pools_ = std::make_shared<ThreadPool>();
+      pools_->InitThreadPool(local_device_list_.size());
+      std::vector<std::future<int32_t>> tdt_status;
+      for (auto device_id : local_device_list_) {
+        tdt_status.emplace_back(pools_->Enqueue(TdtInFeedInit, device_id));
+      }
+      for (auto &&result : tdt_status) {
+        OP_REQUIRES(ctx, result.get() == 0, errors::InvalidArgument("Tdt host init failed."));
+      }
+      LOG(INFO) << "Init all tdt host success.";
+    } else if (local_rank_id_ == -1) {
+      LOG(INFO) << "Start to init tdt.";
+      uint32_t device_id = 0;
+      OP_REQUIRES_OK(ctx, GetEnvDeviceID(device_id));
+      device_id_ = device_id;
+      int32_t tdt_status = TdtInFeedInit(device_id_);
+      OP_REQUIRES(ctx, tdt_status == 0, errors::InvalidArgument("Tdt client init failed."));
+      LOG(INFO) << "Init tdt host success.";
+    } else {
+      LOG(INFO) << "Tdt client do not init in slave.";
+    }
     tdt_release = false;
   }
   ~HostQueueDatasetOp() {
-    LOG(INFO) << "Start to destroy tdt.";
-    if (!tdt_release) {
-      int32_t tdt_status = TdtHostDestroy();
+    int32_t tdt_status = 0;
+    if (!tdt_release && local_rank_id_ == 0) {
+      LOG(INFO) << "Start to destroy all host tdt.";
+      std::vector<std::future<int32_t>> tdt_status;
+      for (auto device_id : local_device_list_) {
+        tdt_status.emplace_back(pools_->Enqueue(TdtInFeedDestroy, device_id));
+      }
+      for (auto &&result : tdt_status) {
+        if (result.get() != 0) { LOG(ERROR) << "Tdt client close failed."; }
+      }
+      LOG(INFO) << "Tdt client close all host success.";
+      tdt_release = true;
+    } else if (!tdt_release && local_rank_id_ == -1) {
+      LOG(INFO) << "Start to destroy tdt.";
+      tdt_status = TdtInFeedDestroy(device_id_);
       if (tdt_status != 0) {
         LOG(ERROR) << "Tdt client close failed.";
       } else {
         LOG(INFO) << "Tdt client close success.";
         tdt_release = true;
       }
-    }
+    } else {
+      LOG(INFO) << "Tdt client do not destroy in slave.";
+    }
   }
   void MakeDataset(OpKernelContext *ctx, DatasetBase **output) override {
     std::vector<DatasetBase *> inputs;
@@ -93,7 +135,8 @@ class HostQueueDatasetOp : public DatasetOpKernel {
       OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
       inputs.push_back(input);
     }
-    *output = new (nothrow) Dataset(ctx, inputs, channel_name_, output_types_, output_shapes_);
+    *output = new (nothrow) Dataset(ctx, inputs, channel_name_, output_types_, output_shapes_,
+                                    local_rank_id_, local_device_list_, device_id_, pools_);
     OP_REQUIRES(ctx, *output != nullptr,
                 errors::InvalidArgument("Data process host queue dataset op: new dataset failed."));
   }
@@ -102,9 +145,12 @@ class HostQueueDatasetOp : public DatasetOpKernel {
   class Dataset : public DatasetBase {
    public:
     Dataset(OpKernelContext *ctx, const std::vector<DatasetBase *> &inputs, const string &channelName,
-            const DataTypeVector &outputTypes, const vector<PartialTensorShape> &outputShapes)
+            const DataTypeVector &outputTypes, const vector<PartialTensorShape> &outputShapes,
+            const int &local_rank_id, const std::vector<uint32_t> &local_device_list,
+            const uint32_t &device_id, std::shared_ptr<ThreadPool> pools)
         : DatasetBase(DatasetContext(ctx)), inputs_(inputs), channel_name_(channelName), output_types_(outputTypes),
-          output_shapes_(outputShapes) {
+          output_shapes_(outputShapes), local_rank_id_(local_rank_id), local_device_list_(local_device_list),
+          device_id_(device_id), pools_(pools) {
       for (const auto &input : inputs_) { input->Ref(); }
     }
@@ -226,8 +272,19 @@ class HostQueueDatasetOp : public DatasetOpKernel {
           LOG(ERROR) << "Get data failed " << buffer_.front().status.ToString();
         }
         items.emplace_back(end_item);
-        int32_t tdt_status = TdtHostPushData(dataset()->channel_name_, items);
-        if (tdt_status != 0) { LOG(INFO) << "End training as tdt host push end data failed " << tdt_status; }
+        if (dataset()->local_rank_id_ == 0) {
+          std::vector<std::future<int32_t>> tdt_status;
+          for (auto device_id : dataset()->local_device_list_) {
+            tdt_status.emplace_back(dataset()->pools_->Enqueue(TdtHostPushData,
+                                                               dataset()->channel_name_, items, device_id));
+          }
+          for (auto &&result : tdt_status) {
+            if (result.get() != 0) { LOG(INFO) << "End training as tdt host push end data failed."; }
+          }
+        } else {
+          int32_t tdt_status = TdtHostPushData(dataset()->channel_name_, items, dataset()->device_id_);
+          if (tdt_status != 0) { LOG(INFO) << "End training as tdt host push end data failed " << tdt_status; }
+        }
         cancelled_ = true;
         cond_var_.notify_all();
         return;
@@ -274,13 +331,30 @@ class HostQueueDatasetOp : public DatasetOpKernel {
           total_bytes += tensor.TotalBytes();
         }
         // call tdt interface
-        int32_t tdt_status = TdtHostPushData(dataset()->channel_name_, items);
-        if (tdt_status != 0) {
-          LOG(INFO) << "End training as tdt host push data finished: " << tdt_status;
-          mutex_lock lck(mu_);
-          cancelled_ = true;
-          cond_var_.notify_all();
-          return;
+        if (dataset()->local_rank_id_ == 0) {
+          std::vector<std::future<int32_t>> tdt_status;
+          for (auto device_id : dataset()->local_device_list_) {
+            tdt_status.emplace_back(dataset()->pools_->Enqueue(TdtHostPushData,
+                                                               dataset()->channel_name_, items, device_id));
+          }
+          for (auto &&result : tdt_status) {
+            if (result.get() != 0) {
+              LOG(INFO) << "End training as tdt host push data finished.";
+              mutex_lock lck(mu_);
+              cancelled_ = true;
+              cond_var_.notify_all();
+              return;
+            }
+          }
+        } else {
+          int32_t tdt_status = TdtHostPushData(dataset()->channel_name_, items, dataset()->device_id_);
+          if (tdt_status != 0) {
+            LOG(INFO) << "End training as tdt host push data finished: " << tdt_status;
+            mutex_lock lck(mu_);
+            cancelled_ = true;
+            cond_var_.notify_all();
+            return;
+          }
         }
         {
           mutex_lock lck(mu_);
@@ -330,8 +404,13 @@ class HostQueueDatasetOp : public DatasetOpKernel {
       }
       {
         mutex_lock lck(mu_);
-        TF_RETURN_IF_ERROR(EnsureReceiveThreadStarted(ctx));
-        TF_RETURN_IF_ERROR(EnsureSendThreadStarted(ctx));
+        if (dataset()->local_rank_id_ <= 0) {
+          TF_RETURN_IF_ERROR(EnsureReceiveThreadStarted(ctx));
+          TF_RETURN_IF_ERROR(EnsureSendThreadStarted(ctx));
+        } else {
+          LOG(INFO) << "HostQueue is not chief, not send data.";
+          return Status::OK();
+        }
       }
       LOG(INFO) << "HostQueue success to Initialize. channelName: " << dataset()->channel_name_;
@@ -376,14 +455,22 @@ class HostQueueDatasetOp : public DatasetOpKernel {
       std::unique_ptr<Thread> receive_thread_ GUARDED_BY(mu_);
       std::unique_ptr<Thread> send_thread_ GUARDED_BY(mu_);
     };
+    std::shared_ptr<ThreadPool> pools_;
     const std::vector<DatasetBase *> inputs_;
     std::string channel_name_;
    const DataTypeVector output_types_;
    const vector<PartialTensorShape> output_shapes_;
+    int local_rank_id_;
+    std::vector<uint32_t> local_device_list_;
+    uint32_t device_id_;
   };
   std::string channel_name_;
   DataTypeVector output_types_;
   vector<PartialTensorShape> output_shapes_;
+  std::shared_ptr<ThreadPool> pools_;
+  int local_rank_id_;
+  std::vector<uint32_t> local_device_list_;
+  uint32_t device_id_;
 };
 REGISTER_KERNEL_BUILDER(Name("HostQueueDataset").Device(DEVICE_CPU), HostQueueDatasetOp);
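
Note on the constructor above: it walks the `_local_device_list` attribute two characters at a time, so it only works for the comma-separated single-digit device ids the feature documents (rank ids 0-7). A minimal standalone sketch of that parsing behavior; the `main` driver and the sample list are illustrative, not part of the patch:

    // Sketch of the "_local_device_list" parsing used by HostQueueDatasetOp.
    // The i += 2 stride assumes single-digit device ids, e.g. "0,1,2,3";
    // a two-digit id such as "10" would be mis-parsed.
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      std::string tmp_device_list = "0,1,2,3";
      std::vector<uint32_t> local_device_list;
      for (size_t i = 0; i < tmp_device_list.size(); i += 2) {
        int device_id = std::atoi(&tmp_device_list[i]);  // parses digits up to the next ','
        if (device_id < 0) { return 1; }                 // mirrors the OP_REQUIRES check
        local_device_list.push_back(device_id);
      }
      for (auto id : local_device_list) { std::cout << id << "\n"; }  // prints 0 1 2 3
      return 0;
    }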
diff --git a/tf_adapter/kernels/threads_pool.h b/tf_adapter/kernels/threads_pool.h
new file mode 100644
index 000000000..23d9a58af
--- /dev/null
+++ b/tf_adapter/kernels/threads_pool.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_TF_ADAPTER_KERNELS_THREAD_POOL_H
+#define TENSORFLOW_TF_ADAPTER_KERNELS_THREAD_POOL_H
+
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <stdexcept>
+#include <thread>
+#include <utility>
+#include <vector>
+#include "tensorflow/core/platform/logging.h"
+
+class ThreadPool {
+ public:
+  template <class F, class... Args>
+  auto Enqueue(F &&f, Args &&... args)
+      -> std::future<typename std::result_of<F(Args...)>::type>;
+  // initialize thread pool
+  void InitThreadPool(size_t threads);
+  // ThreadPool construct
+  ThreadPool() : stop_(false), init_flag_(false) {}
+  // ThreadPool destruct
+  ~ThreadPool();
+
+ private:
+  // need to keep track of threads so we can join them
+  std::vector<std::thread> workers_;
+  // the task queue
+  std::queue<std::function<void()>> tasks_;
+  std::mutex queue_mutex_;
+  std::condition_variable condition_;
+  std::atomic<bool> init_flag_;
+  bool stop_;
+};
+
+// launch some amount of workers_
+void ThreadPool::InitThreadPool(size_t threads)
+{
+  if (!init_flag_) {
+    for (size_t i = 0; i < threads; ++i) {
+      workers_.emplace_back([this] {
+        for (;;) {
+          std::function<void()> task;
+          {
+            std::unique_lock<std::mutex> lock(this->queue_mutex_);
+            this->condition_.wait(lock, [this] { return this->stop_ || !this->tasks_.empty(); });
+            if (this->stop_ || this->tasks_.empty()) { return; }
+            task = std::move(this->tasks_.front());
+            this->tasks_.pop();
+          }
+          task();
+        }
+      });
+    }
+  }
+  init_flag_ = true;
+}
+
+// add new work item to the pool
+template <class F, class... Args>
+auto ThreadPool::Enqueue(F &&f, Args &&... args)
+    -> std::future<typename std::result_of<F(Args...)>::type>
+{
+  if (!init_flag_) { LOG(ERROR) << "thread pool is not initialized."; }
+  using return_type = typename std::result_of<F(Args...)>::type;
+  auto task = std::make_shared<std::packaged_task<return_type()>>(
+      std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+  std::future<return_type> res = task->get_future();
+  {
+    std::unique_lock<std::mutex> lock(queue_mutex_);
+    if (stop_) { LOG(ERROR) << "Enqueue on stopped ThreadPool."; }
+    tasks_.emplace([task]() { (*task)(); });
+  }
+  condition_.notify_one();
+  return res;
+}
+
+ThreadPool::~ThreadPool()
+{
+  {
+    std::unique_lock<std::mutex> lock(queue_mutex_);
+    stop_ = true;
+  }
+  init_flag_ = false;
+  condition_.notify_all();
+  for (std::thread &worker : workers_) { worker.join(); }
+}
+
+#endif
\ No newline at end of file
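
Each `Enqueue` call returns a `std::future`, which is how the host-queue kernel fans the TDT calls out across devices and then joins on every result. A minimal usage sketch of that pattern, assuming only this header (and its TensorFlow logging dependency); `fake_init` is a hypothetical stand-in for `TdtInFeedInit`:

    #include <cstdint>
    #include <iostream>
    #include <vector>
    #include "tf_adapter/kernels/threads_pool.h"

    // Hypothetical stand-in for TdtInFeedInit; returns 0 on success.
    static int32_t fake_init(uint32_t device_id) { return device_id <= 7 ? 0 : -1; }

    int main() {
      ThreadPool pools;
      std::vector<uint32_t> devices = {0, 1, 2, 3};
      pools.InitThreadPool(devices.size());  // one worker per device
      std::vector<std::future<int32_t>> tdt_status;
      for (auto id : devices) {
        tdt_status.emplace_back(pools.Enqueue(fake_init, id));  // runs in parallel
      }
      for (auto &&result : tdt_status) {     // join: fail if any device failed
        if (result.get() != 0) { std::cout << "init failed\n"; return 1; }
      }
      std::cout << "all devices initialized\n";
      return 0;
    }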
diff --git a/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc b/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc
index b8fa518bc..945b7045c 100644
--- a/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc
+++ b/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc
@@ -48,6 +48,7 @@ limitations under the License.
 namespace tensorflow {
 static const int64 kMicrosToMillis = 1000;
+static int64 g_channel_index = 1;
 // GE ops white list
 const static std::vector<std::string> GE_OPS_WHITELIST = {
     "MapDataset", "ParallelMapDataset", "BatchDataset", "MapAndBatchDataset", "DeviceQueueDataset",
@@ -165,7 +166,8 @@ class DpTfToGEConversionPassImpl {
   inline bool IsDeviceSupportedFunc(const std::string &fn) const;
   inline Status GetSplitEdges(const Node *n, std::vector<const Edge *> &split_edges, const Edge *e);
   inline void RemoveSplitEdges(Node *topo_end);
-  inline Status InsertChannelQueue(Node *topo_end, std::string &host_queue_name, std::string &device_queue_name) const;
+  inline Status InsertChannelQueue(Node *topo_end, std::string &host_queue_name, std::string &device_queue_name,
+                                   std::map<std::string, std::string> &all_options) const;
   bool GetNodeFuncs(const FunctionLibraryDefinition *flib_def, Node *node, std::vector<string> &node_funcs);
   bool RemoveIsolatedNode(Graph *g, std::unordered_set<Node *> visited);
   Status RemoveNotSupportDataset(Graph *g, const std::string &device_queue_dataset,
@@ -376,13 +378,22 @@ Status DpTfToGEConversionPassImpl::GetSplitEdges(const Node *n, std::vector<cons
-inline Status DpTfToGEConversionPassImpl::InsertChannelQueue(Node *topo_end, std::string &host_queue_name,
-                                                             std::string &device_queue_name) const {
+inline Status DpTfToGEConversionPassImpl::InsertChannelQueue(Node *topo_end, std::string &host_queue_name,
+                                                             std::string &device_queue_name,
+                                                             std::map<std::string, std::string> &all_options) const {
   LOG(INFO) << "Start to insert HostQueueDataset and DeviceQueueDataset.";
   for (const Edge *e : split_edges_.at(topo_end)) {
     REQUIRES_NOT_NULL(e);
     REQUIRES_NOT_NULL(e->src());
     REQUIRES_NOT_NULL(e->dst());
-    std::string queue_name = strings::StrCat("Queue_", GetEdgeName(e), "_", GetRandomName());
+    std::string local_rank_id = all_options["local_rank_id"];
+    std::string local_device_list = all_options["local_device_list"];
+    std::string queue_name;
+    if (local_rank_id == "-1") {
+      queue_name = strings::StrCat("Queue_", GetEdgeName(e), "_", GetRandomName());
+    } else {
+      queue_name = strings::StrCat(e->src()->name(), "_index_", std::to_string(g_channel_index));
+      g_channel_index += 1;
+    }
     host_queue_name = strings::StrCat("Host", queue_name);
     device_queue_name = strings::StrCat("Device", queue_name);
     LOG(INFO) << "Add_" << host_queue_name;
@@ -401,6 +412,8 @@ Status DpTfToGEConversionPassImpl::InsertChannelQueue(Node *topo_end, std::strin
                            .Attr("channel_name", queue_name)
                            .Attr("output_types", type_status ? m_src["output_types"] : m_src["Toutput_types"])
                            .Attr("output_shapes", m_src["output_shapes"])
+                           .Attr("_local_rank_id", local_rank_id)
+                           .Attr("_local_device_list", local_device_list)
                            .Finalize(&*graph_, &queue_node_host));
     REQUIRES_NOT_NULL(queue_node_host);
     LOG(INFO) << "Add_" << device_queue_name;
@@ -558,7 +571,7 @@ bool DpTfToGEConversionPassImpl::RunPass(std::unique_ptr<Graph> *g, FunctionLibr
   LOG(INFO) << "Start to add host and device queue on split edges";
   std::string host_queue_name;
   std::string device_queue_name;
-  TF_DO_CHECK_OK(InsertChannelQueue(topo_end, host_queue_name, device_queue_name), ERROR);
+  TF_DO_CHECK_OK(InsertChannelQueue(topo_end, host_queue_name, device_queue_name, all_options), ERROR);
   LOG(INFO) << "host queue name is " << host_queue_name;
   LOG(INFO) << "device queue name is " << device_queue_name;
   // Remove all split edges
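
When `local_rank_id` is "-1" (feature off), each process keeps the old per-process random queue suffix; otherwise every process sharing a host queue must derive the same channel name, so the pass switches to a deterministic "<src-node>_index_<counter>" scheme. A small sketch of the two naming paths, using plain `std::string` in place of `strings::StrCat` and illustrative node names:

    #include <cstdint>
    #include <iostream>
    #include <string>

    static int64_t g_channel_index = 1;  // mirrors the pass-level counter

    // Sketch of the queue naming above; src_name/random_name are illustrative inputs.
    std::string MakeQueueName(const std::string &local_rank_id, const std::string &src_name,
                              const std::string &random_name) {
      if (local_rank_id == "-1") {
        return "Queue_" + src_name + "_" + random_name;  // per-process random name
      }
      return src_name + "_index_" + std::to_string(g_channel_index++);  // deterministic, shared
    }

    int main() {
      std::cout << "Host" << MakeQueueName("-1", "edge0", "a1b2") << "\n";       // HostQueue_edge0_a1b2
      std::cout << "Host" << MakeQueueName("0", "IteratorGetNext", "") << "\n";  // HostIteratorGetNext_index_1
      return 0;
    }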
diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py
index dc666c922..717f68502 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py
@@ -65,7 +65,9 @@ class NPURunConfig(run_config_lib.RunConfig):
                  hcom_multi_mode=False,
                  dynamic_input=False,
                  dynamic_graph_execute_mode="dynamic_execute",
-                 dynamic_inputs_shape_range=None
+                 dynamic_inputs_shape_range=None,
+                 local_rank_id=None,
+                 local_device_list=None
                  ):
         """
         Constructs a NPUConfig.
@@ -129,6 +131,8 @@ class NPURunConfig(run_config_lib.RunConfig):
         dynamic_input: Whether Input is dynamic.
         dynamic_graph_execute_mode: Dynamic graph execute mode. lazy_recompile or dynamic_execute
         dynamic_inputs_shape_range: Inputs shape range.
+        local_rank_id: Local sequence number of the device in a group.
+        local_device_list: Available devices.
         """
         # Check iterations_per_loop.
@@ -205,6 +209,8 @@ class NPURunConfig(run_config_lib.RunConfig):
         self._dynamic_input = dynamic_input
         self._dynamic_graph_execute_mode = dynamic_graph_execute_mode
         self._dynamic_inputs_shape_range = dynamic_inputs_shape_range
+        self._local_rank_id = local_rank_id
+        self._local_device_list = local_device_list

         super(NPURunConfig, self).__init__(
             model_dir=model_dir,
diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py
index 11abb1d4a..6d9484f4e 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py
@@ -730,6 +730,10 @@ class NPUEstimator(estimator_lib.Estimator):
            custom_op.parameter_map["dynamic_graph_execute_mode"].s = tf.compat.as_bytes(config._dynamic_graph_execute_mode)
        if config._dynamic_inputs_shape_range is not None:
            custom_op.parameter_map["dynamic_inputs_shape_range"].s = tf.compat.as_bytes(config._dynamic_inputs_shape_range)
+       if config._local_rank_id is not None:
+           custom_op.parameter_map["local_rank_id"].i = config._local_rank_id
+       if config._local_device_list is not None:
+           custom_op.parameter_map["local_device_list"].s = tf.compat.as_bytes(config._local_device_list)

        # add profiling options to custom_op
        self.__load_profiling_options(config, custom_op)
diff --git a/tf_adapter/util/ge_plugin.cc b/tf_adapter/util/ge_plugin.cc
index c9c231ab3..7ee8e590a 100644
--- a/tf_adapter/util/ge_plugin.cc
+++ b/tf_adapter/util/ge_plugin.cc
@@ -180,7 +180,7 @@ void GePlugin::Init(std::map<std::string, std::string> &init_options, bool is_gl
   // Open TsdClient first, then call GEInitialize
   LOG(INFO) << "[GePlugin] Open TsdClient and Init tdt host.";
-  int32_t ret = tdt::TdtHostInit(static_cast<uint32_t>(device_id_));
+  int32_t ret = tdt::TdtOutFeedInit(static_cast<uint32_t>(device_id_));
   if (ret != 0) {
     std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime));
     LOG(FATAL) << "[GePlugin] Tdt host init failed, tdt error code : " << ret;
@@ -226,7 +226,8 @@ void GePlugin::Finalize() {
   if (status_parser != ge::SUCCESS) { LOG(ERROR) << "[GePlugin] Parser finalize failed, ret : " << ToString(status); }
   LOG(INFO) << "[GePlugin] Close TsdClient and destroy tdt.";
-  int32_t ret = tdt::TdtHostDestroy();
+  int32_t ret = tdt::TdtOutFeedDestroy();
+  if (ret != 0) { LOG(ERROR) << "[GePlugin] Close tdt host failed."; }
   isInit_ = false;
 }
diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc
index fd2a60ced..01a1f3feb 100644
--- a/tf_adapter/util/npu_attrs.cc
+++ b/tf_adapter/util/npu_attrs.cc
@@ -198,6 +198,23 @@ inline Status CheckDynamicDims(const string &dynamic_dims) {
   return Status::OK();
 }
+inline Status CheckLocalRankId(int local_rank_id) {
+  int kMaxDeviceId = 7;
+  if (local_rank_id < 0 || local_rank_id > kMaxDeviceId) {
+    return errors::InvalidArgument("local rank id should be in [0,7]");
+  }
+  return Status::OK();
+}
+
+inline Status CheckDeviceList(std::string local_device_list) {
+  std::string tmp_device_list = local_device_list + ",";
+  std::regex pattern("(\\d{1,},)+");
+  if (!regex_match(tmp_device_list, pattern)) {
+    return errors::InvalidArgument("local_device_list style is invalid, example: '1,2,3'");
+  }
+  return Status::OK();
+}
+
 std::map<std::string, std::string> NpuAttrs::GetSessOptions(OpKernelConstruction *ctx) {
   std::map<std::string, std::string> sess_options;
   std::string variable_format_optimize = std::to_string(true);
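
`CheckDeviceList` appends a trailing comma so the whole string has to match `(\d{1,},)+`: digits and commas only, no spaces and no empty elements. A standalone check of that pattern; the `main` driver and sample inputs are illustrative, not part of the patch:

    #include <iostream>
    #include <regex>
    #include <string>

    // Sketch of the CheckDeviceList validation above.
    static bool DeviceListOk(const std::string &local_device_list) {
      std::string tmp_device_list = local_device_list + ",";  // "0,1" -> "0,1,"
      std::regex pattern("(\\d{1,},)+");                      // one or more "<digits>," groups
      return std::regex_match(tmp_device_list, pattern);
    }

    int main() {
      std::cout << DeviceListOk("0,1,2,3") << "\n";  // 1: valid
      std::cout << DeviceListOk("0, 1") << "\n";     // 0: space not allowed
      std::cout << DeviceListOk("0,,1") << "\n";     // 0: empty element
      return 0;
    }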
@@ -379,6 +396,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
   bool dynamic_input = false;
   std::string dynamic_graph_execute_mode = "dynamic_execute";
   std::string dynamic_inputs_shape_range;
+  int local_rank_id = -1;
+  std::string local_device_list;
   for (const auto &custom_optimizer : rewrite_options.custom_optimizers()) {
     if (custom_optimizer.name() == "NpuOptimizer") {
       do_npu_optimizer = true;
@@ -406,6 +425,16 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
         if (params.count("dynamic_inputs_shape_range")) {
           dynamic_inputs_shape_range = params.at("dynamic_inputs_shape_range").s();
         }
       }
+      if (params.count("local_rank_id")) {
+        local_rank_id = params.at("local_rank_id").i();
+        Status s = CheckLocalRankId(local_rank_id);
+        if (!s.ok()) { LOG(FATAL) << s.error_message(); }
+      }
+      if (params.count("local_device_list")) {
+        local_device_list = params.at("local_device_list").s();
+        Status s = CheckDeviceList(local_device_list);
+        if (!s.ok()) { LOG(FATAL) << s.error_message(); }
+      }
     }
   }
   if (!do_npu_optimizer) {
@@ -427,6 +456,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
   pass_options["dynamic_input"] = std::to_string(dynamic_input);
   pass_options["dynamic_graph_execute_mode"] = dynamic_graph_execute_mode;
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
+  pass_options["local_rank_id"] = std::to_string(local_rank_id);
+  pass_options["local_device_list"] = local_device_list;
   return pass_options;
 }
@@ -444,6 +475,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
   std::string dynamic_input = std::to_string(false);
   std::string dynamic_graph_execute_mode = "dynamic_execute";
   std::string dynamic_inputs_shape_range;
+  std::string local_rank_id = "-1";
+  std::string local_device_list;
   Status s = Status::OK();
   string npuOptimizer;
@@ -459,6 +492,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
       ctx->GetAttr("_dynamic_input", &dynamic_input);
       ctx->GetAttr("_dynamic_graph_execute_mode", &dynamic_graph_execute_mode);
       ctx->GetAttr("_dynamic_inputs_shape_range", &dynamic_inputs_shape_range);
+      ctx->GetAttr("_local_rank_id", &local_rank_id);
+      ctx->GetAttr("_local_device_list", &local_device_list);
     }
   }
   // pass options
@@ -473,6 +508,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
   pass_options["dynamic_input"] = dynamic_input;
   pass_options["dynamic_graph_execute_mode"] = dynamic_graph_execute_mode;
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
+  pass_options["local_rank_id"] = local_rank_id;
+  pass_options["local_device_list"] = local_device_list;
   return pass_options;
 }
@@ -490,6 +527,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
   std::string dynamic_input = std::to_string(false);
   std::string dynamic_graph_execute_mode = "dynamic_execute";
   std::string dynamic_inputs_shape_range;
+  std::string local_rank_id = "-1";
+  std::string local_device_list;
   Status s = Status::OK();
   if (attrs.Find("_NpuOptimizer") != nullptr) {
@@ -516,6 +555,12 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
     if (attrs.Find("_dynamic_inputs_shape_range") != nullptr) {
       dynamic_inputs_shape_range = attrs.Find("_dynamic_inputs_shape_range")->s();
     }
+    if (attrs.Find("_local_rank_id") != nullptr) {
+      local_rank_id = attrs.Find("_local_rank_id")->s();
+    }
+    if (attrs.Find("_local_device_list") != nullptr) {
+      local_device_list = attrs.Find("_local_device_list")->s();
+    }
   }
   // pass options
   pass_options["do_npu_optimizer"] = do_npu_optimizer;
@@ -529,6 +574,8 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
   pass_options["dynamic_input"] = dynamic_input;
   pass_options["dynamic_graph_execute_mode"] = dynamic_graph_execute_mode;
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
+  pass_options["local_rank_id"] = local_rank_id;
+  pass_options["local_device_list"] = local_device_list;
   return pass_options;
 }
@@ -543,6 +590,8 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
   std::string lower_functional_ops = std::to_string(false);
   string job = "default";
   std::string task_index = "0";
+  std::string local_rank_id = "-1";
+  std::string local_device_list;
   Status s = Status::OK();
   std::string variable_format_optimize = std::to_string(true);
@@ -601,6 +650,12 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
       job = "localhost";
     }
     if (attrs.Find("_task_index") != nullptr) { task_index = attrs.Find("_task_index")->s(); }
+    if (attrs.Find("_local_rank_id") != nullptr) {
+      local_rank_id = attrs.Find("_local_rank_id")->s();
+    }
+    if (attrs.Find("_local_device_list") != nullptr) {
+      local_device_list = attrs.Find("_local_device_list")->s();
+    }
     if (attrs.Find("_variable_format_optimize") != nullptr) {
       variable_format_optimize = attrs.Find("_variable_format_optimize")->s();
@@ -728,6 +783,8 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
   all_options["lower_functional_ops"] = lower_functional_ops;
   all_options["job"] = job;
   all_options["task_index"] = task_index;
+  all_options["local_rank_id"] = local_rank_id;
+  all_options["local_device_list"] = local_device_list;
   all_options["op_select_implmode"] = op_select_implmode;
   all_options["optypelist_for_implmode"] = optypelist_for_implmode;
   all_options["input_shape"] = input_shape;
@@ -794,6 +851,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
   bool lower_functional_ops = false;
   string job = "localhost";
   int task_index = 0;
+  int local_rank_id = -1;
+  std::string local_device_list;
   bool dynamic_input = false;
   std::string dynamic_graph_execute_mode = "dynamic_execute";
   std::string dynamic_inputs_shape_range;
@@ -929,6 +988,16 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
       if (params.count("dynamic_inputs_shape_range")) {
         dynamic_inputs_shape_range = params.at("dynamic_inputs_shape_range").s();
       }
     }
+    if (params.count("local_rank_id")) {
+      local_rank_id = params.at("local_rank_id").i();
+      Status s = CheckLocalRankId(local_rank_id);
+      if (!s.ok()) { LOG(FATAL) << s.error_message(); }
+    }
+    if (params.count("local_device_list")) {
+      local_device_list = params.at("local_device_list").s();
+      Status s = CheckDeviceList(local_device_list);
+      if (!s.ok()) { LOG(FATAL) << s.error_message(); }
+    }
     if (params.count("enable_exception_dump")) { enable_exception_dump = params.at("enable_exception_dump").i(); }
     if (!params.count("op_select_implmode") && !params.count("optypelist_for_implmode")) {
@@ -1042,6 +1111,8 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
   pass_options["dynamic_input"] = std::to_string(dynamic_input);
   pass_options["dynamic_graph_execute_mode"] = dynamic_graph_execute_mode;
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
+  pass_options["local_rank_id"] = std::to_string(local_rank_id);
+  pass_options["local_device_list"] = local_device_list;
   std::string attr_name;
   for (const auto &option : sess_options) {
--
Gitee