From 638267ba7a9c43d806ce1bb5cf3162163e22a659 Mon Sep 17 00:00:00 2001
From: lianghuikang <505519763@qq.com>
Date: Thu, 18 Mar 2021 10:42:50 +0800
Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=B7=B7=E5=90=88=E8=AE=A1?=
 =?UTF-8?q?=E7=AE=97=E9=85=8D=E7=BD=AE=E9=A6=96=E5=B0=BE=E7=AE=97=E5=AD=90?=
 =?UTF-8?q?=E5=90=8D=E4=B8=8B=E6=B2=89=E9=85=8D=E7=BD=AE=E3=80=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../optimizers/om_partition_subgraphs_pass.cc | 191 +++++++++++++++++-
 .../optimizers/om_partition_subgraphs_pass.h  |   6 +-
 tf_adapter/util/npu_attrs.cc                  |  15 ++
 3 files changed, 201 insertions(+), 11 deletions(-)

diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
index 87f73041d..c9fbfc96d 100644
--- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
+++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
@@ -214,15 +214,16 @@ bool EndsWith(const std::string &str, const std::string &suffix) {
   return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
 }
 
-bool IsWhiteListSupport(const string &op_name, bool mix_compile_mode, const string &node_name) {
+bool IsWhiteListSupport(const string &op_name, bool mix_compile_mode, const string &node_name,
+                        bool support_const = false) {
   static const std::string suffix_op = "Dataset";
   static const std::string suffix_op_v2 = "DatasetV2";
 
   auto identifier = NpuOpsIdentifier::GetInstance(mix_compile_mode);
 
   bool ans = (identifier->IsNpuSupported(op_name, node_name)) && !EndsWith(op_name, suffix_op) &&
-             !EndsWith(op_name, suffix_op_v2) && !(op_name == "Const") && !(op_name == "_Arg") &&
-             !(op_name == "_Retval") && !(op_name == "StringJoin");
+             !EndsWith(op_name, suffix_op_v2) && (support_const || !(op_name == "Const")) &&
+             !(op_name == "_Arg") && !(op_name == "_Retval") && !(op_name == "StringJoin");
   if (!ans) {
     auto ret = not_support_nodes.insert(op_name);
     if (ret.second) {
@@ -313,15 +314,17 @@ bool IsNpuSupportingFunc(Node *node, FunctionLibraryDefinition *func_lib, int de
   return true;
 }
 
-bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, FunctionLibraryDefinition *func_lib) {
+bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode,
+                         FunctionLibraryDefinition *func_lib, bool support_const) {
   if (IsWithoutNpuScope(node_def)) { return false; }
-  if (IsWhiteListSupport(node_def.op(), mix_compile_mode, node_def.name())) { return true; }
+  if (IsWhiteListSupport(node_def.op(), mix_compile_mode, node_def.name(), support_const)) { return true; }
   if (IsNpuSupportingFunc(node_def.op(), func_lib, 0)) { return true; }
   return false;
 }
 
-bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib) {
-  return IsNpuSupportingNode(node->def(), mix_compile_mode, func_lib);
+bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib,
+                         bool support_const) {
+  return IsNpuSupportingNode(node->def(), mix_compile_mode, func_lib, support_const);
 }
 
 bool IsUnSupportedResource(bool mix_compile_mode, Node* node) {
@@ -344,6 +347,171 @@ bool IsUnSupportedResource(bool mix_compile_mode, Node* node) {
   }
   return false;
 }
+
+// Containers used while collecting the nodes that sit between a head op and a tail op.
+using NodeSet = std::set<Node *>;
+using NodeMap = std::map<Node *, NodeSet>;
+using NodeStack = std::vector<Node *>;
+using GraphPath = std::vector<Node *>;
+using GraphPaths = std::vector<GraphPath>;
+
+// Adds to ops_save every node that lies on a path from op_head to op_tail, both ends included,
+// and returns the resulting size of ops_save. A node already present in ops_save ends a path
+// just like op_tail does. Nothing is added when a pointer is null or head and tail coincide.
+//
+// Works as two linear sweeps (forward from op_head, then backward from the path ends) rather
+// than by listing paths one by one, so a shared sub-graph is walked once and cycles are harmless.
+int FindAllPath(Node *op_head, Node *op_tail, NodeSet &ops_save) {
+  if (!op_head || !op_tail || op_head == op_tail) { return static_cast<int>(ops_save.size()); }
+
+  // Forward sweep: everything reachable from op_head, never walking past a path end.
+  NodeSet reached;
+  NodeSet path_ends;
+  NodeStack stack;
+  reached.insert(op_head);
+  stack.push_back(op_head);
+  while (!stack.empty()) {
+    Node *cur_node = stack.back();
+    stack.pop_back();
+    if (cur_node == op_tail || ops_save.count(cur_node) > 0) {
+      path_ends.insert(cur_node);
+      continue;
+    }
+    for (auto out_node : cur_node->out_nodes()) {
+      if (reached.insert(out_node).second) { stack.push_back(out_node); }
+    }
+  }
+
+  // Backward sweep: of the reached nodes, keep those that lead to a path end.
+  NodeSet on_path(path_ends);
+  stack.assign(path_ends.begin(), path_ends.end());
+  while (!stack.empty()) {
+    Node *cur_node = stack.back();
+    stack.pop_back();
+    for (auto in_node : cur_node->in_nodes()) {
+      if (reached.count(in_node) > 0 && on_path.insert(in_node).second) { stack.push_back(in_node); }
+    }
+  }
+
+  ops_save.insert(on_path.begin(), on_path.end());
+  return static_cast<int>(ops_save.size());
+}
+
+// A (head op name, tail op name) pair, the pairs of one group, and the pairs of all groups.
+using IOP = std::pair<std::string, std::string>;
+using OneGraphIOP = std::vector<IOP>;
+using AllGraphIOP = std::vector<OneGraphIOP>;
+
+// Parses in_out_pair, a bracketed list shaped like [[['in0', 'in1'], ['out0']], [['in2'], ['out2']]].
+// For each [inputs, outputs] group one OneGraphIOP holding every (input, output) combination is
+// appended to all_graph_iop. Blanks, tabs and single quotes are dropped from the names.
+// Returns the number of (input, output) pairs produced, so 0 means nothing usable was parsed.
+int ParseInOutPair(const std::string &in_out_pair, AllGraphIOP &all_graph_iop) {
+  using Nodes = std::vector<std::string>;
+  using Pair = std::vector<Nodes>;
+  using Pairs = std::vector<Pair>;
+
+  int model = 0;  // bracket nesting depth: 1 = groups, 2 = name lists, 3 = names
+  std::string s;
+  Nodes nodes;
+  Pair pair;
+  Pairs pairs;
+  for (char c : in_out_pair) {
+    switch (c) {
+      case '[':
+        ++model;
+        break;
+      case ']':
+      case ',':
+        // A moved-from object is only "valid but unspecified", so clear it before it is reused.
+        if (1 == model && !pair.empty()) {
+          pairs.emplace_back(std::move(pair));
+          pair.clear();
+        } else if (2 == model && !nodes.empty()) {
+          pair.emplace_back(std::move(nodes));
+          nodes.clear();
+        } else if (3 == model && !s.empty()) {
+          nodes.emplace_back(std::move(s));
+          s.clear();
+        }
+        if (']' == c) { --model; }
+        break;
+      case ' ':
+      case '\t':
+      case '\'':
+        break;
+      default:
+        s += c;
+    }
+  }
+
+  int size = 0;
+  for (const auto &group : pairs) {
+    // A group needs an input list and an output list; any further list is ignored.
+    if (group.size() < 2) { continue; }
+    OneGraphIOP one_graph_iop;
+    for (const auto &in : group[0]) {
+      for (const auto &out : group[1]) {
+        ++size;
+        one_graph_iop.emplace_back(in, out);
+      }
+    }
+    all_graph_iop.push_back(std::move(one_graph_iop));
+  }
+  return size;
+}
+
+// Fills candidates with the nodes lying between the head and tail ops named by in_out_pair
+// (format: see ParseInOutPair). Returns an Internal error when in_out_pair yields no pair, when
+// no node is collected, or when a collected node is rejected by IsNpuSupportingNode.
+Status FindCandidatesByInOutPair(const Graph &graph, OrderedNodeSet *candidates,
+                                 FunctionLibraryDefinition *func_lib, const std::string &in_out_pair) {
+  AllGraphIOP all_graph_iop;
+  if (0 >= ParseInOutPair(in_out_pair, all_graph_iop)) {
+    return errors::Internal(in_out_pair, " is invalid.");
+  }
+
+  // Index the graph by node name once instead of scanning every node for every pair.
+  std::map<std::string, Node *> nodes_by_name;
+  for (auto node : graph.nodes()) { nodes_by_name[node->name()] = node; }
+
+  for (auto &one_graph_iop : all_graph_iop) {
+    for (auto &iop : one_graph_iop) {
+      LOG(INFO) << iop.first << " -> " << iop.second;
+
+      auto head_it = nodes_by_name.find(iop.first);
+      auto tail_it = nodes_by_name.find(iop.second);
+      Node *op_head = (head_it == nodes_by_name.end()) ? nullptr : head_it->second;
+      Node *op_tail = (tail_it == nodes_by_name.end()) ? nullptr : tail_it->second;
+      // Report every missing end, also when both ends are missing.
+      if (!op_head) {
+        ADP_LOG(ERROR) << iop.first << " -> " << iop.second << ", but " << iop.first << " is not found.";
+        LOG(ERROR) << iop.first << " -> " << iop.second << ", but " << iop.first << " is not found.";
+      }
+      if (!op_tail) {
+        ADP_LOG(ERROR) << iop.first << " -> " << iop.second << ", but " << iop.second << " is not found.";
+        LOG(ERROR) << iop.first << " -> " << iop.second << ", but " << iop.second << " is not found.";
+      }
+      if (!op_head || !op_tail) { continue; }
+
+      NodeSet ops_save;
+      if (0 < FindAllPath(op_head, op_tail, ops_save)) {
+        for (auto node : ops_save) { candidates->insert(node); }
+      }
+    }
+    LOG(INFO) << "\n";
+  }
+  if (candidates->empty()) {
+    return errors::Internal("no node is dump.");
+  }
+  // Reject the configuration unless every collected node passes the support check (Const ops allowed).
+  for (auto node : *candidates) {
+    if (!IsNpuSupportingNode(node, true, func_lib, true)) {
+      return errors::Internal(node->name(), " is not supported npu node.");
+    }
+  }
+  return Status::OK();
+}
 
 Status FindNpuSupportCandidates(const Graph &graph, OrderedNodeSet *candidates,
                                 FunctionLibraryDefinition *func_lib, bool enableDP, bool mix_compile_mode) {
@@ -686,8 +854,13 @@ Status MarkForPartition(std::unique_ptr<Graph> *graphIn, int &clusterNum, bool m
   Graph *graph = graphIn->get();
   bool enableDP = pass_options["enable_dp"] == "1";
   OrderedNodeSet npuSupportCandidates;
-  TF_RETURN_IF_ERROR(FindNpuSupportCandidates(*graph, &npuSupportCandidates, func_lib, enableDP, mix_compile_mode));
-  TF_RETURN_IF_ERROR(AddRelationalConst(*graph, &npuSupportCandidates));
+  if (!pass_options["in_out_pair"].empty()) {
+    TF_RETURN_IF_ERROR(FindCandidatesByInOutPair(*graph, &npuSupportCandidates, func_lib, pass_options["in_out_pair"]));
+  }
+  if (npuSupportCandidates.empty()) {
+    TF_RETURN_IF_ERROR(FindNpuSupportCandidates(*graph, &npuSupportCandidates, func_lib, enableDP, mix_compile_mode));
+    TF_RETURN_IF_ERROR(AddRelationalConst(*graph, &npuSupportCandidates));
+  }
 
   std::map> cluster_map;
   tensorflow::GraphCycles cycles;
diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.h b/tf_adapter/optimizers/om_partition_subgraphs_pass.h
index 7a72061bd..7392fb0d7 100644
--- a/tf_adapter/optimizers/om_partition_subgraphs_pass.h
+++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.h
@@ -47,8 +47,10 @@ Status MarkForPartition(const GraphOptimizationPassOptions &options, int &cluste
 Status OMPartitionSubgraphsInFunctions(string groupAttribute, const GraphOptimizationPassOptions &options,
                                        string graph_format);
 
-bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, FunctionLibraryDefinition *func_lib);
-bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib);
+bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode,
+                         FunctionLibraryDefinition *func_lib, bool support_const = false);
+bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib,
+                         bool support_const = false);
 }  // namespace OMSplitter
 
 class OMPartitionSubgraphsPass : public GraphOptimizationPass {
diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc
index 2d6a2f1bf..15ec63979 100644
--- a/tf_adapter/util/npu_attrs.cc
+++ b/tf_adapter/util/npu_attrs.cc
@@ -420,6 +420,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
   std::string dynamic_inputs_shape_range;
   int local_rank_id = -1;
   std::string local_device_list;
+  std::string in_out_pair;
   for (const auto &custom_optimizer : rewrite_options.custom_optimizers()) {
     if (custom_optimizer.name() == "NpuOptimizer") {
       do_npu_optimizer = true;
@@ -464,6 +465,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
           LOG(FATAL) << s.error_message();
         }
       }
+      if (params.count("in_out_pair")) { in_out_pair = params.at("in_out_pair").s(); }
     }
   }
   if (!do_npu_optimizer) {
@@ -487,6 +489,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(const GraphOptimizat
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
   pass_options["local_rank_id"] = std::to_string(local_rank_id);
   pass_options["local_device_list"] = local_device_list;
+  pass_options["in_out_pair"] = in_out_pair;
 
   return pass_options;
 }
@@ -506,6 +509,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
   std::string dynamic_inputs_shape_range;
   std::string local_rank_id = "-1";
   std::string local_device_list;
+  std::string in_out_pair;
   Status s = Status::OK();
   string npuOptimizer;
 
@@ -524,6 +528,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
       ctx->GetAttr("_local_rank_id", &local_rank_id);
       ctx->GetAttr("_local_device_list", &local_device_list);
     }
