From ace31df16d5852a17e21653121f4f18ed3d1afeb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 07:25:46 +0000
Subject: [PATCH 1/7] test npu_dropout

---
 test/test_network_ops/test_dropout.py          |  73 ++++++++
 .../test_network_ops/test_dropout_backward.py  |  89 +++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml  |   8 +-
 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp   | 188 ++++++++++++++++++
 4 files changed, 357 insertions(+), 1 deletion(-)
 create mode 100644 test/test_network_ops/test_dropout.py
 create mode 100644 test/test_network_ops/test_dropout_backward.py
 create mode 100644 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp

diff --git a/test/test_network_ops/test_dropout.py b/test/test_network_ops/test_dropout.py
new file mode 100644
index 0000000000..324d52e0c7
--- /dev/null
+++ b/test/test_network_ops/test_dropout.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append('..')
+import torch
+import torch_npu
+import numpy as np
+from torch.nn import functional as F
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE
+
+class TestDropOutDoMask(TestCase):
+    def cpu_op_exec(self, input):
+        out = torch.nn.Dropout(0.5)(input)
+        out = out.numpy()
+        return out
+
+    def npu_op_exec(self, input):
+        out = torch.nn.Dropout(0.5)(input)
+        out = out.to("cpu")
+        out = out.numpy()
+        return out
+
+    def dropout_list_exec(self, list):
+        epsilon = 1e-3
+        for item in list:
+            cpu_input1, npu_input1 = create_common_tensor(item, 0, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            # comparison method for the random results of this operator
+            for a, b in zip(cpu_output.flatten(), npu_output.flatten()):
+                if abs(a) > 0 and abs(b) > 0 and abs(a - b) > epsilon:
+                    print(f'input = {item}, ERROR!')
+                    break
+            else:
+                print(f'input = {item}, Successfully!')
+
+    def test_op_shape_format_fp16(self, device):
+        format_list = [-1]
+        shape_list = [1, (256, 1280), (32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float16, i, j] for i in format_list for j in shape_list
+        ]
+        self.dropout_list_exec(shape_format)
+
+    def test_op_shape_format_fp32(self, device):
+        format_list = [-1]
+        shape_list = [1, (256, 1280), (32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float32, i, j] for i in format_list for j in shape_list
+        ]
+        self.dropout_list_exec(shape_format)
+
+instantiate_device_type_tests(TestDropOutDoMask, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_dropout_backward.py b/test/test_network_ops/test_dropout_backward.py
new file mode 100644
index 0000000000..281524924d
--- /dev/null
+++ b/test/test_network_ops/test_dropout_backward.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+sys.path.append('..')
+import torch
+import torch_npu
+import numpy as np
+from torch.nn import functional as F
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE
+
+class TestDropOutBackward(TestCase):
+    def cpu_op_exec(self, input1):
+        input1.requires_grad = True
+        out = torch.nn.Dropout(0.5)(input1)
+        out.backward(torch.ones_like(out))
+        out_grad = input1.grad
+        out_grad = out_grad.detach().numpy()
+        out = out.detach().numpy()
+        return out_grad, out
+
+    def npu_op_exec(self, input1):
+        input1.requires_grad = True
+        out = torch.nn.Dropout(0.5)(input1)
+        out.backward(torch.ones_like(out))
+        out_grad = input1.grad
+        out_grad = out_grad.to("cpu")
+        out_grad = out_grad.detach().numpy()
+        out = out.to("cpu")
+        out = out.detach().numpy()
+        return out_grad, out
+
+    def dropout_list_exec(self, list):
+        epsilon = 1e-3
+        for item in list:
+            cpu_input1, npu_input1 = create_common_tensor(item, 0, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output_grad, cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output_grad, npu_output = self.npu_op_exec(npu_input1)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            # comparison method for the random results of this operator
+            for a, b in zip(cpu_output.flatten(), npu_output.flatten()):
+                if abs(a) > 0 and abs(b) > 0 and abs(a - b) > epsilon:
+                    print(f'input = {item}, ERROR!')
+                    break
+            else:
+                print(f'input = {item}, Successfully!')
+
+            for a, b in zip(cpu_output_grad.flatten(), npu_output_grad.flatten()):
+                if abs(a) > 0 and abs(b) > 0 and abs(a - b) > epsilon:
+                    print(f'input = {item}, ERROR!')
+                    break
+            else:
+                print(f'input = {item}, Successfully!')
+
+    def test_op_shape_format_fp16(self, device):
+        format_list = [-1]
+        shape_list = [1, (32, 3, 3)]
+        shape_format = [
+            [np.float16, i, j] for i in format_list for j in shape_list
+        ]
+        self.dropout_list_exec(shape_format)
+
+    def test_op_shape_format_fp32(self, device):
+        format_list = [-1]
+        shape_list = [1, (32, 3, 3)]
+        shape_format = [
+            [np.float32, i, j] for i in format_list for j in shape_list
+        ]
+        self.dropout_list_exec(shape_format)
+
+instantiate_device_type_tests(TestDropOutBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1aaf07af26..b86f394d2e 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1926,10 +1926,16 @@ custom:
   - func: npu_confusion_transpose_backward(Tensor grad, int[] perm, int[] shape, bool transpose_first) -> Tensor
   - func: npu_one_hot(Tensor self, int num_classes=-1, int depth=1, Scalar on_value=1, Scalar off_value=0) -> Tensor
     variants: function, method
+  - func: npu_linear_backward(Tensor grad, Tensor input, Tensor weight) -> (Tensor, Tensor)
+  - func: npu_anchor_response_flags(Tensor self, int[2] featmap_size, int[2] stride, int num_base_anchors) -> Tensor
+    variants: function, method
+  - func: npu_dropout_backward(Tensor grad_output, Tensor mask, float p) -> Tensor
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
   - func: fast_gelu(Tensor self) -> Tensor
   - func: npu_confusion_transpose(Tensor self, int[] perm, int[] shape, bool transpose_first) -> Tensor
     variants: function, method
-  - func: npu_ps_roi_pooling(Tensor self, Tensor rois, float spatial_scale, int group_size, int output_dim) -> Tensor
\ No newline at end of file
+  - func: npu_ps_roi_pooling(Tensor self, Tensor rois, float spatial_scale, int group_size, int output_dim) -> Tensor
+  - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+  - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor)
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
new file mode 100644
index 0000000000..287e3537d7
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
@@ -0,0 +1,188 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
+#include "torch_npu/csrc/framework/utils/NpuUtils.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/interface/EnvVariables.h"
+#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h"
+#include "torch_npu/csrc/framework/utils/OpTemplate.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using namespace torch::autograd;
+// using torch::autograd::AutogradContext;
+// using tensor_list = std::vector<at::Tensor>;
+// using tensor_list1 = std::vector<at::Tensor>;
+
+at::Tensor dropout_do_mask(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& mask,
+    at::Scalar prob) {
+  OpCommand cmd;
+  cmd.Name("DropOutDoMask")
+      .Input(self)
+      .Input(mask)
+      .Input(prob, self.scalar_type(), CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor dropout_gen_mask(const at::Tensor& self, at::Scalar prob) {
+  bool isFuzzyCompile = env::CheckFuzzyEnable();
+  int64_t numels;
+  auto desc_ = self.storage().get_npu_desc();
+  numels = isFuzzyCompile ? at::prod_intlist(desc_.storage_sizes_) : self.numel();
+
+  uint32_t length = (numels + 128 - 1) / 128 * 128;
+  at::Tensor mask = OpPreparation::ApplyTensorWithFormat(
+      {length / 8},
+      self.options().dtype(at::kByte),
+      ACL_FORMAT_ND);
+
+  at::IntArrayRef selfShape = isFuzzyCompile ? desc_.storage_sizes_ : self.sizes();
+
+  OpCommand cmd;
+  // If either seed or seed2 are set to be non-zero, the random number generator
+  // is seeded by the given seed. Otherwise, it is seeded by a random seed.
+  int64_t seed = 0;
+  int64_t seed2 = 0;
+  cmd.Name("DropOutGenMask")
+      .Input(selfShape)
+      .Input(prob, self.scalar_type(), CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
+      .Output(mask)
+      .Attr("seed", seed)
+      .Attr("seed2", seed2)
+      .Run();
+  return mask;
+}
+
+std::tuple<at::Tensor, at::Tensor> dropout_v1_npu_impl(
+    at::Tensor result,
+    const at::Tensor& self,
+    double p) {
+  at::Tensor selfCp = NpuUtils::format_contiguous(self);
+  TORCH_CHECK(
+      p >= 0 && p <= 1,
+      "dropout probability has to be between 0 and 1, but got ",
+      p);
+  TORCH_CHECK(
+      at::isFloatingType(selfCp.scalar_type()),
+      "dropout only supports floating-point dtypes");
+
+  double retain = 1. - p;
+  at::Scalar prob = at::Scalar(retain);
+  at::Tensor mask;
+  auto original_stream = c10::npu::getCurrentNPUStream();
+  {
+    // During the life cycle of this raii instance, the calcu stream is set as the
+    // secondary stream, and tasks are distributed to the secondary stream. At the
+    // same time, according to the one-stream-one-pool principle, memory is also
+    // alloced from the pool of the secondary stream.
+    c10::npu::SecondaryStreamGuard guard(c10::npu::getCurrentSecondaryStream());
+    mask = dropout_gen_mask(selfCp, prob);
+  }
+  // When tasks on multiple streams read and write the same block of memory,
+  // recordStream needs to be called to ensure the correctness of memory reuse.
+  c10_npu::NPUCachingAllocator::recordStream(mask.storage().data_ptr(), original_stream);
+  dropout_do_mask(result, selfCp, mask, prob);
+
+  return std::tie(result, mask);
+}
+
+at::Tensor NPUNativeFunctions::npu_dropout_backward(
+    const at::Tensor& grad_output,
+    const at::Tensor& mask,
+    double scale) {
+  TORCH_CHECK(
+      at::isFloatingType(grad_output.scalar_type()),
+      "dropoutbackward only supports floating-point dtypes");
+  TORCH_CHECK(
+      mask.scalar_type() == at::ScalarType::Byte,
+      "mask should be torch.uint8 dtype");
+  double retain = 1. - scale;
+  at::Tensor result = OpPreparation::ApplyTensor(grad_output);
+
+  OpCommand cmd;
+  cmd.Name("DropOutDoMask")
+      .Input(grad_output)
+      .Input(mask)
+      .Input(retain, grad_output.scalar_type(), CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+std::tuple<at::Tensor, at::Tensor> _dropout_npu_com(
+    const at::Tensor& self,
+    double p) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  return dropout_v1_npu_impl(result, self, p);
+}
+
+class NPUdropoutFunction : public torch::autograd::Function<NPUdropoutFunction> {
+public:
+  static tensor_list forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      double p) {
+    ctx->saved_data["p"] = p;
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({self});
+    auto result = _dropout_npu_com(self, p);
+    auto result1 = std::get<1>(result);
+    ctx->saved_data["output"] = result1;
+    tensor_list result_list = {std::get<0>(result), result1};
+    return result_list;
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto p = ctx->saved_data["p"].toDouble();
+    auto mask = ctx->saved_data["output"].toTensor();
+    auto saved = ctx->get_saved_variables();
+
+    at::Tensor result = NPUNativeFunctions::npu_dropout_backward(grad_outputs[0], mask, p);
+    tensor_list output = {result, at::Tensor()};
+    return output;
+  }
+};
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::_npu_dropout(
+    const at::Tensor& self,
+    double p) {
+  auto result = NPUdropoutFunction::apply(self, p);
+  std::tuple<at::Tensor, at::Tensor> output(result[0], result[1]);
+  return output;
+}
+
+at::Tensor NPUNativeFunctions::dropout(const at::Tensor& self, double p, bool train) {
+  if (p == 0 || !train || self.numel() == 0) {
+    return self;
+  }
+  if (p == 1) {
+    return self.mul(at::zeros(self.sizes(), self.options()));
+  }
+  at::Tensor result = std::get<0>(NPUNativeFunctions::_npu_dropout(self, p));
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee
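A minimal usage sketch for the two custom ops this patch registers in npu_native_functions.yaml. It assumes the yaml entries are exposed as Python attributes of the torch_npu module (torch_npu._npu_dropout and torch_npu.npu_dropout_backward), which is how other custom entries are typically bound; treat the exact binding names as an assumption rather than part of the patch:

```python
# Sketch only: assumes torch_npu binds the custom yaml entries as module attributes.
import torch
import torch_npu

x = torch.randn(32, 3, 3).npu()

# _npu_dropout(Tensor self, float p) -> (Tensor, Tensor): dropout output plus the bit mask
out, mask = torch_npu._npu_dropout(x, 0.5)

# npu_dropout_backward(Tensor grad_output, Tensor mask, float p) -> Tensor:
# re-applies the saved mask (with 1/(1-p) scaling) to an incoming gradient
grad_in = torch_npu.npu_dropout_backward(torch.ones_like(out), mask, 0.5)
```

The autograd wrapper (NPUdropoutFunction) ties the two together, so plain torch.nn.Dropout on NPU tensors, as exercised by the tests above, goes through the same forward/backward path.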
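For reference, dropout_gen_mask above sizes the mask by rounding the element count up to a multiple of 128 and packing one bit per element into uint8 storage. A plain Python mirror of that arithmetic (the helper name is illustrative, not part of the patch):

```python
# Mirrors the sizing logic in dropout_gen_mask (dropoutKernelNpu.cpp above):
#   length = (numels + 128 - 1) / 128 * 128   -> element count rounded up to a multiple of 128
#   mask shape = {length / 8}, dtype uint8    -> one bit per element, packed into bytes
def dropout_mask_bytes(numels: int) -> int:
    length = (numels + 128 - 1) // 128 * 128
    return length // 8

assert dropout_mask_bytes(32 * 3 * 3) == 48      # 288 elements -> 384 bits -> 48 bytes
assert dropout_mask_bytes(256 * 1280) == 40960   # already a multiple of 128
assert dropout_mask_bytes(1) == 16               # even one element keeps a full 128-bit block
```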
From dd96fe0c7741861e4dbdd4d00337ed8819c85ce8 Mon Sep 17 00:00:00 2001
From: sunxing <346736790@qq.com>
Date: Fri, 18 Feb 2022 15:31:44 +0800
Subject: [PATCH 2/7] i

---
 torch_npu/csrc/aten/npu_native_functions.yaml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 788d93bc56..d185d2ebdb 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1939,4 +1939,3 @@ custom_autograd:
   - func: npu_ps_roi_pooling(Tensor self, Tensor rois, float spatial_scale, int group_size, int output_dim) -> Tensor
   - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
   - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor)
-  - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
--
Gitee

From 2f463bb24350b8cc2ce50f8d3753fd89493237ca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 07:41:05 +0000
Subject: [PATCH 3/7] f

---
 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
index 287e3537d7..fb0023419b 100644
--- a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
@@ -12,9 +12,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
-#include
 #include
+#include "torch_npu/csrc/core/npu/SecondaryStreamGuard.h"
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/framework/utils/NpuUtils.h"
 #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
--
Gitee

From 19787c7afede0d79e3eca0b27b97db9af5a99982 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 08:00:26 +0000
Subject: [PATCH 4/7] i

---
 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
index fb0023419b..03bd2dc5ea 100644
--- a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
@@ -89,13 +89,13 @@ std::tuple<at::Tensor, at::Tensor> dropout_v1_npu_impl(
   double retain = 1. - p;
   at::Scalar prob = at::Scalar(retain);
   at::Tensor mask;
-  auto original_stream = c10::npu::getCurrentNPUStream();
+  auto original_stream = c10_npu::getCurrentNPUStream();
   {
     // During the life cycle of this raii instance, the calcu stream is set as the
     // secondary stream, and tasks are distributed to the secondary stream. At the
     // same time, according to the one-stream-one-pool principle, memory is also
     // alloced from the pool of the secondary stream.
-    c10::npu::SecondaryStreamGuard guard(c10::npu::getCurrentSecondaryStream());
+    c10_npu::SecondaryStreamGuard guard(c10_npu::getCurrentSecondaryStream());
     mask = dropout_gen_mask(selfCp, prob);
   }
   // When tasks on multiple streams read and write the same block of memory,
--
Gitee
From 67f78e29488df73d293f79b9a90bcf73aee31552 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 08:20:00 +0000
Subject: [PATCH 5/7] i

---
 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
index 03bd2dc5ea..5d56bebccd 100644
--- a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
@@ -89,13 +89,13 @@ std::tuple<at::Tensor, at::Tensor> dropout_v1_npu_impl(
   double retain = 1. - p;
   at::Scalar prob = at::Scalar(retain);
   at::Tensor mask;
-  auto original_stream = c10_npu::getCurrentNPUStream();
+  auto original_stream = torch_npu::getCurrentNPUStream();
   {
     // During the life cycle of this raii instance, the calcu stream is set as the
     // secondary stream, and tasks are distributed to the secondary stream. At the
     // same time, according to the one-stream-one-pool principle, memory is also
     // alloced from the pool of the secondary stream.
-    c10_npu::SecondaryStreamGuard guard(c10_npu::getCurrentSecondaryStream());
+    torch_npu::SecondaryStreamGuard guard(torch_npu::getCurrentSecondaryStream());
     mask = dropout_gen_mask(selfCp, prob);
   }
   // When tasks on multiple streams read and write the same block of memory,
--
Gitee

From 53850b80caee868086ece69796b9bc3b870c29ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 08:42:37 +0000
Subject: [PATCH 6/7] i

---
 torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
index 5d56bebccd..d7889a4d59 100644
--- a/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/dropoutKernelNpu.cpp
@@ -89,13 +89,13 @@ std::tuple<at::Tensor, at::Tensor> dropout_v1_npu_impl(
   double retain = 1. - p;
   at::Scalar prob = at::Scalar(retain);
   at::Tensor mask;
-  auto original_stream = torch_npu::getCurrentNPUStream();
+  auto original_stream = c10::npu::getCurrentNPUStream();
   {
     // During the life cycle of this raii instance, the calcu stream is set as the
     // secondary stream, and tasks are distributed to the secondary stream. At the
     // same time, according to the one-stream-one-pool principle, memory is also
     // alloced from the pool of the secondary stream.
-    torch_npu::SecondaryStreamGuard guard(torch_npu::getCurrentSecondaryStream());
+    torch_npu::SecondaryStreamGuard guard(c10::npu::getCurrentSecondaryStream());
     mask = dropout_gen_mask(selfCp, prob);
   }
   // When tasks on multiple streams read and write the same block of memory,
--
Gitee

From 9bbcfd05ff9b962467315d3055663198ecefad55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Fri, 18 Feb 2022 08:55:01 +0000
Subject: [PATCH 7/7] i

---
 torch_npu/csrc/core/npu/SecondaryStreamGuard.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/core/npu/SecondaryStreamGuard.h b/torch_npu/csrc/core/npu/SecondaryStreamGuard.h
index 3a5160949b..5994b7fe65 100644
--- a/torch_npu/csrc/core/npu/SecondaryStreamGuard.h
+++ b/torch_npu/csrc/core/npu/SecondaryStreamGuard.h
@@ -25,11 +25,11 @@ struct SecondaryStreamGuard{
   explicit SecondaryStreamGuard(c10::Stream stream) : guard_(stream) {};
 
   ~SecondaryStreamGuard() {
-    c10::NPUEvent npu_event;
+    c10::npu::NPUEvent npu_event;
     npu_event.record(guard_.current_stream());
     npu_event.block(guard_.original_stream());
   }
 private:
-  c10::NPUStreamGuard guard_;
+  c10::npu::NPUStreamGuard guard_;
 };
 } // namespace torch_npu
\ No newline at end of file
--
Gitee