From 80a1afaa288c76d3681d6e77bd49b21024d18c09 Mon Sep 17 00:00:00 2001
From: panghongjun
Date: Wed, 9 Jun 2021 15:18:50 +0800
Subject: [PATCH] weight_update_sharding

---
 tf_adapter_2.x/npu_device/core/npu_device.cpp | 323 ++++++++++++-
 tf_adapter_2.x/npu_device/core/npu_device.h   |   4 +
 .../core/weight_update_sharding_pass.cc       | 428 ++++++++++++++++++
 .../core/weight_update_sharding_pass.h        |  51 +++
 .../npu_device/_api/distribute/__init__.py    |   1 +
 .../distribute/weight_update_grouping.py      | 159 +++++++
 .../optimizer/npu_loss_scale_optimizer.py     |   8 +-
 7 files changed, 971 insertions(+), 3 deletions(-)
 create mode 100644 tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.cc
 create mode 100644 tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.h
 create mode 100644 tf_adapter_2.x/python/npu_device/distribute/weight_update_grouping.py

diff --git a/tf_adapter_2.x/npu_device/core/npu_device.cpp b/tf_adapter_2.x/npu_device/core/npu_device.cpp
index bdaba6abf..d3e645f88 100644
--- a/tf_adapter_2.x/npu_device/core/npu_device.cpp
+++ b/tf_adapter_2.x/npu_device/core/npu_device.cpp
@@ -45,6 +45,12 @@ limitations under the License.
 #include "framework/omg/parser/model_parser.h"
 #include "framework/omg/parser/parser_factory.h"
 
+#define REQUIRES_NOT_NULL(v)                                        \
+  if ((v) == nullptr) {                                             \
+    LOG(ERROR) << #v " is nullptr.";                                \
+    return tensorflow::errors::InvalidArgument(#v " is nullptr.");  \
+  }
+
 using Format = ge::Format;
 
 const static uint64_t kInvalidGeGraphId = -1;
@@ -74,7 +80,10 @@ size_t RemoveRedundantHcomControlEdges(tensorflow::Graph *graph) {
   std::vector<const tensorflow::Edge *> edges_to_remove;
   for (auto edge : graph->edges()) {
     if (edge->IsControlEdge() && (edge->src()->type_string() == kHcomType || edge->dst()->type_string() == kHcomType)) {
-      edges_to_remove.push_back(edge);
+      if (!(edge->src()->type_string() == "HcomAllReduce" &&
+            edge->src()->def().attr().find("_npu_loss_scale") != edge->src()->def().attr().end())) {
+        edges_to_remove.push_back(edge);
+      }
     }
   }
   for (auto edge : edges_to_remove) {
@@ -1002,6 +1011,12 @@ void NpuDevice::GetOrCreateSpec(TFE_Context *context, const char *op_name, const
     graph_dumper.DumpWithSubGraphs("after_optimize", optimize_graph->ToGraphDefDebug(), lib_def);
   }
 
+  NPU_CTX_REQUIRES_OK(s, TailingOptimizer(optimize_graph.get()));
+
+  if (kDumpExecutionDetail || kDumpGraph) {
+    graph_dumper.DumpWithSubGraphs("after_tailing_optimize", optimize_graph->ToGraphDefDebug(), lib_def);
+  }
+
   std::map<int, std::shared_ptr<IteratorResourceProvider>> dependent_host_resources;
   NPU_CTX_REQUIRES_OK(
     s, TransResourceInput2GraphNode(context, optimize_graph.get(), num_inputs, inputs, dependent_host_resources));
@@ -1010,6 +1025,12 @@ void NpuDevice::GetOrCreateSpec(TFE_Context *context, const char *op_name, const
     graph_dumper.DumpWithSubGraphs("after_replace_resource_inputs", optimize_graph->ToGraphDefDebug(), lib_def);
   }
 
+  // NPU_CTX_REQUIRES_OK(s, OptimizeWeight(optimize_graph.get()));
+
+  // if (kDumpExecutionDetail || kDumpGraph) {
+  //   graph_dumper.DumpWithSubGraphs("after_weight_update_sharding", optimize_graph->ToGraphDefDebug(), lib_def);
+  // }
+
   PruneFunction(*fdef, optimize_graph.get());
 
   DLOG() << "NPU Start inferring shape for function node " << op_name;
@@ -1977,3 +1998,303 @@ tensorflow::Status NpuDevice::GetMirroredIteratorShapesAndTypes(const tensorflow
   types.assign(iter->second.second.begin(), iter->second.second.end());
   return tensorflow::Status::OK();
 }
+
+static std::atomic<int> graph_run_num(1);
+
+tensorflow::Status NpuDevice::OptimizeWeight(tensorflow::Graph *graphIn) {
+  int graph_num = graph_run_num++;
+  (void)graph_num;  // referenced only by the commented-out dump code below
+
+  bool tailing_optimize = false;
+  bool npu_loss_scale = false;
+  for (tensorflow::Node *node : graphIn->op_nodes()) {
+    REQUIRES_NOT_NULL(node);
+    std::string op_name = node->name();
+    std::string op_type = node->type_string();
+    if (op_name.find("_Broadcast_tailing_optimize") != std::string::npos) {
+      tailing_optimize = true;
+      if (npu_loss_scale) { break; }
+    }
+    if (op_name.find("NPULossScaleOptimizer") != std::string::npos && op_type == "NpuAllocFloatStatus") {
+      npu_loss_scale = true;
+      if (tailing_optimize) { break; }
+    }
+  }
+
+  if (tailing_optimize) {
+    // tensorflow::int64 startTime = InferShapeUtil::GetCurrentTimestap();
+    // char *need_print = getenv("PRINT_MODEL");
+    // if (need_print != nullptr && strcmp("1", need_print) == 0) {
+    //   tensorflow::GraphDef ori_graph_def;
+    //   graphIn->ToGraphDef(&ori_graph_def);
+    //   string ori_model_path = GetDumpPath() + "BeforeWeightUpdateSharding_";
+    //   string omodel_path = ori_model_path + std::to_string(graph_num) + ".pbtxt";
+    //   Status status_out = WriteTextProto(Env::Default(), omodel_path, ori_graph_def);
+    // }
+
+    std::vector<tensorflow::Node *> in_nodes;
+    for (tensorflow::Node *node : graphIn->nodes()) { in_nodes.push_back(node); }
+    for (int i = static_cast<int>(in_nodes.size()) - 1; i >= 0; i--) {
+      tensorflow::Node *node = in_nodes.at(i);
+      REQUIRES_NOT_NULL(node);
+      std::string op_type = node->type_string();
+      std::string dst_name;
+      std::string dst_type;
+      if (op_type == "VarHandleOp" || op_type == "Identity" || op_type == "ReadVariableOp") {
+        tensorflow::Node *var_node = nullptr;
+        tensorflow::Node *broadcast_node = nullptr;
+        std::vector<const tensorflow::Edge *> remove_edges;
+        for (auto in_edge : node->in_edges()) {
+          REQUIRES_NOT_NULL(in_edge);
+          REQUIRES_NOT_NULL(in_edge->src());
+          REQUIRES_NOT_NULL(in_edge->dst());
+          if (in_edge->src()->IsVariable()) {
+            var_node = in_edge->src();
+            break;
+          }
+        }
+        std::vector<const tensorflow::Edge *> out_edges;
+        for (auto edge : node->out_edges()) { out_edges.push_back(edge); }
+        for (auto out_edge : out_edges) {
+          REQUIRES_NOT_NULL(out_edge);
+          REQUIRES_NOT_NULL(out_edge->src());
+          REQUIRES_NOT_NULL(out_edge->dst());
+          dst_name = out_edge->dst()->name();
+          dst_type = out_edge->dst()->type_string();
+          if (!npu_loss_scale) {
+            if (dst_name.find("_Broadcast_tailing_optimize") != std::string::npos &&
+                dst_type == "HcomBroadcast") {
+              bool find_broadcast = false;
+              for (auto broadcast_edge : out_edge->dst()->in_edges()) {
+                REQUIRES_NOT_NULL(broadcast_edge);
+                REQUIRES_NOT_NULL(broadcast_edge->src());
+                REQUIRES_NOT_NULL(broadcast_edge->dst());
+                if (broadcast_edge->IsControlEdge()) {
+                  find_broadcast = true;
+                  // remove edge : reduce/apply --> broadcast
+                  remove_edges.push_back(broadcast_edge);
+                }
+              }
+              if (find_broadcast) {
+                broadcast_node = out_edge->dst();
+                // remove edge : VarHandleOp/Identity --> broadcast
+                remove_edges.push_back(out_edge);
+                for (auto broadcast_edge : out_edge->dst()->out_edges()) {
+                  REQUIRES_NOT_NULL(broadcast_edge);
+                  REQUIRES_NOT_NULL(broadcast_edge->src());
+                  REQUIRES_NOT_NULL(broadcast_edge->dst());
+                  if (broadcast_edge->IsControlEdge()) {
+                    // remove edge : broadcast --> group
+                    remove_edges.push_back(broadcast_edge);
+                  }
+                }
+                break;
+              }
+            }
+          } else {
+            if (dst_type == "Switch") {
+              for (auto switch_out_edge : out_edge->dst()->out_edges()) {
+                REQUIRES_NOT_NULL(switch_out_edge);
+                REQUIRES_NOT_NULL(switch_out_edge->src());
+                REQUIRES_NOT_NULL(switch_out_edge->dst());
+                std::string node_name = switch_out_edge->dst()->name();
+                std::string node_type = switch_out_edge->dst()->type_string();
+                if (node_name.find("_Broadcast_tailing_optimize") != std::string::npos &&
+                    node_type == "HcomBroadcast") {
+                  bool find_broadcast = false;
+                  for (auto broadcast_edge : switch_out_edge->dst()->in_edges()) {
+                    REQUIRES_NOT_NULL(broadcast_edge);
+                    REQUIRES_NOT_NULL(broadcast_edge->src());
+                    REQUIRES_NOT_NULL(broadcast_edge->dst());
+                    if (broadcast_edge->IsControlEdge()) {
+                      find_broadcast = true;
+                      // remove edge : reduce/apply --> broadcast
+                      remove_edges.push_back(broadcast_edge);
+                    }
+                  }
+                  if (find_broadcast) {
+                    broadcast_node = switch_out_edge->dst();
+                    // remove edge : Switch --> broadcast
+                    remove_edges.push_back(switch_out_edge);
+                    for (auto broadcast_edge : switch_out_edge->dst()->out_edges()) {
+                      REQUIRES_NOT_NULL(broadcast_edge);
+                      REQUIRES_NOT_NULL(broadcast_edge->src());
+                      REQUIRES_NOT_NULL(broadcast_edge->dst());
+                      if (broadcast_edge->IsControlEdge()) {
+                        // remove edge : broadcast --> group
+                        remove_edges.push_back(broadcast_edge);
+                      }
+                    }
+                    break;
+                  }
+                }
+              }
+            }
+          }
+        }
+        if (broadcast_node != nullptr && var_node != nullptr) {
+          for (auto edge : remove_edges) {
+            graphIn->RemoveEdge(edge);
+          }
+          // add edge : variable --> broadcast
+          graphIn->AddEdge(var_node, 0, broadcast_node, 0);
+          for (auto var_edge : var_node->out_edges()) {
+            REQUIRES_NOT_NULL(var_edge);
+            REQUIRES_NOT_NULL(var_edge->src());
+            REQUIRES_NOT_NULL(var_edge->dst());
+            if (var_edge->dst() != broadcast_node) {
+              graphIn->AddControlEdge(broadcast_node, var_edge->dst());
+            }
+          }
+        }
+      }
+    }
+
+    // if (need_print != nullptr && strcmp("1", need_print) == 0) {
+    //   tensorflow::GraphDef omg_graph_def;
+    //   graphIn->ToGraphDef(&omg_graph_def);
+    //   string tmodel_path = "AfterWeightUpdateSharding_" + std::to_string(graph_num) + ".pbtxt";
+    //   Status status_o = WriteTextProto(Env::Default(), tmodel_path, omg_graph_def);
+    // }
+    // tensorflow::int64 endTime = InferShapeUtil::GetCurrentTimestap();
+    // LOG(INFO) << "WeightUpdateSharding_" << std::to_string(graph_num) << " success. ["
+    //           << ((endTime - startTime) / kMicrosToMillis) << " ms]";
+  }
+
+  return tensorflow::Status::OK();
+}
+
+tensorflow::Status NpuDevice::TailingOptimizer(tensorflow::Graph *graphIn) {
+  for (tensorflow::Node *node : graphIn->nodes()) {
+    REQUIRES_NOT_NULL(node);
+    if (node->type_string() != "NpuAllocFloatStatus" ||
+        node->def().attr().find("_npu_loss_scale") == node->def().attr().end()) {
+      continue;
+    }
+    std::vector<const tensorflow::Edge *> need_remove;
+    std::vector<tensorflow::Node *> grads;
+    tensorflow::Node *last_allreduce = nullptr;
+    tensorflow::Node *float_status_allreduce = nullptr;
+    // Collect the control edges that serialize overflow detection behind the
+    // gradient allreduces, remembering the last gradient HcomAllReduce.
+    for (auto in_edge : node->in_edges()) {
+      REQUIRES_NOT_NULL(in_edge);
+      REQUIRES_NOT_NULL(in_edge->src());
+      REQUIRES_NOT_NULL(in_edge->dst());
+      if (in_edge->IsControlEdge()) {
+        if (in_edge->src()->type_string() == "HcomAllReduce" && last_allreduce == nullptr) {
+          last_allreduce = in_edge->src();
+        }
+        need_remove.push_back(in_edge);
+      }
+    }
+    // Find the allreduce that aggregates the float status across devices.
+    for (auto out_edge : node->out_edges()) {
+      REQUIRES_NOT_NULL(out_edge);
+      REQUIRES_NOT_NULL(out_edge->src());
+      REQUIRES_NOT_NULL(out_edge->dst());
+      if (out_edge->dst()->type_string() == "HcomAllReduce" &&
+          out_edge->dst()->def().attr().find("_npu_loss_scale") != out_edge->dst()->def().attr().end()) {
+        float_status_allreduce = out_edge->dst();
+      }
+    }
+    // Walk backwards through the chain of gradient allreduces and collect
+    // their data inputs (the raw gradients).
+    tensorflow::Node *previous_allreduce = last_allreduce;
+    while (previous_allreduce != nullptr) {
+      bool find_previous = false;
+      for (auto in_edge : previous_allreduce->in_edges()) {
+        REQUIRES_NOT_NULL(in_edge);
+        REQUIRES_NOT_NULL(in_edge->src());
+        REQUIRES_NOT_NULL(in_edge->dst());
+        if (!in_edge->IsControlEdge()) {
+          grads.push_back(in_edge->src());
+        } else if (in_edge->src()->type_string() == "HcomAllReduce") {
+          previous_allreduce = in_edge->src();
+          find_previous = true;
+        }
+      }
+      if (!find_previous) { break; }
+    }
+
+    if (last_allreduce != nullptr && !grads.empty() &&
+        float_status_allreduce != nullptr && !need_remove.empty()) {
+      // Re-anchor overflow detection on the raw gradients so the float-status
+      // allreduce overlaps with gradient communication instead of waiting on it.
+      for (auto edge : need_remove) {
+        graphIn->RemoveEdge(edge);
+      }
+      for (auto grad : grads) {
+        graphIn->AddControlEdge(grad, node);
+      }
+      graphIn->AddControlEdge(float_status_allreduce, last_allreduce);
+    }
+  }
+
+  return tensorflow::Status::OK();
+}
diff --git a/tf_adapter_2.x/npu_device/core/npu_device.h b/tf_adapter_2.x/npu_device/core/npu_device.h
index bafeb9f57..e69d4006e 100644
--- a/tf_adapter_2.x/npu_device/core/npu_device.h
+++ b/tf_adapter_2.x/npu_device/core/npu_device.h
@@ -201,6 +201,10 @@ class NpuDevice {
 
   tensorflow::CancellationManager *CancellationManager() { return cancellation_manager_.get(); }
 
+  tensorflow::Status OptimizeWeight(tensorflow::Graph *graphIn);
+
+  tensorflow::Status TailingOptimizer(tensorflow::Graph *graphIn);
+
   int device_id;
   tensorflow::string device_name;
   tensorflow::string underlying_device;
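The grappler pass added below applies the same weight-update-sharding idea in graph mode: after gradients are reduced, each rank runs the optimizer only on the variables it owns, and every variable is then re-published from its owner via HcomBroadcast. A schematic sketch of that dataflow (the round-robin ownership rule and helper names here are illustrative, not the pass's actual matching logic):

    def sharded_update(variables, grads, rank_id, rank_size, apply_one, broadcast):
        # 1. Each rank applies the optimizer only to the variables it owns.
        for i, (v, g) in enumerate(zip(variables, grads)):
            if i % rank_size == rank_id:  # toy ownership rule
                apply_one(v, g)
        # 2. Every variable is then broadcast from its owner so ranks agree.
        for i, v in enumerate(variables):
            broadcast(v, root=i % rank_size)

    updates = []
    sharded_update(["w0", "w1", "w2", "w3"], ["g0", "g1", "g2", "g3"],
                   rank_id=0, rank_size=2,
                   apply_one=lambda v, g: updates.append(("apply", v, g)),
                   broadcast=lambda v, root: updates.append(("bcast", v, root)))
    print(updates)  # rank 0 applies w0/w2; all four vars are broadcast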
diff --git a/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.cc b/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.cc
new file mode 100644
index 000000000..864d6842b
--- /dev/null
+++ b/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.cc
@@ -0,0 +1,428 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "weight_update_sharding_pass.h"
+
+#include <atomic>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tf_adapter/common/adp_logger.h"
+#include "tf_adapter/common/common.h"
+#include "tf_adapter/util/npu_attrs.h"
+#include "tf_adapter/util/infershape_util.h"
+
+namespace tensorflow {
+namespace grappler {
+static const int64 kMicrosToMillis = 1000;
+
+static std::atomic<int> graph_run_num(1);
+
+Status WeightUpdateShardingOptimizer::Optimize(Cluster *cluster, const GrapplerItem &item, GraphDef *optimizedGraph) {
+  REQUIRES_NOT_NULL(optimizedGraph);
+  *optimizedGraph = item.graph;
+  ADP_LOG(INFO) << "WeightUpdateShardingOptimizer::Optimize begin, OriginNodeNum: " << item.graph.node_size();
+
+  REQUIRES_STATUS_OK(TopologicalSort(optimizedGraph));
+
+  std::unique_ptr<Graph> device_graph(new (std::nothrow) Graph(OpRegistry::Global()));
+  REQUIRES_NOT_NULL(device_graph);
+  GraphConstructorOptions device_opts;
+  // There are internal operations (e.g., send/recv) that we now allow.
+  device_opts.allow_internal_ops = true;
+  device_opts.expect_device_spec = true;
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(device_opts, *optimizedGraph, device_graph.get()));
+  TF_RETURN_IF_ERROR(Run(device_graph.get()));
+  device_graph->ToGraphDef(optimizedGraph);
+
+  return Status::OK();
+}
+
+Status WeightUpdateShardingOptimizer::Run(Graph *graphIn) {
+  int graph_num = graph_run_num++;
+
+  bool weight_update_sharding = false;
+  bool npu_loss_scale = false;
+  for (Node *node : graphIn->op_nodes()) {
+    REQUIRES_NOT_NULL(node);
+    std::string op_name = node->name();
+    std::string op_type = node->type_string();
+    if (op_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos) {
+      weight_update_sharding = true;
+      if (npu_loss_scale) { break; }
+    }
+    if (op_name.find("NPULossScaleOptimizer") != std::string::npos && op_type == "NpuAllocFloatStatus") {
+      npu_loss_scale = true;
+      if (weight_update_sharding) { break; }
+    }
+  }
+
+  if (weight_update_sharding) {
+    int64 startTime = InferShapeUtil::GetCurrentTimestap();
+    char *need_print = getenv("PRINT_MODEL");
+
+    if (need_print != nullptr && strcmp("1", need_print) == 0) {
+      GraphDef ori_graph_def;
+      graphIn->ToGraphDef(&ori_graph_def);
+      string ori_model_path = GetDumpPath() + "BeforeWeightUpdateSharding_";
+      string omodel_path = ori_model_path + std::to_string(graph_num) + ".pbtxt";
+      Status status_out = WriteTextProto(Env::Default(), omodel_path, ori_graph_def);
+    }
+
+    std::vector<Node *> in_nodes;
+    for (Node *node : graphIn->nodes()) { in_nodes.push_back(node); }
+    for (int i = static_cast<int>(in_nodes.size()) - 1; i >= 0; i--) {
+      Node *node = in_nodes.at(i);
+      REQUIRES_NOT_NULL(node);
+      std::string op_type = node->type_string();
+      std::string dst_name;
+      std::string dst_type;
+      if (op_type == "VarHandleOp" || op_type == "Identity" || op_type == "ReadVariableOp") {
+        Node *var_node = nullptr;
+        Node *broadcast_node = nullptr;
+        std::vector<const Edge *> remove_edges;
+        for (auto in_edge : node->in_edges()) {
+          REQUIRES_NOT_NULL(in_edge);
+          REQUIRES_NOT_NULL(in_edge->src());
+          REQUIRES_NOT_NULL(in_edge->dst());
+          if (in_edge->src()->IsVariable()) {
+            var_node = in_edge->src();
+            break;
+          }
+        }
+        std::vector<const Edge *> out_edges;
+        for (auto edge : node->out_edges()) { out_edges.push_back(edge); }
+        for (auto out_edge : out_edges) {
+          REQUIRES_NOT_NULL(out_edge);
+          REQUIRES_NOT_NULL(out_edge->src());
+          REQUIRES_NOT_NULL(out_edge->dst());
+          dst_name = out_edge->dst()->name();
+          dst_type = out_edge->dst()->type_string();
+          if (!npu_loss_scale) {
+            if (dst_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos &&
+                dst_type == "HcomBroadcast") {
+              bool find_broadcast = false;
+              for (auto broadcast_edge : out_edge->dst()->in_edges()) {
+                REQUIRES_NOT_NULL(broadcast_edge);
+                REQUIRES_NOT_NULL(broadcast_edge->src());
+                REQUIRES_NOT_NULL(broadcast_edge->dst());
+                if (broadcast_edge->IsControlEdge()) {
+                  find_broadcast = true;
+                  // remove edge : reduce/apply --> broadcast
+                  remove_edges.push_back(broadcast_edge);
+                }
+              }
+              if (find_broadcast) {
+                broadcast_node = out_edge->dst();
+                // remove edge : VarHandleOp/Identity --> broadcast
+                remove_edges.push_back(out_edge);
+                for (auto broadcast_edge : out_edge->dst()->out_edges()) {
+                  REQUIRES_NOT_NULL(broadcast_edge);
+                  REQUIRES_NOT_NULL(broadcast_edge->src());
+                  REQUIRES_NOT_NULL(broadcast_edge->dst());
+                  if (broadcast_edge->IsControlEdge()) {
+                    // remove edge : broadcast --> group
+                    remove_edges.push_back(broadcast_edge);
+                  }
+                }
+                break;
+              }
+            }
+          } else {
+            if (dst_type == "Switch") {
+              for (auto switch_out_edge : out_edge->dst()->out_edges()) {
+                REQUIRES_NOT_NULL(switch_out_edge);
+                REQUIRES_NOT_NULL(switch_out_edge->src());
+                REQUIRES_NOT_NULL(switch_out_edge->dst());
+                std::string node_name = switch_out_edge->dst()->name();
+                std::string node_type = switch_out_edge->dst()->type_string();
+                if (node_name.find("_Broadcast_Weight_Update_Sharding") != std::string::npos &&
+                    node_type == "HcomBroadcast") {
+                  bool find_broadcast = false;
+                  for (auto broadcast_edge : switch_out_edge->dst()->in_edges()) {
+                    REQUIRES_NOT_NULL(broadcast_edge);
+                    REQUIRES_NOT_NULL(broadcast_edge->src());
+                    REQUIRES_NOT_NULL(broadcast_edge->dst());
+                    if (broadcast_edge->IsControlEdge()) {
+                      find_broadcast = true;
+                      // remove edge : reduce/apply --> broadcast
+                      remove_edges.push_back(broadcast_edge);
+                    }
+                  }
+                  if (find_broadcast) {
+                    broadcast_node = switch_out_edge->dst();
+                    // remove edge : Switch --> broadcast
+                    remove_edges.push_back(switch_out_edge);
+                    for (auto broadcast_edge : switch_out_edge->dst()->out_edges()) {
+                      REQUIRES_NOT_NULL(broadcast_edge);
+                      REQUIRES_NOT_NULL(broadcast_edge->src());
+                      REQUIRES_NOT_NULL(broadcast_edge->dst());
+                      if (broadcast_edge->IsControlEdge()) {
+                        // remove edge : broadcast --> group
+                        remove_edges.push_back(broadcast_edge);
+                      }
+                    }
+                    break;
+                  }
+                }
+              }
+            }
+          }
+        }
+        if (broadcast_node != nullptr && var_node != nullptr) {
+          for (auto edge : remove_edges) {
+            graphIn->RemoveEdge(edge);
+          }
+          // add edge : variable --> broadcast
+          graphIn->AddEdge(var_node, 0, broadcast_node, 0);
+          for (auto var_edge : var_node->out_edges()) {
+            REQUIRES_NOT_NULL(var_edge);
+            REQUIRES_NOT_NULL(var_edge->src());
+            REQUIRES_NOT_NULL(var_edge->dst());
+            if (var_edge->dst() != broadcast_node) {
+              graphIn->AddControlEdge(broadcast_node, var_edge->dst());
+            }
+          }
+        }
+      }
+    }
+
+    if (need_print != nullptr && strcmp("1", need_print) == 0) {
+      GraphDef omg_graph_def;
+      graphIn->ToGraphDef(&omg_graph_def);
+      string tmpmodel_path = "AfterWeightUpdateSharding_";
+      string tmodel_path = tmpmodel_path + std::to_string(graph_num) + ".pbtxt";
+      Status status_o = WriteTextProto(Env::Default(), tmodel_path, omg_graph_def);
+    }
+    int64 endTime = InferShapeUtil::GetCurrentTimestap();
+    ADP_LOG(INFO) << "WeightUpdateSharding_" << std::to_string(graph_num) << " success. ["
+                  << ((endTime - startTime) / kMicrosToMillis) << " ms]";
+  }
+
+  return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER(WeightUpdateShardingOptimizer);
+}  // end namespace grappler
+}  // end namespace tensorflow
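REGISTER_GRAPH_OPTIMIZER registers the pass above under its class name. Note that this patch adds no build or session wiring for it, so the following is only a sketch of how a grappler custom optimizer is normally enabled from Python, assuming the pass were compiled into the TensorFlow process:

    # Hypothetical: name the registered pass in the session RewriterConfig.
    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto()
    custom = config.graph_options.rewrite_options.custom_optimizers.add()
    # The name must match the one used by REGISTER_GRAPH_OPTIMIZER, which
    # registers the optimizer under its C++ class name.
    custom.name = "WeightUpdateShardingOptimizer"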
diff --git a/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.h b/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.h
new file mode 100644
index 000000000..02f7bcf7b
--- /dev/null
+++ b/tf_adapter_2.x/npu_device/core/weight_update_sharding_pass.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_WEIGHT_UPDATE_SHARDING_OPTIMIZER_H_
+#define TENSORFLOW_WEIGHT_UPDATE_SHARDING_OPTIMIZER_H_
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+class WeightUpdateShardingOptimizer : public CustomGraphOptimizer {
+ public:
+  WeightUpdateShardingOptimizer() = default;
+  ~WeightUpdateShardingOptimizer() override = default;
+  Status Optimize(Cluster *cluster, const GrapplerItem &item, GraphDef *optimizedGraph) override;
+
+ private:
+  Status Run(Graph *graphIn);
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+#endif  // TENSORFLOW_WEIGHT_UPDATE_SHARDING_OPTIMIZER_H_
diff --git a/tf_adapter_2.x/python/npu_device/_api/distribute/__init__.py b/tf_adapter_2.x/python/npu_device/_api/distribute/__init__.py
index a83b076a3..31b75154d 100644
--- a/tf_adapter_2.x/python/npu_device/_api/distribute/__init__.py
+++ b/tf_adapter_2.x/python/npu_device/_api/distribute/__init__.py
@@ -1,3 +1,4 @@
 from npu_device.distribute.hccl import all_reduce
 from npu_device.distribute.hccl import broadcast
 from npu_device.distribute.hccl import shard_and_rebatch_dataset
+from npu_device.distribute.weight_update_grouping import grouping_gradients_apply
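The module added below assigns each variable to an owning rank with a greedy size-balancing pass (GroupingVars._fair_division). A standalone sketch of that idea: walk the variables in order, filling one rank's bucket up to roughly the average of what remains before moving to the next rank. It omits the large-tensor special case the real code handles:

    def fair_division(sizes, rank_count):
        """Greedily assign contiguous runs of tensors (by byte size) to ranks."""
        owners, start = [-1] * len(sizes), 0
        for rank in range(rank_count):
            # Target: an even share of the bytes that are still unassigned.
            remaining = sum(sizes[start:]) / max(rank_count - rank, 1)
            bucket = 0
            for i in range(start, len(sizes)):
                if bucket and bucket + sizes[i] > remaining:
                    break
                owners[i] = rank
                bucket += sizes[i]
                start = i + 1
        return owners

    print(fair_division([4, 4, 2, 6, 3, 5], 3))  # -> [0, 0, 1, 1, 2, 2]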
diff --git a/tf_adapter_2.x/python/npu_device/distribute/weight_update_grouping.py b/tf_adapter_2.x/python/npu_device/distribute/weight_update_grouping.py
new file mode 100644
index 000000000..f4cb129b3
--- /dev/null
+++ b/tf_adapter_2.x/python/npu_device/distribute/weight_update_grouping.py
@@ -0,0 +1,159 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
+# Description: Group variables across ranks so each rank applies only its own
+#              weight updates, then broadcasts the updated weights.
+
+import os
+import tensorflow as tf
+from npu_device.distribute import hccl_ops
+from npu_device.npu_device import never_nested_function
+
+
+class GroupingVars():
+    def __init__(self, variables, rank_size):
+        self._vars = []
+        for var in variables:
+            if var is not None:
+                item = self._GradDivisionItem(var)
+                self._vars.append(item)
+        self._fair_division(rank_size)
+
+    def _fair_division(self, number):
+        def get_sum(items):
+            res = 0
+            for item in items:
+                res += item.size
+            return res
+
+        def get_average(size):
+            large_number_list = []
+            average_size = 0
+            res = 0
+            if size == 1:
+                for item in self._vars:
+                    if item.root_rank_id < 0:
+                        res += item.size
+                return res
+            while True:
+                res = 0
+                find_large_number = False
+                for item in self._vars:
+                    if item not in large_number_list and item.root_rank_id < 0:
+                        res += item.size
+                average_size = res // (size - len(large_number_list))
+                for item in self._vars:
+                    if item not in large_number_list and item.root_rank_id < 0 and item.size > res - item.size:
+                        find_large_number = True
+                        large_number_list.append(item)
+                if not find_large_number:
+                    break
+            return average_size
+
+        if number > len(self._vars) or number < 0:
+            raise ValueError("'number' is greater than the number of vars or 'number' is less than 0. ")
+        elif number == len(self._vars):
+            for i in range(len(self._vars)):
+                self._vars[i].root_rank_id = i
+            return
+
+        j = -1
+        last_index = 0
+        while True:
+            j = j + 1
+            total_number = number - j
+            if total_number == 0:
+                break
+            average_size = get_average(total_number)
+            tmp_list = []
+            tmp_last_index = last_index
+            for i in range(tmp_last_index, len(self._vars) - total_number + 1):
+                if get_sum(tmp_list) + self._vars[i].size <= average_size:
+                    self._vars[i].root_rank_id = j
+                    tmp_list.append(self._vars[i])
+                    last_index = i + 1
+                else:
+                    if len(tmp_list) <= 0:
+                        self._vars[i].root_rank_id = j
+                        tmp_list.append(self._vars[i])
+                        last_index = i + 1
+                    elif (get_sum(tmp_list) + self._vars[i].size - average_size) <= (average_size - get_sum(tmp_list)):
+                        self._vars[i].root_rank_id = j
+                        tmp_list.append(self._vars[i])
+                        last_index = i + 1
+                    break
+        return
+
+    class _GradDivisionItem():
+        def __init__(self, var):
+            self.var = var
+            self.size = self.__get_size()
+            self.root_rank_id = -1
+
+        def __get_size(self):
+            size = 1
+            var_shape = self.var.shape
+            if len(var_shape) <= 0:
+                return 0
+            for i in range(len(var_shape)):
+                size = size * int(var_shape[i])
+            size = size * self.var.dtype.size
+            return size
+
+    def get_all_grad_item(self):
+        return self._vars
+
+    def get_gid_by_var(self, var):
+        gid = -1
+        for item in self._vars:
+            if item.var is var:
+                gid = item.root_rank_id
+        return gid
+
+
+@never_nested_function
+def grouping_gradients_apply(apply_func, grads, variables, *args, **kwargs):
+    rank_size = os.getenv('RANK_SIZE')
+    if rank_size is None or int(rank_size) <= 1:
+        return apply_func(zip(grads, variables), *args, **kwargs)
+
+    local_rank_id = os.getenv('RANK_ID')
+    if local_rank_id is None or int(local_rank_id) < 0:
+        raise ValueError('Please set the correct RANK_ID value, current RANK_ID is:', local_rank_id)
+
+    op_list = []
+    grouping_vars = GroupingVars(variables, int(rank_size))
+    local_grads_and_vars = []
+    for i in range(len(variables)):
+        var = variables[i]
+        rank_id = grouping_vars.get_gid_by_var(var)
+        if rank_id >= 0 and rank_id == int(local_rank_id):
+            local_grads_and_vars.append((grads[i], var))
+    apply_res = apply_func(local_grads_and_vars, *args, **kwargs)
+    with tf.control_dependencies([apply_res]):
+        with tf.name_scope("NPU_Broadcast_Weight_Update_Sharding"):
+            for i in range(len(variables)):
+                var = variables[i]
+                rank_id = grouping_vars.get_gid_by_var(var)
+                with tf.control_dependencies(op_list):
+                    outputs = hccl_ops.broadcast([var], rank_id, 0)
+                if outputs is not None:
+                    op_list.append(outputs[0].op)
+    for i in range(len(variables)):
+        var = variables[i]
+        rank_id = grouping_vars.get_gid_by_var(var)
+        if rank_id >= 0 and rank_id != int(local_rank_id):
+            op_list.append(grads[i])
+    op_list.append(apply_res)
+    return tf.group(op_list)
+
+
+@never_nested_function
+def grouping_broadcast(variables):
+    rank_size = os.getenv('RANK_SIZE')
+    if rank_size is None or int(rank_size) <= 1:
+        raise ValueError('Please set the correct RANK_SIZE value, current RANK_SIZE is:', rank_size)
+
+    grouping_vars = GroupingVars(variables, int(rank_size))
+    op_list = []
+    with tf.name_scope("NPU_Broadcast_Weight_Update_Sharding"):
+        for var in variables:
+            rank_id = grouping_vars.get_gid_by_var(var)
+            with tf.control_dependencies(op_list):
+                outputs = hccl_ops.broadcast([var], rank_id, 0)
+            if outputs is not None:
+                op_list.append(outputs[0].op)
+    return tf.group(op_list)
\ No newline at end of file
diff --git a/tf_adapter_2.x/python/npu_device/train/optimizer/npu_loss_scale_optimizer.py b/tf_adapter_2.x/python/npu_device/train/optimizer/npu_loss_scale_optimizer.py
index aaa973fb8..b634dd3c0 100644
--- a/tf_adapter_2.x/python/npu_device/train/optimizer/npu_loss_scale_optimizer.py
+++ b/tf_adapter_2.x/python/npu_device/train/optimizer/npu_loss_scale_optimizer.py
@@ -1,4 +1,6 @@
 import tensorflow as tf
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.util import nest
@@ -17,8 +19,10 @@ from npu_device.distribute.hccl import all_reduce
 def _npu_finite_status_after_executed(executed_ops):
     if not isinstance(executed_ops, (tuple, list)):
         executed_ops = [executed_ops]
-    with tf.control_dependencies(executed_ops):
-        current_status = gen_npu_ops.npu_alloc_float_status()
+    with ops.get_default_graph()._attr_scope(
+            {"_npu_loss_scale": attr_value_pb2.AttrValue(b=True)}):
+        with tf.control_dependencies(executed_ops):
+            current_status = gen_npu_ops.npu_alloc_float_status()
     assign_float_status = gen_npu_ops.npu_get_float_status(current_status)
     finite_status = gen_npu_ops.npu_clear_float_status(assign_float_status)
     if global_npu_ctx() and global_npu_ctx().workers_num > 1:
--
Gitee
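For reference, a hypothetical TF2 training step using the grouping_gradients_apply API exported above. RANK_SIZE and RANK_ID must be set in the environment for grouping to take effect; the model and optimizer scaffolding here is illustrative and not part of the patch:

    import tensorflow as tf
    from npu_device.distribute.weight_update_grouping import grouping_gradients_apply

    optimizer = tf.keras.optimizers.SGD(0.01)

    @tf.function
    def train_step(model, images, labels):
        with tf.GradientTape() as tape:
            logits = model(images, training=True)
            loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True))
        grads = tape.gradient(loss, model.trainable_variables)
        # Instead of optimizer.apply_gradients(zip(grads, vars)): apply only
        # the locally owned (grad, var) pairs, then broadcast the updated
        # weights from their owning ranks.
        grouping_gradients_apply(optimizer.apply_gradients, grads,
                                 model.trainable_variables)
        return loss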