From 99a99d8809eed1746e476a25d4f680d84a1589bd Mon Sep 17 00:00:00 2001 From: yanqingshang Date: Thu, 24 Dec 2020 10:32:23 +0800 Subject: [PATCH 01/17] sync code --- .../interface_spec/api_npu_optimizer.pyh | 2 +- .../optimizers/dp_tf_ge_conversion_pass.cc | 2 +- .../optimizers/gradient_fusion_optimizer.cc | 5 - .../optimizers/om_partition_subgraphs_pass.cc | 2 +- .../optimizers/weight_update_sharding_pass.cc | 232 ------------------ .../optimizers/weight_update_sharding_pass.h | 44 ---- .../npu_bridge/estimator/npu/npu_estimator.py | 4 +- .../npu_bridge/estimator/npu/npu_hook.py | 2 +- .../estimator/npu/npu_loss_scale_optimizer.py | 3 +- .../npu_bridge/estimator/npu/npu_optimizer.py | 100 +------- .../npu_bridge/estimator/npu/npu_saver.py | 143 ----------- .../python/npu_bridge/estimator/npu/util.py | 123 +--------- 12 files changed, 15 insertions(+), 647 deletions(-) delete mode 100644 tf_adapter/optimizers/weight_update_sharding_pass.cc delete mode 100644 tf_adapter/optimizers/weight_update_sharding_pass.h delete mode 100644 tf_adapter/python/npu_bridge/estimator/npu/npu_saver.py diff --git a/tf_adapter/interface_spec/api_npu_optimizer.pyh b/tf_adapter/interface_spec/api_npu_optimizer.pyh index 27d4163de..485580506 100644 --- a/tf_adapter/interface_spec/api_npu_optimizer.pyh +++ b/tf_adapter/interface_spec/api_npu_optimizer.pyh @@ -5,4 +5,4 @@ class NPUOptimizer(optimizer.Optimizer): is_loss_scale=False, is_tailing_optimization=False, name=None): class NPUDistributedOptimizer(tf.train.Optimizer): - def __init__(self, optimizer, is_weight_update_sharding=False, name=None): + def __init__(self, optimizer, name=None): diff --git a/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc b/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc index b8fa518bc..b28cd32f9 100644 --- a/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc +++ b/tf_adapter/optimizers/dp_tf_ge_conversion_pass.cc @@ -935,6 +935,6 @@ Status DpTfToGEConversionPassImpl::ProcessGraph(std::unique_ptr *graph, F } // We register DpTfToGE insertion for phase 102 in POST_PARTITIONING grouping -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 4, DpTfToGEConversionPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 3, DpTfToGEConversionPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PARTITIONING, 102, DpTfToGEConversionPass); } // namespace tensorflow diff --git a/tf_adapter/optimizers/gradient_fusion_optimizer.cc b/tf_adapter/optimizers/gradient_fusion_optimizer.cc index 98dab597a..0235497e0 100644 --- a/tf_adapter/optimizers/gradient_fusion_optimizer.cc +++ b/tf_adapter/optimizers/gradient_fusion_optimizer.cc @@ -263,11 +263,6 @@ Status GradFusionOptimizer::Optimize(Cluster *cluster, const GrapplerItem &item, for (const auto &nodeDef : graphOrigin.node()) { if (IsHcomOp(nodeDef)) { - std::string op_name; - op_name = nodeDef.name(); - if (op_name.find("_Weight_Update_Sharding") != std::string::npos) { - continue; - } DataType dType; auto attrMap = nodeDef.attr(); auto iter = attrMap.find(DATA_TYPE_ATTR); diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index 08da5757c..151f21f9f 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -1972,6 +1972,6 @@ Status OMPartitionSubgraphsPass::ProcessGraph(std::unique_ptr *graph, Fun return Status::OK(); } -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 3, 
OMPartitionSubgraphsPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 2, OMPartitionSubgraphsPass); REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PARTITIONING, 101, OMPartitionSubgraphsPass); } // namespace tensorflow diff --git a/tf_adapter/optimizers/weight_update_sharding_pass.cc b/tf_adapter/optimizers/weight_update_sharding_pass.cc deleted file mode 100644 index 1d34e9f7c..000000000 --- a/tf_adapter/optimizers/weight_update_sharding_pass.cc +++ /dev/null @@ -1,232 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tf_adapter/optimizers/weight_update_sharding_pass.h" - -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/public/session_options.h" -#include "tf_adapter/common/common.h" -#include "tf_adapter/util/npu_attrs.h" -#include "tf_adapter/util/infershape_util.h" - -namespace tensorflow { -static const int64 kMicrosToMillis = 1000; - -static std::atomic graph_run_num(1); -Status WeightUpdateShardingPass::Run(const GraphOptimizationPassOptions &options) { - if (options.graph == nullptr || options.flib_def == nullptr || options.session_options == nullptr) { - return Status::OK(); - } - int graph_num; - graph_num = graph_run_num++; - - Graph *graphIn = (options.graph)->get(); - std::map pass_options = NpuAttrs::GetPassOptions(options); - std::string job = pass_options["job"]; - if (job == "ps" || job == "default") { - LOG(INFO) << "job is " << job << " Skip the optimizer : WeightUpdateShardingPass. 
"; - return Status::OK(); - } - - bool weight_update_sharding = false; - bool npu_loss_scale = false; - for (Node *node : graphIn->op_nodes()) { - REQUIRES_NOT_NULL(node); - std::string op_name; - std::string op_type; - op_name = node->name(); - op_type = node->type_string(); - if (op_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos) { - weight_update_sharding = true; - if (npu_loss_scale == true) { - break; - } - } - if (op_name.find("NPULossScaleOptimizer") != std::string::npos && - op_type == "NpuAllocFloatStatus") { - npu_loss_scale = true; - if (weight_update_sharding == true) { - break; - } - } - } - - if (weight_update_sharding) { - int64 startTime = InferShapeUtil::GetCurrentTimestap(); - char *need_print = getenv("PRINT_MODEL"); - - if (need_print != nullptr && strcmp("1", need_print) == 0) { - GraphDef ori_graph_def; - graphIn->ToGraphDef(&ori_graph_def); - string ori_model_path = "BeforeWeightUpdateSharding_"; - string omodel_path = ori_model_path + std::to_string(graph_num) + ".pbtxt"; - Status status_out = WriteTextProto(Env::Default(), omodel_path, ori_graph_def); - } - - std::vector in_nodes; - for (Node *node : graphIn->nodes()) { in_nodes.push_back(node); } - for (int i = in_nodes.size() - 1; i >= 0; i--) { - Node *node = in_nodes.at(i); - REQUIRES_NOT_NULL(node); - std::string op_type = node->type_string(); - std::string dst_name; - std::string dst_type; - if (op_type == "VarHandleOp" || op_type == "Identity" || - op_type == "ReadVariableOp") { - Node *var_node = nullptr; - Node *broadcast_node = nullptr; - Node *forward_node = nullptr; - int forward_input_idx = -1; - std::vector remove_edges; - for (auto in_edge : node->in_edges()) { - REQUIRES_NOT_NULL(in_edge); - REQUIRES_NOT_NULL(in_edge->src()); - REQUIRES_NOT_NULL(in_edge->dst()); - if (in_edge->src()->IsVariable()) { - var_node = in_edge->src(); - break; - } - } - std::vector out_edges; - for (auto edge : node->out_edges()) { out_edges.push_back(edge); } - for (auto out_edge : out_edges) { - REQUIRES_NOT_NULL(out_edge); - REQUIRES_NOT_NULL(out_edge->src()); - REQUIRES_NOT_NULL(out_edge->dst()); - dst_name = out_edge->dst()->name(); - dst_type = out_edge->dst()->type_string(); - if (!npu_loss_scale) { - if (dst_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos && - dst_type == "HcomBroadcast") { - bool find_broadcast = false; - for (auto broadcast_edge : out_edge->dst()->in_edges()) { - REQUIRES_NOT_NULL(broadcast_edge); - REQUIRES_NOT_NULL(broadcast_edge->src()); - REQUIRES_NOT_NULL(broadcast_edge->dst()); - if (broadcast_edge->IsControlEdge()) { - find_broadcast = true; - // remove edge : reduce/apply --> broadcast - remove_edges.push_back(broadcast_edge); - } - } - if (find_broadcast) { - broadcast_node = out_edge->dst(); - //remove edge : VarHandleOp/Identity --> broadcast - remove_edges.push_back(out_edge); - for (auto broadcast_edge : out_edge->dst()->out_edges()) { - REQUIRES_NOT_NULL(broadcast_edge); - REQUIRES_NOT_NULL(broadcast_edge->src()); - REQUIRES_NOT_NULL(broadcast_edge->dst()); - if (broadcast_edge->IsControlEdge()) { - // remove edge : broadcast --> group - remove_edges.push_back(broadcast_edge); - } - } - break; - } - } - } else { - if (dst_type == "Switch") { - for (auto switch_out_edge : out_edge->dst()->out_edges()) { - REQUIRES_NOT_NULL(switch_out_edge); - REQUIRES_NOT_NULL(switch_out_edge->src()); - REQUIRES_NOT_NULL(switch_out_edge->dst()); - std::string node_name = switch_out_edge->dst()->name(); - std::string node_type = 
switch_out_edge->dst()->type_string(); - if (node_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos && - node_type == "HcomBroadcast") { - bool find_broadcast = false; - for (auto broadcast_edge : switch_out_edge->dst()->in_edges()) { - REQUIRES_NOT_NULL(broadcast_edge); - REQUIRES_NOT_NULL(broadcast_edge->src()); - REQUIRES_NOT_NULL(broadcast_edge->dst()); - if (broadcast_edge->IsControlEdge()) { - find_broadcast = true; - // remove edge : reduce/apply --> broadcast - remove_edges.push_back(broadcast_edge); - } - } - if (find_broadcast) { - broadcast_node = switch_out_edge->dst(); - //remove edge : Switch --> broadcast - remove_edges.push_back(switch_out_edge); - for (auto broadcast_edge : switch_out_edge->dst()->out_edges()) { - REQUIRES_NOT_NULL(broadcast_edge); - REQUIRES_NOT_NULL(broadcast_edge->src()); - REQUIRES_NOT_NULL(broadcast_edge->dst()); - if (broadcast_edge->IsControlEdge()) { - //remove edge : broadcast --> group - remove_edges.push_back(broadcast_edge); - } - } - break; - } - } - } - } - } - } - if (broadcast_node != nullptr && var_node != nullptr) { - for (auto edge : remove_edges) { - graphIn->RemoveEdge(edge); - } - // add edge : variable --> broadcast - graphIn->AddEdge(var_node, 0, broadcast_node, 0); - for (auto var_edge : var_node->out_edges()) { - REQUIRES_NOT_NULL(var_edge); - REQUIRES_NOT_NULL(var_edge->src()); - REQUIRES_NOT_NULL(var_edge->dst()); - if (var_edge->dst() != broadcast_node) { - graphIn->AddControlEdge(broadcast_node, var_edge->dst()); - } - } - } - } - } - - if (need_print != nullptr && strcmp("1", need_print) == 0) { - GraphDef omg_graph_def; - graphIn->ToGraphDef(&omg_graph_def); - string tmpmodel_path = "AfterWeightUpdateSharding_"; - string tmodel_path = tmpmodel_path + std::to_string(graph_num) + ".pbtxt"; - Status status_o = WriteTextProto(Env::Default(), tmodel_path, omg_graph_def); - } - int64 endTime = InferShapeUtil::GetCurrentTimestap(); - LOG(INFO) << "WeightUpdateSharding_" << std::to_string(graph_num) << " success. [" - << ((endTime - startTime) / kMicrosToMillis) << " ms]"; - } - - return Status::OK(); -} - -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 2, WeightUpdateShardingPass); -} // namespace tensorflow diff --git a/tf_adapter/optimizers/weight_update_sharding_pass.h b/tf_adapter/optimizers/weight_update_sharding_pass.h deleted file mode 100644 index 4306bfca0..000000000 --- a/tf_adapter/optimizers/weight_update_sharding_pass.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_WEIGHT_UPDATE_SHARDING_PASS_H_ -#define TENSORFLOW_WEIGHT_UPDATE_SHARDING_PASS_H_ - -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/core/status.h" - -namespace tensorflow { -class WeightUpdateShardingPass : public GraphOptimizationPass { - public: - WeightUpdateShardingPass() = default; - ~WeightUpdateShardingPass() override = default; - Status Run(const GraphOptimizationPassOptions &options) override; -}; -} // namespace tensorflow -#endif // TENSORFLOW_WEIGHT_UPDATE_SHARDING_PASS_H_ diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py index 690e02103..699ceb464 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py @@ -22,7 +22,6 @@ from npu_bridge.estimator.npu.npu_config import NPURunConfig from npu_bridge.estimator.npu.npu_hook import * from npu_bridge.estimator.npu.npu_common import NPUBasics from npu_bridge.estimator import npu_ops -from npu_bridge.estimator.npu.npu_saver import * import six from six.moves import queue as Queue @@ -442,8 +441,7 @@ class NPUEstimator(estimator_lib.Estimator): npu_hooks.append(NPUCheckpointSaverHook( checkpoint_dir=model_dir, save_secs=config.save_checkpoints_secs, - save_steps=config.save_checkpoints_steps, - saver=NPUSaver())) + save_steps=config.save_checkpoints_steps)) if isinstance(estimator_spec, NPUEstimatorSpec): if estimator_spec._host_call is not None: diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_hook.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_hook.py index 02a5dd576..f1fc8bb69 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_hook.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_hook.py @@ -143,7 +143,7 @@ def broadcast_global_variables(root_rank, index): to all other processes. 
""" op_list = [] - for var in tf.trainable_variables(): + for var in tf.global_variables(): # the input and out tensor of HCOMBroadcast interface are list inputs = [var] outputs=hccl_ops.broadcast(tensor=inputs,root_rank=root_rank) diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_loss_scale_optimizer.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_loss_scale_optimizer.py index 24f3759f3..0356f7f8f 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_loss_scale_optimizer.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_loss_scale_optimizer.py @@ -116,8 +116,7 @@ class NPULossScaleOptimizer(optimizer.Optimizer): scaled_loss = loss_val * math_ops.cast(loss_scale, loss_val.dtype.base_dtype) - with tf.name_scope(self._name): - self._float_status = gen_npu_ops.npu_alloc_float_status() + self._float_status = gen_npu_ops.npu_alloc_float_status() grads_and_vars = self._opt.compute_gradients( scaled_loss, diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py index 2d0930e42..8e89daca9 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_optimizer.py @@ -22,7 +22,6 @@ gen_npu_ops = helper.get_gen_ops(); from tensorflow.python.platform import tf_logging as logging from npu_bridge.estimator.npu.npu_common import NPUBasics -from npu_bridge.estimator.npu import util def allreduce(tensor, var, average=True): """ @@ -71,43 +70,6 @@ def allreduce(tensor, var, average=True): return new_tensor -def reduce(tensor, var, root_rank, average=True, fusion=0, fusion_id=-1): - basic = NPUBasics("") - size = basic.size() - # the tensor is the instance of tf.IndexedSlices - if isinstance(tensor, tf.IndexedSlices): - # For IndexedSlices, do two allgathers intead of a reduce. - logging.debug("HcomAllgather...") - values=hccl_ops.allgather(tensor.values, size) - indices=hccl_ops.allgather(tensor.indices, size) - - if values is None: - raise ValueError('the result of tf.HcomAllgather([tensor.values]) is empty') - if indices is None: - raise ValueError('the result of tf.HcomAllgather([tensor.indices]) is empty') - - # To make this operation into an average, divide all gathered values by the size. - rank_size = tf.cast(size, tensor.values.dtype) - new_values = tf.div(values, rank_size) if average else values - - return tf.IndexedSlices(new_values, indices,dense_shape=tensor.dense_shape) - - else: - logging.debug("HcomReduce...") - local_rank_id = os.getenv('DEVICE_ID') - if local_rank_id == None or int(local_rank_id) < 0: - raise ValueError('Please set the correct RANK_ID value, current RANK_ID is:', local_rank_id) - - summed_tensor=hccl_ops.reduce(tensor,"sum", root_rank, fusion, fusion_id) - if summed_tensor is None:# and summed_tensor: - raise ValueError('the result of tf.DavinciReduce([tensor]) is empty') - if root_rank != int(local_rank_id): - return summed_tensor - else: - rank_size = tf.cast(size, dtype=tensor.dtype) - new_tensor = tf.div(summed_tensor, rank_size) if average else summed_tensor - return new_tensor - class NPUOptimizer(optimizer.Optimizer): """An optimizer that wraps another tf.Optimizer that can using an allreduce to average gradient values before applying gradients to model weights when @@ -266,9 +228,7 @@ class NPUDistributedOptimizer(tf.train.Optimizer): average gradient values before applying gradients to model weights. 
""" - def __init__(self, optimizer, - is_weight_update_sharding=False, - name=None): + def __init__(self, optimizer, name=None): """ Construct a new DistributedOptimizer, which uses another optimizer under the hood for computing single-process gradient values and @@ -285,7 +245,6 @@ class NPUDistributedOptimizer(tf.train.Optimizer): if name is None: name = "Distributed{}".format(type(optimizer).__name__) self._optimizer = optimizer - self._is_weight_update_sharding = is_weight_update_sharding super(NPUDistributedOptimizer, self).__init__(name=name, use_locking=False) def compute_gradients(self, *args, **kwargs): @@ -302,58 +261,15 @@ class NPUDistributedOptimizer(tf.train.Optimizer): return gradients averaged_gradients = [] - if self._is_weight_update_sharding and int(rank_size) <= len(gradients): - local_rank_id = os.getenv('DEVICE_ID') - if local_rank_id == None or int(local_rank_id) < 0: - raise ValueError('Please set the correct RANK_ID value, current RANK_ID is:', local_rank_id) - util.add_grads_and_vars(gradients, int(rank_size)) - with tf.name_scope(self._name + "_Reduce_Weight_Update_Sharding"): - for grad, var in gradients: - rank_id = util.get_gid_by_grad(grad) - avg_grad = reduce(grad, var, rank_id, True, 2, rank_id) if grad is not None else None - averaged_gradients.append((avg_grad, var)) - elif self._is_weight_update_sharding and int(rank_size) > len(gradients): - raise ValueError("The number of gradients is less than rank_size, " - "so weight_update_sharding cannot be executed") - else: - with tf.name_scope(self._name + "_Allreduce"): - for grad, var in gradients: - avg_grad = allreduce(grad, var, True) if grad is not None else None - averaged_gradients.append((avg_grad, var)) + with tf.name_scope(self._name + "_Allreduce"): + for grad, var in gradients: + avg_grad = allreduce(grad, var, True) if grad is not None else None + averaged_gradients.append((avg_grad, var)) return averaged_gradients - def apply_gradients(self, grads_and_vars, global_step=None, name=None): - rank_size = os.getenv('RANK_SIZE') - if rank_size == None or int(rank_size) <= 1: - return self._optimizer.apply_gradients(grads_and_vars, global_step, name) - - if self._is_weight_update_sharding: - op_list = [] - local_rank_id = os.getenv('DEVICE_ID') - if local_rank_id == None or int(local_rank_id) < 0: - raise ValueError('Please set the correct RANK_ID value, current RANK_ID is:', local_rank_id) - local_grads_and_vars = [] - for grad, var in grads_and_vars: - rank_id = util.get_gid_by_weight(var) - if rank_id >= 0 and rank_id == int(local_rank_id): - local_grads_and_vars.append((grad, var)) - apply_res = self._optimizer.apply_gradients(local_grads_and_vars, global_step, name) - with tf.get_default_graph().control_dependencies([apply_res]): - with tf.name_scope(self._name + "_Broadcast_Weight_Update_Sharding"): - for grad, var in grads_and_vars: - rank_id = util.get_gid_by_weight(var) - with tf.get_default_graph().control_dependencies(op_list): - outputs = hccl_ops.broadcast([var], rank_id) - if outputs is not None: - op_list.append(outputs[0].op) - for grad, var in grads_and_vars: - rank_id = util.get_gid_by_weight(var) - if rank_id >= 0 and rank_id != int(local_rank_id): - op_list.append(grad) - op_list.append(apply_res) - return tf.group(op_list) - else: - return self._optimizer.apply_gradients(grads_and_vars, global_step, name) + def apply_gradients(self, *args, **kwargs): + """Calls this same method on the underlying optimizer.""" + return self._optimizer.apply_gradients(*args, **kwargs) def 
get_slot(self, *args, **kwargs): """Calls this same method on the underlying optimizer.""" diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_saver.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_saver.py deleted file mode 100644 index 0d2b87954..000000000 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_saver.py +++ /dev/null @@ -1,143 +0,0 @@ -from tensorflow.python.training.saver import BulkSaverBuilder -from tensorflow.python.training.saver import Saver -from npu_bridge.estimator.npu import util -import tensorflow as tf -from npu_bridge.hccl import hccl_ops -from tensorflow.python.platform import tf_logging as logging - -from tensorflow.python.eager import context -from tensorflow.python.training.saving import saveable_object_util -from tensorflow.core.protobuf import saver_pb2 -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variables - -class NPUBulkSaverBuilder(BulkSaverBuilder): - def _build_internal(self, - names_to_saveables, - reshape=False, - sharded=False, - max_to_keep=5, - keep_checkpoint_every_n_hours=10000.0, - name=None, - restore_sequentially=False, - filename="model", - build_save=True, - build_restore=True): - """build() with option to only perform save and restore.""" - if not context.executing_eagerly() and (not build_save or - not build_restore): - raise ValueError("save and restore operations need to be built together " - " when eager execution is not enabled.") - - saveables = saveable_object_util.validate_and_slice_inputs( - names_to_saveables) - if max_to_keep is None: - max_to_keep = 0 - - with ops.name_scope(name, "save", - [saveable.op for saveable in saveables]) as name: - # Add a placeholder string tensor for the filename. - filename_tensor = array_ops.placeholder_with_default( - filename or "model", shape=(), name="filename") - # Keep the name "Const" for backwards compatibility. - filename_tensor = array_ops.placeholder_with_default( - filename_tensor, shape=(), name="Const") - - # Add the save ops. 
- if sharded: - per_device = self._GroupByDevices(saveables) - if build_save: - op_list = [] - with tf.name_scope("Save_Weight_Update_Sharding"): - grad_and_var_items = util.get_all_grad_item() - for item in grad_and_var_items: - if item.var in names_to_saveables: - rank_id = item.root_rank_id - if rank_id >= 0: - with tf.get_default_graph().control_dependencies(op_list): - out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id) - op_list.append(out_var[0].op) - if len(op_list) > 0: - with tf.get_default_graph().control_dependencies(op_list): - save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) - else: - save_tensor = self._AddShardedSaveOps(filename_tensor, per_device) - if build_restore: - restore_op = self._AddShardedRestoreOps(filename_tensor, per_device, - restore_sequentially, reshape) - else: - if build_save: - op_list = [] - with tf.name_scope("Save_Weight_Update_Sharding"): - grad_and_var_items = util.get_all_grad_item() - for item in grad_and_var_items: - if item.var in names_to_saveables: - rank_id = item.root_rank_id - if rank_id >= 0: - with tf.get_default_graph().control_dependencies(op_list): - out_var = hccl_ops.broadcast([item.var], rank_id, 2, rank_id) - op_list.append(out_var[0].op) - if len(op_list) > 0: - with tf.get_default_graph().control_dependencies(op_list): - save_tensor = self._AddSaveOps(filename_tensor, saveables) - else: - save_tensor = self._AddSaveOps(filename_tensor, saveables) - if build_restore: - restore_op = self._AddRestoreOps(filename_tensor, saveables, - restore_sequentially, reshape) - - # In the following use case, it's possible to have restore_ops be called - # something else: - # - Build inference graph and export a meta_graph. - # - Import the inference meta_graph - # - Extend the inference graph to a train graph. - # - Export a new meta_graph. - # Now the second restore_op will be called "restore_all_1". - # As such, comment out the assert for now until we know whether supporting - # such usage model makes sense. - # - # assert restore_op.name.endswith("restore_all"), restore_op.name - if context.executing_eagerly(): - # Store the tensor values to the tensor_names. - save_tensor_name = save_tensor.numpy() if build_save else "" - return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.numpy(), - save_tensor_name=save_tensor_name, - restore_op_name="", - max_to_keep=max_to_keep, - sharded=sharded, - keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, - version=self._write_version) - else: - graph = ops.get_default_graph() - # Do some sanity checking on collections containing - # PartitionedVariables. If a saved collection has a PartitionedVariable, - # the GraphDef needs to include concat ops to get the value (or there'll - # be a lookup error on load). - check_collection_list = graph.get_all_collection_keys() - for collection_type in check_collection_list: - for element in graph.get_collection(collection_type): - if isinstance(element, variables.PartitionedVariable): - try: - graph.get_operation_by_name(element.name) - except KeyError: - # Create a concat op for this PartitionedVariable. The user may - # not need it, but we'll try looking it up on MetaGraph restore - # since it's in a collection. 
- element.as_tensor() - return saver_pb2.SaverDef( - filename_tensor_name=filename_tensor.name, - save_tensor_name=save_tensor.name, - restore_op_name=restore_op.name, - max_to_keep=max_to_keep, - sharded=sharded, - keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, - version=self._write_version) - -class NPUSaver(Saver): - def _build(self, checkpoint_path, build_save, build_restore): - if not self.saver_def or context.executing_eagerly(): - if self._builder is None: - self._builder = NPUBulkSaverBuilder(self._write_version) - super()._build(checkpoint_path=checkpoint_path, build_save=build_save, build_restore=build_restore) \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/estimator/npu/util.py b/tf_adapter/python/npu_bridge/estimator/npu/util.py index 1796f7862..955eae281 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/util.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/util.py @@ -249,125 +249,4 @@ def variable_initializer_in_host(var_list): Returns: An Op that run the initializers of all the specified variables. """ - return tf.initializers.variables(var_list, name='var_in_host') - -def fair_division(input, number): - def get_sum(list): - res = 0 - for item in list: - res += item.size - return res - - def get_left_input_sum(list): - res = 0 - for item in list: - if item.root_rank_id < 0: - res += item.size - return res - - def get_average(list, size): - large_number_list = [] - average_size = 0 - res = 0 - if size == 1: - for item in list: - if item.root_rank_id < 0: - res += item.size - return res - while True: - res = 0 - find_large_number = False - for item in list: - if item not in large_number_list and item.root_rank_id < 0: - res += item.size - average_size = res // (size - len(large_number_list)) - for item in list: - if item not in large_number_list and item.root_rank_id < 0 and item.size > res - item.size: - find_large_number = True - large_number_list.append(item) - if not find_large_number: - break - return average_size - - if number > len(input) or number < 0: - raise ValueError("'number' is greater than the number of inputs or 'number' is less than 0. 
") - elif number == len(input): - for i in range(len(input)): - input[i].root_rank_id = i - return input - - j = -1 - last_index = 0 - while True: - j = j+1 - total_number = number - j - if total_number == 0: - break - average_size = get_average(input, total_number) - tmp_list = [] - tmp_last_index = last_index - for i in range(tmp_last_index, len(input) - total_number + 1): - if get_sum(tmp_list) + input[i].size <= average_size: - input[i].root_rank_id = j - tmp_list.append(input[i]) - last_index = i+1 - else: - if len(tmp_list) <= 0: - input[i].root_rank_id = j - tmp_list.append(input[i]) - last_index = i+1 - elif (get_sum(tmp_list) + input[i].size - average_size) <= (average_size - get_sum(tmp_list)): - input[i].root_rank_id = j - tmp_list.append(input[i]) - last_index = i+1 - break - - return input - -class GradDivisionItem(): - def __init__(self, grad, var): - self.grad = grad - self.var = var - self.size = self.__get_size() - self.root_rank_id = -1 - - def __get_size(self): - size = 1 - grad_shape = self.grad.shape - if len(grad_shape) <= 0: - return 0 - for i in range(len(grad_shape)): - size = size * int(grad_shape[i]) - size = size * self.grad.dtype.size - return size - -_GRADIENTS_AND_VARS = [] - -def add_grads_and_vars(grads_and_vars, rank_size): - global _GRADIENTS_AND_VARS - _GRADIENTS_AND_VARS.clear() - for grad, var in grads_and_vars: - if grad is not None: - item = GradDivisionItem(grad, var) - _GRADIENTS_AND_VARS.append(item) - _GRADIENTS_AND_VARS = fair_division(_GRADIENTS_AND_VARS, rank_size) - -def get_gid_by_grad(grad): - gid = -1 - global _GRADIENTS_AND_VARS - for item in _GRADIENTS_AND_VARS: - if item.grad.name == grad.name: - gid = item.root_rank_id - return gid - -def get_gid_by_weight(weight): - gid = -1 - global _GRADIENTS_AND_VARS - for item in _GRADIENTS_AND_VARS: - if item.var.name == weight.name: - gid = item.root_rank_id - return gid - -def get_all_grad_item(): - global _GRADIENTS_AND_VARS - return _GRADIENTS_AND_VARS + return tf.initializers.variables(var_list, name='var_in_host') \ No newline at end of file -- Gitee From de62f9325b2d9f9d0f7634505ed6162eccd90d38 Mon Sep 17 00:00:00 2001 From: yanqingshang Date: Fri, 25 Dec 2020 11:03:50 +0800 Subject: [PATCH 02/17] sync code --- conver_tf2npu/README.md | 4 ++-- conver_tf2npu/conver.py | 3 ++- conver_tf2npu/main.py | 14 +++++++++++--- conver_tf2npu/util_global.py | 3 +++ 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/conver_tf2npu/README.md b/conver_tf2npu/README.md index 17ab08d16..ee3550aa9 100644 --- a/conver_tf2npu/README.md +++ b/conver_tf2npu/README.md @@ -20,9 +20,9 @@ /home/BERT --这个是被迁移的脚本路径 - /home/out --这个是迁移后的脚本路径 + /home/out --这个是迁移后的脚本路径,会在这个目录下生成转换后的脚本,文件命名规则:BERT_npu_yyyyMMddHHmmss - /home/report --这个是迁移过程的迁移报告 + /home/report --这个是迁移过程的迁移报告,会在这个目录下生成报告,文件命名规则:report_npu_yyyyMMddHHmmss 迁移报告分三种: diff --git a/conver_tf2npu/conver.py b/conver_tf2npu/conver.py index cf4acd977..a3929326e 100644 --- a/conver_tf2npu/conver.py +++ b/conver_tf2npu/conver.py @@ -24,10 +24,11 @@ def conver(): print("Begin conver, input file: " + util_global.get_value('input')) out_path = util_global.get_value('output') dst_path = os.path.split(util_global.get_value('input').rstrip('\\/'))[-1] + dst_path_new = dst_path + util_global.get_value('timestap') conver_path = os.walk(util_global.get_value('input')) for path,dir_list,file_list in conver_path: for file_name in file_list: - out_path_dst = abs_join(dst_path, path.split(dst_path)[1]) + out_path_dst = abs_join(dst_path_new, path.split(dst_path)[1]) if 
file_name.endswith(".py"): util_global.set_value('path', os.path.join(path, file_name)) mkdir(os.path.join(out_path, out_path_dst)) diff --git a/conver_tf2npu/main.py b/conver_tf2npu/main.py index 5ed7f3d04..925cac690 100644 --- a/conver_tf2npu/main.py +++ b/conver_tf2npu/main.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import os import sys import getopt import util_global @@ -20,8 +21,9 @@ from conver import conver def para_check_and_set(argv): input = "input" - output = "output" - report = "report" + output = "output" + util_global.get_value('timestap') + report = "report" + util_global.get_value('timestap') + report_suffix = report try: opts, args = getopt.getopt(argv, "hi:o:r:", ["help", "input=", "output=", "report="]) @@ -44,10 +46,17 @@ def para_check_and_set(argv): sys.exit() elif opt in ("-i", "--input"): input = arg + if str(input).endswith('/'): + input = input[0:len(input)-1] elif opt in ("-o", "--output"): output = arg + if str(output).endswith('/'): + output = output[0:len(output)-1] elif opt in ("-r", "--report"): report = arg + if str(report).endswith('/'): + report = report[0:len(report)-1] + report = os.path.join(report, report_suffix) util_global.set_value('input', input) util_global.set_value('output', output) util_global.set_value('report', report) @@ -55,5 +64,4 @@ def para_check_and_set(argv): if __name__ == "__main__": util_global._init() para_check_and_set(sys.argv[1:]) - before_clear() conver() \ No newline at end of file diff --git a/conver_tf2npu/util_global.py b/conver_tf2npu/util_global.py index 0a4ecefd0..6ce1b0bf5 100644 --- a/conver_tf2npu/util_global.py +++ b/conver_tf2npu/util_global.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import json +import time def _init(): global _global_dict @@ -22,6 +23,8 @@ def _init(): items = load_dict.items() for key, value in items: set_value(key, value) + value = "_npu_" + time.strftime('%Y%m%d%H%M%S') + set_value('timestap', value) def set_value(key, value): _global_dict[key] = value -- Gitee From 380ff725a6892a400fac923618ea4a9f541aea6f Mon Sep 17 00:00:00 2001 From: yanqingshang Date: Fri, 25 Dec 2020 15:49:51 +0800 Subject: [PATCH 03/17] sync code --- .idea/workspace.xml | 3 +-- configure.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.idea/workspace.xml b/.idea/workspace.xml index e5ad95231..f60966896 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -10,8 +10,7 @@ - - +
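
Patch 02 derives every output name from a single timestamp computed at startup: util_global._init() stores "_npu_" + time.strftime('%Y%m%d%H%M%S') under the key 'timestap', conver.py appends it to the input folder's basename, and main.py appends it to the default output/report names, joining the suffixed report name onto a user-supplied -r path. A minimal standalone sketch of the resulting naming behaviour, assuming the hypothetical /home/BERT input and /home/report report path from the README example (not part of the patch itself):

    import os
    import time

    # Same suffix format as util_global.py in patch 02; computed once at
    # startup and reused for both the converted-script and report names.
    timestap = "_npu_" + time.strftime('%Y%m%d%H%M%S')    # e.g. "_npu_20201225110350"

    # conver.py: converted scripts land in "<input basename><suffix>".
    input_dir = "/home/BERT"                               # hypothetical input path
    dst_path = os.path.split(input_dir.rstrip('\\/'))[-1]  # -> "BERT"
    dst_path_new = dst_path + timestap                     # -> "BERT_npu_20201225110350"

    # main.py: a user-supplied -r directory gets the suffixed default report
    # name joined onto it, mirroring the report_suffix handling.
    report_suffix = "report" + timestap
    report_dir = os.path.join("/home/report", report_suffix)

    print(dst_path_new)  # BERT_npu_20201225110350
    print(report_dir)    # /home/report/report_npu_20201225110350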