From 61708201775e9b15c54042e5fe050c47a53b9ff6 Mon Sep 17 00:00:00 2001 From: caiguangxing Date: Fri, 2 Dec 2022 10:40:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9C=A8=E7=BA=BF=E5=8A=A0?= =?UTF-8?q?=E6=8F=92Aipp=E7=AE=97=E5=AD=90=20=E8=AE=BE=E7=BD=AEinput=5Ffor?= =?UTF-8?q?mat=E7=9A=84=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/geop_npu.cc | 14 +++++++++----- tf_adapter/kernels/geop_npu.h | 2 ++ .../ut/kernels/pbtxt/geop_dynamic_execute.pbtxt | 6 ++++++ .../tests/ut/kernels/testcase/geop_npu_test.cc | 10 ++++++++++ 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 75c5a72b3..bb79c7dbc 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -325,6 +325,11 @@ void GeOp::Initialize(OpKernelConstruction *ctx) { ADP_LOG(INFO) << "[GEOP] get session info from attr, tf session: " << tf_session_; } + // if insert aipp op + s = ctx->GetAttr("_insert_op_file", &insert_op_file); + if (s.ok()) { + ADP_LOG(INFO) << "[Insert][Aipp] get _insert_op_file attr from file: " << insert_op_file; + } ctx->GetAttr("_recompute_mode", &recompute_mode_); ctx->GetAttr("_deploy_inject_config", &deploy_inject_config_); ctx->GetAttr("_execute_times", &execute_times_); @@ -845,9 +850,6 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { } // convert to ge::graph - if (graph_options_.count("input_format") != 0) { - ADP_LOG(INFO) << "graph_options_[\"input_format\"] = " << graph_options_["input_format"]; - } ge::Graph ge_graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); if (iteration_per_loop_ > 1) { ge_graph.SetNeedIteration(this->need_iteration_); @@ -1202,12 +1204,14 @@ Status GeOp::ProcessForDiffNodeTypes(Graph &graph, bool &is_initialize, bool &is if (node->type_string() == "NpuOnnxGraphOp") { ret = this->ParseOnnxGraphOpAttr(node); - graph_options_["input_format"] = "NCHW"; - ADP_LOG(INFO) << "onnx_graph_parser graph_options_[\"input_format\"] = " << graph_options_["input_format"]; if (!ret.ok()) { LOG(ERROR) << "[GEOP]node: " << node->name() << " Parse Node with Onnx Model failed, " << ret.error_message(); return ret; } + if (!insert_op_file.empty()) { + graph_options_["input_format"] = "NCHW"; + } + ADP_LOG(INFO) << "[GEOP]node: " << node->name() << " Parse Node with Onnx Model succeed."; } if (node->type_string() == "IteratorGetNext") { diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 647298a0d..d3f5eee30 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -190,9 +190,11 @@ private: std::string max_num_; std::string embedding_dim_; std::string recompute_mode_; + std::string insert_op_file; std::vector> input_shapes_vec_; bool jit_compile_; bool is_getnext_dynamic_shape_; + SessionId session_id_; AoeInitializeFunc aoe_initialize_; AoeFinalizeFunc aoe_finalize_; diff --git a/tf_adapter/tests/ut/kernels/pbtxt/geop_dynamic_execute.pbtxt b/tf_adapter/tests/ut/kernels/pbtxt/geop_dynamic_execute.pbtxt index fdd35b6b5..2616af875 100644 --- a/tf_adapter/tests/ut/kernels/pbtxt/geop_dynamic_execute.pbtxt +++ b/tf_adapter/tests/ut/kernels/pbtxt/geop_dynamic_execute.pbtxt @@ -138,6 +138,12 @@ node { s: "dynamic_execute" } } + attr { + key: "_insert_op_file" + value { + s: "aipp.cfg" + } + } attr { key: "_dynamic_input" value { diff --git a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc index 9bf6897bb..f431310e4 100644 --- a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc @@ -171,6 +171,16 @@ TEST_F(GeOpTest, GeOpDynamicInput1Test) { EXPECT_TRUE(!attrs["_dynamic_input"].s().empty()); EXPECT_EQ(attrs["_dynamic_graph_execute_mode"].s() == "dynamic_execute", true); } +TEST_F(GeOpTest, GeOpAippParamTest) { + NodeDef node_def; + std::string graph_def_path = "tf_adapter/tests/ut/kernels/pbtxt/geop_dynamic_execute.pbtxt"; + Tensor a(DT_INT32, TensorShape({1,})); + gtl::InlinedVector inputs{TensorValue(&a)}; + EXPECT_TRUE(GeOpRunGraphAsync(graph_def_path, inputs, node_def, "GeOp14_0", false).ok()); + auto attrs = node_def.attr(); + EXPECT_TRUE(attrs.find("_insert_op_file") != attrs.end()); + EXPECT_TRUE(!attrs["_insert_op_file"].s().empty()); +} TEST_F(GeOpTest, GeOpAoeTuningAndDynamicDimsTest) { NodeDef node_def; std::string graph_def_path = "tf_adapter/tests/ut/kernels/pbtxt/geop_aoe_tuning_and_dynamic_dims.pbtxt"; -- Gitee