From 5cc7867d4d70a8d16e72d2bf64221b060b6a31fb Mon Sep 17 00:00:00 2001 From: panghongjun Date: Wed, 28 Apr 2021 10:57:30 +0800 Subject: [PATCH] while_loop --- .../optimizers/om_partition_subgraphs_pass.cc | 86 +++++++++++++++---- 1 file changed, 67 insertions(+), 19 deletions(-) diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index f4d241308..308dbe4fc 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -601,29 +601,77 @@ Status FindNpuSupportCandidates(const Graph &graph, OrderedNodeSet *candidates, } } } - if (mix_compile_mode) { - std::vector cfInfos; - Status status = BuildControlFlowInfo(&graph, &cfInfos); - if (!status.ok()) return status; - std::set unsupportedFrames; - for (auto it : outSet) { - auto cfInfo = cfInfos[it->id()]; - if (!cfInfo.frame_name.empty()) { unsupportedFrames.insert(cfInfo.frame_name); } - while (!cfInfos[cfInfo.parent_frame->id()].frame_name.empty()) { - unsupportedFrames.insert(cfInfos[cfInfo.parent_frame->id()].frame_name); - cfInfo = cfInfos[cfInfo.parent_frame->id()]; + // if (mix_compile_mode) { + std::vector cycleEdges; + for (auto edge : graph.edges()) { + REQUIRES_NOT_NULL(edge); + Node *src = edge->src(); + Node *dst = edge->dst(); + REQUIRES_NOT_NULL(src); + REQUIRES_NOT_NULL(dst); + // Skip source/sink + if (!src->IsOp() || !dst->IsOp()) { continue; } + if (src->IsNextIteration()) { + cycleEdges.push_back(edge); } } - for (auto it = candidates->begin(); it != candidates->end();) { - auto cfInfo = cfInfos[(*it)->id()]; - if (unsupportedFrames.find(cfInfo.frame_name) != unsupportedFrames.end()) { - outSet.insert(*it); - it = candidates->erase(it); - } else { - ++it; + + std::vector> while_loop_nodes; + for (auto edge : cycleEdges) { + REQUIRES_NOT_NULL(edge); + Node *src = edge->src(); + Node *dst = edge->dst(); + REQUIRES_NOT_NULL(src); + REQUIRES_NOT_NULL(dst); + Node *enter_node = nullptr; + Node *exit_node = nullptr; + if (src->IsNextIteration() && dst->IsMerge()) { + for (auto edge : dst->in_edges()) { + REQUIRES_NOT_NULL(edge); + REQUIRES_NOT_NULL(edge->src()); + REQUIRES_NOT_NULL(edge->dst()); + if (edge->src()->IsEnter()) { + enter_node = edge->src(); + break; + } + } + for (auto edge : dst->out_edges()) { + REQUIRES_NOT_NULL(edge); + REQUIRES_NOT_NULL(edge->src()); + REQUIRES_NOT_NULL(edge->dst()); + if (edge->dst()->IsSwitch()) { + for (auto switch_edge : edge->dst()->out_edges()) { + REQUIRES_NOT_NULL(switch_edge); + REQUIRES_NOT_NULL(switch_edge->src()); + REQUIRES_NOT_NULL(switch_edge->dst()); + if (switch_edge->dst()->IsExit()) { + exit_node = switch_edge->dst(); + } + } + } + } + } + if (enter_node != nullptr && exit_node != nullptr) { + while_loop_nodes.push_back({enter_node, exit_node}); } } - } + for (auto pair_node : while_loop_nodes) { + Node *op_head = pair_node.first; + NodeSet ops_tail; + NodeSet ops_save; + ops_tail.insert(pair_node.second); + FindNodesInPaths(op_head, ops_tail, ops_save); + for (auto it : outSet) { + if (ops_save.count(it) > 0) { + for (auto tmp : ops_save) { + outSet.insert(tmp); + candidates->erase(tmp); + } + break; + } + } + } + // } // Reference edge: The reference input/output of the sinking node does not sink while (!outSet.empty()) { -- Gitee