diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 7f5683bc42a3fd5f709d1418a9656d5fbfcbff9e..56fe196d543a4cd3dccd21f62dea0ba241d52829 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -62,6 +62,7 @@ #include "graph/ascend_string.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_adapter.h" #include "graph/compute_graph.h" #include "graph/ge_attr_value.h" #include "graph/model.h" @@ -978,17 +979,25 @@ Status GeOp::ParseOnnxGraphOpAttr(Node *&node) { std::string model_path = node_def.attr().find("model_path")->second.s(); ge::Graph sub_graph("onnx_compute_graph_" + node->name()); std::map parser_params; - std::string subgrph_name("onnx_compute_graph_" + node->name() + CurrentTimeInStr()); + std::string subgrph_name("onnx_compute_graph_" + node->name() + '_' + CurrentTimeInStr()); parser_params.insert({ge::AscendString(ge::ir_option::OUTPUT), ge::AscendString(subgrph_name.c_str())}); - if(ge::SUCCESS != ge::aclgrphParseONNX(model_path.c_str(), parser_params, sub_graph)) { + if (ge::SUCCESS != ge::aclgrphParseONNX(model_path.c_str(), parser_params, sub_graph)) { LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Parse Failed."; return errors::Internal("[GEOP] node: %s Onnx Model Parse Failed.",node->name()); } + // rename the nodes in subgraph of onnx model + for (auto &sub_node : sub_graph.GetAllNodes()) { + auto snode = ge::NodeAdapter::GNode2Node(sub_node); + auto orig_name = snode->GetName(); + auto modi_name = node->name() + '_' + orig_name; + snode->GetOpDesc()->SetName(modi_name); + } + ge::Model onnx_model("onnx_compute_model_" + node->name(), ""); onnx_model.SetGraph(sub_graph); ge::Buffer model_buf; - if(ge::SUCCESS != onnx_model.Save(model_buf, false)){ + if (ge::SUCCESS != onnx_model.Save(model_buf, false)) { LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Serialized Failed."; return errors::Internal("[GEOP] node: %s Onnx Model Serialized Failed.", node->name()); } diff --git a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc index 5f564352a4613a3d5940efb23b1c749663121f78..8a7eb0ef41b93c9a888c009f86b0cde956ec36ae 100644 --- a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc +++ b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc @@ -26,6 +26,7 @@ #include "ge/ge_api_types.h" #include "graph/tensor.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_adapter.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "graph/buffer.h" @@ -282,6 +283,24 @@ Graph GraphUtils::CreateGraphFromComputeGraph(const ComputeGraphPtr compute_grap void Graph::SetNeedIteration(bool need_iteration) {} +std::vector Graph::GetAllNodes() const { + return std::vector(); +} + +NodePtr NodeAdapter::GNode2Node(ge::GNode const &node) { + return nullptr; +} + +std::string Node::GetName() { + return ""; +} + +OpDescPtr Node::GetOpDesc() const { + return nullptr; +} + +void OpDesc::SetName(std::string const &name) {} + graphStatus aclgrphParseONNX(const char *model_file, const std::map &parser_params, ge::Graph &graph) { std::string model_(model_file);