From f5363e5bf5f4d80f0f657ff7081d551dc20e38b4 Mon Sep 17 00:00:00 2001 From: shibo19 Date: Wed, 9 Mar 2022 19:39:30 +0800 Subject: [PATCH] =?UTF-8?q?aoe=E8=B0=83=E4=BC=98=E4=B8=ADtransdata?= =?UTF-8?q?=E7=AE=97=E5=AD=90=E4=B8=8D=E5=8F=82=E4=B8=8E=E8=B0=83=E8=B0=83?= =?UTF-8?q?=E4=BC=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/framework/OpParamMaker.cpp | 6 ++++-- torch_npu/csrc/framework/aoe/AoeUtils.cpp | 8 ++++++++ torch_npu/csrc/framework/aoe/AoeUtils.h | 5 ++++- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 48acf5e9adb..16a8cc63e6f 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -186,7 +186,8 @@ namespace at_npu int index = 0; do { - if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled() && + !at_npu::native::aoe::aoe_manager().IsInBlacklist(name)) { ret = at_npu::native::AclGenGraphAndDumpForOp( name.c_str(), inputSize, @@ -238,7 +239,8 @@ namespace at_npu AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_DEFAULT); reset_flag = true; } - if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled() && + !at_npu::native::aoe::aoe_manager().IsInBlacklist(cur_paras->opType)) { ret = at_npu::native::AclGenGraphAndDumpForOp( (cur_paras->opType).c_str(), cur_paras->paras.input_num, diff --git a/torch_npu/csrc/framework/aoe/AoeUtils.cpp b/torch_npu/csrc/framework/aoe/AoeUtils.cpp index e1fe83ce9b1..fd4e892822a 100644 --- a/torch_npu/csrc/framework/aoe/AoeUtils.cpp +++ b/torch_npu/csrc/framework/aoe/AoeUtils.cpp @@ -47,6 +47,14 @@ bool AoeDumpGraphManager::IsAoeEnabled() const { return aoe_enable; } +bool AoeDumpGraphManager::IsInBlacklist(const std::string &opName) const { + if (black_list_.find(opName) != black_list_.end()) + { + return true; + } + return false; +} + AoeDumpGraphManager& aoe_manager() { static AoeDumpGraphManager instance; return instance; diff --git a/torch_npu/csrc/framework/aoe/AoeUtils.h b/torch_npu/csrc/framework/aoe/AoeUtils.h index 085a2e42302..23770d741e5 100644 --- a/torch_npu/csrc/framework/aoe/AoeUtils.h +++ b/torch_npu/csrc/framework/aoe/AoeUtils.h @@ -17,6 +17,7 @@ #ifndef __NATIVE_NPU_TOOLS_AOEUTILS__ #define __NATIVE_NPU_TOOLS_AOEUTILS__ +#include #include #include @@ -34,11 +35,13 @@ public: void EnableAoe(); bool IsAoeEnabled() const; - + bool IsInBlacklist(const std::string &opName) const; + bool aoe_enable=false; // to save graph for autotune, default path is ./ std::string autotune_graphdumppath="./"; aclGraphDumpOption* AclGraphDumpOption=NULL; + std::unordered_set black_list_ = {"TransData"}; }; -- Gitee