From 1184120288f24d80613eaf1e434807d022a8e13c Mon Sep 17 00:00:00 2001
From: wang-xiangX
Date: Thu, 8 Sep 2022 20:59:19 +0800
Subject: [PATCH] floatOverflowMode

---
Reviewer notes (ignored by `git am`):
- GeOp::ComputeAsync now queries the device float overflow mode with
  aclrtGetDeviceSatMode and stores the matching name from kFloatOverflowMode
  in graph_options_["ge.satuateMode"].
- The enum value is range-checked before it is used as an index into
  kFloatOverflowMode; previously an unexpected value read past the table.
- This patch text was rebuilt from a whitespace-collapsed copy. Context-line
  indentation and blank lines are inferred, the `using geDataUniquePtr`
  context line appears to have lost its template arguments, and the `index`
  blob hashes were not regenerated -- re-create the patch against the real
  tree before applying.

 tf_adapter/kernels/geop_npu.cc | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc
index 18ddfc868..1082c46cd 100644
--- a/tf_adapter/kernels/geop_npu.cc
+++ b/tf_adapter/kernels/geop_npu.cc
@@ -78,6 +78,8 @@
 #include "tf_adapter_2.x/npu_device/core/npu_micros.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "acl/acl_base.h"
+#include "acl/acl_rt.h"
 
 namespace tensorflow {
 #ifdef TF_VERSION_TF2
@@ -93,6 +95,9 @@ const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_SHAPE = "_subgraph_multi_d
 const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_DIMS = "_subgraph_multi_dims_input_dims";
 const std::string kMdatTuning = "mdat";
 const std::string kAutoRecompute = "auto";
+// Names written to the "ge.satuateMode" graph option, indexed by the aclrtFloatOverflowMode value.
+// NOTE(review): assumes the enum order is SATURATION, INFNAN, UNDEF -- confirm against acl_rt.h.
+constexpr const char* kFloatOverflowMode[] = {"SAT", "INF_NAN", "UNDEF"};
 using geDataUniquePtr = std::unique_ptr>;
 
 class NpuHostFixedAllocator : public tensorflow::Allocator, public tensorflow::core::RefCounted {
@@ -883,6 +888,25 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) {
     OP_REQUIRES_ASYNC(ctx, run_graph_status == ge::SUCCESS, errors::Unavailable(ss.str()), done);
   }
 
+  // Query the device float overflow mode through ACL and record its name in
+  // graph_options_ under "ge.satuateMode".
+  aclrtFloatOverflowMode floatOverflowMode = ACL_RT_OVERFLOW_MODE_UNDEF;
+  aclError ret = aclrtGetDeviceSatMode(&floatOverflowMode);
+  if (ret != ACL_SUCCESS) {
+    ADP_LOG(ERROR) << "[GePlugin] get device satMode failed, ret: " << ToString(ret);
+    LOG(ERROR) << "[GePlugin] get device satMode failed, ret: " << ToString(ret);
+    return;
+  }
+  // Guard the table lookup: a value outside kFloatOverflowMode would be read out of bounds.
+  const size_t modeIndex = static_cast<size_t>(floatOverflowMode);
+  if (modeIndex >= (sizeof(kFloatOverflowMode) / sizeof(kFloatOverflowMode[0]))) {
+    ADP_LOG(ERROR) << "[GePlugin] unexpected device satMode: " << modeIndex;
+    LOG(ERROR) << "[GePlugin] unexpected device satMode: " << modeIndex;
+    return;
+  }
+  ADP_LOG(INFO) << "[GePlugin] get device satMode success.";
+  graph_options_["ge.satuateMode"] = kFloatOverflowMode[modeIndex];
+
   endTime = InferShapeUtil::GetCurrentTimestap();
   ADP_LOG(INFO) << "[GEOP] End GeOp::ComputeAsync, kernel_name:" << geop_name << ", ret_status:" << ToString(run_graph_status)
                 << " ,tf session: " << tf_session_
@@ -1078,6 +1102,7 @@ void GeOp::HandleDpOpAndGetNextNodes(Graph &graph) {
       }
     }
   }
+
   for (Node *node : remove_nodes) {
     ADP_LOG(INFO) << "[GEOP] Remove node: " << node->name();
     graph.RemoveNode(node);
-- 
Gitee