diff --git a/test/test_cann_profiler.py b/test/test_cann_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..a906ad3e54caecb304d3a83e83e7ceb6f49b3b9d --- /dev/null +++ b/test/test_cann_profiler.py @@ -0,0 +1,112 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +from itertools import combinations +import torch +import torch_npu +from torch.testing._internal.common_utils import TestCase, run_tests + +class SmallModel(torch.nn.Module): + def __init__(self, in_channel=3, out_channel=12): + super(SmallModel, self).__init__() + self.conv1 = torch.nn.Conv2d(in_channel, in_channel, 3, padding=1) + self.relu1 = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(in_channel, out_channel, 3, padding=1) + + def forward(self, input_1): + input_1 = self.conv1(input_1) + input_1 = self.relu1(input_1) + input_1 = self.conv2(input_1) + return input_1.reshape(input_1.shape[0], -1) + +class TestCannProfiler(TestCase): + enevtTypeResults = [] + results_path = "./results" + + @classmethod + def setUpClass(cls): + if not os.path.exists(TestCannProfiler.results_path): + os.makedirs(TestCannProfiler.results_path) + torch.npu.prof_init(TestCannProfiler.results_path) + tensor = torch.rand(2,3).npu() + + enevtTypes = [{"ACL_PROF_ACL_API":False}, {"ACL_PROF_TASK_TIME":False}, + {"ACL_PROF_AICORE_METRICS":False}, 
{"ACL_PROF_AICPU":False}, + {"ACL_PROF_L2CACHE":False}, {"ACL_PROF_HCCL_TRACE":False}, + {"ACL_PROF_TRAINING_TRACE":False}] + + enevtTypeCombinations = list(combinations(enevtTypes, 1)) + list(combinations(enevtTypes, 2)) + \ + list(combinations(enevtTypes, 3)) + list(combinations(enevtTypes, 4)) + \ + list(combinations(enevtTypes, 5)) + list(combinations(enevtTypes, 6)) + for events in enevtTypeCombinations: + temp_events = {} + for event in events: + temp_events.update(event) + TestCannProfiler.enevtTypeResults.append(temp_events) + + @classmethod + def tearDownClass(cls): + torch.npu.prof_finalize() # finalize first so profiling is shut down before results_path is deleted + if os.path.exists(TestCannProfiler.results_path): + shutil.rmtree(TestCannProfiler.results_path) + + def _run_ops(self): + input_1 = torch.rand(10, 10).npu() + input_2 = torch.rand(10, 10).npu() + out = input_1*input_2 + + def _run_small_model(self): + input_shape = (4, 3, 24, 24) + out_shape = (4, 12, 24, 24) + device = "npu" + model = SmallModel(input_shape[1], out_shape[1]).to(device) + criterion = torch.nn.MSELoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + for i in range(10): + inputs = torch.rand(input_shape).to(device) + target = torch.rand(out_shape).reshape(out_shape[0], -1).to(device) + output = model(inputs) + loss = criterion(output, target) + optimizer.zero_grad() # clear old gradients before backward so step() uses this iteration's gradients + loss.backward() + optimizer.step() + + def _test_cann_ops(self, *args, **kwargs): + config = torch.npu.profileConfig(**kwargs) + torch.npu.prof_start(config.NpuEventConfig, config.AiCoreMetricsConfig) + self._run_ops() + torch.npu.prof_stop() + + def _test_cann_model(self, *args, **kwargs): + config = torch.npu.profileConfig(**kwargs) + torch.npu.prof_start(config.NpuEventConfig, config.AiCoreMetricsConfig) + self._run_small_model() + torch.npu.prof_stop() + + def test_with_ops(self): + for events in TestCannProfiler.enevtTypeResults: + for i in range(5): + self._test_cann_ops(**events, aiCoreMetricsType=i) + + def test_with_small_model(self): + for events in 
TestCannProfiler.enevtTypeResults: + for i in range(5): + self._test_cann_model(**events, aiCoreMetricsType=i) + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index 5025e3accd61aedca61b9d118681ed7c7cbe7292..6bc360ae5bc1360064c35195fda6f612c1116470 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include "torch_npu/csrc/register/OptionsManager.h" #include "torch_npu/csrc/framework/contiguous/ContiguousOpt.h" #include "torch_npu/csrc/framework/FormatHelper.h" @@ -132,7 +132,7 @@ void copy_d2d_last_method( bool same_type, bool non_blocking) { // general copy method but Low performance - if (c10::npu::OptionsManager::CheckPTcopy_Enable()) { + if (torch_npu::option::OptionsManager::CheckPTcopy_Enable()) { RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector({src})); copy_kernel_npu(self, src, non_blocking); } else { diff --git a/torch_npu/csrc/aten/ops/AddKernelNpu.cpp b/torch_npu/csrc/aten/ops/AddKernelNpu.cpp index 363599e38aef94966cebf3d8e9fbfc1758fbbaf6..19db800366429c9bedaad43374429a142de6d743 100644 --- a/torch_npu/csrc/aten/ops/AddKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AddKernelNpu.cpp @@ -15,7 +15,7 @@ // limitations under the License. 
#include -#include +#include "torch_npu/csrc/register/OptionsManager.h" #include #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" @@ -121,7 +121,7 @@ namespace at_npu } else { - if (c10::npu::OptionsManager::CheckDynamicOptimizer("ADD")) + if (torch_npu::option::OptionsManager::CheckDynamicOptimizer("ADD")) { cmd.Name("AxpyV2") .Input(self) diff --git a/torch_npu/csrc/aten/ops/BmmKernelNpu.cpp b/torch_npu/csrc/aten/ops/BmmKernelNpu.cpp index 38aceb87fd3faf9805639c1659673666e757ab15..4042b6f678cfd6ed37bb6da08fb845a4e793015c 100644 --- a/torch_npu/csrc/aten/ops/BmmKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/BmmKernelNpu.cpp @@ -72,7 +72,7 @@ at::Tensor NPUNativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat // 检查是否指定mm输出为NCHW。待NLP模型总体策略制定后删去 if ((self.scalar_type() == at::ScalarType::Float || self.scalar_type() == at::ScalarType::Half) && - !c10::npu::OptionsManager::CheckSwitchMMOutputEnable()) { + !torch_npu::option::OptionsManager::CheckSwitchMMOutputEnable()) { result = at::empty_with_format(outputSize, self.options(), ACL_FORMAT_FRACTAL_NZ); } else { result = at::empty_with_format(outputSize, self.options(), ACL_FORMAT_ND); diff --git a/torch_npu/csrc/aten/ops/MmKernelNpu.cpp b/torch_npu/csrc/aten/ops/MmKernelNpu.cpp index e836bee17c78ff00769fbce295c28527c8ce0db2..ceba5c8a22cb635ddaf98f7495126365b1053941 100644 --- a/torch_npu/csrc/aten/ops/MmKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/MmKernelNpu.cpp @@ -181,7 +181,7 @@ Return: // construct the output tensor of the NPU at::Tensor result; - if ((self.scalar_type() == at::ScalarType::Half) && !c10::npu::OptionsManager::CheckSwitchMMOutputEnable()) + if ((self.scalar_type() == at::ScalarType::Half) && !torch_npu::option::OptionsManager::CheckSwitchMMOutputEnable()) { result = at::empty_with_format( outputSize, self.options(), ACL_FORMAT_FRACTAL_NZ); diff --git a/torch_npu/csrc/aten/ops/MulKernelNpu.cpp b/torch_npu/csrc/aten/ops/MulKernelNpu.cpp index 
7e6403e1fcb3a130b0ab7520f0487ccf0d6b7c71..1fa1b7afc88814b1ab3962ab716b90c39286f11b 100644 --- a/torch_npu/csrc/aten/ops/MulKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/MulKernelNpu.cpp @@ -15,8 +15,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include - +#include "torch_npu/csrc/register/OptionsManager.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" @@ -36,7 +35,7 @@ namespace at_npu { auto unified_result = OpPreparation::binary_op_check(result, self, other, true); OpCommand cmd; - if (c10::npu::OptionsManager::CheckDynamicOptimizer("MUL")) + if (torch_npu::option::OptionsManager::CheckDynamicOptimizer("MUL")) { cmd.Name("Mul") .Expect(unified_result) diff --git a/torch_npu/csrc/framework/InferFormat.cpp b/torch_npu/csrc/framework/InferFormat.cpp index 6ef3c77742d74485b2e669ed9a91328096b75cc5..7a060e6cf2a4c4e831d6a6400d83c4e6b7ca8475 100644 --- a/torch_npu/csrc/framework/InferFormat.cpp +++ b/torch_npu/csrc/framework/InferFormat.cpp @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include +#include "torch_npu/csrc/register/OptionsManager.h" #include "torch_npu/csrc/framework/InferFormat.h" #include "torch_npu/csrc/framework/FormatHelper.h" diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 748cd17bebd6cdfcd8e0fb61a31123d1dc194f47..069c2e4e8211226a682f16b0f703cd73e1fd4c4e 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -15,7 +15,7 @@ #ifndef __PULGIN_NATIVE_UTILS_COMMAND_BASE__ #define __PULGIN_NATIVE_UTILS_COMMAND_BASE__ -#include +#include "torch_npu/csrc/register/OptionsManager.h" #include #include "torch_npu/csrc/aten/mirror/NPUTensorIterator.h" @@ -137,7 +137,7 @@ namespace at_npu void Run() { - if (c10::npu::OptionsManager::CheckQueueEnable()) + if (torch_npu::option::OptionsManager::CheckQueueEnable()) { ExecuteParas params; aclCmd->ExportParams(params); diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp index 4512721318e6f94a32ff93f40249d03b024d987a..3983af329461f83345b8afacc3e246019bc1c1b3 100644 --- a/torch_npu/csrc/framework/OpParamMaker.cpp +++ b/torch_npu/csrc/framework/OpParamMaker.cpp @@ -13,7 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include +#include "torch_npu/csrc/register/OptionsManager.h" #include #include #include diff --git a/torch_npu/csrc/framework/contiguous/ContiguousOpt.cpp b/torch_npu/csrc/framework/contiguous/ContiguousOpt.cpp index 0854e27c09179e30091d5c7b853aa333be3e8059..246adc7ec10c2bd5c7dd00fe8b8f307483988780 100644 --- a/torch_npu/csrc/framework/contiguous/ContiguousOpt.cpp +++ b/torch_npu/csrc/framework/contiguous/ContiguousOpt.cpp @@ -138,7 +138,7 @@ namespace at_npu } if (OpenCombined && - c10::npu::OptionsManager::CheckCombinedOptimizerEnable()) + torch_npu::option::OptionsManager::CheckCombinedOptimizerEnable()) { optimizations.emplace_back("combined"); } diff --git a/torch_npu/csrc/framework/contiguous/ContiguousOpt.h b/torch_npu/csrc/framework/contiguous/ContiguousOpt.h index cdb230047aa9d067deabc98dea1c258de5c5c7ee..c96073f240a699b6bc45255b839820b9eecaa07d 100644 --- a/torch_npu/csrc/framework/contiguous/ContiguousOpt.h +++ b/torch_npu/csrc/framework/contiguous/ContiguousOpt.h @@ -16,7 +16,7 @@ #ifndef __PULGIN_NATIVE_CONTIGUOUS_CONTIGUOUS_OPTIMIZE__ #define __PULGIN_NATIVE_CONTIGUOUS_CONTIGUOUS_OPTIMIZE__ -#include +#include "torch_npu/csrc/register/OptionsManager.h" #include #include "torch_npu/csrc/framework/utils/NpuUtils.h" diff --git a/torch_npu/csrc/framework/contiguous/combined_opt.cpp b/torch_npu/csrc/framework/contiguous/combined_opt.cpp index cf270817f8015219da9f0bf36158d8a0899f09a9..a0fa15f3cb95fbc3efc8e710dd42a8208371c2da 100644 --- a/torch_npu/csrc/framework/contiguous/combined_opt.cpp +++ b/torch_npu/csrc/framework/contiguous/combined_opt.cpp @@ -35,7 +35,7 @@ namespace at_npu int64_t maxLen = 2; // Setting for 3-combined cases: "TRI_COMBINED_ENABLE=1". 
- if (c10::npu::OptionsManager::CheckTriCombinedOptimizerEnable()) + if (torch_npu::option::OptionsManager::CheckTriCombinedOptimizerEnable()) { maxLen = 3; } diff --git a/torch_npu/csrc/framework/interface/AclInterface.cpp b/torch_npu/csrc/framework/interface/AclInterface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c97c99df4419c966c405c342b99ed6f665e90150 --- /dev/null +++ b/torch_npu/csrc/framework/interface/AclInterface.cpp @@ -0,0 +1,192 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "torch_npu/csrc/register/FunctionLoader.h" +#include "torch_npu/csrc/framework/interface/AclInterface.h" + +namespace at_npu { +namespace native { + +#undef LOAD_FUNCTION +#define LOAD_FUNCTION(funcName) \ + REGISTER_FUNCTION(libascendcl, funcName) +#undef GET_FUNC +#define GET_FUNC(funcName) \ + GET_FUNCTION(libascendcl, funcName) + +REGISTER_LIBRARY(libascendcl) +LOAD_FUNCTION(aclGetRecentErrMsg) +LOAD_FUNCTION(aclrtCreateEventWithFlag) +LOAD_FUNCTION(aclrtQueryEventWaitStatus) +LOAD_FUNCTION(aclprofCreateStepInfo) +LOAD_FUNCTION(aclprofGetStepTimestamp) +LOAD_FUNCTION(aclprofDestroyStepInfo) +LOAD_FUNCTION(aclprofInit) +LOAD_FUNCTION(aclprofStart) +LOAD_FUNCTION(aclprofStop) +LOAD_FUNCTION(aclprofFinalize) +LOAD_FUNCTION(aclprofCreateConfig) +LOAD_FUNCTION(aclprofDestroyConfig) + +aclprofStepInfoPtr init_stepinfo(){ + typedef aclprofStepInfoPtr(*npdInitFunc)(); + static npdInitFunc func = nullptr; + if(func == nullptr){ + func = (npdInitFunc)GET_FUNC(aclprofCreateStepInfo); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofCreateStepInfo"); + auto ret = func(); + return ret; +} + +NpdStatus destroy_stepinfo(aclprofStepInfoPtr stepInfo){ + typedef NpdStatus(*npdDestroyFunc)(aclprofStepInfoPtr); + static npdDestroyFunc func = nullptr; + if(func == nullptr){ + func = (npdDestroyFunc)GET_FUNC(aclprofDestroyStepInfo); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofDestroyStepInfo"); + auto ret = func(stepInfo); + return ret; +} + +NpdStatus start_deliver_op(aclprofStepInfoPtr stepInfo, aclprofStepTag stepTag, aclrtStream stream){ + typedef NpdStatus(*npdStartProfiling)(aclprofStepInfoPtr, aclprofStepTag, aclrtStream); + static npdStartProfiling func = nullptr; + if(func == nullptr){ + func = (npdStartProfiling)GET_FUNC(aclprofGetStepTimestamp); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofGetStepTimestamp"); + auto ret = func(stepInfo, stepTag, stream); + return ret; +} + +NpdStatus 
stop_deliver_op(aclprofStepInfoPtr stepInfo, aclprofStepTag stepTag, aclrtStream stream){ + typedef NpdStatus(*npdStopProfiling)(aclprofStepInfoPtr, aclprofStepTag, aclrtStream); + static npdStopProfiling func = nullptr; + if(func == nullptr){ + func = (npdStopProfiling)GET_FUNC(aclprofGetStepTimestamp); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofGetStepTimestamp"); + auto ret = func(stepInfo, stepTag, stream); + return ret; +} + +const char *AclGetErrMsg() +{ + typedef const char *(*aclGetErrMsg)(); + static aclGetErrMsg func = nullptr; + if (func == nullptr) { + func = (aclGetErrMsg)GET_FUNC(aclGetRecentErrMsg); + } + if (func != nullptr) { + return func(); + } + return ""; +} + +aclError AclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag) { + typedef aclError(*AclrtCreateEventWithFlagFunc)(aclrtEvent*, uint32_t); + static AclrtCreateEventWithFlagFunc func = nullptr; + if (func == nullptr) { + func = (AclrtCreateEventWithFlagFunc)GET_FUNC(aclrtCreateEventWithFlag); + } + TORCH_CHECK(func, "Failed to find function ", "aclrtCreateEventWithFlag"); + return func(event, flag); +} + +aclError AclQueryEventStatus(aclrtEvent event, aclrtEventWaitStatus *waitStatus, aclrtEventStatus *recordStatus) +{ + typedef aclError (*aclQueryEventWaitStatus)(aclrtEvent event, aclrtEventWaitStatus *status); + static aclQueryEventWaitStatus func = nullptr; + if (func == nullptr) { + func = (aclQueryEventWaitStatus)GET_FUNC(aclrtQueryEventWaitStatus); + } + if (func != nullptr) { + return func(event, waitStatus); + } else { + return aclrtQueryEvent(event, recordStatus); + } +} + +aclError AclProfilingInit(const char *profilerResultPath, size_t length) { + typedef aclError (*AclProfInitFunc) (const char *, size_t); + static AclProfInitFunc func = nullptr; + if (func == nullptr) { + func = (AclProfInitFunc)GET_FUNC(aclprofInit); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofInit"); + return func(profilerResultPath, length); +} + +aclError 
AclProfilingStart(const aclprofConfig *profilerConfig) { + typedef aclError (*AclProfStartFunc) (const aclprofConfig *); + static AclProfStartFunc func = nullptr; + if (func == nullptr) { + func = (AclProfStartFunc)GET_FUNC(aclprofStart); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofStart"); + return func(profilerConfig); +} + +aclError AclProfilingStop(const aclprofConfig *profilerConfig) { + typedef aclError (*AclProfStopFunc) (const aclprofConfig*); + static AclProfStopFunc func = nullptr; + if (func == nullptr) { + func = (AclProfStopFunc)GET_FUNC(aclprofStop); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofStop"); + return func(profilerConfig); +} + +aclError AclProfilingFinalize() { + typedef aclError (*AclProfFinalizeFunc) (); + static AclProfFinalizeFunc func = nullptr; + if (func == nullptr) { + func = (AclProfFinalizeFunc)GET_FUNC(aclprofFinalize); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofFinalize"); + return func(); +} + +aclprofConfig *AclProfilingCreateConfig( + uint32_t *deviceIdList, + uint32_t deviceNums, + aclprofAicoreMetrics aicoreMetrics, + aclprofAicoreEvents *aicoreEvents, + uint64_t dataTypeConfig) { + typedef aclprofConfig *(*AclProfCreateConfigFunc) \ + (uint32_t *, uint32_t, aclprofAicoreMetrics, aclprofAicoreEvents *, uint64_t); + static AclProfCreateConfigFunc func = nullptr; + if (func == nullptr) { + func = (AclProfCreateConfigFunc)GET_FUNC(aclprofCreateConfig); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofCreateConfig"); + return func(deviceIdList, deviceNums, aicoreMetrics, aicoreEvents, dataTypeConfig); +} + +aclError AclProfilingDestroyConfig(const aclprofConfig *profilerConfig) { + typedef aclError (*AclProfDestroyConfigFunc) (const aclprofConfig *); + static AclProfDestroyConfigFunc func = nullptr; + if (func == nullptr) { + func = (AclProfDestroyConfigFunc)GET_FUNC(aclprofDestroyConfig); + } + TORCH_CHECK(func, "Failed to find function ", "aclprofDestroyConfig"); + 
return func(profilerConfig); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/framework/interface/AclInterface.h b/torch_npu/csrc/framework/interface/AclInterface.h new file mode 100644 index 0000000000000000000000000000000000000000..3403bf955fc058a1058f77cf2d3d9366acd34689 --- /dev/null +++ b/torch_npu/csrc/framework/interface/AclInterface.h @@ -0,0 +1,94 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef __TORCH_NPU_INTERFACE_ACLINTERFACE__ +#define __TORCH_NPU_INTERFACE_ACLINTERFACE__ + +#include "third_party/acl/inc/acl/acl_rt.h" +#include +#include + +namespace at_npu { +namespace native { + +typedef enum aclrtEventWaitStatus { + ACL_EVENT_WAIT_STATUS_COMPLETE = 0, + ACL_EVENT_WAIT_STATUS_NOT_READY = 1, + ACL_EVENT_WAIT_STATUS_RESERVED = 0xffff, +} aclrtEventWaitStatus; + +/** + aclprofStepInfo is provide by acl, it used to be store dispatch op info. + */ +using aclprofStepInfoPtr = aclprofStepInfo *; +/** + NpdStatus is provide by acl, it used to store the return value. + */ +using NpdStatus = int; + +/** + This Api is used to init npd, it need to be called once at process. + */ +aclprofStepInfoPtr init_stepinfo(); +/** + This Api is used to destroy npd, it need to be called once at process. 
+ */ +NpdStatus destroy_stepinfo(aclprofStepInfoPtr stepInfo); +/** + This Api is used to start dispatch op, this operation should be called after init. + */ +NpdStatus start_deliver_op(aclprofStepInfoPtr stepInfo, aclprofStepTag stepTag, aclrtStream stream); +/** + This Api is used to stop dispatch op, this operation should be called after start dispatch op. + */ +NpdStatus stop_deliver_op(aclprofStepInfoPtr stepInfo, aclprofStepTag stepTag, aclrtStream stream); + +/** + This API is used to get error msg + */ +const char *AclGetErrMsg(); + +/** + * @ingroup AscendCL + * @brief create event instance + * + * @param event [OUT] created event + * @param flag [IN] event flag + * @retval ACL_ERROR_NONE The function is successfully executed. + * @retval OtherValues Failure + */ +aclError AclrtCreateEventWithFlag(aclrtEvent *event, uint32_t flag); + +/** + This API is used to query status of event task + */ +aclError AclQueryEventStatus(aclrtEvent event, aclrtEventWaitStatus *waitStatus, aclrtEventStatus *recordStatus); + +aclError AclProfilingInit(const char *profilerResultPath, size_t length); +aclError AclProfilingStart(const aclprofConfig *profilerConfig); +aclError AclProfilingStop(const aclprofConfig *profilerConfig); +aclError AclProfilingFinalize(); +aclprofConfig * AclProfilingCreateConfig( + uint32_t *deviceIdList, + uint32_t deviceNums, + aclprofAicoreMetrics aicoreMetrics, + aclprofAicoreEvents *aicoreEvents, + uint64_t dataTypeConfig); +aclError AclProfilingDestroyConfig(const aclprofConfig *profilerConfig); + +} // namespace native +} // namespace at_npu + +#endif // __TORCH_NPU_INTERFACE_ACLINTERFACE__ \ No newline at end of file diff --git a/torch_npu/csrc/framework/interface/EnvVariables.cpp b/torch_npu/csrc/framework/interface/EnvVariables.cpp index b37c0952c377bccbfc659f9fbfa8f425a71e7ff0..728fe6d4af49105ae6e74a6f225b3892f6f57ab3 100644 --- a/torch_npu/csrc/framework/interface/EnvVariables.cpp +++ b/torch_npu/csrc/framework/interface/EnvVariables.cpp @@ 
-15,15 +15,14 @@ // limitations under the License. #include -#include #include -#include -#include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" -#include "torch_npu/csrc/framework/utils/NpuProfilingDispatch.h" #include "third_party/acl/inc/acl/acl_mdl.h" +#include "torch_npu/csrc/framework/utils/NpuFuzzyBlacklist.h" #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" #include "torch_npu/csrc/framework/aoe/AoeUtils.h" +#include "torch_npu/csrc/register/OptionRegister.h" +#include "torch_npu/csrc/profiler/cann_profiling.h" namespace at_npu { namespace native { namespace env { @@ -106,23 +105,21 @@ REGISTER_OPTION_HOOK(deliverswitch, [](const std::string &val) { "before you prepare to deliver op, ", "you should be enture profiling mode is on correctly!"); if (val == "enable") { - NpuProfilingDispatch::Instance().start(); + torch_npu::profiler::NpuProfilingDispatch::Instance().start(); } else { - NpuProfilingDispatch::Instance().stop(); + torch_npu::profiler::NpuProfilingDispatch::Instance().stop(); } }) REGISTER_OPTION_HOOK(profilerResultPath, [](const std::string &val) { - at::native::npu::NpuProfiling::Instance().Init(val); + torch_npu::profiler::NpuProfiling::Instance().Init(val); }) REGISTER_OPTION_HOOK(profiling, [](const std::string &val) { - if (val.compare("start") == 0) { - at::native::npu::NpuProfiling::Instance().Start(); - } else if (val.compare("stop") == 0) { - at::native::npu::NpuProfiling::Instance().Stop(); + if (val.compare("stop") == 0) { + torch_npu::profiler::NpuProfiling::NpuProfiling::Instance().Stop(); } else if (val.compare("finalize") == 0) { - at::native::npu::NpuProfiling::Instance().Finalize(); + torch_npu::profiler::NpuProfiling::NpuProfiling::Instance().Finalize(); } else { TORCH_CHECK(false, "profiling input: (", val, " ) error!") } diff --git a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp index 
a847435adcc1441967c66cb87af68dd95f00251f..71bd52d3469d08c8d64bd9472edc7d18a34263df 100644 --- a/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp +++ b/torch_npu/csrc/framework/utils/CalcuOpUtil.cpp @@ -17,7 +17,7 @@ #include #include #include -#include +#include "torch_npu/csrc/register/OptionsManager.h" #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/interface/AclOpCompileInterface.h" @@ -686,7 +686,7 @@ namespace at_npu c10::SmallVector &outputs, const c10::SmallVector &attrs) { - if (c10::npu::OptionsManager::CheckQueueEnable()) + if (torch_npu::option::OptionsManager::CheckQueueEnable()) { ExecuteParas cur_paras; cur_paras.opType = opName; diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp index 73f8f06d36285a12325201c3e7671796b52ff470..2031ffe2884919532d3d3ebaf238769d1ebb7315 100644 --- a/torch_npu/csrc/npu/Module.cpp +++ b/torch_npu/csrc/npu/Module.cpp @@ -24,7 +24,6 @@ #include #include #include -#include #include #include #include @@ -40,7 +39,8 @@ #include #include "third_party/acl/inc/acl/acl.h" - +#include "torch_npu/csrc/register/OptionRegister.h" +#include "torch_npu/csrc/profiler/cann_profiling.h" static PyObject* THNPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS @@ -407,12 +407,32 @@ PyObject* THNPModule_setOption_wrap(PyObject* self, PyObject* arg) { torch::utils::npu_lazy_init(); { pybind11::gil_scoped_release no_gil; - c10::npu::SetOption(option); + torch_npu::option::SetOption(option); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS } +// Python binding: parses (npu_event, aicore_metrics) and forwards both to NpuProfiling::Start. +PyObject* THNPModule_prof_start(PyObject* self, PyObject* args) { + HANDLE_TH_ERRORS + + PyObject *value_1 = nullptr; + PyObject *value_2 = nullptr; + if(!PyArg_ParseTuple(args, "OO", &value_1, &value_2)) { + throw torch::TypeError("prof_start npu_event type or aicore_metrics set error."); + } + uint64_t npu_event = THPUtils_unpackLong(value_1); + uint64_t aicore_metrics = THPUtils_unpackLong(value_2); + // Release the GIL only around the profiler call; Py_RETURN_NONE below updates a refcount and must run with the GIL held. + { + pybind11::gil_scoped_release 
no_gil; + torch_npu::profiler::NpuProfiling::Instance().Start(npu_event, aicore_metrics); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + static struct PyMethodDef THNPModule_methods[] = { {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr}, {"_npu_set_run_yet_variable_to_false", (PyCFunction)THNPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr}, @@ -437,6 +457,7 @@ static struct PyMethodDef THNPModule_methods[] = { {"_npu_setDump", (PyCFunction)THNPModule_setDump, METH_O, nullptr}, {"_npu_finalizeDump", (PyCFunction)THNPModule_finalizeDump, METH_NOARGS, nullptr}, {"_npu_setOption", (PyCFunction)THNPModule_setOption_wrap, METH_O, nullptr}, + {"_prof_start", (PyCFunction)THNPModule_prof_start, METH_VARARGS, nullptr}, {nullptr}}; PyMethodDef* THNPModule_get_methods() { diff --git a/torch_npu/csrc/profiler/cann_profiling.cpp b/torch_npu/csrc/profiler/cann_profiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41c9a99c56ddef108843f99a0807a631f3933137 --- /dev/null +++ b/torch_npu/csrc/profiler/cann_profiling.cpp @@ -0,0 +1,161 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/profiler/cann_profiling.h" +#include +#include + +namespace torch_npu { +namespace profiler { + +NpuProfiling& NpuProfiling::Instance() { + static NpuProfiling npuProfiling; + return npuProfiling; +} + +void NpuProfiling::Init(const std::string &path) { + TORCH_CHECK(status == PROFILING_FINALIZE, "init current profile status is: ", status, " error!"); + c10::npu::npuSynchronizeDevice(); + auto ret = c10::npu::acl::AclProfilingInit(path.c_str(), path.length()); + if (ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)) { + NPU_LOGE("npu AclProfInit fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + return; + } + status = PROFILING_INIT; +} + +// Starts CANN profiling on the current device; npu_event is passed as the data type config and aicore_metrics as the AI Core metrics selector. +void NpuProfiling::Start(uint64_t npu_event, uint64_t aicore_metrics) { + TORCH_CHECK(status == PROFILING_INIT || status == PROFILING_STOP, + "start current profile status is: ", status, " error!"); + // Start may follow Stop, so release the previous config before creating a new one. + if (profCfg != nullptr) { + (void)c10::npu::acl::AclProfilingDestroyConfig(profCfg); + profCfg = nullptr; + } + int deviceIndex = 0; + aclError ret = aclrtGetDevice(&deviceIndex); + if(ret){ + NPU_LOGE("npu profiling aclrtGetDevice fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + (void)c10::npu::acl::AclProfilingFinalize(); + status = PROFILING_FINALIZE; + return; + } + const uint32_t deviceNum = 1; + // aclrtGetDevice yields an int; convert explicitly because braced initialization forbids narrowing. + uint32_t deviceIdList[deviceNum] = {static_cast<uint32_t>(deviceIndex)}; + profCfg = c10::npu::acl::AclProfilingCreateConfig( + deviceIdList, + deviceNum, + (aclprofAicoreMetrics)aicore_metrics, + nullptr, + npu_event); + if (profCfg == nullptr) { + NPU_LOGE("npu profiling profiling_create_config fail, error profCfg is null."); + C10_NPU_SHOW_ERR_MSG(); + (void)c10::npu::acl::AclProfilingFinalize(); + status = PROFILING_FINALIZE; + return; + } + c10::npu::npuSynchronizeDevice(); + ret = c10::npu::acl::AclProfilingStart(profCfg); + if(ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)){ + NPU_LOGE("npu profiling AclProfStart fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + status = PROFILING_START; +} + +void NpuProfiling::Stop() { + TORCH_CHECK(status == PROFILING_START, "stop current profile status is: ", status, " error!"); + 
c10::npu::npuSynchronizeDevice(); + auto ret = c10::npu::acl::AclProfilingStop(profCfg); + if (ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)) { + NPU_LOGE("npu AclProfStop fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + status = PROFILING_STOP; +} + +void NpuProfiling::Finalize() { + if (profCfg != nullptr) { + if (status != PROFILING_STOP) { + NPU_LOGW("finalize current profile status ( %u ) is not stopped, and call stop now.", status); + auto ret = c10::npu::acl::AclProfilingStop(profCfg); + if (ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)) { + NPU_LOGE("npu AclProfStop fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + } + auto ret = c10::npu::acl::AclProfilingDestroyConfig(profCfg); + if (ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)) { + NPU_LOGE("npu AclProfDestoryConfig fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + profCfg = nullptr; + } + auto ret = c10::npu::acl::AclProfilingFinalize(); + if (ret && (ret != ACL_ERROR_PROF_ALREADY_RUN)) { + NPU_LOGE("npu AclProfFinalize fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + status = PROFILING_FINALIZE; +} + + +NpuProfilingDispatch& NpuProfilingDispatch::Instance(){ + static NpuProfilingDispatch npuProfilingDispatch; + return npuProfilingDispatch; +} + +void NpuProfilingDispatch::init(){ + profStepInfo = c10::npu::acl::init_stepinfo(); +} + +void NpuProfilingDispatch::start(){ + this->init(); + auto stream = c10::npu::getCurrentNPUStream(); + auto ret = c10::npu::acl::start_deliver_op( + profStepInfo, + aclprofStepTag::ACL_STEP_START, + stream); + if(ret != ACL_ERROR_NONE){ + NPU_LOGE("npu profiling start fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } +} + +void NpuProfilingDispatch::stop(){ + auto stream = c10::npu::getCurrentNPUStream(); + auto ret = c10::npu::acl::stop_deliver_op( + profStepInfo, + aclprofStepTag::ACL_STEP_END, + stream); + if(ret != ACL_ERROR_NONE){ + NPU_LOGE("npu profiling stop fail, error code: %d", ret); + C10_NPU_SHOW_ERR_MSG(); + } + 
this->destroy(); +} + +void NpuProfilingDispatch::destroy(){ + if(profStepInfo != nullptr){ + c10::npu::acl::destroy_stepinfo(profStepInfo); + } +} + +} +} diff --git a/torch_npu/csrc/framework/utils/NpuProfilingDispatch.cpp b/torch_npu/csrc/profiler/cann_profiling.h similarity index 30% rename from torch_npu/csrc/framework/utils/NpuProfilingDispatch.cpp rename to torch_npu/csrc/profiler/cann_profiling.h index ae2863fcd94ee1a7c0cac81c212238b1880909f5..8e80afbc0bc7691b50fc86b9d10a35060f1be01a 100644 --- a/torch_npu/csrc/framework/utils/NpuProfilingDispatch.cpp +++ b/torch_npu/csrc/profiler/cann_profiling.h @@ -4,7 +4,7 @@ // // Licensed under the BSD 3-Clause License (the "License"); // you may not use this file except in compliance with the License. -// You may obtain a copy of the License at_npu +// You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // @@ -14,64 +14,53 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include -#include +#ifndef __CANN_PROFILING__ +#define __CANN_PROFILING__ -#include "torch_npu/csrc/framework/utils/NpuProfilingDispatch.h" +#include +#include +#include -namespace at_npu -{ - namespace native - { - - NpuProfilingDispatch &NpuProfilingDispatch::Instance() - { - static NpuProfilingDispatch npuProfilingDispatch; - return npuProfilingDispatch; - } +typedef enum { + PROFILING_FINALIZE, + PROFILING_INIT, + PROFILING_START, + PROFILING_STOP +} PROFILING_STATUS; - void NpuProfilingDispatch::init() - { - profStepInfo = c10::npu::acl::init_stepinfo(); - } +namespace torch_npu { +namespace profiler { - void NpuProfilingDispatch::start() - { - this->init(); - auto stream = c10::npu::getCurrentNPUStream(); - auto ret = c10::npu::acl::start_deliver_op( - profStepInfo, - aclprofStepTag::ACL_STEP_START, - stream); - if (ret != ACL_ERROR_NONE) - { - NPU_LOGE("npu profiling start fail, error code: %d", ret); - C10_NPU_SHOW_ERR_MSG(); - } - } - - void NpuProfilingDispatch::stop() - { - auto stream = c10::npu::getCurrentNPUStream(); - auto ret = c10::npu::acl::stop_deliver_op( - profStepInfo, - aclprofStepTag::ACL_STEP_END, - stream); - if (ret != ACL_ERROR_NONE) - { - NPU_LOGE("npu profiling stop fail, error code: %d", ret); - C10_NPU_SHOW_ERR_MSG(); - } - this->destroy(); - } +class TORCH_NPU_API NpuProfiling +{ +public: + static NpuProfiling& Instance(); + void Init(const std::string &profilerResultPath); + void Start(uint64_t npu_event, uint64_t aicore_metrics); + void Stop(); + void Finalize(); +private: + aclprofConfig* profCfg = nullptr; + PROFILING_STATUS status = PROFILING_FINALIZE; + NpuProfiling() = default; + ~NpuProfiling() = default; +}; - void NpuProfilingDispatch::destroy() - { - if (profStepInfo != nullptr) - { - c10::npu::acl::destroy_stepinfo(profStepInfo); - } - } +class NpuProfilingDispatch +{ +public: + static NpuProfilingDispatch& Instance(); + void start(); + void stop(); +private: + aclprofStepInfo* profStepInfo = nullptr; + 
NpuProfilingDispatch() = default; + ~NpuProfilingDispatch() = default; + void init(); + void destroy(); +}; - } } +} + +#endif // __CANN_PROFILING__ \ No newline at end of file diff --git a/torch_npu/csrc/register/FunctionLoader.cpp b/torch_npu/csrc/register/FunctionLoader.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f076791aa26d83a3e61d0bf93ab77354d1a33916 --- /dev/null +++ b/torch_npu/csrc/register/FunctionLoader.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "torch_npu/csrc/register/FunctionLoader.h" + +namespace torch_npu { +namespace option { + +FunctionLoader::FunctionLoader(const std::string& name) { + this->fileName = name + ".so"; +} + +FunctionLoader::~FunctionLoader() { + if (this->handle != nullptr) { + dlclose(this->handle); + } +} + +void FunctionLoader::Set(const std::string& name) { + this->registry[name] = nullptr; +} + +void* FunctionLoader::Get(const std::string& name) { + if (this->handle == nullptr) { + auto handle = dlopen(this->fileName.c_str(), RTLD_LAZY); + if (handle == nullptr) { + AT_ERROR(dlerror()); + return nullptr; + } + this->handle = handle; + } + + auto itr = registry.find(name); + if (itr == registry.end()) { + AT_ERROR("function(", name, ") is not registered."); + return nullptr; + } + + if (itr->second != nullptr) { + return itr->second; + } + + auto func = dlsym(this->handle, name.c_str()); + if (func == nullptr) { + return nullptr; + } + this->registry[name] = func; + return func; +} + +namespace register_function { + FunctionRegister* FunctionRegister::GetInstance() { + static FunctionRegister instance; + return &instance; + } + void FunctionRegister::Register(const std::string& name, ::std::unique_ptr& ptr) { + std::lock_guard lock(mu_); + registry.emplace(name, std::move(ptr)); + } + + void FunctionRegister::Register(const std::string& name, const std::string& funcName) { + auto itr = registry.find(name); + if (itr == registry.end()) { + AT_ERROR(name, " library should register first."); + return; + } + itr->second->Set(funcName); + } + + void* FunctionRegister::Get(const std::string& soName, const std::string& funcName) { + auto itr = registry.find(soName); + if (itr != registry.end()) { + return itr->second->Get(funcName); + } + return nullptr; + } + + FunctionRegisterBuilder::FunctionRegisterBuilder(const std::string& name, ::std::unique_ptr& ptr) { + FunctionRegister::GetInstance()->Register(name, ptr); + } + 
FunctionRegisterBuilder::FunctionRegisterBuilder(const std::string& soName, const std::string& funcName) { + FunctionRegister::GetInstance()->Register(soName, funcName); + } +} // namespace register_function + + +} // namespace option +} // namespace torch_npu diff --git a/torch_npu/csrc/register/FunctionLoader.h b/torch_npu/csrc/register/FunctionLoader.h new file mode 100644 index 0000000000000000000000000000000000000000..d56e2ef5b04e66e45e4e89a2c587f38e0e8d8c4f --- /dev/null +++ b/torch_npu/csrc/register/FunctionLoader.h @@ -0,0 +1,113 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +namespace torch_npu { +namespace option { + +/** + FunctionLoader is used to store function address in the process. + */ +class FunctionLoader { +public: + /** + ctr + */ + explicit FunctionLoader(const std::string& filename); + /** + dectr + */ + ~FunctionLoader(); + /** + set function name + */ + void Set(const std::string& name); + /** + get function address by function name. 
+ */ + void* Get(const std::string& name); +private: + mutable std::mutex mu_; + std::string fileName; + void* handle = nullptr; + mutable std::unordered_map registry; +}; // class FunctionLoader + + +namespace register_function { +/** + this class is used to register + */ +class FunctionRegister { +public: + /** + Singleton + */ + static FunctionRegister* GetInstance(); + /** + this API is used to store FunctionLoader class + */ + void Register(const std::string& name, ::std::unique_ptr& ptr); + /** + this API is used to associate library name and function name. + */ + void Register(const std::string& name, const std::string& funcName); + /** + this API is used to get the function address by library and function name. + */ + void* Get(const std::string& soName, const std::string& funcName); + +private: + FunctionRegister() = default; + mutable std::mutex mu_; + mutable std::unordered_map> registry; +}; // class FunctionRegister + +/** + FunctionRegisterBuilder is the helper of FunctionRegister. 
+ */ +class FunctionRegisterBuilder { +public: + /** + ctr + */ + FunctionRegisterBuilder(const std::string& name, ::std::unique_ptr& ptr); + /** + ctr + */ + FunctionRegisterBuilder(const std::string& soName, const std::string& funcName); +}; // class FunctionRegisterBuilder + +} // namespace register_function + +#define REGISTER_LIBRARY(soName) \ + auto library_##soName = \ + ::std::unique_ptr(new torch_npu::option::FunctionLoader(#soName)); \ + static torch_npu::option::register_function::FunctionRegisterBuilder \ + register_library_##soName(#soName, library_##soName); + +#define REGISTER_FUNCTION(soName, funcName) \ + static torch_npu::option::register_function::FunctionRegisterBuilder \ + register_function_##funcName(#soName, #funcName); + +#define GET_FUNCTION(soName, funcName) \ + torch_npu::option::register_function::FunctionRegister::GetInstance()->Get(#soName, #funcName); + +} // namespace option +} // namespace torch_npu diff --git a/torch_npu/csrc/register/OptionRegister.cpp b/torch_npu/csrc/register/OptionRegister.cpp new file mode 100644 index 0000000000000000000000000000000000000000..30090f0a52997abad280f42b3b651ae0c57df568 --- /dev/null +++ b/torch_npu/csrc/register/OptionRegister.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include "torch_npu/csrc/register/OptionRegister.h" + +namespace torch_npu { +namespace option { + +OptionInterface::OptionInterface(OptionCallBack callback) { + this->callback = callback; +} + +void OptionInterface::Set(const std::string& in) { + this->val = in; + if (this->callback != nullptr) { + this->callback(in); + } +} + +std::string OptionInterface::Get() { + return val; +} + + +namespace register_options { +OptionRegister* OptionRegister::GetInstance() { + static OptionRegister instance; + return &instance; +} + +void OptionRegister::Register(const std::string& name, + ::std::unique_ptr& ptr) { + std::lock_guard lock(mu_); + registry.emplace(name, std::move(ptr)); +} + +void OptionRegister::Set(const std::string& name, const std::string& val) { + auto itr = registry.find(name); + if (itr != registry.end()) { + itr->second->Set(val); + } else { + AT_ERROR("invalid npu option name:", name); + } +} + +c10::optional OptionRegister::Get(const std::string& name) { + auto itr = registry.find(name); + if (itr != registry.end()) { + return itr->second->Get(); + } + return c10::nullopt; // default value +} + +OptionInterfaceBuilder::OptionInterfaceBuilder( + const std::string& name, + ::std::unique_ptr& ptr, + const std::string& type) { + OptionRegister::GetInstance()->Register(name, ptr); + + // init the value if env variable. 
+ if (type == "env") { + std::string env_name = name; + std::transform(env_name.begin(), env_name.end(), env_name.begin(), ::toupper); + char* env_val = std::getenv(env_name.c_str()); + if (env_val != nullptr) { + std::string val(env_val); + OptionRegister::GetInstance()->Set(name, val); + } + } +} +} // namespace register_options + +void SetOption(const std::string& key, const std::string& val) { + register_options::OptionRegister::GetInstance()->Set(key, val); +} + +void SetOption(const std::map& options) { + for (auto item : options) { + SetOption(item.first, item.second); + } +} + +c10::optional GetOption(const std::string& key) { + return register_options::OptionRegister::GetInstance()->Get(key); +} + +} // namespace option +} // namespace torch_npu diff --git a/torch_npu/csrc/register/OptionRegister.h b/torch_npu/csrc/register/OptionRegister.h new file mode 100644 index 0000000000000000000000000000000000000000..da51c188b1d29d299fb5ccc7d53fd40eb86eede7 --- /dev/null +++ b/torch_npu/csrc/register/OptionRegister.h @@ -0,0 +1,145 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef __TORCH_NPU_OPTION_REGISTER_H__ +#define __TORCH_NPU_OPTION_REGISTER_H__ + +#include +#include +#include +#include +#include +#include + +namespace torch_npu { +namespace option { + +typedef void(*OptionCallBack) (const std::string&); +/** + This class is used to storage env value, and provide Set and Get to + */ +class OptionInterface { +public: + /** + dctr + */ + OptionInterface(OptionCallBack callback=nullptr); + /** + This API is used to store value. + */ + void Set(const std::string& in); + /** + This API is used to load value. + */ + std::string Get(); +private: +/** + Its used to store hook. + */ + OptionCallBack callback = nullptr; + std::string val; +}; + +namespace register_options { + +/** + This class is used to register OptionInterface + */ +class OptionRegister { +public: + /** + dctr + */ + ~OptionRegister() = default; + /** + singleton + */ + static OptionRegister* GetInstance(); + /** + register + */ + void Register(const std::string& name, ::std::unique_ptr& ptr); + /** + This API is used to store value to special key. + */ + void Set(const std::string& name, const std::string& val); + /** + This API is used to load value from special key. 
+ */ + c10::optional Get(const std::string& name); +private: + OptionRegister() {} + mutable std::mutex mu_; + mutable std::unordered_map> registry; +}; + +/** + This class is the helper to construct class OptionRegister + */ +class OptionInterfaceBuilder { +public: + OptionInterfaceBuilder(const std::string& name, ::std::unique_ptr& ptr, const std::string& type = "cli"); +}; + +} // namespace register_options + +/** + This API is used to store key-value pairs + */ +void SetOption(const std::map& options); +/** + This API is used to store key-value pair + */ +void SetOption(const std::string& key, const std::string& val); +/** + This API is used to load value by key + */ +c10::optional GetOption(const std::string& key); + +#define REGISTER_OPTION(name) \ + REGISTER_OPTION_UNIQ(name, name, cli) + +#define REGISTER_OPTION_INIT_BY_ENV(name) \ + REGISTER_OPTION_UNIQ(name, name, env) + +#define REGISTER_OPTION_UNIQ(id, name, type) \ + auto options_interface_##id = \ + ::std::unique_ptr(new torch_npu::option::OptionInterface()); \ + static torch_npu::option::register_options::OptionInterfaceBuilder \ + register_options_interface_##id(#name, options_interface_##id, #type); + +#define REGISTER_OPTION_HOOK(name, ...) \ + REGISTER_OPTION_HOOK_UNIQ(name, name, __VA_ARGS__) + +#define REGISTER_OPTION_HOOK_UNIQ(id, name, ...) 
\ + auto options_interface_##id = \ + ::std::unique_ptr( \ + new torch_npu::option::OptionInterface(torch_npu::option::OptionCallBack(__VA_ARGS__))); \ + static torch_npu::option::register_options::OptionInterfaceBuilder \ + register_options_interface_##id(#name, options_interface_##id); + +#define REGISTER_OPTION_BOOL_FUNCTION(func, key, defaultVal, trueVal) \ + bool func() { \ + auto val = torch_npu::option::GetOption(#key); \ + if (val.value_or(defaultVal) == (trueVal)) { \ + return true; \ + } \ + return false; \ + } + +} // namespace option +} // namespace torch_npu + +#endif // __TORCH_NPU_OPTION_REGISTER_H__ \ No newline at end of file diff --git a/torch_npu/csrc/register/OptionsManager.cpp b/torch_npu/csrc/register/OptionsManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0afa6ce4d49bdd48ee26026256829ea18559574f --- /dev/null +++ b/torch_npu/csrc/register/OptionsManager.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "torch_npu/csrc/register/OptionRegister.h" +#include "torch_npu/csrc/register/OptionsManager.h" + +namespace torch_npu { +namespace option { + +using namespace std; + +bool OptionsManager::CheckQueueEnable() { + static int32_t queue_enable = -1; + if (queue_enable == -1) { + queue_enable = GetBoolTypeOption("TASK_QUEUE_ENABLE"); + } + return (queue_enable == 1); +} + +bool OptionsManager::CheckPTcopy_Enable() { + static int32_t PTcopy__enable = -1; + if (PTcopy__enable == -1) { + PTcopy__enable = GetBoolTypeOption("PTCOPY_ENABLE"); + } + return (PTcopy__enable == 1); +} + +bool OptionsManager::CheckCombinedOptimizerEnable() { + static int32_t combined_optimize = -1; + if (combined_optimize == -1) { + combined_optimize = GetBoolTypeOption("COMBINED_ENABLE"); + } + return (combined_optimize == 1); +} + +bool OptionsManager::CheckTriCombinedOptimizerEnable() { + static int32_t tri_combined_optimize = -1; + if (tri_combined_optimize == -1) { + tri_combined_optimize = GetBoolTypeOption("TRI_COMBINED_ENABLE"); + } + return (tri_combined_optimize == 1); +} + + +bool OptionsManager::CheckAclDumpDateEnable() { + static int aclDumpDataEnable = -1; + if (aclDumpDataEnable == -1) { + aclDumpDataEnable = GetBoolTypeOption("ACL_DUMP_DATA"); + } + return (aclDumpDataEnable == 1); +} + +bool OptionsManager::CheckSwitchMMOutputEnable() { + static int switchMMOutputEnable = -1; + if (switchMMOutputEnable == -1) { + switchMMOutputEnable = GetBoolTypeOption("SWITCH_MM_OUTPUT_ENABLE"); + } + return (switchMMOutputEnable == 1); +} + +int OptionsManager::GetBoolTypeOption(const char* env_str) { + char* env_val = std::getenv(env_str); + int64_t envFlag = (env_val != nullptr) ? strtol(env_val, nullptr, 10) : 0; + return (envFlag != 0) ? 
1 : 0; +} + +bool OptionsManager::CheckUseNpuLogEnable() { + static int useNpuLog = -1; + if (useNpuLog == -1) { + useNpuLog = GetBoolTypeOption("NPU_LOG_ENABLE"); + } + + return (useNpuLog == 1); +} + +bool OptionsManager::CheckDynamicOptimizer(const char* op) { + static int isGetOps = 0; + static std::map op_map = {{"ADD", false}, {"MUL", false}}; + if (isGetOps == 0) { + char* dynamicOptimizerEnv = std::getenv("DYNAMIC_OP"); + if (dynamicOptimizerEnv != nullptr) { + std::string dynamicOptimizerEnvStr = dynamicOptimizerEnv; + const std::string separator = "#"; + std::string::size_type pos1 = 0; + std::string::size_type pos2 = dynamicOptimizerEnvStr.find(separator); + std::string substr; + while (pos2 != std::string::npos) { + substr = dynamicOptimizerEnvStr.substr(pos1, pos2 - pos1); + if (op_map.find(substr) != op_map.end()) { + op_map[substr] = true; + } + pos1 = pos2 + separator.size(); + pos2 = dynamicOptimizerEnvStr.find(separator, pos1); + } + if (pos1 != dynamicOptimizerEnvStr.size()) { + substr = dynamicOptimizerEnvStr.substr(pos1); + if (op_map.find(substr) != op_map.end()) { + op_map[substr] = true; + } + } + } + isGetOps = 1; + } + TORCH_CHECK( + op_map.find(op) != op_map.end(), "This op is not currently optimized."); + return op_map[op]; +} + +} // namespace option +} // namespace torch_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/utils/NpuProfilingDispatch.h b/torch_npu/csrc/register/OptionsManager.h similarity index 43% rename from torch_npu/csrc/framework/utils/NpuProfilingDispatch.h rename to torch_npu/csrc/register/OptionsManager.h index ec8e7c90b556d93c2ac3e999f1524fdfef19e058..dd59ea6513cb6d4da2518f3c728e6ea30a39e598 100644 --- a/torch_npu/csrc/framework/utils/NpuProfilingDispatch.h +++ b/torch_npu/csrc/register/OptionsManager.h @@ -1,10 +1,9 @@ // Copyright (c) 2020 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. // All rights reserved. 
// // Licensed under the BSD 3-Clause License (the "License"); // you may not use this file except in compliance with the License. -// You may obtain a copy of the License at_npu +// You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // @@ -14,28 +13,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef __PLUGIN_NPU_PROFILING_DISPATCH__ -#define __PLUGIN_NPU_PROFILING_DISPATCH__ +#ifndef __TORCH_NPU_OPTIONSMANAGER_H__ +#define __TORCH_NPU_OPTIONSMANAGER_H__ -#include +#include +#include +#include +#include -namespace at_npu { -namespace native { +namespace torch_npu { +namespace option { -class NpuProfilingDispatch -{ +class OptionsManager { public: - static NpuProfilingDispatch& Instance(); - void start(); - void stop(); + static bool CheckQueueEnable(); + static bool CheckPTcopy_Enable(); + static bool CheckCombinedOptimizerEnable(); + static bool CheckTriCombinedOptimizerEnable(); + static bool CheckAclDumpDateEnable(); + static bool CheckSwitchMMOutputEnable(); + static bool CheckDynamicOptimizer(const char* op); + static bool CheckUseNpuLogEnable(); + static std::string CheckDisableDynamicPath(); private: - aclprofStepInfo* profStepInfo = nullptr; - NpuProfilingDispatch() = default; - ~NpuProfilingDispatch() = default; - void init(); - void destroy(); + static int GetBoolTypeOption(const char* env_str); }; -} -} -#endif // __PLUGIN_NPU_PROFILING_DISPATCH__ \ No newline at end of file +} // namespace option +} // namespace torch_npu + +#endif // __TORCH_NPU_OPTIONSMANAGER_H__ \ No newline at end of file diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index d2f0179f38f2152a140d67b5a79c509e2385392b..c370fd20ea8bd8df2d507c6400d4fb9d99d0cc74 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -26,7 +26,8 @@ __all__ = [ "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", "memory_allocated", 
"max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", - "Stream", "Event", "profiler", "set_option", "set_aoe" + "Stream", "Event", "profiler", "set_option", "set_aoe", "profile", "prof_init", + "prof_start", "prof_stop", "prof_finalize", "profileConfig" ] @@ -44,4 +45,5 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del memory_cached, max_memory_cached, memory_snapshot, memory_summary) from .streams import Stream, Event from . import profiler -from .npu_frontend_enhance import set_option, set_aoe \ No newline at end of file +from .npu_frontend_enhance import (set_option, set_aoe, profile, prof_init, + prof_start, prof_stop, prof_finalize, profileConfig) \ No newline at end of file diff --git a/torch_npu/npu/npu_frontend_enhance.py b/torch_npu/npu/npu_frontend_enhance.py index bb3b14bfa78e5616aaffd2f17a67adab849eacae..70730860c5fc7cd27b1e15bd75bb68d3ab71ce68 100644 --- a/torch_npu/npu/npu_frontend_enhance.py +++ b/torch_npu/npu/npu_frontend_enhance.py @@ -19,7 +19,8 @@ import os import torch_npu._C # this file is used to enhance the npu frontend API by set_option or other. 
def prof_init(path):
    """Point the NPU profiler at an existing result location.

    The path is made absolute and handed to the backend through the
    ``profilerResultPath`` option.

    Args:
        path: Location that receives the profiling results. Only its
            existence is checked here; it is not created.

    Raises:
        AssertionError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise AssertionError("profiler_result_path: %s not exists."%(path))
    profiler_result_path = os.path.abspath(path)
    option = {"profilerResultPath": profiler_result_path}
    torch_npu._C._npu_setOption(option)


def prof_start(npu_event, aicore_metrics):
    """Start profiling with the given event mask and AI Core metrics type.

    Args:
        npu_event: Event bitmask, e.g. ``profileConfig().NpuEventConfig``.
        aicore_metrics: Metrics selector, e.g.
            ``profileConfig().AiCoreMetricsConfig``.
    """
    torch_npu._C._prof_start(npu_event, aicore_metrics)


def prof_stop():
    """Request a profiling stop by setting the ``profiling`` option to ``stop``."""
    option = {"profiling": "stop"}
    torch_npu._C._npu_setOption(option)


def prof_finalize():
    """Request profiler teardown by setting the ``profiling`` option to ``finalize``."""
    option = {"profiling": "finalize"}
    torch_npu._C._npu_setOption(option)


class npuEvent(object):
    """Bit flags selecting which profiling event categories are collected.

    Every flag starts enabled; ``update`` zeroes the ones that are switched
    off and ``getConfig`` ORs the remaining values into a single mask.
    """

    # Flags that update() can switch off, in the same order as its
    # keyword arguments.
    _SWITCHABLE = (
        "ACL_PROF_ACL_API",
        "ACL_PROF_TASK_TIME",
        "ACL_PROF_AICORE_METRICS",
        "ACL_PROF_AICPU",
        "ACL_PROF_L2CACHE",
        "ACL_PROF_HCCL_TRACE",
        "ACL_PROF_TRAINING_TRACE",
    )

    def __init__(self):
        self.ACL_PROF_ACL_API = 0x0001
        self.ACL_PROF_TASK_TIME = 0x0002
        self.ACL_PROF_AICORE_METRICS = 0x0004
        self.ACL_PROF_AICPU = 0x0008
        self.ACL_PROF_L2CACHE = 0x0010
        self.ACL_PROF_HCCL_TRACE = 0x0020
        self.ACL_PROF_TRAINING_TRACE = 0x0040
        # Defined for completeness; getConfig() never adds it to the mask.
        self.ACL_PROF_MSPROFTX = 0x0080
        # Always part of the mask; update() offers no switch for it.
        self.ACL_PROF_RUNTIME_API = 0x0100

    def update(self, ACL_PROF_ACL_API=True, ACL_PROF_TASK_TIME=True,
               ACL_PROF_AICORE_METRICS=True, ACL_PROF_AICPU=True,
               ACL_PROF_L2CACHE=False, ACL_PROF_HCCL_TRACE=True,
               ACL_PROF_TRAINING_TRACE=True):
        """Disable the flags passed as falsy and return the resulting mask.

        A truthy argument leaves the flag untouched, so a flag that an
        earlier call cleared is not turned back on.

        Returns:
            The integer mask produced by ``getConfig``.
        """
        requested = (ACL_PROF_ACL_API, ACL_PROF_TASK_TIME,
                     ACL_PROF_AICORE_METRICS, ACL_PROF_AICPU,
                     ACL_PROF_L2CACHE, ACL_PROF_HCCL_TRACE,
                     ACL_PROF_TRAINING_TRACE)
        for flag_name, enabled in zip(self._SWITCHABLE, requested):
            if not enabled:
                setattr(self, flag_name, 0x00)
        return self.getConfig()

    def getConfig(self):
        """Return the OR of the switchable flags and ``ACL_PROF_RUNTIME_API``."""
        mask = self.ACL_PROF_RUNTIME_API
        for flag_name in self._SWITCHABLE:
            mask |= getattr(self, flag_name)
        return mask


class aiCoreMetrics(object):
    """Named values accepted as the AI Core metrics type."""

    ACL_AICORE_ARITHMETIC_UTILIZATION = 0
    ACL_AICORE_PIPE_UTILIZATION = 1
    ACL_AICORE_MEMORY_BANDWIDTH = 2
    ACL_AICORE_L0B_AND_WIDTH = 3
    ACL_AICORE_RESOURCE_CONFLICT_RATIO = 4
    ACL_AICORE_NONE = 0xFF


class profileConfig(object):
    """Bundle an event mask and an AI Core metrics type for ``prof_start``.

    Attributes are ``NpuEventConfig`` (the mask built by ``npuEvent.update``
    from the boolean switches) and ``AiCoreMetricsConfig`` (the value of
    ``aiCoreMetricsType``, stored unchanged).
    """

    def __init__(self, ACL_PROF_ACL_API=True, ACL_PROF_TASK_TIME=True, ACL_PROF_AICORE_METRICS=True,
                 ACL_PROF_AICPU=True, ACL_PROF_L2CACHE=False, ACL_PROF_HCCL_TRACE=True,
                 ACL_PROF_TRAINING_TRACE=True, aiCoreMetricsType=0):
        self.NpuEventConfig = npuEvent().update(ACL_PROF_ACL_API, ACL_PROF_TASK_TIME, ACL_PROF_AICORE_METRICS,
                                                ACL_PROF_AICPU, ACL_PROF_L2CACHE, ACL_PROF_HCCL_TRACE,
                                                ACL_PROF_TRAINING_TRACE)
        self.AiCoreMetricsConfig = aiCoreMetricsType


class profile(object):
    """Context manager that profiles the enclosed code on the NPU.

    Entering calls ``prof_init`` and ``prof_start``; leaving calls
    ``prof_stop`` and ``prof_finalize``. An instance can be entered
    successfully only once.

    Args:
        profiler_result_path: Existing location for the results
            (see ``prof_init``).
        use_e2e_profiler: Must be falsy; E2E profiling is rejected.
        config: A ``profileConfig``. ``None`` (the default) builds a fresh
            default configuration for this instance.

    Raises:
        ValueError: If ``use_e2e_profiler`` is truthy.
    """

    def __init__(self, profiler_result_path="./", use_e2e_profiler=False,
                 config=None):
        # Build the default per instance instead of sharing one object
        # created when the class body was executed.
        if config is None:
            config = profileConfig()
        self.result_path = profiler_result_path
        self.use_e2e_profiler = use_e2e_profiler
        self.npu_event = config.NpuEventConfig
        self.aicore_metrics = config.AiCoreMetricsConfig
        self.entered = False
        if self.use_e2e_profiler:
            raise ValueError("This version dose not support E2E profiling now!")

    def __enter__(self):
        if self.entered:
            raise RuntimeError("npu profiler traces are not reentrant")
        prof_init(self.result_path)
        # Latch only after initialisation succeeded, so that a rejected
        # result path does not turn every later attempt into a misleading
        # "not reentrant" error.
        self.entered = True
        prof_start(self.npu_event, self.aicore_metrics)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Finalize even when stopping fails; otherwise the profiler would
        # be left initialised.
        try:
            prof_stop()
        finally:
            prof_finalize()
        # Never suppress an exception raised inside the with-body.
        return False