+    ctx->GetAttr("_in_out_pair", &in_out_pair);
   }
   // pass options
   pass_options["do_npu_optimizer"] = do_npu_optimizer;
@@ -539,6 +544,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(OpKernelConstruction
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
   pass_options["local_rank_id"] = local_rank_id;
   pass_options["local_device_list"] = local_device_list;
+  pass_options["in_out_pair"] = in_out_pair;
 
   return pass_options;
 }
@@ -558,6 +564,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
   std::string dynamic_inputs_shape_range;
   std::string local_rank_id = "-1";
   std::string local_device_list;
+  std::string in_out_pair;
   Status s = Status::OK();
 
   if (attrs.Find("_NpuOptimizer") != nullptr) {
@@ -590,6 +597,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
     if (attrs.Find("_local_device_list") != nullptr) {
       local_device_list = attrs.Find("_local_device_list")->s();
     }
+    if (attrs.Find("_in_out_pair") != nullptr) { in_out_pair = attrs.Find("_in_out_pair")->s(); }
   }
   // pass options
   pass_options["do_npu_optimizer"] = do_npu_optimizer;
@@ -605,6 +613,7 @@ std::map<std::string, std::string> NpuAttrs::GetPassOptions(AttrSlice attrs) {
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
   pass_options["local_rank_id"] = local_rank_id;
   pass_options["local_device_list"] = local_device_list;
+  pass_options["in_out_pair"] = in_out_pair;
 
   return pass_options;
 }
