From 440a1079517dede5ba4efcfb39635e7996d770c2 Mon Sep 17 00:00:00 2001 From: lianghuikang <505519763@qq.com> Date: Sat, 16 Oct 2021 16:28:14 +0800 Subject: [PATCH] test --- tf_adapter/kernels/geop_npu.cc | 51 +++++++++++++------ tf_adapter/kernels/geop_npu.h | 8 +-- .../om_partition_subgraphs_pass_test.cc | 2 +- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index f8f3bc3d6..e29e4dbb1 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -213,7 +213,7 @@ GeOp::GeOp(OpKernelConstruction *ctx) sess_init_flag_(false), compute_graph_empty_(false), data_format_(""), graph_id_(0), is_initialized_graph_(false), need_iteration_(false), tf_session_(""), ge_session_(nullptr), job_type_(""), is_host_graph_(false), handle_(nullptr), aoe_tuning_(nullptr), - need_compile_graph_first_(false), aoe_init_(nullptr), aoe_finalize_(nullptr) { + need_compile_graph_first_(false), aoe_init_(nullptr), aoe_finalize_(nullptr), geop_aoe_mode("") { Initialize(ctx); } @@ -457,7 +457,12 @@ void GeOp::GetExecGraphId(OpKernelContext *ctx, uint32_t &cache_graph_id, compute_graph_empty_ = false; } } - +void GeOp::RevertAoeMode() { + if (!geop_aoe_mode.empty()) { + init_options_["ge.jobType"] = geop_aoe_mode; + geop_aoe_mode = ""; + } +} void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { // ctx is not nullptr OP_REQUIRES_ASYNC(ctx, init_flag_, errors::InvalidArgument("GeOp not Initialize success."), done); @@ -478,6 +483,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { bool res = IncrementGraphIdCount(tf_session_, graph_id_); if (!res || graph_id_ < kInvalidGraphId) { OP_REQUIRES_ASYNC(ctx, false, errors::Unavailable("Get ge session failed."), done); + this->RevertAoeMode(); return; } @@ -486,6 +492,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { res = SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, ge_session_, sess_options_); if (!res || tf_session_.empty() || ge_session_ == nullptr) { OP_REQUIRES_ASYNC(ctx, false, errors::Unavailable("Get ge session failed."), done); + this->RevertAoeMode(); return; } if (!init_options_["ge.jobType"].empty() && !init_options_["ge.tuningPath"].empty()) { @@ -526,10 +533,12 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { std::stringstream ss; ss << "dynamic input config can not use with mstuning."; OP_REQUIRES_ASYNC(ctx, false, errors::Internal(ss.str()), done); + this->RevertAoeMode(); return; } else if (is_set_dynamic_config && !is_tuning) { if (InitRebuildFlag(cache_graph_id) != 0) { OP_REQUIRES_ASYNC(ctx, false, errors::Internal("Failed to check rebuild flag"), done); + this->RevertAoeMode(); return; } } else if (!is_set_dynamic_config && is_tuning) { @@ -541,6 +550,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { } if (InitRebuildFlag(cache_graph_id) != 0) { OP_REQUIRES_ASYNC(ctx, false, errors::Internal("Failed to check rebuild flag"), done); + this->RevertAoeMode(); return; } } @@ -562,6 +572,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { Tensor initialized_tensor(ctx->expected_output_dtype(0), TensorShape({0})); ctx->set_output(0, initialized_tensor); done(); + this->RevertAoeMode(); return; } @@ -672,6 +683,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { << ", ret_status:" << ToString(ge::SUCCESS) << " , tf session: " << tf_session_ << " ,graph id: " << cache_graph_id << " [" << ((endTime - startTime) / kMicrosToMillis) << " ms]"; done(); + this->RevertAoeMode(); return; } @@ -699,10 +711,10 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(INFO) << "[GEOP] in tune mode, nontraining graphs should be cache."; OP_REQUIRES_ASYNC(ctx, SessionManager::GetInstance().CacheGeGraphs(ge_session_, ge_graph), errors::Internal("[GEOP] cache ge session failed."), done); - build_flag_ = true; - BuildOutTensorInfo(ctx); - done(); - return; + // build_flag_ = true; + // BuildOutTensorInfo(ctx); + // done(); + // return; } else { ADP_LOG(INFO) << "[GEOP] in tune mode, training graph handled by tools."; std::vector ge_graphs; @@ -712,11 +724,14 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { AoeStatus tune_ret = (*aoe_tuning_)(ge_graph, ge_graphs, ge_session_, tune_options_); OP_REQUIRES_ASYNC(ctx, tune_ret == AOE_SUCCESS, errors::Internal("[GEOP] exec aoe tuning func failed."), done); ADP_LOG(INFO) << "[GEOP] aoe success."; - build_flag_ = true; - BuildOutTensorInfo(ctx); - done(); - return; + // build_flag_ = true; + // BuildOutTensorInfo(ctx); + // done(); + // return; } + geop_aoe_mode = init_options_["ge.jobType"]; + init_options_["ge.jobType"] = ""; + return this->ComputeAsync(ctx, done); } // call ge session addGraph api @@ -757,6 +772,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(INFO) << "[GEOP] Build graph success."; done(); + this->RevertAoeMode(); return; } LOG(INFO) << "The model has been compiled on the Ascend AI processor, current graph id is:" << cache_graph_id; @@ -767,16 +783,18 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { << ", ret_status:" << ToString(ge::SUCCESS) << " , tf session: " << tf_session_ << " ,graph id: " << cache_graph_id << " [" << ((endTime - startTime) / kMicrosToMillis) << " ms]"; done(); + this->RevertAoeMode(); return; } } - if (is_tuning) { - ADP_LOG(INFO) << "in mstune mode, graph only execute once, The remaining steps return directly."; - BuildOutTensorInfo(ctx); - done(); - return; - } + // if (is_tuning) { + // ADP_LOG(INFO) << "in mstune mode, graph only execute once, The remaining steps return directly."; + // BuildOutTensorInfo(ctx); + // done(); + // this->RevertAoeMode(); + // return; + // } int64 run_start_time = InferShapeUtil::GetCurrentTimestap(); auto callback = [done, ctx, run_start_time](ge::Status ge_status, std::vector &outputs) { @@ -834,6 +852,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(INFO) << "[GEOP] End GeOp::ComputeAsync, kernel_name:" << geop_name << ", ret_status:" << ToString(status) << " ,tf session: " << tf_session_ << " ,graph id: " << cache_graph_id << " [" << ((endTime - startTime) / kMicrosToMillis) << " ms]"; + this->RevertAoeMode(); return; } diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index f6c1d5d75..f7aea78c1 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -41,7 +41,7 @@ class GeOp : public AsyncOpKernel { explicit GeOp(OpKernelConstruction *ctx); ~GeOp(); void ComputeAsync(OpKernelContext *ctx, DoneCallback done) override; - + void RevertAoeMode(); private: void Initialize(OpKernelConstruction *ctx); void Finalize(); @@ -104,7 +104,7 @@ class GeOp : public AsyncOpKernel { static const std::string SERIALIZE_DATATYPE; static const std::string SERIALIZE_SHAPE; static const std::string SubGraph; - + std::string geop_aoe_mode; static mutex mu_; bool init_flag_; @@ -127,9 +127,9 @@ class GeOp : public AsyncOpKernel { std::map sess_options_; std::map init_options_; static std::unordered_map session_and_graph_id_map_; - uint32_t iteration_per_loop_; + bool is_host_graph_; - std::map graph_options_; + std::map graph_options_; std::map outputs_shape_; std::string is_train_graph_; void *handle_; diff --git a/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc b/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc index ddda5fb9d..c7c23496d 100644 --- a/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc +++ b/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc @@ -150,7 +150,7 @@ TEST_F(OmOptimizationPassTest, MergeClustersTest) { && target_graph.find("v3->") != target_graph.npos) { ret = true; } EXPECT_EQ(ret, true); } -TEST_F(OmOptimizationPassTest, MixCompileTest) { +TEST_F(OmOptimizationPassTest, MixCompileTest) { string org_graph_def_path = "tf_adapter/tests/ut/optimizers/pbtxt/om_test_mix_compile.pbtxt"; InitGraph(org_graph_def_path); std::string target_graph = "GeOp7_0->Unique;Unique->GeOp7_1"; -- Gitee