From 6903630c3a8b4118fa82eaaeb5f299358baaa8a5 Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Sat, 4 Sep 2021 18:18:40 +0800 Subject: [PATCH 1/2] add onnx graph op --- CMakeLists.txt | 4 + configure.py | 2 + inc/parser/inc/external/parser/onnx_parser.h | 47 +++++ module.mk | 1 + tf_adapter/BUILD | 6 +- tf_adapter/kernels/geop_npu.cc | 41 +++++ tf_adapter/kernels/npu_ops.cc | 12 ++ tf_adapter/module.BUILD | 1 + tf_adapter/ops/npu_ops.cc | 10 ++ .../depends/ge_runner/src/ge_runner_stub.cc | 36 +++- .../tests/ut/kernels/onnx_model/conv2d.onnx | Bin 0 -> 273 bytes .../pbtxt/geop_npu_onnx_graph_op.pbtxt | 170 ++++++++++++++++++ .../ut/kernels/testcase/geop_npu_test.cc | 10 ++ .../testcase/npu_onnx_graph_op_test.cc | 32 ++++ 14 files changed, 368 insertions(+), 4 deletions(-) create mode 100644 inc/parser/inc/external/parser/onnx_parser.h create mode 100644 tf_adapter/tests/ut/kernels/onnx_model/conv2d.onnx create mode 100644 tf_adapter/tests/ut/kernels/pbtxt/geop_npu_onnx_graph_op.pbtxt create mode 100644 tf_adapter/tests/ut/kernels/testcase/npu_onnx_graph_op_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 65d273d0a..efcc02851 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,8 @@ if (ENABLE_OPEN_SRC) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/toolchain/tuning_tool) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/external) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/soft_dp) + include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/parser/inc) + include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/parser/inc/external) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/graphengine/inc) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/graphengine/inc/external) include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/metadef/inc) @@ -154,6 +156,8 @@ else() ${TOP_DIR}/inc/ ${TOP_DIR}/inc/external/ ${TOP_DIR}/inc/common/ + ${TOP_DIR}/inc/parser/inc/ + ${TOP_DIR}/inc/parser/inc/external/ ${TOP_DIR}/soft_dp/ ${TOP_DIR}/ace/execfwk/soft_dp/ ${TOP_DIR}/graphengine/inc/ diff --git a/configure.py b/configure.py index 0e2a1bde5..520198a2a 100755 --- a/configure.py +++ b/configure.py @@ -143,12 +143,14 @@ def setup_ascend(env_path): if 'ALL_IN_ONE_ENABLE' in os.environ: f.write(os.path.join(ascend_path, "compiler", "lib64", "libge_runner.so\n")) f.write(os.path.join(ascend_path, "compiler", "lib64", "libfmk_parser.so\n")) + f.write(os.path.join(ascend_path, "compiler", "lib64", "libfmk_onnx_parser.so\n")) f.write(os.path.join(ascend_path, "runtime", "lib64", "libdatatransfer.so\n")) f.write(os.path.join(ascend_path, "runtime", "lib64", "libindextransform.so\n")) f.write(os.path.join(ascend_path, "compiler", "lib64", "libalog.so\n")) else: f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libge_runner.so\n")) f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libfmk_parser.so\n")) + f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libfmk_onnx_parser.so\n")) f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libdatatransfer.so\n")) f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libindextransform.so\n")) f.write(os.path.join(ascend_path, "fwkacllib", "lib64", "libalog.so\n")) diff --git a/inc/parser/inc/external/parser/onnx_parser.h b/inc/parser/inc/external/parser/onnx_parser.h new file mode 100644 index 000000000..92877c633 --- /dev/null +++ b/inc/parser/inc/external/parser/onnx_parser.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ +#define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" +#include "graph/graph.h" +#include "graph/types.h" + +namespace ge { +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph); + +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, + const std::map &parser_params, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_PARSER_ONNX_PARSER_H_ diff --git a/module.mk b/module.mk index bdce6fe1e..4c2d817fc 100644 --- a/module.mk +++ b/module.mk @@ -42,6 +42,7 @@ LOCAL_SHARED_LIBRARIES := \ libtsdclient \ libdatatransfer \ libfmk_parser \ + libfmk_onnx_parser \ libindextransform LOCAL_SOFT_DP_LIBRARIES := libSoftDp diff --git a/tf_adapter/BUILD b/tf_adapter/BUILD index d28056115..05a146b5e 100644 --- a/tf_adapter/BUILD +++ b/tf_adapter/BUILD @@ -33,9 +33,9 @@ cc_binary( linkopts = [] + select({ # Public introduction of external dependencies on project. # External linked libraries, typically, located in out/${product}/host/obj/lib - ":cloud_build": ["-Lexternal/tf_adapter_cloud_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lindextransform"], - ":mini_build": ["-Lexternal/tf_adapter_mini_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lindextransform",], - ":onetrack_build": ["-Lexternal/tf_adapter_onetrack_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lindextransform",], + ":cloud_build": ["-Lexternal/tf_adapter_cloud_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lfmk_onnx_parser -lindextransform"], + ":mini_build": ["-Lexternal/tf_adapter_mini_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lfmk_onnx_parser -lindextransform",], + ":onetrack_build": ["-Lexternal/tf_adapter_onetrack_host_libs/ -lc_sec -lge_runner -ltsdclient -ldatatransfer -lfmk_parser -lfmk_onnx_parser -lindextransform",], "//conditions:default": [], }) + [ # "-z defs", diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 8337a9c34..f666a3642 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -67,9 +67,15 @@ limitations under the License. #include "framework/omg/parser/parser_api.h" #include "framework/omg/parser/parser_factory.h" #include "framework/omg/parser/parser_inner_ctx.h" +#include "parser/onnx_parser.h" #include "ge/ge_api.h" #include "ge/ge_api_types.h" +#include "graph/utils/graph_utils.h" +#include "graph/compute_graph.h" +#include "graph/ge_attr_value.h" +#include "graph/model.h" + namespace tensorflow { Status FunctionalizeControlFlow(Graph *graph, FunctionLibraryDefinition *library); namespace { @@ -911,6 +917,31 @@ Status GeOp::BuildGraphDef(FunctionLibraryDefinition &flib_def, << ret.error_message(); return ret; } + + if(node->type_string() == "NpuOnnxGraphOp") { + NodeDef &node_def = const_cast(node->def()); + std::string model_path = node_def.attr().find("model_path")->second.s(); + + ge::Graph sub_graph("onnx_compute_graph_" + node->name()); + std::map parser_params; + if(ge::SUCCESS != ge::aclgrphParseONNX(model_path.c_str(), parser_params, sub_graph)) { + LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Parse Failed."; + return errors::Internal("[GEOP] node: %s Onnx Model Parse Failed.",node->name()); + } + + ge::Model onnx_model("onnx_compute_model_" + node->name(), ""); + onnx_model.SetGraph(sub_graph); + ge::Buffer model_buf; + if(ge::SUCCESS != onnx_model.Save(model_buf, false)){ + LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Serialized Failed."; + return errors::Internal("[GEOP] node: %s Onnx Model Serialized Failed.", node->name()); + } + + std::string model_str(reinterpret_cast(model_buf.GetData()), model_buf.GetSize()); + AttrValue attr_value; + attr_value.set_s(model_str); + node_def.mutable_attr()->insert({"_external_model", attr_value}); + } if (is_tuning) { // output handle NodeDef &node_def = const_cast(node->def()); @@ -1178,6 +1209,16 @@ Status GeOp::GenerateDesc(Node *&node) { DataTypeVector outputs; TF_RETURN_IF_ERROR(tensorflow::InOutTypesForNode(node_def, op_def, &inputs, &outputs)); + //Get input and output numbers of NpuOnnxGraphOp op + if(node->type_string() == "NpuOnnxGraphOp") { + AttrValue in_value; + in_value.set_i(inputs.size()); + node_def.mutable_attr()->insert({"_input_num", in_value}); + AttrValue ot_value; + ot_value.set_i(outputs.size()); + node_def.mutable_attr()->insert({"_output_num", ot_value}); + } + int num; Node *in_node = nullptr; const Edge *in_edge = nullptr; diff --git a/tf_adapter/kernels/npu_ops.cc b/tf_adapter/kernels/npu_ops.cc index 6662a6162..ae5985bee 100644 --- a/tf_adapter/kernels/npu_ops.cc +++ b/tf_adapter/kernels/npu_ops.cc @@ -39,5 +39,17 @@ class NPUTestOP : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("NPUTest").Device(DEVICE_CPU), NPUTestOP); + +class NpuOnnxGraphOp : public OpKernel { + public: + explicit NpuOnnxGraphOp(OpKernelConstruction *context) : OpKernel(context) {} + ~NpuOnnxGraphOp() override = default; + void Compute(OpKernelContext *context) override { + return; + } + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("NpuOnnxGraphOp").Device(DEVICE_CPU), NpuOnnxGraphOp); } // namespace } // namespace tensorflow diff --git a/tf_adapter/module.BUILD b/tf_adapter/module.BUILD index d95f2a136..79985d093 100644 --- a/tf_adapter/module.BUILD +++ b/tf_adapter/module.BUILD @@ -44,6 +44,7 @@ cc_library( "libge_runner.so", "libdatatransfer.so", "libfmk_parser.so", + "libfmk_onnx_parser.so", "libindextransform.so", "libmmpa" ]), diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index 261585242..eb67c8f9a 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -26,6 +26,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -487,6 +488,15 @@ REGISTER_OP("AdamApplyOneWithDecayAssign") .Attr("T: {float16, float32}") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("NpuOnnxGraphOp") + .Input("inputs: Tin") + .Attr("Tin: list(type) >= 0") + .Output("outputs: Tout") + .Attr("Tout: list(type) >= 0") + .Attr("model_path: string") + .SetShapeFn(shape_inference::UnknownShape); + + REGISTER_OP("KMeansCentroids") .Input("x: T") .Input("y: T") diff --git a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc index 8f0567234..d6c799c35 100644 --- a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc +++ b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc @@ -28,6 +28,8 @@ #include "graph/utils/graph_utils.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "graph/buffer.h" +#include "graph/model.h" #include @@ -273,12 +275,44 @@ size_t ComputeGraph::GetAllNodesSize() const { return 1; } +Graph::Graph(const std::string& grph) {} Graph::Graph(char const* name) {} Graph GraphUtils::CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph) { return Graph("ge"); } void Graph::SetNeedIteration(bool need_iteration) {} +graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph) { + return SUCCESS; +} + +Buffer::Buffer() {} +std::size_t Buffer::GetSize() const { + return sizeof("_external_model"); +} +std::uint8_t *Buffer::GetData() { + std::string *buf_ = new std::string("_external_model"); + return reinterpret_cast (buf_); +} +const std::uint8_t *Buffer::GetData() const { + std::string *buf_ = new std::string("_external_model"); + return reinterpret_cast (buf_); +} + +Model::Model() {} +Model::Model(const string &name, const string &custom_version) {} +void Model::SetGraph(const Graph& graph) {} +graphStatus Model::Save(Buffer &buffer, bool is_dump) const { + return GRAPH_SUCCESS; +} +ConstProtoAttrMapHelper Model::GetAttrMap() const { + return ConstProtoAttrMapHelper(); +} +ProtoAttrMapHelper Model::MutableAttrMap() { + return ProtoAttrMapHelper(); +} + Tensor::Tensor() {} graphStatus Tensor::SetTensorDesc(const TensorDesc &tensorDesc) { return GRAPH_SUCCESS; } @@ -340,4 +374,4 @@ std::shared_ptr ModelParserFactory::CreateModelParser(const d } ModelParserFactory::~ModelParserFactory() {} -} // end domi \ No newline at end of file +} // end domi diff --git a/tf_adapter/tests/ut/kernels/onnx_model/conv2d.onnx b/tf_adapter/tests/ut/kernels/onnx_model/conv2d.onnx new file mode 100644 index 0000000000000000000000000000000000000000..aa823ed7ada00ce7ee57815bcec5667678de3e8d GIT binary patch literal 273 zcmaiu%L>9U5JfwE&^Ri@pdz>{t_rT)xv|8LD7Yz+1}&B}Bvo|hr}&kwjjyH1a2bYs z?#z&w(MGBBTnl5RDOzdUqebYhMlsiMO!t>W{|-R;F~$-1z6p=B0i!4n!_s73k8d~+0I+nTDB-%ad&V?n@GKmrQM%=50Rf8vm* jeF6>-pC|{unY!QE7>5KHMW4V&k?YacnU`NC_i6V5aw9Tp literal 0 HcmV?d00001 diff --git a/tf_adapter/tests/ut/kernels/pbtxt/geop_npu_onnx_graph_op.pbtxt b/tf_adapter/tests/ut/kernels/pbtxt/geop_npu_onnx_graph_op.pbtxt new file mode 100644 index 000000000..1afa2132e --- /dev/null +++ b/tf_adapter/tests/ut/kernels/pbtxt/geop_npu_onnx_graph_op.pbtxt @@ -0,0 +1,170 @@ +node { + name: "_arg_conv_input_0_0" + op: "_Arg" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "_retval_conv2d_0_0" + op: "_RetVal" + input: "GeOp91_0" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "GeOp91_0" + op: "GeOp" + input: "_arg_conv_input_0_0" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_NpuOptimizer" + value { + s: "NpuOptimizer" + } + } + attr { + key: "_do_npu_optimizer" + value { + s: "1" + } + } + attr { + key: "_use_off_line" + value { + s: "1" + } + } + attr { + key: "_variable_format_optimize" + value { + s: "1" + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + attr { + key: "function" + value { + func { + name: "GeOp91_0" + } + } + } +} + +library { + function { + signature { + name: "GeOp91_0" + input_arg { + name: "_arg_conv_input_0_0_0_arg" + type: DT_FLOAT + } + output_arg { + name: "conv2d_0_retval" + type: DT_FLOAT + } + } + node_def { + name: "conv2d" + op: "NpuOnnxGraphOp" + input: "_arg_conv_input_0_0_0_arg" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_NpuOptimizer" + value { + s: "NpuOptimizer" + } + } + attr { + key: "_do_npu_optimizer" + value { + s: "1" + } + } + attr { + key: "_use_off_line" + value { + s: "1" + } + } + attr { + key: "_variable_format_optimize" + value { + s: "1" + } + } + attr { + key: "model_path" + value { + s: "tf_adapter/tests/ut/kernels/onnx_model/conv2d.onnx" + } + } + } + ret { + key: "conv2d_0_retval" + value: "conv2d:outputs:0" + } + } +} +versions { + producer: 134 +} \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc index 3d49a063c..4436e9fd4 100644 --- a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc @@ -277,5 +277,15 @@ TEST_F(GeOpTest, GeOpWhileLoopV2Test) { EXPECT_TRUE(GeOpRunGraphAsync(graph_def_path, inputs, node_def, "GeOp13_0").ok()); } +TEST_F(GeOpTest, GeOpNpuOnnxGraphOpTest) { + NodeDef node_def; + //std::string onnx_model_path = "tf_adapter/tests/ut/kernels/onnx_model/conv2d.onnx"; + std::string grph_pbtxt_path = "tf_adapter/tests/ut/kernels/pbtxt/geop_npu_onnx_graph_op.pbtxt"; + + Tensor in(DT_FLOAT, TensorShape({1,1,5,5})); + gtl::InlinedVector inputs{TensorValue(&in)}; + EXPECT_TRUE(GeOpRunGraphAsync(grph_pbtxt_path, inputs, node_def, "GeOp91_0").ok()); +} + } } //end tensorflow \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_onnx_graph_op_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_onnx_graph_op_test.cc new file mode 100644 index 000000000..08bd3f476 --- /dev/null +++ b/tf_adapter/tests/ut/kernels/testcase/npu_onnx_graph_op_test.cc @@ -0,0 +1,32 @@ +#include +#include "tf_adapter/kernels/npu_ops.cc" +#include "gtest/gtest.h" + +namespace tensorflow { +class NpuOnnxGraphOpTest : public testing::Test { + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(NpuOnnxGraphOpTest, TestNpuOnnxGraphOp) { + DataTypeSlice input_types({DT_FLOAT}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_FLOAT}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + NpuOnnxGraphOp npu_onnx_graph_conv(context); + OpKernelContext *ctx = nullptr; + npu_onnx_graph_conv.Compute(ctx); + npu_onnx_graph_conv.IsExpensive(); + delete device; + delete node_def; + delete op_def; + delete context; +} +} \ No newline at end of file -- Gitee From 1fa638dc4b64da7282f2c4159c68384579d40c2f Mon Sep 17 00:00:00 2001 From: Fu Jingguo Date: Sat, 4 Sep 2021 19:40:27 +0800 Subject: [PATCH 2/2] add onnx graph op, def a function for parse component --- tf_adapter/kernels/geop_npu.cc | 66 +++++++++++++++++----------------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index f666a3642..55e7d532c 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -919,28 +919,7 @@ Status GeOp::BuildGraphDef(FunctionLibraryDefinition &flib_def, } if(node->type_string() == "NpuOnnxGraphOp") { - NodeDef &node_def = const_cast(node->def()); - std::string model_path = node_def.attr().find("model_path")->second.s(); - - ge::Graph sub_graph("onnx_compute_graph_" + node->name()); - std::map parser_params; - if(ge::SUCCESS != ge::aclgrphParseONNX(model_path.c_str(), parser_params, sub_graph)) { - LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Parse Failed."; - return errors::Internal("[GEOP] node: %s Onnx Model Parse Failed.",node->name()); - } - - ge::Model onnx_model("onnx_compute_model_" + node->name(), ""); - onnx_model.SetGraph(sub_graph); - ge::Buffer model_buf; - if(ge::SUCCESS != onnx_model.Save(model_buf, false)){ - LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Serialized Failed."; - return errors::Internal("[GEOP] node: %s Onnx Model Serialized Failed.", node->name()); - } - - std::string model_str(reinterpret_cast(model_buf.GetData()), model_buf.GetSize()); - AttrValue attr_value; - attr_value.set_s(model_str); - node_def.mutable_attr()->insert({"_external_model", attr_value}); + this->ParseOnnxGraphOpAttr(node); } if (is_tuning) { // output handle @@ -988,6 +967,39 @@ Status GeOp::BuildGraphDef(FunctionLibraryDefinition &flib_def, return Status::OK(); } +void GeOp::ParseOnnxGraphOpAttr(Node *&node) { + NodeDef &node_def = const_cast(node->def()); + std::string model_path = node_def.attr().find("model_path")->second.s(); + + //Get input and output numbers of NpuOnnxGraphOp op + AttrValue in_value; + in_value.set_i(static_cast(node->num_inputs())); + node_def.mutable_attr()->insert({"_input_num", in_value}); + AttrValue ot_value; + ot_value.set_i(static_cast(node->num_outputs())); + node_def.mutable_attr()->insert({"_output_num", ot_value}); + + ge::Graph sub_graph("onnx_compute_graph_" + node->name()); + std::map parser_params; + if(ge::SUCCESS != ge::aclgrphParseONNX(model_path.c_str(), parser_params, sub_graph)) { + LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Parse Failed."; + return errors::Internal("[GEOP] node: %s Onnx Model Parse Failed.",node->name()); + } + + ge::Model onnx_model("onnx_compute_model_" + node->name(), ""); + onnx_model.SetGraph(sub_graph); + ge::Buffer model_buf; + if(ge::SUCCESS != onnx_model.Save(model_buf, false)){ + LOG(ERROR) << "[GEOP] node: " << node->name() << ": Onnx Model Serialized Failed."; + return errors::Internal("[GEOP] node: %s Onnx Model Serialized Failed.", node->name()); + } + + std::string model_str(reinterpret_cast(model_buf.GetData()), model_buf.GetSize()); + AttrValue attr_value; + attr_value.set_s(model_str); + node_def.mutable_attr()->insert({"_external_model", attr_value}); +} + void GeOp::BuildShapeNodeAndCacheArgNodes(Graph &graph) { std::string dynamic_node_type = sess_options_["ge.dynamicNodeType"]; for (Node *node : graph.nodes()) { @@ -1209,16 +1221,6 @@ Status GeOp::GenerateDesc(Node *&node) { DataTypeVector outputs; TF_RETURN_IF_ERROR(tensorflow::InOutTypesForNode(node_def, op_def, &inputs, &outputs)); - //Get input and output numbers of NpuOnnxGraphOp op - if(node->type_string() == "NpuOnnxGraphOp") { - AttrValue in_value; - in_value.set_i(inputs.size()); - node_def.mutable_attr()->insert({"_input_num", in_value}); - AttrValue ot_value; - ot_value.set_i(outputs.size()); - node_def.mutable_attr()->insert({"_output_num", ot_value}); - } - int num; Node *in_node = nullptr; const Edge *in_edge = nullptr; -- Gitee