diff --git a/tf_adapter_2.x/npu_device/core/npu_device.cpp b/tf_adapter_2.x/npu_device/core/npu_device.cpp index 05b62a3e1dc53dcf3d4daa5630e111ccfd6f9287..12e4ce0de52800850173d284558d5d556ce41192 100644 --- a/tf_adapter_2.x/npu_device/core/npu_device.cpp +++ b/tf_adapter_2.x/npu_device/core/npu_device.cpp @@ -93,6 +93,60 @@ size_t RemoveRedundantControlEdges(tensorflow::Graph *graph) { } } } + + for (auto node : graph->op_nodes()) { + if (node->type_string() == kDropOutGenMaskV3) { + bool is_first_dropout_mask = true; + for (const auto edge : node->in_edges()) { + if (edge->IsControlEdge() && edge->src()->type_string() == kDropOutDoMaskV3) { + is_first_dropout_mask = false; + break; + } + } + if (is_first_dropout_mask) { + std::vector dropout_gen_masks; + std::vector dropout_do_masks; + std::set seen_masks; + + const std::function &enter = [&dropout_gen_masks, &dropout_do_masks, + &seen_masks](tensorflow::Node *node) { + if (node->type_string() == kDropOutGenMaskV3) { + if (seen_masks.insert(node).second) { + dropout_gen_masks.push_back(node); + } + } else if (node->type_string() == kDropOutDoMaskV3) { + if (seen_masks.insert(node).second) { + dropout_do_masks.push_back(node); + } + } + }; + + tensorflow::EdgeFilter filter = [](const tensorflow::Edge &edge) { + return edge.dst()->type_string() == kDropOutDoMaskV3 || edge.dst()->type_string() == kDropOutGenMaskV3; + }; + + tensorflow::DFSFrom(*graph, {node}, enter, {}, {}, filter); + + if (dropout_gen_masks.size() != dropout_do_masks.size()) { + LOG(ERROR) << "Size mismatch gen masks " << dropout_gen_masks.size() << " vs. do masks " + << dropout_do_masks.size(); + } else { + size_t start = 0; + size_t end = dropout_do_masks.size() - dropout_do_masks.size() / 2; + while (end < dropout_do_masks.size()) { + graph->AddControlEdge(dropout_gen_masks[start++], dropout_do_masks[end++]); + } + start = 1; + end = dropout_gen_masks.size() - dropout_gen_masks.size() / 2 + 1; + while (end < dropout_gen_masks.size()) { + graph->AddControlEdge(dropout_do_masks[start++], dropout_gen_masks[end++]); + } + } + break; + } + } + } + for (auto edge : edges_to_remove) { DLOG() << "Remove redundant control edge " << edge->DebugString(); graph->RemoveEdge(edge);