// Review notes for this patch to tf_adapter/optimizers/om_partition_subgraphs_pass.cc.
// NOTE(review): the diff text below has lost its line breaks and its template argument lists
// (e.g. `std::vector> &clusters_to_reachable`), so it cannot be applied as-is; the notes describe
// only what the visible tokens show.
//
// Hunk 1 (@@ -87 +87): adds a helper taking clusterSet, cycles and clusters_to_reachable (it is
// invoked as CacheClusterReachableStatus in hunk 3). It resizes clusters_to_reachable to
// max_index x max_index, where max_index is `*clusterSet.rbegin() + 1`, then stores
// cycles.IsReachableNonConst(src, dst) at [src][dst] for every ordered pair with src != dst.
//   NOTE(review): `*clusterSet.end()` in the first LOG(INFO) dereferences the past-the-end iterator,
//   which is undefined behaviour for standard containers; `*clusterSet.rbegin()` (already used for
//   max_index) is presumably what was intended - confirm.
//   NOTE(review): `*clusterSet.rbegin()` needs a non-empty container; the only call site visible
//   here is guarded by `clusterSet.size() > 1`.
//   NOTE(review): one LOG(INFO) is emitted per (src, dst) pair inside the nested loops, and the
//   "reachable result" literal has no leading space, so it is printed directly after the dst value.
//
// Hunk 2 (@@ -832 +849, MergeSubgraphsInNewWay): before recording mergePair[toMerge] = dstSubgraph,
// the new code scans mergePair for entries whose target is dstSubgraph and clears can_merge when
// such an entry's source and toMerge are not present in each other's clusterToMerge sets.
//   NOTE(review): `for (const auto merged : mergePair)` copies every entry; `const auto &` would not.
//   NOTE(review): clusterToMerge is read through operator[]; if it is a std::map-like container this
//   inserts empty entries for missing keys - container type is not visible here, confirm.
//
// Hunk 3 (@@ -1084 +1111, MarkForPartition): replaces the direct cycles.IsReachableNonConst calls
// with lookups in the cached clusters_to_reachable table, visits each pair once (dst_iter starts at
// std::next(src_iter)), and inserts the pair into clusterToMerge in both directions. The
// clusterIndexToMerge-based reachability loop and its insert are removed.
//   NOTE(review): with dst_iter starting after src_iter, the `if (src == dst) { continue; }` guard
//   cannot fire if clusterSet holds unique values (presumably a std::set - confirm).
//   NOTE(review): `is_seen_clusters` is declared but no use of it appears in this hunk.
//   NOTE(review): in the mix_compile_mode branch canReach is initialised to false and the loop that
//   could set it to true is removed, so `!canReach` is always true where it is first tested.
//   NOTE(review): a LOG(INFO) is emitted for every visited pair here as well.
diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index 3eced9e721fa508258f9bf8fa9f5e69a9887588c..70e0ff512d7c571ac80176b78e469b4640657988 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -87,6 +87,23 @@ void GetAccumulateBuilderInfo(Node *node, std::vector &clusterSet, tensorflow::GraphCycles &cycles, + std::vector> &clusters_to_reachable) { + const size_t max_index = *clusterSet.rbegin() + 1; + clusters_to_reachable.resize(max_index); + for (size_t i = 0U; i < max_index; ++i) { + clusters_to_reachable[i].resize(max_index); + } + LOG(INFO) << "cluster set max index:" << *clusterSet.end(); + for (int src : clusterSet) { + for (int dst : clusterSet) { + if (src == dst) { continue; } + LOG(INFO) << "cached src:" << src << " and dst:" << dst << "reachable result"; + clusters_to_reachable[src][dst] = cycles.IsReachableNonConst(src, dst); + } + } +} } // namespace static const int64 kMicrosToMillis = 1000; @@ -832,9 +849,19 @@ Status MergeSubgraphsInNewWay(std::vector> &sortedCluster if (mergedClusters.count(dstSubgraph) < 1) { (void) mergedClusters.insert(dstSubgraph); for (const string &toMerge : clusterToMerge[dstSubgraph]) { + bool can_merge = true; if (clusterToMerge[toMerge].count(dstSubgraph) > 0) { - (void) mergedClusters.insert(toMerge); - mergePair[toMerge] = dstSubgraph; + for (const auto merged : mergePair) { + if (merged.second == dstSubgraph && (clusterToMerge[merged.first].count(toMerge) == 0 || + clusterToMerge[toMerge].count(merged.first) == 0)) { + can_merge = false; + break; + } + } + if (can_merge) { + mergePair[toMerge] = dstSubgraph; + (void) mergedClusters.insert(toMerge); + } } } } @@ -1084,19 +1111,18 @@ Status MarkForPartition(const std::unique_ptr *graph_in, int &clusterNum, } // Generate Merge possibility between clusters if (clusterSet.size() > 1) { - for (int src : clusterSet) { - for (int dst 
: clusterSet) { + std::vector> clusters_to_reachable; + CacheClusterReachableStatus(clusterSet, cycles, clusters_to_reachable); + std::set> is_seen_clusters; + for (auto src_iter = clusterSet.begin(); src_iter != clusterSet.end(); ++src_iter) { + for (auto dst_iter = std::next(src_iter); dst_iter != clusterSet.end(); ++dst_iter) { + const int src = *src_iter; + const int dst = *dst_iter; if (src == dst) { continue; } - if (!cycles.IsReachableNonConst(src, dst) && !cycles.IsReachableNonConst(dst, src)) { + LOG(INFO) << "find result src:" << src << " and dst:" << dst; + if (!clusters_to_reachable[src][dst] && !clusters_to_reachable[dst][src]) { if (mix_compile_mode) { bool canReach = false; - for (auto cluster : clusterIndexToMerge[src]) { - if (cycles.IsReachableNonConst(dst, cluster) || cycles.IsReachableNonConst(cluster, dst)) { - canReach = true; - break; - } - } - if (!canReach && job != "localhost") { Cluster *cluster_src = nullptr; Cluster *cluster_dst = nullptr; @@ -1121,10 +1147,11 @@ Status MarkForPartition(const std::unique_ptr *graph_in, int &clusterNum, } if (!canReach) { (void) clusterToMerge[clusterInfo[src].first].insert(clusterInfo[dst].first); - (void) clusterIndexToMerge[src].insert(dst); + (void) clusterToMerge[clusterInfo[dst].first].insert(clusterInfo[src].first); } } else { (void) clusterToMerge[clusterInfo[src].first].insert(clusterInfo[dst].first); + (void) clusterToMerge[clusterInfo[dst].first].insert(clusterInfo[src].first); } } }