@@ -621,6 +630,7 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
   std::string task_index = "0";
   std::string local_rank_id = "-1";
   std::string local_device_list;
+  std::string in_out_pair;
   Status s = Status::OK();
 
   std::string variable_format_optimize = std::to_string(true);
@@ -688,6 +698,7 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
     if (attrs.Find("_local_device_list") != nullptr) {
       local_device_list = attrs.Find("_local_device_list")->s();
     }
+    if (attrs.Find("_in_out_pair") != nullptr) { in_out_pair = attrs.Find("_in_out_pair")->s(); }
 
     if (attrs.Find("_variable_format_optimize") != nullptr) {
       variable_format_optimize = attrs.Find("_variable_format_optimize")->s();
@@ -831,6 +842,7 @@ std::map<std::string, std::string> NpuAttrs::GetAllAttrOptions(AttrSlice attrs)
   all_options["task_index"] = task_index;
   all_options["local_rank_id"] = local_rank_id;
   all_options["local_device_list"] = local_device_list;
+  all_options["in_out_pair"] = in_out_pair;
   all_options["op_select_implmode"] = op_select_implmode;
   all_options["optypelist_for_implmode"] = optypelist_for_implmode;
   all_options["input_shape"] = input_shape;
@@ -905,6 +917,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
   std::string dynamic_inputs_shape_range;
   int local_rank_id = -1;
   std::string local_device_list;
+  std::string in_out_pair;
   int enable_exception_dump = 0;
   std::string op_select_implmode;
   std::string optypelist_for_implmode;
@@ -1089,6 +1102,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
           LOG(FATAL) << s.error_message();
         }
       }
+      if (params.count("in_out_pair")) { in_out_pair = params.at("in_out_pair").s(); }
       if (params.count("enable_exception_dump")) { enable_exception_dump = params.at("enable_exception_dump").i(); }
 
       if (!params.count("op_select_implmode") && !params.count("optypelist_for_implmode")) {
@@ -1232,6 +1246,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options
   pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range;
   pass_options["local_rank_id"] = std::to_string(local_rank_id);
   pass_options["local_device_list"] = local_device_list;
+  pass_options["in_out_pair"] = in_out_pair;
 
   std::string attr_name;
   for (const auto &option : sess_options) {
-- 
Gitee