diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
index 2438922834cbc054761058b70a381ca56e8171af..65dcb76ead6b36277e2e7103431010dc74a0161a 100644
--- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
+++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc
@@ -25,7 +25,9 @@
 #include
 #include
 #include
+#include <stack>
+#include <unordered_set>
 
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/shape_refiner.h"
@@ -254,32 +256,41 @@ Status SetIteratorShardName(Node *node) {
 }
 
 // Make sure we don't recurse infinitely on recursive functions.
-const int kMaxRecursionDepth = 10;
-
-bool IsNpuSupportingFunc(const string &func_name, const FunctionLibraryDefinition *func_lib, int depth) {
+// Checks func_name and every function it references through node attrs, using an explicit
+// work stack. Returns false if func_lib is null, if any reachable function is missing from
+// func_lib, or if any non-Const node fails IsNpuSupportingNode.
+bool IsNpuSupportingFunc(const string &func_name, const FunctionLibraryDefinition *func_lib) {
   if (func_lib == nullptr) {
     ADP_LOG(ERROR) << "func lib is nullptr, function name is " << func_name;
     LOG(ERROR) << "func lib is nullptr, function name is " << func_name;
     return false;
   }
-  if (depth >= kMaxRecursionDepth) {
-    ADP_LOG(ERROR) << "Rejecting " << func_name << ": function depth limit exceeded.";
-    LOG(ERROR) << "Rejecting " << func_name << ": function depth limit exceeded.";
-    return false;
-  }
-  const FunctionDef *func_def = func_lib->Find(func_name);
-  if (func_def == nullptr) {
-    return false;
-  }
-  for (NodeDef node_def : func_def->node_def()) {
-    if (node_def.op() == "Const") {
-      ADP_LOG(INFO) << "Const node in function can dump.";
-    } else if (!IsNpuSupportingNode(node_def, compile_mode, func_lib)) {
+  // func_name_stack holds functions still to inspect. visited_func_names records every name
+  // that has been queued, so a function reachable through an attr cycle is expanded once.
+  std::stack<std::string> func_name_stack;
+  std::unordered_set<std::string> visited_func_names;
+  func_name_stack.emplace(func_name);
+  visited_func_names.emplace(func_name);
+  while (!func_name_stack.empty()) {
+    const std::string top_func_name = func_name_stack.top();
+    func_name_stack.pop();
+
+    const FunctionDef *func_def = func_lib->Find(top_func_name);
+    if (func_def == nullptr) {
       return false;
     }
-    for (const auto &item : node_def.attr()) {
-      if (item.second.has_func()) {
-        if (!IsNpuSupportingFunc(item.second.func().name(), func_lib, depth + 1)) { return false; }
+    for (const NodeDef &node_def : func_def->node_def()) {
+      if (node_def.op() == "Const") {
+        ADP_LOG(INFO) << "Const node in function can dump.";
+      } else if (!IsNpuSupportingNode(node_def, compile_mode, func_lib)) {
+        return false;
+      }
+      for (const auto &item : node_def.attr()) {
+        if (!item.second.has_func()) { continue; }
+        const std::string &callee_name = item.second.func().name();
+        if (visited_func_names.insert(callee_name).second) {
+          func_name_stack.emplace(callee_name);
+        }
       }
     }
   }
@@ -290,7 +301,7 @@ bool IsNpuSupportingFunc(const Node *node, const FunctionLibraryDefinition *func
   for (const auto &it : node->attrs()) {
     if (it.second.has_func()) {
       string func_name = it.second.func().name();
-      if (!IsNpuSupportingFunc(func_name, func_lib, depth)) { return false; }
+      if (!IsNpuSupportingFunc(func_name, func_lib)) { return false; }
     }
   }
   return true;
@@ -300,7 +311,7 @@ bool IsNpuSupportingNode(const NodeDef &node_def, bool mix_compile_mode,
                          const FunctionLibraryDefinition *func_lib, bool support_const) {
   if (IsWithoutNpuScope(node_def)) { return false; }
   if (IsWhiteListSupport(node_def.op(), mix_compile_mode, node_def.name(), support_const)) { return true; }
-  if (IsNpuSupportingFunc(node_def.op(), func_lib, 0)) { return true; }
+  if (IsNpuSupportingFunc(node_def.op(), func_lib)) { return true; }
   return false;
 }
 
@@ -560,7 +571,6 @@ Status FindNpuSupportCandidates(const Graph &graph, OrderedNodeSet *candidates,
   OrderedNodeSet outSet;
   for (Node *node : sortedNodes) {
-    // 0 is function depth
-    if (!IsNpuSupportingFunc(node, func_lib, 0)) { continue; }
+    if (!IsNpuSupportingFunc(node, func_lib)) { continue; }
     if (!node->IsOp()) {  // Ship Sink/Source nodes.
       continue;
     }