diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index 411a86dcc6a2c627151e86fda3e43e22cd7c404f..bf82bd89a16901f52ea6c79c7892b90c3ab963f8 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -214,15 +214,16 @@ bool EndsWith(const std::string &str, const std::string &suffix) { return str.size() >= suffix.size() && str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; } -bool IsWhiteListSupport(const string &op_name, bool mix_compile_mode, const string &node_name) { +bool IsWhiteListSupport(const string &op_name, bool mix_compile_mode, const string &node_name, + bool support_const = false) { static const std::string suffix_op = "Dataset"; static const std::string suffix_op_v2 = "DatasetV2"; auto identifier = NpuOpsIdentifier::GetInstance(mix_compile_mode); bool ans = (identifier->IsNpuSupported(op_name, node_name)) && !EndsWith(op_name, suffix_op) && - !EndsWith(op_name, suffix_op_v2) && !(op_name == "Const") && !(op_name == "_Arg") && - !(op_name == "_Retval") && !(op_name == "StringJoin"); + !EndsWith(op_name, suffix_op_v2) && (support_const || !(op_name == "Const")) && + !(op_name == "_Arg") && !(op_name == "_Retval") && !(op_name == "StringJoin"); if (!ans) { auto ret = not_support_nodes.insert(op_name); if (ret.second) { @@ -313,15 +314,17 @@ bool IsNpuSupportingFunc(Node *node, FunctionLibraryDefinition *func_lib, int de return true; } -bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, FunctionLibraryDefinition *func_lib) { +bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, + FunctionLibraryDefinition *func_lib, bool support_const) { if (IsWithoutNpuScope(node_def)) { return false; } - if (IsWhiteListSupport(node_def.op(), mix_compile_mode, node_def.name())) { return true; } + if (IsWhiteListSupport(node_def.op(), mix_compile_mode, node_def.name(), support_const)) { return true; } if (IsNpuSupportingFunc(node_def.op(), func_lib, 0)) { return true; } return false; } -bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib) { - return IsNpuSupportingNode(node->def(), mix_compile_mode, func_lib); +bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib, + bool support_const) { + return IsNpuSupportingNode(node->def(), mix_compile_mode, func_lib, support_const); } bool IsUnSupportedResource(bool mix_compile_mode, Node* node) { @@ -344,6 +347,150 @@ bool IsUnSupportedResource(bool mix_compile_mode, Node* node) { } return false; } +using NodeSet = std::set; +using NodeMap = std::map; +using NodeStack = std::vector; +using GraphPath = std::vector; +using GraphPaths = std::vector; + +int FindNodesInPaths(Node *op_head, Node *op_tail, NodeSet &ops_save) { + if (!op_head || !op_tail || op_head == op_tail) { return ops_save.size(); } + + NodeMap seen; + NodeStack stack; + GraphPath path; + + NodeSet empty; + seen.insert(NodeMap::value_type(op_head, empty)); + stack.push_back(op_head); + + while (!stack.empty()) { + Node *cur_node = stack.back(); + stack.pop_back(); + + if (!cur_node) { + seen.erase(path.back()); + if (path.size() >= 2) { seen[path[path.size() - 2]].erase(path.back()); } + path.pop_back(); + continue; + } + + path.push_back(cur_node); + if (path.size() >= 2) { seen[path[path.size() - 2]].insert(path.back()); } + stack.push_back(nullptr); + + if (cur_node == op_tail || ops_save.count(cur_node) > 0) { + for (auto node : path) { ops_save.insert(node); } + continue; + } + + for (auto out_node : cur_node->out_nodes()) { + unsigned int n = 0; + for (auto out_out_node : out_node->out_nodes()) { ++n; } + if (seen.insert(NodeMap::value_type(out_node, empty)).second) { + stack.push_back(out_node); + } else if (seen[out_node].size() < n) { + stack.push_back(out_node); + } + } + } + return ops_save.size(); +} + +using IOP = std::pair; +using OneGraphIOP = std::vector; +using AllGraphIOP = std::vector; + +int ParseInOutPair(const std::string &in_out_pair, AllGraphIOP &all_graph_iop) { + using Nodes = std::vector; + using Pair = std::vector; + using Pairs = std::vector; + + int model = 0; + std::string s; + Nodes nodes; + Pair pair; + Pairs pairs; + for (char c : in_out_pair) { + switch (c) { + case '[': + ++model; + break; + case ']': + case ',': + if (1 == model && !pair.empty()) { pairs.emplace_back(std::move(pair)); } + else if (2 == model && !nodes.empty()) { pair.emplace_back(std::move(nodes)); } + else if (3 == model && !s.empty()) { nodes.emplace_back(std::move(s)); } + if (']' == c) { --model; } + break; + case ' ': + case '\t': + case '\'': + break; + default: + s += c; + } + } + + int size = 0; + for (auto &pair : pairs) { + OneGraphIOP one_graph_iop; + if (pair.size() < 2) { continue; } + for (auto &in : pair[0]) { + for (auto &out : pair[1]) { + ++size; + one_graph_iop.push_back(IOP(in, out)); + } + } + all_graph_iop.push_back(one_graph_iop); + } + return size; +} + +Status FindCandidatesByInOutPair(const Graph &graph, OrderedNodeSet *candidates, + FunctionLibraryDefinition *func_lib, const std::string &in_out_pair) { + AllGraphIOP all_graph_iop; + if (0 >= ParseInOutPair(in_out_pair, all_graph_iop)) { + return errors::Internal("in_out_pair: ", in_out_pair, " is invalid."); + } + for (auto &one_graph_iop : all_graph_iop) { + for (auto &iop : one_graph_iop) { + LOG(INFO) << iop.first << " -> " << iop.second; + + Node *op_head = nullptr; + Node *op_tail = nullptr; + for (auto node : graph.nodes()) { + if (node->name() == iop.first) { op_head = node; } + if (node->name() == iop.second) { op_tail = node; } + if (op_head && op_tail) { + break; + } + } + if (!op_head && op_tail) { + ADP_LOG(WARNING) << iop.first << " -> " << iop.second << ", but " << iop.first << " is not find."; + LOG(WARNING) << iop.first << " -> " << iop.second << ", but " << iop.first << " is not find."; + } + if (op_head && !op_tail) { + ADP_LOG(WARNING) << iop.first << " -> " << iop.second << ", but " << iop.second << " is not find."; + LOG(WARNING) << iop.first << " -> " << iop.second << ", but " << iop.second << " is not find."; + } + if (op_head && op_tail) { + NodeSet ops_save; + if (0 < FindNodesInPaths(op_head, op_tail, ops_save)) { + for (auto node : ops_save) { + candidates->insert(node); + } + } + } + } + } + for (auto node : *candidates) { + if (!IsNpuSupportingNode(node, true, func_lib, true)) { + return errors::Internal(node->name(), " is not supported npu node."); + } + } + return Status::OK(); +} Status FindNpuSupportCandidates(const Graph &graph, OrderedNodeSet *candidates, FunctionLibraryDefinition *func_lib, bool enableDP, bool mix_compile_mode) { @@ -686,8 +833,13 @@ Status MarkForPartition(std::unique_ptr *graphIn, int &clusterNum, bool m Graph *graph = graphIn->get(); bool enableDP = pass_options["enable_dp"] == "1"; OrderedNodeSet npuSupportCandidates; - TF_RETURN_IF_ERROR(FindNpuSupportCandidates(*graph, &npuSupportCandidates, func_lib, enableDP, mix_compile_mode)); - TF_RETURN_IF_ERROR(AddRelationalConst(*graph, &npuSupportCandidates)); + if (!pass_options["in_out_pair"].empty()) { + TF_RETURN_IF_ERROR(FindCandidatesByInOutPair(*graph, &npuSupportCandidates, func_lib, pass_options["in_out_pair"])); + } + if (npuSupportCandidates.empty()) { + TF_RETURN_IF_ERROR(FindNpuSupportCandidates(*graph, &npuSupportCandidates, func_lib, enableDP, mix_compile_mode)); + TF_RETURN_IF_ERROR(AddRelationalConst(*graph, &npuSupportCandidates)); + } std::map> cluster_map; tensorflow::GraphCycles cycles; diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.h b/tf_adapter/optimizers/om_partition_subgraphs_pass.h index 7a72061bd66fbfa6a2fa8e95ee740a0ad393db77..7392fb0d7ca0d2f8576838a211a2d5d08e00f3a8 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.h +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.h @@ -47,8 +47,10 @@ Status MarkForPartition(const GraphOptimizationPassOptions &options, int &cluste Status OMPartitionSubgraphsInFunctions(string groupAttribute, const GraphOptimizationPassOptions &options, string graph_format); -bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, FunctionLibraryDefinition *func_lib); -bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib); +bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode, + FunctionLibraryDefinition *func_lib, bool support_const = false); +bool IsNpuSupportingNode(Node *node, bool mix_compile_mode, FunctionLibraryDefinition *func_lib, + bool support_const = false); } // namespace OMSplitter class OMPartitionSubgraphsPass : public GraphOptimizationPass { diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index fa8a94e70eeb61d0c0ad320529284ce7ddf88782..feb36e01b8f797e7992fc2c4509bfe2ee303aec4 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -463,6 +463,7 @@ std::map NpuAttrs::GetPassOptions(const GraphOptimizat std::string dynamic_inputs_shape_range; int local_rank_id = -1; std::string local_device_list; + std::string in_out_pair; for (const auto &custom_optimizer : rewrite_options.custom_optimizers()) { if (custom_optimizer.name() == "NpuOptimizer") { do_npu_optimizer = true; @@ -514,6 +515,7 @@ std::map NpuAttrs::GetPassOptions(const GraphOptimizat LOG(FATAL) << s.error_message(); } } + if (params.count("in_out_pair")) { in_out_pair = params.at("in_out_pair").s(); } } } if (!do_npu_optimizer) { @@ -537,6 +539,7 @@ std::map NpuAttrs::GetPassOptions(const GraphOptimizat pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range; pass_options["local_rank_id"] = std::to_string(local_rank_id); pass_options["local_device_list"] = local_device_list; + pass_options["in_out_pair"] = in_out_pair; return pass_options; } @@ -556,6 +559,7 @@ std::map NpuAttrs::GetPassOptions(OpKernelConstruction std::string dynamic_inputs_shape_range; std::string local_rank_id = "-1"; std::string local_device_list; + std::string in_out_pair; Status s = Status::OK(); string npuOptimizer; @@ -574,6 +578,7 @@ std::map NpuAttrs::GetPassOptions(OpKernelConstruction ctx->GetAttr("_local_rank_id", &local_rank_id); ctx->GetAttr("_local_device_list", &local_device_list); } + ctx->GetAttr("_in_out_pair", &in_out_pair); } // pass options pass_options["do_npu_optimizer"] = do_npu_optimizer; @@ -589,6 +594,7 @@ std::map NpuAttrs::GetPassOptions(OpKernelConstruction pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range; pass_options["local_rank_id"] = local_rank_id; pass_options["local_device_list"] = local_device_list; + pass_options["in_out_pair"] = in_out_pair; return pass_options; } @@ -608,6 +614,7 @@ std::map NpuAttrs::GetPassOptions(AttrSlice attrs) { std::string dynamic_inputs_shape_range; std::string local_rank_id = "-1"; std::string local_device_list; + std::string in_out_pair; Status s = Status::OK(); if (attrs.Find("_NpuOptimizer") != nullptr) { @@ -640,6 +647,7 @@ std::map NpuAttrs::GetPassOptions(AttrSlice attrs) { if (attrs.Find("_local_device_list") != nullptr) { local_device_list = attrs.Find("_local_device_list")->s(); } + if (attrs.Find("_in_out_pair") != nullptr) { in_out_pair = attrs.Find("_in_out_pair")->s(); } } // pass options pass_options["do_npu_optimizer"] = do_npu_optimizer; @@ -655,6 +663,7 @@ std::map NpuAttrs::GetPassOptions(AttrSlice attrs) { pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range; pass_options["local_rank_id"] = local_rank_id; pass_options["local_device_list"] = local_device_list; + pass_options["in_out_pair"] = in_out_pair; return pass_options; } @@ -671,6 +680,7 @@ std::map NpuAttrs::GetAllAttrOptions(AttrSlice attrs) std::string task_index = "0"; std::string local_rank_id = "-1"; std::string local_device_list; + std::string in_out_pair; Status s = Status::OK(); std::string variable_format_optimize = std::to_string(true); @@ -738,6 +748,7 @@ std::map NpuAttrs::GetAllAttrOptions(AttrSlice attrs) if (attrs.Find("_local_device_list") != nullptr) { local_device_list = attrs.Find("_local_device_list")->s(); } + if (attrs.Find("_in_out_pair") != nullptr) { in_out_pair = attrs.Find("_in_out_pair")->s(); } if (attrs.Find("_variable_format_optimize") != nullptr) { variable_format_optimize = attrs.Find("_variable_format_optimize")->s(); @@ -881,6 +892,7 @@ std::map NpuAttrs::GetAllAttrOptions(AttrSlice attrs) all_options["task_index"] = task_index; all_options["local_rank_id"] = local_rank_id; all_options["local_device_list"] = local_device_list; + all_options["in_out_pair"] = in_out_pair; all_options["op_select_implmode"] = op_select_implmode; all_options["optypelist_for_implmode"] = optypelist_for_implmode; all_options["input_shape"] = input_shape; @@ -955,6 +967,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options std::string dynamic_inputs_shape_range; int local_rank_id = -1; std::string local_device_list; + std::string in_out_pair; int enable_exception_dump = 0; std::string op_select_implmode; std::string optypelist_for_implmode; @@ -1139,6 +1152,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options LOG(FATAL) << s.error_message(); } } + if (params.count("in_out_pair")) { in_out_pair = params.at("in_out_pair").s(); } if (params.count("enable_exception_dump")) { enable_exception_dump = params.at("enable_exception_dump").i(); } if (!params.count("op_select_implmode") && !params.count("optypelist_for_implmode")) { @@ -1282,6 +1296,7 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options pass_options["dynamic_inputs_shape_range"] = dynamic_inputs_shape_range; pass_options["local_rank_id"] = std::to_string(local_rank_id); pass_options["local_device_list"] = local_device_list; + pass_options["in_out_pair"] = in_out_pair; std::string attr_name; for (const auto &option : sess_options) {