Review notes for this revision (preamble text; not part of any hunk).

  * RevertAoeMode() now runs BEFORE every OP_REQUIRES_ASYNC(ctx, false, ...). When its
    condition is false that macro invokes the done callback and returns from the enclosing
    function, so a call placed after it is unreachable and "ge.jobType" stayed cleared.
  * RevertAoeMode() now runs BEFORE done() on the early-exit paths, so the saved job type
    is back in init_options_ before completion is signalled.
    NOTE(review): the call at the very end of ComputeAsync still follows the asynchronous
    run request; confirm the completion callback cannot race with it.
  * geop_aoe_mode is renamed geop_aoe_mode_ (member naming used by the rest of the class)
    and is no longer listed in the constructor initialiser list: std::string already
    default-constructs empty, and the entry was out of declaration order (-Wreorder).
  * RevertAoeMode() is declared private; it is only called through `this`.
  * Blocks that were committed as commented-out code are deleted instead.
  * NOTE(review): the geop_npu.h hunk that drops iteration_per_loop_ and the whitespace-only
    -/+ pairs (graph_options_, MixCompileTest) are carried over unchanged; confirm they are
    intended. Context whitespace was reconstructed from a whitespace-collapsed copy and the
    index lines are those of the original revision, so regenerate against the tree before
    applying.

diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc
index f8f3bc3d62705d7bd781e68dcceb42656fffde5e..e29e4dbb1fce56e12c5769765da7a727200dfe20 100644
--- a/tf_adapter/kernels/geop_npu.cc
+++ b/tf_adapter/kernels/geop_npu.cc
@@ -457,7 +457,16 @@ void GeOp::GetExecGraphId(OpKernelContext *ctx, uint32_t &cache_graph_id,
     compute_graph_empty_ = false;
   }
 }
 
+// Puts the "ge.jobType" value saved in geop_aoe_mode_ back into init_options_ and clears the
+// saved copy. Does nothing when no value is saved, so it can be called on every exit path.
+void GeOp::RevertAoeMode() {
+  if (!geop_aoe_mode_.empty()) {
+    init_options_["ge.jobType"] = geop_aoe_mode_;
+    geop_aoe_mode_.clear();
+  }
+}
+
 void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
   // ctx is not nullptr
   OP_REQUIRES_ASYNC(ctx, init_flag_, errors::InvalidArgument("GeOp not Initialize success."), done);
@@ -478,6 +487,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     bool res = IncrementGraphIdCount(tf_session_, graph_id_);
     if (!res || graph_id_ < kInvalidGraphId) {
+      this->RevertAoeMode();
       OP_REQUIRES_ASYNC(ctx, false, errors::Unavailable("Get ge session failed."), done);
       return;
     }
 
@@ -486,6 +496,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     res = SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, ge_session_, sess_options_);
     if (!res || tf_session_.empty() || ge_session_ == nullptr) {
+      this->RevertAoeMode();
       OP_REQUIRES_ASYNC(ctx, false, errors::Unavailable("Get ge session failed."), done);
       return;
     }
     if (!init_options_["ge.jobType"].empty() && !init_options_["ge.tuningPath"].empty()) {
@@ -526,10 +537,12 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     std::stringstream ss;
     ss << "dynamic input config can not use with mstuning.";
+    this->RevertAoeMode();
     OP_REQUIRES_ASYNC(ctx, false, errors::Internal(ss.str()), done);
     return;
   } else if (is_set_dynamic_config && !is_tuning) {
     if (InitRebuildFlag(cache_graph_id) != 0) {
+      this->RevertAoeMode();
       OP_REQUIRES_ASYNC(ctx, false, errors::Internal("Failed to check rebuild flag"), done);
       return;
     }
   } else if (!is_set_dynamic_config && is_tuning) {
@@ -541,6 +554,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     }
     if (InitRebuildFlag(cache_graph_id) != 0) {
+      this->RevertAoeMode();
       OP_REQUIRES_ASYNC(ctx, false, errors::Internal("Failed to check rebuild flag"), done);
       return;
     }
   }
@@ -562,6 +576,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     Tensor initialized_tensor(ctx->expected_output_dtype(0), TensorShape({0}));
     ctx->set_output(0, initialized_tensor);
+    this->RevertAoeMode();
     done();
     return;
   }
 
@@ -672,6 +687,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
                    << ", ret_status:" << ToString(ge::SUCCESS) << " , tf session: " << tf_session_
                    << " ,graph id: " << cache_graph_id << " [" << ((endTime - startTime) / kMicrosToMillis) << " ms]";
+    this->RevertAoeMode();
     done();
     return;
   }
 
@@ -699,10 +715,6 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
         ADP_LOG(INFO) << "[GEOP] in tune mode, nontraining graphs should be cache.";
         OP_REQUIRES_ASYNC(ctx, SessionManager::GetInstance().CacheGeGraphs(ge_session_, ge_graph),
                           errors::Internal("[GEOP] cache ge session failed."), done);
