diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index 7b86b01b7cf6554e128ef47c4694c7fc14893c4b..32e85aa0923320da4e929772bdaca2ccf83e0a8c 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -2395,17 +2395,21 @@ Status OMPartitionSubgraphsPass::SplitUnaryOpsComposition(Graph *graph, Node *no Node *pre_node = (*node->in_edges().begin())->src(); auto node_list = node->def().attr().at("op_names").list(); Node *unary_node = nullptr; + bool src_from_org = true; for (int i = 0; i < node_list.s_size(); i++) { const std::string &node_name = node_list.s(i); std::string op_name = node->name() + "_" + std::to_string(i) + "_" + node_name; - ADP_LOG(INFO) << "op_names node_list: " << i << " is node: " << node_name; + const auto src_output = src_from_org ? (*node->in_edges().begin())->src_output() : 0; + ADP_LOG(INFO) << "op_names node_list: " << i << " is node: " << node_name << "src_node:" << pre_node->name() + << "output index:" << src_output; TF_CHECK_OK(NodeBuilder(op_name, node_name) - .Input(pre_node, 0) - .Device(pre_node->def().device()) - .Finalize(graph, &unary_node)); + .Input(pre_node, src_output) + .Device(pre_node->def().device()) + .Finalize(graph, &unary_node)); ADP_LOG(INFO) << unary_node->type_string() << " has built success."; pre_node = unary_node; REQUIRES_NOT_NULL(pre_node); + src_from_org = false; } for (auto out_edge : node->out_edges()) { if (out_edge->IsControlEdge()) {