From 3dadda004fba926b9553549f467ad7916cb36a3c Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Tue, 25 Jan 2022 21:16:20 +0800
Subject: [PATCH 1/4] bnbe

---
 .../test_batchnorm_backward_elemt.py          |  73 +++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   1 +
 .../ops/BatchNormBackwardElemtKernelNpu.cpp   | 100 ++++++++++++++++++
 3 files changed, 174 insertions(+)
 create mode 100644 test/test_network_ops/test_batchnorm_backward_elemt.py
 create mode 100644 torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp

diff --git a/test/test_network_ops/test_batchnorm_backward_elemt.py b/test/test_network_ops/test_batchnorm_backward_elemt.py
new file mode 100644
index 00000000000..d85aaa5aefc
--- /dev/null
+++ b/test/test_network_ops/test_batchnorm_backward_elemt.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestBatchNormBackwardElemt(TestCase):
+
+    def test_batch_norm_backward_elemt_4d(self, device):
+        grad_output = torch.ones([2, 3, 1, 4]).npu()
+        input1 = torch.ones([2, 3, 1, 4]).npu()
+        mean = torch.tensor([8., 5., 9.]).npu()
+        invstd = torch.tensor([2., 1., 2.]).npu()
+        weight = torch.tensor([1., 1., 4.]).npu()
+        mean_dy = torch.tensor([2., 2., 6.]).npu()
+        mean_dy_xmn = torch.tensor([2., 3., 11.]).npu()
+
+        grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd,
+                                                     weight, mean_dy, mean_dy_xmn)
+        cuda_expect_out = torch.tensor([[[[110., 110., 110., 110.]],
+                                         [[11, 11, 11, 11]],
+                                         [[2776., 2776., 2776, 2776.]]],
+                                        [[[110., 110., 110., 110.]],
+                                         [[11, 11, 11, 11]],
+                                         [[2776., 2776., 2776, 2776.]]]])
+        self.assertRtolEqual(grad_input.cpu(), cuda_expect_out)
+
+    def test_batch_norm_backward_elemt_2d(self, device):
+        grad_output = torch.ones([2, 3]).npu()
+        input1 = torch.ones([2, 3]).npu()
+        mean = torch.tensor([8., 5., 9.]).npu()
+        invstd = torch.tensor([2., 1., 2.]).npu()
+        weight = torch.tensor([1., 1., 4.]).npu()
+        mean_dy = torch.tensor([2., 2., 6.]).npu()
+        mean_dy_xmn = torch.tensor([2., 3., 11.]).npu()
+
+        grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd,
+                                                     weight, mean_dy, mean_dy_xmn)
+        cuda_expect_out = torch.tensor([[110., 11., 2776.],
+                                        [110., 11., 2776.]])
+        self.assertRtolEqual(grad_input.cpu(), cuda_expect_out)
+
+    def test_batch_norm_backward_elemt_2d_fp(self, device):
+        grad_output = torch.ones([2, 3]).npu()
+        input1 = torch.ones([2, 3]).npu()
+        mean = torch.tensor([8.123456, 5.147125, 9.365778]).npu()
+        invstd = torch.tensor([2.65485, 1.36541, 2.25879]).npu()
+        weight = torch.tensor([1.36987, 1.36944, 4.25774]).npu()
+        mean_dy = torch.tensor([2., 2., 6.]).npu()
+        mean_dy_xmn = torch.tensor([2., 3., 11.]).npu()
+
+        grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd,
+                                                     weight, mean_dy, mean_dy_xmn)
+        cuda_expect_out = torch.tensor([[361.5542, 41.5013, 4467.4121],
+                                        [361.5542, 41.5013, 4467.4121]])
+        self.assertRtolEqual(grad_input.cpu(), cuda_expect_out)
+
+instantiate_device_type_tests(TestBatchNormBackwardElemt, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 27b8a88a836..39d5a21d1bd 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -241,6 +241,7 @@ supported:
   - addcdiv_
   - addcdiv.out
   - slogdet
+  - batch_norm_backward_elemt
 
 custom:
   - func: npu_transpose_to_contiguous(Tensor self) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
new file mode 100644
index 00000000000..97503e3cb16
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor batch_norm_backward_elemt_npu_nocheck(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& invstd,
+    const at::Tensor& weight,
+    const at::Tensor& mean_dy,
+    const at::Tensor& mean_dy_xmu,
+    at::Tensor& grad_input) {
+  OpCommand cmd;
+  cmd.Name("SyncBatchNormBackwardElemt")
+      .Input(grad_out)
+      .Input(input)
+      .Input(mean)
+      .Input(invstd)
+      .Input(weight)
+      .Input(mean_dy)
+      .Input(mean_dy_xmu)
+      .Output(grad_input)
+      .Run();
+  return grad_input;
+}
+
+void batch_norm_backward_elemt_npu_expand_tensor(
+    at::Tensor& expand_tensor,
+    size_t dim_c,
+    int64_t input_ndim,
+    at::IntArrayRef input_shape) {
+  if (input_ndim >2) {
+    expand_tensor = at::npu_broadcast(expand_tensor, {1, dim_c}).t();
+    for (int64_t i = 0; i < input_ndim - 3; i++) {
+      expand_tensor = expand_tensor.unsqueeze(1);
+    }
+  }
+  expand_tensor = at::npu_broadcast(expand_tensor, input_shape);
+}
+
+at::Tensor NPUNativeFunctions::batch_norm_backward_elemt(
+    const at::Tensor& grad_out,
+    const at::Tensor& input,
+    const at::Tensor& mean,
+    const at::Tensor& invstd,
+    const at::Tensor& weight,
+    const at::Tensor& mean_dy,
+    const at::Tensor& mean_dy_xmu) {
+  int64_t input_ndim = input.dim();
+
+  TORCH_CHECK(input_ndim > 1, "input.dim() <= 1")
+  size_t dim_c = input.size(1);
+  at::IntArrayRef input_shape = input.sizes();
+  at::Tensor mean_expanded(mean);
+
+  batch_norm_backward_elemt_npu_expand_tensor(mean_expanded, dim_c, input_ndim, input_shape);
+  at::Tensor invstd_expanded(invstd);
+
+  batch_norm_backward_elemt_npu_expand_tensor(invstd_expanded, dim_c, input_ndim, input_shape);
+  at::Tensor weight_expanded(weight);
+
+  batch_norm_backward_elemt_npu_expand_tensor(weight_expanded, dim_c, input_ndim, input_shape);
+  at::Tensor mean_dy_expanded(mean_dy);
+
+  batch_norm_backward_elemt_npu_expand_tensor(mean_dy_expanded, dim_c, input_ndim, input_shape);
+  at::Tensor mean_dy_xmu_expanded(mean_dy_xmu);
+
+  batch_norm_backward_elemt_npu_expand_tensor(mean_dy_xmu_expanded, dim_c, input_ndim, input_shape);
+  at::Tensor grad_input = OpPreparation::ApplyTensor(input);
+  return batch_norm_backward_elemt_npu_nocheck(
+      grad_out,
+      input,
+      mean_expanded,
+      invstd_expanded,
+      weight_expanded,
+      mean_dy_expanded,
+      mean_dy_xmu_expanded,
+      grad_input);
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee

From d0638881d8c5cf9f0ec9f9e370b0d0d989c6b54e Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Tue, 25 Jan 2022 21:27:46 +0800
Subject: [PATCH 2/4] bnbe

---
 torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
index 97503e3cb16..3248e9b6a99 100644
--- a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
@@ -49,12 +49,12 @@ void batch_norm_backward_elemt_npu_expand_tensor(
     int64_t input_ndim,
     at::IntArrayRef input_shape) {
   if (input_ndim >2) {
-    expand_tensor = at::npu_broadcast(expand_tensor, {1, dim_c}).t();
+    expand_tensor = NPUNativeFunctions::npu_broadcast(expand_tensor, {1, dim_c}).t();
     for (int64_t i = 0; i < input_ndim - 3; i++) {
       expand_tensor = expand_tensor.unsqueeze(1);
     }
   }
-  expand_tensor = at::npu_broadcast(expand_tensor, input_shape);
+  expand_tensor = NPUNativeFunctions::npu_broadcast(expand_tensor, input_shape);
 }
 
 at::Tensor NPUNativeFunctions::batch_norm_backward_elemt(
-- 
Gitee

From e183f8c6e4afeb7a702ae8074b8982858ffdb83e Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Wed, 26 Jan 2022 10:22:52 +0800
Subject: [PATCH 3/4] bnbe

---
 torch_npu/csrc/aten/npu_native_functions.yaml               | 1 +
 torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 39d5a21d1bd..deac45abc45 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -242,6 +242,7 @@ supported:
   - addcdiv.out
   - slogdet
   - batch_norm_backward_elemt
+  - npu_broadcast
 
 custom:
   - func: npu_transpose_to_contiguous(Tensor self) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
index 3248e9b6a99..23f45436b09 100644
--- a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp
@@ -62,11 +62,11 @@ at::Tensor NPUNativeFunctions::batch_norm_backward_elemt(
     const at::Tensor& input,
     const at::Tensor& mean,
     const at::Tensor& invstd,
-    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& weight_opt,
     const at::Tensor& mean_dy,
     const at::Tensor& mean_dy_xmu) {
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();});
   int64_t input_ndim = input.dim();
-
   TORCH_CHECK(input_ndim > 1, "input.dim() <= 1")
   size_t dim_c = input.size(1);
   at::IntArrayRef input_shape = input.sizes();
-- 
Gitee

From f725e410685e806cef10462fec5f26d6f50c7bf6 Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Wed, 26 Jan 2022 11:23:54 +0800
Subject: [PATCH 4/4] bnbe

---
 torch_npu/csrc/aten/npu_native_functions.yaml | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index deac45abc45..b42e3404119 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -242,7 +242,6 @@ supported:
   - addcdiv.out
   - slogdet
   - batch_norm_backward_elemt
-  - npu_broadcast
 
 custom:
   - func: npu_transpose_to_contiguous(Tensor self) -> Tensor
@@ -288,4 +287,4 @@ custom:
   - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
-  - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
+  - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
\ No newline at end of file
-- 
Gitee
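
Usage sketch (not part of the patches): with these four patches applied, torch.batch_norm_backward_elemt on NPU tensors dispatches to NPUNativeFunctions::batch_norm_backward_elemt above, which broadcasts the per-channel statistics to the input shape and runs the SyncBatchNormBackwardElemt op. A minimal sketch, assuming a torch_npu build with an NPU device available; the values mirror the 2-D test case in patch 1:

    import torch
    import torch_npu  # registers the NPU backend and the .npu() tensor method

    # Per-channel statistics for a [N, C] input with C == 3 (values from test_batch_norm_backward_elemt_2d).
    grad_output = torch.ones([2, 3]).npu()
    input1 = torch.ones([2, 3]).npu()
    mean = torch.tensor([8., 5., 9.]).npu()
    invstd = torch.tensor([2., 1., 2.]).npu()
    weight = torch.tensor([1., 1., 4.]).npu()
    mean_dy = torch.tensor([2., 2., 6.]).npu()
    mean_dy_xmu = torch.tensor([2., 3., 11.]).npu()

    # Dispatches to the NPU kernel added by this patch series.
    grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd,
                                                 weight, mean_dy, mean_dy_xmu)
    print(grad_input.cpu())  # expected [[110., 11., 2776.], [110., 11., 2776.]] per the test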