-        build_flag_ = true;
-        BuildOutTensorInfo(ctx);
-        done();
-        return;
       } else {
         ADP_LOG(INFO) << "[GEOP] in tune mode, training graph handled by tools.";
         std::vector ge_graphs;
@@ -712,11 +724,12 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
         AoeStatus tune_ret = (*aoe_tuning_)(ge_graph, ge_graphs, ge_session_, tune_options_);
         OP_REQUIRES_ASYNC(ctx, tune_ret == AOE_SUCCESS, errors::Internal("[GEOP] exec aoe tuning func failed."), done);
         ADP_LOG(INFO) << "[GEOP] aoe success.";
-        build_flag_ = true;
-        BuildOutTensorInfo(ctx);
-        done();
-        return;
       }
+      // Save the job type and clear it before re-entering ComputeAsync; RevertAoeMode()
+      // restores it on the way out. NOTE(review): presumably clearing "ge.jobType" is what
+      // keeps the re-entrant call off this tuning branch - confirm how is_tuning is derived.
+      geop_aoe_mode_ = init_options_["ge.jobType"];
+      init_options_["ge.jobType"] = "";
+      return this->ComputeAsync(ctx, done);
     }
 
     // call ge session addGraph api
@@ -757,6 +770,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
 
       ADP_LOG(INFO) << "[GEOP] Build graph success.";
+      this->RevertAoeMode();
       done();
       return;
     }
     LOG(INFO) << "The model has been compiled on the Ascend AI processor, current graph id is:" << cache_graph_id;
@@ -767,16 +781,10 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
                      << ", ret_status:" << ToString(ge::SUCCESS) << " , tf session: " << tf_session_
                      << " ,graph id: " << cache_graph_id << " [" << ((endTime - startTime) / kMicrosToMillis) << " ms]";
+      this->RevertAoeMode();
       done();
       return;
     }
   }
 
-  if (is_tuning) {
-    ADP_LOG(INFO) << "in mstune mode, graph only execute once, The remaining steps return directly.";
-    BuildOutTensorInfo(ctx);
-    done();
-    return;
-  }
-
   int64 run_start_time = InferShapeUtil::GetCurrentTimestap();
   auto callback = [done, ctx, run_start_time](ge::Status ge_status, std::vector &outputs) {
@@ -834,6 +842,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
   ADP_LOG(INFO) << "[GEOP] End GeOp::ComputeAsync, kernel_name:" << geop_name << ", ret_status:" << ToString(status)
                 << " ,tf session: " << tf_session_ << " ,graph id: " << cache_graph_id << " ["
                 << ((endTime - startTime) / kMicrosToMillis) << " ms]";
+  this->RevertAoeMode();
   return;
 }
 
diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h
index f6c1d5d753cccff9804ba18d52b35c094bb5889c..f7aea78c1aab00da055d5682c059b34c10231326 100644
--- a/tf_adapter/kernels/geop_npu.h
+++ b/tf_adapter/kernels/geop_npu.h
@@ -41,7 +41,9 @@ class GeOp : public AsyncOpKernel {
   explicit GeOp(OpKernelConstruction *ctx);
   ~GeOp();
   void ComputeAsync(OpKernelContext *ctx, DoneCallback done) override;
 
  private:
   void Initialize(OpKernelConstruction *ctx);
+  // Restores init_options_["ge.jobType"] from geop_aoe_mode_ when a value is saved there.
+  void RevertAoeMode();
   void Finalize();
@@ -104,7 +106,9 @@ class GeOp : public AsyncOpKernel {
   static const std::string SERIALIZE_DATATYPE;
   static const std::string SERIALIZE_SHAPE;
   static const std::string SubGraph;
 
+  // "ge.jobType" value saved while it is cleared in init_options_; empty when nothing is saved.
+  std::string geop_aoe_mode_;
   static mutex mu_;
 
   bool init_flag_;
@@ -127,9 +131,9 @@ class GeOp : public AsyncOpKernel {
   std::map sess_options_;
   std::map init_options_;
   static std::unordered_map session_and_graph_id_map_;
-  uint32_t iteration_per_loop_;
+
   bool is_host_graph_;
-  std::map graph_options_;
+  std::map graph_options_;
   std::map outputs_shape_;
   std::string is_train_graph_;
   void *handle_;
diff --git a/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc b/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc
index ddda5fb9dd170d3990c1b4a539d766e48ec8239a..c7c23496d4618e756d204872a8a09f9cb2e211fc 100644
--- a/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc
+++ b/tf_adapter/tests/ut/optimizers/testcase/om_partition_subgraphs_pass_test.cc
@@ -150,7 +150,7 @@ TEST_F(OmOptimizationPassTest, MergeClustersTest) {
       && target_graph.find("v3->") != target_graph.npos) { ret = true; }
   EXPECT_EQ(ret, true);
 }
-TEST_F(OmOptimizationPassTest, MixCompileTest) {
+TEST_F(OmOptimizationPassTest, MixCompileTest) {
   string org_graph_def_path = "tf_adapter/tests/ut/optimizers/pbtxt/om_test_mix_compile.pbtxt";
   InitGraph(org_graph_def_path);
   std::string target_graph = "GeOp7_0->Unique;Unique->GeOp7_1";