diff --git a/test/test_aoe/test_aoe.py b/test/test_aoe/test_aoe.py new file mode 100644 index 0000000000000000000000000000000000000000..c6310ba546882da2058cc3a8c14187db1e80c2da --- /dev/null +++ b/test/test_aoe/test_aoe.py @@ -0,0 +1,79 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import torch +from torch_npu.testing.common_utils import TestCase, run_tests +import torch_npu + +class SmallModel(torch.nn.Module): + def __init__(self, in_channel, out_channel): + super(SmallModel, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channel, in_channel, 1) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(in_channel, out_channel, 1) + + def forward(self, input_1): + input_1 = self.conv1(input_1) + input_1 = self.relu1(input_1) + input_1 = self.conv2(input_1) + return input_1.reshape(input_1.shape[0], -1) + +class TestAoe(TestCase): + results_path = os.path.join(os.path.split(os.path.realpath(__file__))[0], "graphs") + + @classmethod + def setUpClass(cls): + if os.path.exists(TestAoe.results_path): + shutil.rmtree(TestAoe.results_path) + os.makedirs(TestAoe.results_path) + TestAoe.enable_aoe() + + @classmethod + def tearDownClass(cls): + if os.path.exists(TestAoe.results_path): + shutil.rmtree(TestAoe.results_path) + + @classmethod + def enable_aoe(cls): + option = {"autotune": "enable", "autotunegraphdumppath": TestAoe.results_path} + torch.npu.set_option(option) + + def test_aoe_dumpgraph(self): + def train(): + for index in range(steps): + x = torch.rand(input_shape).to(device) + y = torch.rand(out_shape).reshape(out_shape[0], -1).to(device) + y_pred = model(x) + loss = criterion(y_pred, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + input_shape = (4, 3, 24, 24) + out_shape = (4, 12, 24, 24) + steps = 5 + device = "npu:0" if torch.npu.is_available() else "cpu" + model = SmallModel(input_shape[1], out_shape[1]).to(device) + criterion = torch.nn.MSELoss(reduction='sum') + optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) + train() + + file_list = os.listdir(TestAoe.results_path) + if torch.npu.is_available(): + self.assertTrue(len(file_list) > 0) + + +if __name__ == '__main__': + run_tests() diff --git a/third_party/acl/inc/acl/acl_op_compiler.h b/third_party/acl/inc/acl/acl_op_compiler.h index 438de77a2a8d1760be5ee8ffd805ec6f7b09acf4..0807efbd8ed85ed77531ee78904e6ba9515f4103 100644 --- a/third_party/acl/inc/acl/acl_op_compiler.h +++ b/third_party/acl/inc/acl/acl_op_compiler.h @@ -105,6 +105,54 @@ ACL_FUNC_VISIBILITY aclError aclopCompileAndExecute(const char *opType, */ ACL_FUNC_VISIBILITY aclError aclSetCompileopt(aclCompileOpt opt, const char *value); +/** + * @ingroup AscendCL + * @brief set compile option + * + * @param aclCompileOpt [IN] compile option + * @param value [IN] pointer for the option value + * + * @retval ACL_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclSetCompileopt(aclCompileOpt opt, const char *value); + +typedef enum { + ACL_GRAPH_STAGE_ORIGIN = 0, // default + ACL_GRAPH_STAGE_FUZZ = 1, +} aclGraphStage; + +typedef struct aclGraphDumpOption aclGraphDumpOption; + +/** + * @ingroup AscendCL + * @brief dump op graph for AOE + * + * @param opType [IN] op type + * @param numInputs [IN] number of inputs + * @param inputDesc [IN] pointer to array of input tensor descriptions + * @param inputs [IN] pointer to array of input buffers + * @param numOutputs [IN] number of outputs + * @param outputDesc [IN] pointer to array of output tensor descriptions + * @param outputs [IN] pointer to array of outputs buffers + * @param attr [IN] pointer to instance of aclopAttr. + * may pass nullptr if the op has no attribute + * @param engineType [IN] engine type + * @param graphDumpPath [IN] path to save dump graph of op + * @param aclGraphDumpOption [IN] dump graph option + * @retval ACL_ERROR_NONE The function is successfully executed. + * @retval OtherValues Failure + */ +ACL_FUNC_VISIBILITY aclError aclGenGraphAndDumpForOp(const char *opType, + int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[], + int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], + const aclopAttr *attr, aclopEngineType engineType, const char *graphDumpPath, + aclGraphDumpOption* graphdumpOpt); + +ACL_FUNC_VISIBILITY aclGraphDumpOption* aclCreateGraphDumpOpt(); + +ACL_FUNC_VISIBILITY aclError aclDestroyGraphDumpOpt(aclGraphDumpOption* aclGraphDumpOpt); + #ifdef __cplusplus } #endif diff --git a/third_party/acl/libs/acl_op_compiler.cpp b/third_party/acl/libs/acl_op_compiler.cpp index ec587c867fc1a9bd723d1e9b40f5052ead9dc999..ea5feee96822a19c6482e48e9a57c5a0d05a2385 100644 --- a/third_party/acl/libs/acl_op_compiler.cpp +++ b/third_party/acl/libs/acl_op_compiler.cpp @@ -82,3 +82,25 @@ aclError aclSetCompileopt( return 0; } +aclError aclGenGraphAndDumpForOp( + const char *opType, + int numInputs, + const aclTensorDesc *const inputDesc[], + const aclDataBuffer *const inputs[], + int numOutputs, + const aclTensorDesc *const outputDesc[], + aclDataBuffer *const outputs[], + const aclopAttr *attr, + aclopEngineType engineType, + const char *graphDumpPath, + aclGraphDumpOption* graphdumpOpt) { + return 0; +} + +aclGraphDumpOption* aclCreateGraphDumpOpt() { + return NULL; +} + +aclError aclDestroyGraphDumpOpt(aclGraphDumpOption* aclGraphDumpOpt) { + return 0; +} \ No newline at end of file diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 8a542cc3b9a2fd7c2ee3b4b003eeda5f7a8eaf34..4512721318e6f94a32ff93f40249d03b024d987a 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -18,7 +18,7 @@ #include #include -#include "torch_npu/csrc/framework/aoe/AutoTune.h" +#include "torch_npu/csrc/framework/aoe/AoeUtils.h" #include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/utils/NpuUtils.h" @@ -171,7 +171,6 @@ namespace at_npu aclError OpCommandImpl::InnerRun(string name, AclExecParam ¶ms) { - AutotuneManager::GetInstance()->PushGraph(name, params.graph); auto stream = c10::npu::getCurrentNPUStream(); auto inputSize = params.inBuffer.size(); auto outputSize = params.outBuffer.size(); @@ -185,6 +184,24 @@ namespace at_npu int index = 0; do { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + ret = at_npu::native::AclGenGraphAndDumpForOp( + name.c_str(), + inputSize, + params.inDesc.data(), + params.inBuffer.data(), + outputSize, + params.outDesc.data(), + params.outBuffer.data(), + params.attr, + ACL_ENGINE_SYS, + at_npu::native::aoe::aoe_manager().GetDumpGraphPath().c_str(), + nullptr); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + TORCH_CHECK(false, "In aoe mode, AclGenGraphAndDumpForOp failed!"); + } + } ret = aclopCompileAndExecute( name.c_str(), inputSize, @@ -222,6 +239,24 @@ namespace at_npu int index = 0; do { + if (at_npu::native::aoe::aoe_manager().IsAoeEnabled()) { + ret = at_npu::native::AclGenGraphAndDumpForOp( + (cur_paras->opType).c_str(), + cur_paras->paras.input_num, + cur_paras->paras.input_desc, + cur_paras->paras.input_data_buf, + cur_paras->paras.output_num, + cur_paras->paras.output_desc, + cur_paras->paras.output_data_buf, + cur_paras->attr, + ACL_ENGINE_SYS, + at_npu::native::aoe::aoe_manager().GetDumpGraphPath().c_str(), + nullptr); + if (ret != ACL_ERROR_NONE) { + C10_NPU_SHOW_ERR_MSG(); + TORCH_CHECK(false, "In aoe mode, AclGenGraphAndDumpForOp failed!"); + } + } ret = aclopCompileAndExecute( (cur_paras->opType).c_str(), cur_paras->paras.input_num, diff --git a/torch_npu/csrc/framework/OpParamMaker.h b/torch_npu/csrc/framework/OpParamMaker.h index 89679fd6d9b4bae2cb49ace84a52e313fa31fe27..92633ab259150c3264b0988b5819cf0d720de926 100644 --- a/torch_npu/csrc/framework/OpParamMaker.h +++ b/torch_npu/csrc/framework/OpParamMaker.h @@ -21,7 +21,6 @@ #include "third_party/acl/inc/acl/acl_base.h" #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" #include "torch_npu/csrc/framework/NPUDefine.h" -#include "torch_npu/csrc/framework/interface/Graph.h" namespace at_npu { @@ -220,7 +219,6 @@ namespace at_npu void SetName(string &name) { opName = name; - execParam.graph.Name(name); } void AddInput( @@ -230,7 +228,6 @@ namespace at_npu aclFormat format) { inputCounter += 1; - execParam.graph.Input(desc); execParam.inDesc.emplace_back(std::move(desc)); execParam.inBuffer.emplace_back(std::move(buffer)); execParam.inDims.emplace_back(dim); @@ -246,7 +243,6 @@ namespace at_npu { AddInput(desc, buffer, dim, format); execParam.hostMem.emplace_back(hostTensor); - execParam.graph.SetConst(hostTensor.data_ptr(), hostTensor.nbytes()); } void AddConst(c10::SmallVector dimList) @@ -268,7 +264,6 @@ namespace at_npu int64_t dim, aclFormat format) { - execParam.graph.Output(desc); execParam.outDesc.emplace_back(std::move(desc)); execParam.outBuffer.emplace_back(std::move(buffer)); execParam.outDims.emplace_back(dim); @@ -281,7 +276,6 @@ namespace at_npu InitAttr(); AttrInfoMaker::Add(value, attrInfo); OpAttrMaker::Set(execParam.attr, attrName, value); - execParam.graph.AddAttr(attrName, value); execParam.hasAttr = true; } @@ -447,7 +441,6 @@ namespace at_npu aclopAttr *attr = nullptr; bool hasAttr = false; - Graph graph; }; void InitAttr() @@ -455,7 +448,6 @@ namespace at_npu if (execParam.attr == nullptr) { execParam.attr = aclopCreateAttr(); - execParam.graph.Make(); } } diff --git a/torch_npu/csrc/framework/interface/GeHelper.cpp b/torch_npu/csrc/framework/aoe/AoeUtils.cpp similarity index 38% rename from torch_npu/csrc/framework/interface/GeHelper.cpp rename to torch_npu/csrc/framework/aoe/AoeUtils.cpp index 1aa6e32e8dded78698ae5f541697af4845817433..e1fe83ce9b17342fe2a1a529e9906a247219e0bd 100644 --- a/torch_npu/csrc/framework/interface/GeHelper.cpp +++ b/torch_npu/csrc/framework/aoe/AoeUtils.cpp @@ -1,4 +1,5 @@ // Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. // All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); @@ -13,43 +14,44 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "torch_npu/csrc/framework/interface/GeHelper.h" - -namespace at_npu -{ - namespace native - { - - ge::TensorDesc GeHelper::Convert(const aclTensorDesc *desc) - { - ge::Shape shape(GetTensorDescDims(desc)); - auto format = Convert(aclGetTensorDescFormat(desc)); - auto dataType = Convert(aclGetTensorDescType(desc)); - ge::TensorDesc tensorDesc(shape, format, dataType); - return tensorDesc; - } - - std::vector GeHelper::GetTensorDescDims(const aclTensorDesc *desc) - { - auto size = aclGetTensorDescNumDims(desc); - std::vector dims; - dims.resize(size); - for (int i = 0; i < size; i++) - { - dims[i] = aclGetTensorDescDim(desc, i); - } - return dims; - } - - ge::DataType GeHelper::Convert(aclDataType dataType) - { - return (ge::DataType)dataType; - } - - ge::Format GeHelper::Convert(aclFormat format) - { - return (ge::Format)format; - } - - } // namespace native -} // namespace at_npu +#include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" +#include "torch_npu/csrc/framework/aoe/AoeUtils.h" + +namespace at_npu { +namespace native { +namespace aoe { + +void AoeDumpGraphManager::SetDumpGraphPath(const std::string& dump_path) { + autotune_graphdumppath = dump_path; +} + +std::string AoeDumpGraphManager::GetDumpGraphPath() const { + return autotune_graphdumppath; +} + +aclGraphDumpOption* AoeDumpGraphManager::CreateGraphDumpOption() { + AclGraphDumpOption = AclCreateGraphDumpOpt(); + return AclGraphDumpOption; +} + +void AoeDumpGraphManager::DestropyGraphDumpOption() { + AclDestroyGraphDumpOpt(AclGraphDumpOption); + AclGraphDumpOption = NULL; +} + +void AoeDumpGraphManager::EnableAoe() { + aoe_enable = true; +} + +bool AoeDumpGraphManager::IsAoeEnabled() const { + return aoe_enable; +} + +AoeDumpGraphManager& aoe_manager() { + static AoeDumpGraphManager instance; + return instance; +} + +} // namespace aoe +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/interface/GeHelper.h b/torch_npu/csrc/framework/aoe/AoeUtils.h similarity index 48% rename from torch_npu/csrc/framework/interface/GeHelper.h rename to torch_npu/csrc/framework/aoe/AoeUtils.h index 502cdfd2f535699e6f784feeea580d65b5a94bf6..085a2e42302ad6a0bb9fddbf41ed3e3259dc94aa 100644 --- a/torch_npu/csrc/framework/interface/GeHelper.h +++ b/torch_npu/csrc/framework/aoe/AoeUtils.h @@ -1,4 +1,5 @@ // Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. // All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); @@ -13,35 +14,38 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef __PLUGIN_NATIVE_NPU_INTERFACE_GE_HELPER__ -#define __PLUGIN_NATIVE_NPU_INTERFACE_GE_HELPER__ +#ifndef __NATIVE_NPU_TOOLS_AOEUTILS__ +#define __NATIVE_NPU_TOOLS_AOEUTILS__ -#include "third_party/acl/inc/graph/tensor.h" // TensorDesc -#include "third_party/acl/inc/graph/types.h" // Format -#include "third_party/acl/inc/acl/acl_base.h" // aclTensorDesc +#include +#include namespace at_npu { namespace native { +namespace aoe { -/** - This class is used to transform acl interface to ge interface - e.g aclTensorDesc vs ge::TensorDesc - */ -class GeHelper { +class AoeDumpGraphManager { public: - /** - This API is used to transform aclTensorDesc to ge::TensorDesc - */ - static ge::TensorDesc Convert(const aclTensorDesc* desc); -private: - static ge::DataType Convert(aclDataType dataType); - static ge::Format Convert(aclFormat format); -private: - static std::vector GetTensorDescDims(const aclTensorDesc* desc); + void SetDumpGraphPath(const std::string& dump_path); + std::string GetDumpGraphPath() const; + + aclGraphDumpOption* CreateGraphDumpOption(); + void DestropyGraphDumpOption(); + + void EnableAoe(); + bool IsAoeEnabled() const; + + bool aoe_enable=false; + // to save graph for autotune, default path is ./ + std::string autotune_graphdumppath="./"; + aclGraphDumpOption* AclGraphDumpOption=NULL; + }; +AoeDumpGraphManager& aoe_manager(); + +} // namespace aoe } // namespace native } // namespace at_npu - -#endif // __NATIVE_NPU_INTERFACE_GE_HELPER__ \ No newline at end of file +#endif // __NATIVE_NPU_TOOLS_AOEUTILS__ \ No newline at end of file diff --git a/torch_npu/csrc/framework/aoe/AutoTune.cpp b/torch_npu/csrc/framework/aoe/AutoTune.cpp deleted file mode 100644 index ae4fd4c9e4c3b053bfece9d6552586b02e1017b7..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/aoe/AutoTune.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "torch_npu/csrc/framework/aoe/AutoTune.h" -#include "torch_npu/csrc/framework/interface/AoeInterface.h" -#include "torch_npu/csrc/framework/interface/EnvVariables.h" -#include "torch_npu/csrc/framework/interface/GeHelper.h" - -namespace at_npu -{ - namespace native - { - // It is better to set MAX_TUNE_THREADS = 8~64, support set to be 1~64 - constexpr std::size_t MAX_TUNE_THREADS = 8; - // tune mode: 0 -> model mode, 2 -> op mode, 4 -> grad mode - constexpr std::size_t AOE_TUNE_MODE = 2; - AutotuneManager *AutotuneManager::GetInstance() - { - static AutotuneManager instance; - return &instance; - } - - AutotuneManager::AutotuneManager() - : isInited(false), - thread_pool_(std::make_shared(MAX_TUNE_THREADS)) - { - sessionOptions["job_type"] = std::to_string(AOE_TUNE_MODE).c_str(); - } - - AutotuneManager::~AutotuneManager() - { - DeInit(); - } - - void AutotuneManager::Init() - { - if (isInited) - { - return; - } - std::map globalOptions; - // graph tune parallel num, only support to be 1~64, default=8 - globalOptions["tuning_parallel_num"] = std::to_string(MAX_TUNE_THREADS).c_str(); - auto ret = aoe::initialize(globalOptions); - if (ret) - { - TORCH_CHECK(ret, "aoe::initialize failed. error code:", ret); - return; - } - isInited = true; - } - - void AutotuneManager::DeInit() - { - if (isInited) - { - this->ge_graphs.clear(); - aoe::finalize(); - isInited = false; - } - } - - void AutotuneManager::CreatSessions() - { - for (int i = 0; i < MAX_TUNE_THREADS; i++) - { - aoe::SessionId sessionId; - auto ret = aoe::create_session(sessionOptions, sessionId); - if (ret) - { - TORCH_CHECK(ret, "aoe::create_session failed. error code:", ret); - return; - } - this->sessionIDs.emplace_back(sessionId); - } - } - - void AutotuneManager::DestroySessions() - { - TORCH_CHECK(this->sessionIDs.size() == MAX_TUNE_THREADS, "The AOE sessionID nums should be same to MAX_TUNE_THREADS!"); - for (auto it = this->sessionIDs.begin(); it < this->sessionIDs.end(); it++) - { - aoe::destroy_session(*it); - } - this->sessionIDs.clear(); - } - - void AutotuneManager::PushGraph(const std::string &name, Graph &tuningGraph) - { - if (not env::AutoTuneEnabled()) - { - return; - } - // TransData does not support tuning. - if (name == "TransData") - { - return; - } - - if (!isInited) - { - Init(); - } - - ge::Graph ge_graph; - tuningGraph.GeGraph(ge_graph); - - ge_graphs.emplace_back(std::move(ge_graph)); - if (this->ge_graphs.size() < MAX_TUNE_THREADS) - { - return; - } - this->TuningGraphs(); - } - - void AutotuneManager::DoGraphsTune() - { - auto alive_thread_nums = thread_pool_->numAvailable(); - TORCH_CHECK(alive_thread_nums >= this->ge_graphs.size(), "The ge_graph size is greater than thread_pool_ size!"); - int tune_ops = -1; - if (alive_thread_nums > 0) - { - for (auto it = this->ge_graphs.begin(); it < this->ge_graphs.begin() + alive_thread_nums && it < this->ge_graphs.end(); it++) - { - tune_ops += 1; - thread_pool_->run(std::bind( - &AutotuneManager::DoGraphTune, - this, - *it, - this->sessionIDs[tune_ops])); - // need to sleep some seconds in master thread to ensure all threads are up! - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } - } - } - - void AutotuneManager::DoGraphTune(ge::Graph &ge_graph, aoe::SessionId sessionId) - { - auto ret = aoe::set_tuning_graph(sessionId, ge_graph); - if (ret) - { - TORCH_CHECK(ret, "aoe::set_tuning_graph failed. error code:", ret); - return; - } - std::map tuingOptions; - ret = aoe::tuning_graph(sessionId, tuingOptions); - if (ret) - { - TORCH_CHECK(ret, "aoe::tuning_graph failed. error code:", ret); - return; - } - } - - void AutotuneManager::TuningGraphs() - { - if (this->ge_graphs.size() == MAX_TUNE_THREADS && - thread_pool_->numAvailable() == MAX_TUNE_THREADS) - { - this->CreatSessions(); - this->DoGraphsTune(); - this->WaitThreadsFinished(); - this->ge_graphs.clear(); - this->DestroySessions(); - } - else - { - TORCH_CHECK(false, "TuningGraphs failed, the size of ge_graphs" - " and thread_pool's numAvailable should be same to MAX_TUNE_THREADS"); - } - } - - void AutotuneManager::WaitThreadsFinished() - { - if (thread_pool_ != nullptr || thread_pool_->numAvailable() < MAX_TUNE_THREADS) - { - thread_pool_->waitWorkComplete(); - } - } - - } // namespace native -} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/aoe/AutoTune.h b/torch_npu/csrc/framework/aoe/AutoTune.h deleted file mode 100644 index d0055f064092e9d664885b1a039ab878aa6a78d9..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/aoe/AutoTune.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __PLUGIN_NATIVE_NPU_AOE_AUTOTUNE__ -#define __PLUGIN_NATIVE_NPU_AOE_AUTOTUNE__ - -#include -#include - -#include "torch_npu/csrc/framework/interface/AoeInterface.h" -#include "torch_npu/csrc/framework/interface/Graph.h" - -namespace at_npu -{ - namespace native - { - /** - class Autotune provide API for tuning. - */ - class AutotuneManager - { - - public: - static AutotuneManager *GetInstance(); - void PushGraph(const std::string &name, Graph &tuningGraph); - - private: - std::atomic isInited; - std::shared_ptr thread_pool_; - std::vector sessionIDs; - std::vector ge_graphs; - std::map sessionOptions; - - AutotuneManager(); - ~AutotuneManager(); - - void Init(); - void DeInit(); - - void DoGraphsTune(); - void DoGraphTune(ge::Graph &ge_graph, aoe::SessionId sessionId); - void TuningGraphs(); - void WaitThreadsFinished(); - - void CreatSessions(); - void DestroySessions(); - }; - - } // namespace native -} // namespace at_npu - -#endif // __NATIVE_NPU_AOE_AUTOTUNE__ \ No newline at end of file diff --git a/torch_npu/csrc/framework/interface/AclOpCompileInterface.cpp b/torch_npu/csrc/framework/interface/AclOpCompileInterface.cpp index 05588c1b2f512df17b25dc62bb47801fa5e42fd2..94a7d08b1b0b0c79635515f27c405b1323653608 100644 --- a/torch_npu/csrc/framework/interface/AclOpCompileInterface.cpp +++ b/torch_npu/csrc/framework/interface/AclOpCompileInterface.cpp @@ -32,19 +32,60 @@ namespace at_npu REGISTER_LIBRARY(libacl_op_compiler) LOAD_FUNCTION(aclopSetCompileFlag) + LOAD_FUNCTION(aclGenGraphAndDumpForOp) + LOAD_FUNCTION(aclCreateGraphDumpOpt) + LOAD_FUNCTION(aclDestroyGraphDumpOpt) - aclError AclopSetCompileFlag(aclOpCompileFlag flag) - { - typedef aclError (*aclopSetCompileFlagFunc)(aclOpCompileFlag); - static aclopSetCompileFlagFunc func = nullptr; - if (func == nullptr) - { - func = (aclopSetCompileFlagFunc)GET_FUNC(aclopSetCompileFlag); - } - TORCH_CHECK(func, "Failed to find function ", "aclopSetCompileFlag"); - auto ret = func(flag); - return ret; - } +aclError AclopSetCompileFlag(aclOpCompileFlag flag) { + typedef aclError (*aclopSetCompileFlagFunc)(aclOpCompileFlag); + static aclopSetCompileFlagFunc func = nullptr; + if (func == nullptr) + { + func = (aclopSetCompileFlagFunc)GET_FUNC(aclopSetCompileFlag); + } + TORCH_CHECK(func, "Failed to find function ", "aclopSetCompileFlag"); + auto ret = func(flag); + return ret; +} + +aclError AclGenGraphAndDumpForOp(const char *opType, + int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[], + int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], + const aclopAttr *attr, aclopEngineType engineType, const char *graphDumpPath, + aclGraphDumpOption* graphdumpOpt) { + typedef aclError(*AclGenGraphAndDumpForOpFunc)(const char *,int, + const aclTensorDesc *const [], const aclDataBuffer *const [], + int, const aclTensorDesc *const [], aclDataBuffer *const [], + const aclopAttr *, aclopEngineType, const char *, aclGraphDumpOption*); + static AclGenGraphAndDumpForOpFunc func = nullptr; + if (func == nullptr) { + func = (AclGenGraphAndDumpForOpFunc)GET_FUNC(aclGenGraphAndDumpForOp); + } + TORCH_CHECK(func, "Failed to find function ", "aclGenGraphAndDumpForOp"); + auto ret = func(opType, numInputs, inputDesc, inputs, numOutputs, + outputDesc, outputs, attr, engineType, graphDumpPath, graphdumpOpt); + return ret; +} + +aclGraphDumpOption* AclCreateGraphDumpOpt() { + typedef aclGraphDumpOption*(*AclCreateGraphDumpOptFunc)(); + static AclCreateGraphDumpOptFunc func = nullptr; + if (func == nullptr) { + func = (AclCreateGraphDumpOptFunc)GET_FUNC(aclCreateGraphDumpOpt); + } + TORCH_CHECK(func, "Failed to find function ", "aclCreateGraphDumpOpt"); + return func(); +} + +aclError AclDestroyGraphDumpOpt(aclGraphDumpOption* aclGraphDumpOpt) { + typedef aclError(*AclDestroyGraphDumpOptFunc)(aclGraphDumpOption*); + static AclDestroyGraphDumpOptFunc func = nullptr; + if (func == nullptr) { + func = (AclDestroyGraphDumpOptFunc)GET_FUNC(aclDestroyGraphDumpOpt); + } + TORCH_CHECK(func, "Failed to find function ", "aclDestroyGraphDumpOpt"); + return func(aclGraphDumpOpt); +} } // namespace native } // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/interface/AclOpCompileInterface.h b/torch_npu/csrc/framework/interface/AclOpCompileInterface.h index 85d31e596149f85674fd5bb013e61bb8fd0e9e19..f59f7fe841cf652a82b50755a34378e2b7c730e4 100644 --- a/torch_npu/csrc/framework/interface/AclOpCompileInterface.h +++ b/torch_npu/csrc/framework/interface/AclOpCompileInterface.h @@ -32,6 +32,46 @@ namespace native { */ aclError AclopSetCompileFlag(aclOpCompileFlag flag); +/** + * @ingroup AscendCL + * @brief dump op graph for AOE + * + * @param opType [IN] op type + * @param numInputs [IN] number of inputs + * @param inputDesc [IN] pointer to array of input tensor descriptions + * @param inputs [IN] pointer to array of input buffers + * @param numOutputs [IN] number of outputs + * @param outputDesc [IN] pointer to array of output tensor descriptions + * @param outputs [IN] pointer to array of outputs buffers + * @param attr [IN] pointer to instance of aclopAttr. + * may pass nullptr if the op has no attribute + * @param engineType [IN] engine type + * @param compileFlag [IN] compile flag + * @param graphDumpPath [IN] path to save dump graph of op + * @param aclGraphDumpOption [IN] dump graph option + * @retval ACL_ERROR_NONE The function is successfully executed. + * @retval OtherValues Failure + */ +aclError AclGenGraphAndDumpForOp(const char *opType, + int numInputs, const aclTensorDesc *const inputDesc[], const aclDataBuffer *const inputs[], + int numOutputs, const aclTensorDesc *const outputDesc[], aclDataBuffer *const outputs[], + const aclopAttr *attr, aclopEngineType engineType, const char *graphDumpPath, + aclGraphDumpOption* graphdumpOpt); + +/** + * @brief create the dump option for AclGenGraphAndDumpForOp API, used for AOE + * @retval created aclGraphDumpOption + */ +aclGraphDumpOption* AclCreateGraphDumpOpt(); + +/** + * @brief destroy the dump option created by aclCreateGraphDumpOpt + * @param aclGraphDumpOpt [IN] dump option created by aclCreateGraphDumpOpt + * @retval ACL_ERROR_NONE The function is successfully executed. + * @retval OtherValues Failure + */ +aclError AclDestroyGraphDumpOpt(aclGraphDumpOption* aclGraphDumpOpt); + } // namespace native } // namespace at_npu diff --git a/torch_npu/csrc/framework/interface/AoeInterface.cpp b/torch_npu/csrc/framework/interface/AoeInterface.cpp deleted file mode 100644 index 71219367401ddcede76e9f81955e325a79e20513..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/interface/AoeInterface.cpp +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "torch_npu/csrc/framework/interface/AoeInterface.h" - -namespace at_npu -{ - namespace native - { - namespace aoe - { - -#undef LOAD_FUNCTION -#define LOAD_FUNCTION(funcName) \ - REGISTER_FUNCTION(libaoe_tuning, funcName) -#undef GET_FUNC -#define GET_FUNC(funcName) \ - GET_FUNCTION(libaoe_tuning, funcName) - - REGISTER_LIBRARY(libaoe_tuning) - LOAD_FUNCTION(AoeInitialize) - LOAD_FUNCTION(AoeFinalize) - LOAD_FUNCTION(AoeCreateSession) - LOAD_FUNCTION(AoeDestroySession) - LOAD_FUNCTION(AoeSetTuningGraph) - LOAD_FUNCTION(AoeTuningGraph) - - AoeStatus initialize(const std::map &globalOptions) - { - typedef AoeStatus (*aoeInitFunc)(const std::map &); - static aoeInitFunc func = nullptr; - if (func == nullptr) - { - func = (aoeInitFunc)GET_FUNC(AoeInitialize); - } - TORCH_CHECK(func, "Failed to find function ", "AoeInitialize"); - auto ret = func(globalOptions); - return ret; - } - - AoeStatus finalize() - { - typedef AoeStatus (*aoeFinalizeFunc)(); - static aoeFinalizeFunc func = nullptr; - if (func == nullptr) - { - func = (aoeFinalizeFunc)GET_FUNC(AoeFinalize); - } - TORCH_CHECK(func, "Failed to find function ", "AoeFinalize"); - auto ret = func(); - return ret; - } - - AoeStatus create_session(const std::map &sessionOptions, SessionId &sessionId) - { - typedef AoeStatus (*aoeCreateSession)(const std::map &, SessionId &); - aoeCreateSession func = nullptr; - if (func == nullptr) - { - func = (aoeCreateSession)GET_FUNC(AoeCreateSession); - } - TORCH_CHECK(func, "Failed to find function ", "AoeCreateSession"); - auto ret = func(sessionOptions, sessionId); - return ret; - } - - AoeStatus destroy_session(SessionId sessionId) - { - typedef AoeStatus (*aoeDestroySession)(SessionId); - aoeDestroySession func = nullptr; - if (func == nullptr) - { - func = (aoeDestroySession)GET_FUNC(AoeDestroySession); - } - TORCH_CHECK(func, "Failed to find function ", "AoeDestroySession"); - auto ret = func(sessionId); - return ret; - } - - AoeStatus set_tuning_graph(SessionId sessionId, ge::Graph &tuningGraph) - { - typedef AoeStatus (*aoeSetTuningGraph)(SessionId, ge::Graph &); - aoeSetTuningGraph func = nullptr; - if (func == nullptr) - { - func = (aoeSetTuningGraph)GET_FUNC(AoeSetTuningGraph); - } - TORCH_CHECK(func, "Failed to find function ", "AoeSetTuningGraph"); - auto ret = func(sessionId, tuningGraph); - return ret; - } - - AoeStatus tuning_graph(SessionId sessionId, const std::map &tuingOptions) - { - typedef AoeStatus (*aoeTuningGraph)(SessionId, const std::map &); - aoeTuningGraph func = nullptr; - if (func == nullptr) - { - func = (aoeTuningGraph)GET_FUNC(AoeTuningGraph); - } - TORCH_CHECK(func, "Failed to find function ", "AoeTuningGraph"); - auto ret = func(sessionId, tuingOptions); - return ret; - } - - } // namespace aoe - } // namespace native -} // namespace at_npu diff --git a/torch_npu/csrc/framework/interface/AoeInterface.h b/torch_npu/csrc/framework/interface/AoeInterface.h deleted file mode 100644 index ce8b7e41486f9f6c52b1a449a4b0acd2fd83231f..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/interface/AoeInterface.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __PLUGIN_NATIVE_NPU_INTERFACE_AOEINTERFACE__ -#define __PLUGIN_NATIVE_NPU_INTERFACE_AOEINTERFACE__ - -#include "third_party/acl/inc/graph/ascend_string.h" -#include "third_party/acl/inc/graph/graph.h" - -namespace at_npu { -namespace native { -namespace aoe { - -/** - SessionId is provide by aoe, it used to store session id. - */ -using SessionId = uint64_t; -/** - AoeStatues is provide by aoe, it used to store the return value. - */ -using AoeStatus = int32_t; - -/** - This API is used to init aoe, it need be called once at process. - */ -AoeStatus initialize(const std::map &globalOptions); -/** - This API is used to finalize aoe, it need be called once at process. - */ -AoeStatus finalize(); -/** - This API is used to create session, this operation should be called after init. - */ -AoeStatus create_session(const std::map &sessionOptions, SessionId &sessionId); -/** - This API is used to destroy session - */ -AoeStatus destroy_session(SessionId sessionId); -/** - This API is used to associate session and graph - */ -AoeStatus set_tuning_graph(SessionId sessionId, ge::Graph &tuningGraph); -/** - This API is used to tuning graphs at session - */ -AoeStatus tuning_graph(SessionId sessionId, const std::map &tuingOptions); - -} // namespace aoe -} // namespace native -} // namespace at_npu - -#endif // __NATIVE_NPU_INTERFACE_AOEINTERFACE__ \ No newline at end of file diff --git a/torch_npu/csrc/framework/interface/EnvVariables.cpp b/torch_npu/csrc/framework/interface/EnvVariables.cpp index c902f6cae18994e4e5de0724bafb2c24f6b6ed98..b37c0952c377bccbfc659f9fbfa8f425a71e7ff0 100644 --- a/torch_npu/csrc/framework/interface/EnvVariables.cpp +++ b/torch_npu/csrc/framework/interface/EnvVariables.cpp @@ -14,6 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -22,115 +23,111 @@ #include "torch_npu/csrc/framework/utils/NpuProfilingDispatch.h" #include "third_party/acl/inc/acl/acl_mdl.h" #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" +#include "torch_npu/csrc/framework/aoe/AoeUtils.h" +namespace at_npu { +namespace native { +namespace env { -namespace at_npu -{ - namespace native - { - namespace env - { - - REGISTER_OPTION(autotune) - REGISTER_OPTION_BOOL_FUNCTION(AutoTuneEnabled, autotune, "disable", "enable") - - REGISTER_OPTION_INIT_BY_ENV(bmmv2_enable) - REGISTER_OPTION_BOOL_FUNCTION(CheckBmmV2Enable, bmmv2_enable, "0", "1") - - REGISTER_OPTION_HOOK(mdldumpswitch, [](const std::string &val) - { - if (val == "enable") - { - aclmdlInitDump(); - } - else - { - aclmdlFinalizeDump(); - } - }) - REGISTER_OPTION_HOOK(mdldumpconfigpath, [](const std::string &val) - { aclmdlSetDump(val.c_str()); }) - - REGISTER_OPTION_HOOK(fuzzycompileswitch, [](const std::string &val) - { - if (val == "enable") - { - AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_FUZZ); - } - else - { - AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_DEFAULT); - } - }) - REGISTER_OPTION_BOOL_FUNCTION(CheckFuzzyEnable, fuzzycompileswitch, "disable", "enable") - - REGISTER_OPTION_HOOK(ACL_OP_DEBUG_LEVEL, [](const std::string &val) - { aclSetCompileopt(aclCompileOpt::ACL_OP_DEBUG_LEVEL, val.c_str()); }) - REGISTER_OPTION_HOOK(ACL_DEBUG_DIR, [](const std::string &val) - { aclSetCompileopt(aclCompileOpt::ACL_DEBUG_DIR, val.c_str()); }) - REGISTER_OPTION_HOOK(ACL_OP_COMPILER_CACHE_MODE, [](const std::string &val) - { aclSetCompileopt(aclCompileOpt::ACL_OP_COMPILER_CACHE_MODE, val.c_str()); }) - REGISTER_OPTION_HOOK(ACL_OP_COMPILER_CACHE_DIR, [](const std::string &val) - { aclSetCompileopt(aclCompileOpt::ACL_OP_COMPILER_CACHE_DIR, val.c_str()); }) - REGISTER_OPTION_HOOK(ACL_OP_SELECT_IMPL_MODE, [](const std::string &val) - { - if (val == "high_precision" || val == "high_performance") - { - aclSetCompileopt(aclCompileOpt::ACL_OP_SELECT_IMPL_MODE, val.c_str()); - } - else - { - TORCH_CHECK(0, "ACL_OP_SELECT_IMPL_MODE only support `high_precision` or " - " `high_performance`, but got ", - val); - } - }) - REGISTER_OPTION_HOOK(ACL_OPTYPELIST_FOR_IMPLMODE, [](const std::string &val) - { aclSetCompileopt(aclCompileOpt::ACL_OPTYPELIST_FOR_IMPLMODE, val.c_str()); }) - REGISTER_OPTION_HOOK(NPU_FUZZY_COMPILE_BLACKLIST, [](const std::string &val) - { FuzzyCompileBlacklist::GetInstance().RegisterBlacklist(val); }) - - REGISTER_OPTION_INIT_BY_ENV(PROFILING_MODE) - REGISTER_OPTION_BOOL_FUNCTION(CheckProfilingEnable, PROFILING_MODE, "false", "true"); - - REGISTER_OPTION_HOOK(deliverswitch, [](const std::string &val) - { - TORCH_CHECK( - CheckProfilingEnable(), - "before you prepare to deliver op, ", - "you should be enture profiling mode is on correctly!"); - if (val == "enable") - { - NpuProfilingDispatch::Instance().start(); - } - else - { - NpuProfilingDispatch::Instance().stop(); - } - }) - - REGISTER_OPTION_HOOK(profilerResultPath, [](const std::string &val) - { at::native::npu::NpuProfiling::Instance().Init(val); }) - - REGISTER_OPTION_HOOK(profiling, [](const std::string &val) - { - if (val.compare("start") == 0) - { - at::native::npu::NpuProfiling::Instance().Start(); - } - else if (val.compare("stop") == 0) - { - at::native::npu::NpuProfiling::Instance().Stop(); - } - else if (val.compare("finalize") == 0) - { - at::native::npu::NpuProfiling::Instance().Finalize(); - } - else - { - TORCH_CHECK(false, "profiling input: (", val, " ) error!") - } - }) - - } // namespace env - } // namespace native +void ValidPathCheck(const std::string& file_path) { + char abs_path[PATH_MAX] = {'\0'}; + if (realpath(file_path.c_str(), abs_path) == nullptr) { + TORCH_CHECK(0, "configPath path Fails, path ", (char*)file_path.c_str()); + } +} + +REGISTER_OPTION_HOOK(autotune, [](const std::string& val) { + if (val == "enable") { + at_npu::native::aoe::aoe_manager().EnableAoe(); + } +}) + +REGISTER_OPTION_HOOK(autotunegraphdumppath, [](const std::string& val) { + ValidPathCheck(val); + at_npu::native::aoe::aoe_manager().SetDumpGraphPath(val); +}) + +REGISTER_OPTION_INIT_BY_ENV(bmmv2_enable) +REGISTER_OPTION_BOOL_FUNCTION(CheckBmmV2Enable, bmmv2_enable, "0", "1") + +REGISTER_OPTION_HOOK(mdldumpswitch, [](const std::string &val){ + if (val == "enable") { + aclmdlInitDump(); + } else { + aclmdlFinalizeDump(); + } +}) +REGISTER_OPTION_HOOK(mdldumpconfigpath, [](const std::string &val) { + aclmdlSetDump(val.c_str()); +}) + +REGISTER_OPTION_HOOK(fuzzycompileswitch, [](const std::string &val) { + if (val == "enable") { + AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_FUZZ); + } else { + AclopSetCompileFlag(aclOpCompileFlag::ACL_OP_COMPILE_DEFAULT); + } +}) +REGISTER_OPTION_BOOL_FUNCTION(CheckFuzzyEnable, fuzzycompileswitch, "disable", "enable") + +REGISTER_OPTION_HOOK(ACL_OP_DEBUG_LEVEL, [](const std::string &val) { + aclSetCompileopt(aclCompileOpt::ACL_OP_DEBUG_LEVEL, val.c_str()); +}) +REGISTER_OPTION_HOOK(ACL_DEBUG_DIR, [](const std::string &val) { + aclSetCompileopt(aclCompileOpt::ACL_DEBUG_DIR, val.c_str()); +}) + +REGISTER_OPTION_HOOK(ACL_OP_COMPILER_CACHE_MODE, [](const std::string &val) { + aclSetCompileopt(aclCompileOpt::ACL_OP_COMPILER_CACHE_MODE, val.c_str()); +}) + +REGISTER_OPTION_HOOK(ACL_OP_COMPILER_CACHE_DIR, [](const std::string &val) { + aclSetCompileopt(aclCompileOpt::ACL_OP_COMPILER_CACHE_DIR, val.c_str()); +}) + +REGISTER_OPTION_HOOK(ACL_OP_SELECT_IMPL_MODE, [](const std::string &val) { + if (val == "high_precision" || val == "high_performance") { + aclSetCompileopt(aclCompileOpt::ACL_OP_SELECT_IMPL_MODE, val.c_str()); + } else { + TORCH_CHECK(0, "ACL_OP_SELECT_IMPL_MODE only support `high_precision` or " + " `high_performance`, but got ", val); + } +}) + +REGISTER_OPTION_HOOK(ACL_OPTYPELIST_FOR_IMPLMODE, [](const std::string &val) + { aclSetCompileopt(aclCompileOpt::ACL_OPTYPELIST_FOR_IMPLMODE, val.c_str()); }) +REGISTER_OPTION_HOOK(NPU_FUZZY_COMPILE_BLACKLIST, [](const std::string &val) + { FuzzyCompileBlacklist::GetInstance().RegisterBlacklist(val); }) + +REGISTER_OPTION_INIT_BY_ENV(PROFILING_MODE) +REGISTER_OPTION_BOOL_FUNCTION(CheckProfilingEnable, PROFILING_MODE, "false", "true"); + +REGISTER_OPTION_HOOK(deliverswitch, [](const std::string &val) { + TORCH_CHECK(CheckProfilingEnable(), + "before you prepare to deliver op, ", + "you should be enture profiling mode is on correctly!"); + if (val == "enable") { + NpuProfilingDispatch::Instance().start(); + } else { + NpuProfilingDispatch::Instance().stop(); + } +}) + +REGISTER_OPTION_HOOK(profilerResultPath, [](const std::string &val) { + at::native::npu::NpuProfiling::Instance().Init(val); +}) + +REGISTER_OPTION_HOOK(profiling, [](const std::string &val) { + if (val.compare("start") == 0) { + at::native::npu::NpuProfiling::Instance().Start(); + } else if (val.compare("stop") == 0) { + at::native::npu::NpuProfiling::Instance().Stop(); + } else if (val.compare("finalize") == 0) { + at::native::npu::NpuProfiling::Instance().Finalize(); + } else { + TORCH_CHECK(false, "profiling input: (", val, " ) error!") + } +}) + +} // namespace env +} // namespace native } // namespace at_npu diff --git a/torch_npu/csrc/framework/interface/Graph.cpp b/torch_npu/csrc/framework/interface/Graph.cpp deleted file mode 100644 index eea8ec2c1fe1c673944738af80ee892771bf9be9..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/interface/Graph.cpp +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "third_party/acl/inc/ge/ge_ir_build.h" // aclgrphGenerateForOp -#include "torch_npu/csrc/framework/interface/Graph.h" -#include "torch_npu/csrc/framework/interface/GeHelper.h" -#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" - -namespace at_npu -{ - namespace native - { - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, bool value) - { - op.SetAttr(name, value); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, int64_t value) - { - op.SetAttr(name, value); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, float value) - { - op.SetAttr(name, value); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, std::string value) - { - ge::AscendString val(value.c_str()); - op.SetAttr(name, val); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, c10::IntArrayRef value) - { - auto vec = value.vec(); - op.SetAttr(name, vec); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, at::ArrayRef value) - { - auto vec = value.vec(); - op.SetAttr(name, vec); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, c10::Scalar value) - { - float val = CalcuOpUtil::get_scalar_float_value(value); - op.SetAttr(name, val); - } - - void GNodeAttrMaker::Set(ge::GNode &op, const ge::AscendString &name, at::ArrayRef value) - { - std::vector> vals; - for (int i = 0; i < value.size(); i++) - { - std::vector val; - val.resize(value[i].size()); - std::copy(value[i].begin(), value[i].end(), val.begin()); - vals.emplace_back(val); - } - op.SetAttr(name, vals); - } - - Graph &Graph::Name(std::string name) - { - this->name = name; - this->inputs.clear(); - this->outputs.clear(); - return *this; - } - - Graph &Graph::Input(const aclTensorDesc *inDesc) - { - inputs.emplace_back(GeHelper::Convert(inDesc)); - return *this; - } - - Graph &Graph::Output(const aclTensorDesc *outDesc) - { - outputs.emplace_back(GeHelper::Convert(outDesc)); - return *this; - } - - Graph &Graph::SetConst(void *const_data_buffer, const size_t &const_data_len) - { - TORCH_CHECK(inputs.size() > 0, "The input vector can not be null!"); - // SetConstData function only support in CANN 5.0.3 (after 2021/08/15) - return *this; - } - - void Graph::Make() - { - if (not env::AutoTuneEnabled()) - { - return; - } - - ge::AscendString opType(name.c_str()); - auto ret = ge::aclgrphGenerateForOp(opType, this->inputs, this->outputs, this->graph); - if (ret != ge::GRAPH_SUCCESS) - { - AT_ERROR("aclgrphGenerateForOp failed. error code:", ret); - return; - } - - auto nodes = this->graph.GetDirectNode(); - ge::AscendString type; - for (auto tmpNode : nodes) - { - tmpNode.GetType(type); - if (type == name.c_str()) - { - node = tmpNode; - break; - } - } - } - - void Graph::GeGraph(ge::Graph &g) - { - g = this->graph; - } - } // namespace native -} // namespace at_npu diff --git a/torch_npu/csrc/framework/interface/Graph.h b/torch_npu/csrc/framework/interface/Graph.h deleted file mode 100644 index 297a25882313f76837b65ddc24ee42ca1e7ba81f..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/framework/interface/Graph.h +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __PLUGIN_NATIVE_INTERFACE_GRAPH__ -#define __PLUGIN_NATIVE_INTERFACE_GRAPH__ - -#include -#include -#include - -#include "third_party/acl/inc/graph/graph.h" // ge::Graph -#include "third_party/acl/inc/graph/tensor.h" // TensorDesc -#include "third_party/acl/inc/graph/types.h" // Format -#include "third_party/acl/inc/acl/acl_base.h" -#include "torch_npu/csrc/framework/interface/EnvVariables.h" - -namespace at_npu -{ - namespace native - { - /** - This class is used to set GNode's attribute. - */ - class GNodeAttrMaker - { - public: - static void Set(ge::GNode &op, const ge::AscendString &name, bool value); - static void Set(ge::GNode &op, const ge::AscendString &name, int64_t value); - static void Set(ge::GNode &op, const ge::AscendString &name, float value); - static void Set(ge::GNode &op, const ge::AscendString &name, std::string value); - static void Set(ge::GNode &op, const ge::AscendString &name, c10::IntArrayRef value); - static void Set(ge::GNode &op, const ge::AscendString &name, at::ArrayRef value); - static void Set(ge::GNode &op, const ge::AscendString &name, c10::Scalar value); - static void Set(ge::GNode &op, const ge::AscendString &name, at::ArrayRef value); - }; // class GNodeAttrMaker - - /** - Class Graph is the wrapper of ge::Graph andd support to use ACL interface to construct. - */ - class Graph - { - public: - /** - This api is used to set graph's name. - */ - Graph &Name(std::string name); - /** - This api is used to set graph's input desc - */ - Graph &Input(const aclTensorDesc *inDesc); - /** - This api is used to set graph's output desc - */ - Graph &Output(const aclTensorDesc *outDesc); - /** - This api is used to set graph's last input desc to be const. - */ - Graph &SetConst(void *const_data_buffer, const size_t &const_data_len); - /** - This api is used to make graph, which is depend on the TensorDesc of inputs and outputs - */ - void Make(); - /** - This api should be called after Make(). - */ - template - void AddAttr(std::string &attrName, dataType value); - /** - This API is used to get the private member: ge::Graph. - */ - void GeGraph(ge::Graph &g); - - private: - std::string name; - std::vector inputs; - std::vector outputs; - ge::Graph graph; - ge::GNode node; - }; - - template - void Graph::AddAttr(std::string &attrName, dataType value) - { - if (not env::AutoTuneEnabled()) - { - return; - } - ge::AscendString attrName_(attrName.c_str()); - GNodeAttrMaker::Set(node, attrName_, value); - } - - } // namespace native -} // namespace at_npu - -#endif // __NATIVE_NPU_INTERFACE_GRAPH__ \ No newline at end of file diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index 0dc927d859c288ca3c559a075352bc565739f331..73f8f06d36285a12325201c3e7671796b52ff470 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -24,12 +24,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -376,6 +378,41 @@ PyObject* THNPModule_finalizeDump(PyObject* _unused, PyObject* noargs) { END_HANDLE_TH_ERRORS } +PyObject* THNPModule_setOption_wrap(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + + if (!PyDict_Check(arg)) { + throw torch::TypeError("npu option must be a dict."); + } + + PyObject *key = nullptr; + PyObject *value = nullptr; + Py_ssize_t pos = 0; + std::map option; + + while (PyDict_Next(arg, &pos, &key, &value)) { + if (key == nullptr || !PyUnicode_Check(key)) { + throw torch::TypeError("option name is nullptr or is not string."); + } + + if (value == nullptr || !PyUnicode_Check(value)) { + throw torch::TypeError("option value is nullptr or is not string."); + } + + const char *pKey = PyUnicode_AsUTF8(key); + const char *pValue = PyUnicode_AsUTF8(value); + option[pKey] = pValue; + } + + torch::utils::npu_lazy_init(); + { + pybind11::gil_scoped_release no_gil; + c10::npu::SetOption(option); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static struct PyMethodDef THNPModule_methods[] = { {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr}, {"_npu_set_run_yet_variable_to_false", (PyCFunction)THNPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr}, @@ -399,6 +436,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_initDump", (PyCFunction)THNPModule_initDump, METH_NOARGS, nullptr}, {"_npu_setDump", (PyCFunction)THNPModule_setDump, METH_O, nullptr}, {"_npu_finalizeDump", (PyCFunction)THNPModule_finalizeDump, METH_NOARGS, nullptr}, + {"_npu_setOption", (PyCFunction)THNPModule_setOption_wrap, METH_O, nullptr}, {nullptr}}; PyMethodDef* THNPModule_get_methods() { diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index b97832386ed32d8b6d46e3a260c5c386d5793391..758c2abd34e4ff9e7678df8e943087d3734bd542 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -26,7 +26,7 @@ __all__ = [ "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", - "Stream", "Event", "profiler" + "Stream", "Event", "profiler", "set_option" ] @@ -43,4 +43,5 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del max_memory_allocated, memory_reserved, max_memory_reserved, memory_cached, max_memory_cached, memory_snapshot, memory_summary) from .streams import Stream, Event -from . import profiler \ No newline at end of file +from . import profiler +from .npu_frontend_enhance import set_option \ No newline at end of file diff --git a/torch_npu/npu/npu_frontend_enhance.py b/torch_npu/npu/npu_frontend_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..af120c713a49b5f3b1f3adeeb041de637d4df979 --- /dev/null +++ b/torch_npu/npu/npu_frontend_enhance.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch_npu._C +# this file is used to enhance the npu frontend API by set_option or other. + +__all__ = ["set_option", "global_step_inc", "set_start_fuzz_compile_step"] + +def set_option(option): + if not isinstance(option, dict): + raise TypeError("npu option must be a dict.") + + for option_name, option_value in option.items(): + option[option_name] = str(option_value) + torch_npu._C._npu_setOption(option) + +def init_dump(): + option = {"mdldumpswitch":"enable"} + torch_npu._C._npu_setOption(option) + +def set_dump(cfg_file): + if not os.path.exists(cfg_file): + raise AssertionError("cfg_file %s path not exists."%(cfg_file)) + cfg_file = os.path.abspath(cfg_file) + option = {"mdldumpconfigpath": cfg_file} + torch_npu._C._npu_setOption(option) + +def finalize_dump(): + option = {"mdldumpswitch": "disable"} + torch_npu._C._npu_setOption(option) + +_GLOBAL_STEP = 0 +_START_FUZZ_COMPILE_STEP = 1 +def global_step_inc(): + global _GLOBAL_STEP + _GLOBAL_STEP += 1 + + option = {"fuzzycompileswitch": "enable" if _GLOBAL_STEP >= _START_FUZZ_COMPILE_STEP \ + else "disable"} + torch_npu._C._npu_setOption(option) + +def set_start_fuzz_compile_step(step): + if not isinstance(step, int): + raise TypeError("step must be a int, but got ", type(step)) + + global _START_FUZZ_COMPILE_STEP + _START_FUZZ_COMPILE_STEP = step + option = {"fuzzycompileswitch": "disable"} + torch_npu._C._npu_setOption(option) +