diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 48acf5e9adb937ec6d9ee8287d0b367b63e45e45..16a8cc63e6f6118c28def000bbd69a5464a2b2f8 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -186,7 +186,8 @@ namespace at_npu int index = 0; do { - if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled() && + !at_npu::native::aoe::aoe_manager().IsInBlacklist(name)) { ret = at_npu::native::AclGenGraphAndDumpForOp( name.c_str(), inputSize, @@ -238,7 +239,8 @@ namespace at_npu AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_DEFAULT); reset_flag = true; } - if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled() && + !at_npu::native::aoe::aoe_manager().IsInBlacklist(cur_paras->opType)) { ret = at_npu::native::AclGenGraphAndDumpForOp( (cur_paras->opType).c_str(), cur_paras->paras.input_num, diff --git a/torch_npu/csrc/framework/aoe/AoeUtils.cpp b/torch_npu/csrc/framework/aoe/AoeUtils.cpp index e1fe83ce9b17342fe2a1a529e9906a247219e0bd..fd4e892822a9b2943f2d9e83f52f612789926697 100644 --- a/torch_npu/csrc/framework/aoe/AoeUtils.cpp +++ b/torch_npu/csrc/framework/aoe/AoeUtils.cpp @@ -47,6 +47,14 @@ bool AoeDumpGraphManager::IsAoeEnabled() const { return aoe_enable; } +bool AoeDumpGraphManager::IsInBlacklist(const std::string &opName) const { + if (black_list_.find(opName) != black_list_.end()) + { + return true; + } + return false; +} + AoeDumpGraphManager& aoe_manager() { static AoeDumpGraphManager instance; return instance; diff --git a/torch_npu/csrc/framework/aoe/AoeUtils.h b/torch_npu/csrc/framework/aoe/AoeUtils.h index 085a2e42302ad6a0bb9fddbf41ed3e3259dc94aa..23770d741e5dcb2127c175862b3ae7fbb27609eb 100644 --- a/torch_npu/csrc/framework/aoe/AoeUtils.h +++ b/torch_npu/csrc/framework/aoe/AoeUtils.h @@ -17,6 +17,7 @@ #ifndef __NATIVE_NPU_TOOLS_AOEUTILS__ 
#define __NATIVE_NPU_TOOLS_AOEUTILS__ +#include <unordered_set> #include #include @@ -34,11 +35,13 @@ public: void EnableAoe(); bool IsAoeEnabled() const; - + bool IsInBlacklist(const std::string &opName) const; + bool aoe_enable=false; // to save graph for autotune, default path is ./ std::string autotune_graphdumppath="./"; aclGraphDumpOption* AclGraphDumpOption=NULL; + std::unordered_set<std::string> black_list_ = {"TransData"}; };