From 5ffc06827e2c6e506ae0aba1d76a7eec10158b08 Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Wed, 26 Jan 2022 14:28:53 +0800 Subject: [PATCH] =?UTF-8?q?BatchNormBackwardElemt=20=E7=AE=97=E5=AD=90?= =?UTF-8?q?=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add adapt file code fix fix code_checking --- .../test_batchnorm_backward_elemt.py | 73 +++++++++++++ .../ops/BatchNormBackwardElemtKernelNpu.cpp | 100 ++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 test/test_network_ops/test_batchnorm_backward_elemt.py create mode 100644 torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp diff --git a/test/test_network_ops/test_batchnorm_backward_elemt.py b/test/test_network_ops/test_batchnorm_backward_elemt.py new file mode 100644 index 0000000000..d85aaa5aef --- /dev/null +++ b/test/test_network_ops/test_batchnorm_backward_elemt.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestBatchNormBackwardElemt(TestCase): + + def test_batch_norm_backward_elemt_4d(self, device): + grad_output = torch.ones([2, 3, 1, 4]).npu() + input1 = torch.ones([2, 3, 1, 4]).npu() + mean = torch.tensor([8., 5., 9.]).npu() + invstd = torch.tensor([2., 1., 2.]).npu() + weight = torch.tensor([1., 1., 4.]).npu() + mean_dy = torch.tensor([2., 2., 6.]).npu() + mean_dy_xmn = torch.tensor([2., 3., 11.]).npu() + + grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd, + weight, mean_dy, mean_dy_xmn) + cuda_expect_out = torch.tensor([[[[110., 110., 110., 110.]], + [[11, 11, 11, 11]], + [[2776., 2776., 2776, 2776.]]], + [[[110., 110., 110., 110.]], + [[11, 11, 11, 11]], + [[2776., 2776., 2776, 2776.]]]]) + self.assertRtolEqual(grad_input.cpu(), cuda_expect_out) + + def test_batch_norm_backward_elemt_2d(self, device): + grad_output = torch.ones([2, 3]).npu() + input1 = torch.ones([2, 3]).npu() + mean = torch.tensor([8., 5., 9.]).npu() + invstd = torch.tensor([2., 1., 2.]).npu() + weight = torch.tensor([1., 1., 4.]).npu() + mean_dy = torch.tensor([2., 2., 6.]).npu() + mean_dy_xmn = torch.tensor([2., 3., 11.]).npu() + + grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd, + weight, mean_dy, mean_dy_xmn) + cuda_expect_out = torch.tensor([[110., 11., 2776.], + [110., 11., 2776.]]) + self.assertRtolEqual(grad_input.cpu(), cuda_expect_out) + + def test_batch_norm_backward_elemt_2d_fp(self, device): + grad_output = torch.ones([2, 3]).npu() + input1 = torch.ones([2, 3]).npu() + mean = torch.tensor([8.123456, 5.147125, 9.365778]).npu() + invstd = torch.tensor([2.65485, 1.36541, 2.25879]).npu() + weight = torch.tensor([1.36987, 1.36944, 4.25774]).npu() + mean_dy = torch.tensor([2., 2., 6.]).npu() + mean_dy_xmn = torch.tensor([2., 3., 11.]).npu() + + grad_input = torch.batch_norm_backward_elemt(grad_output, input1, mean, invstd, + weight, mean_dy, mean_dy_xmn) + cuda_expect_out = torch.tensor([[361.5542, 41.5013, 4467.4121], + [361.5542, 41.5013, 4467.4121]]) + self.assertRtolEqual(grad_input.cpu(), cuda_expect_out) + +instantiate_device_type_tests(TestBatchNormBackwardElemt, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp new file mode 100644 index 0000000000..889587ef29 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BatchNormBackwardElemtKernelNpu.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& batch_norm_backward_elemt_npu_nocheck( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& invstd, + const at::Tensor& weight, + const at::Tensor& mean_dy, + const at::Tensor& mean_dy_xmu, + at::Tensor& grad_input) { + OpCommand cmd; + cmd.Name("SyncBatchNormBackwardElemt") + .Input(grad_out) + .Input(input) + .Input(mean) + .Input(invstd) + .Input(weight) + .Input(mean_dy) + .Input(mean_dy_xmu) + .Output(grad_input) + .Run(); + return grad_input; +} + +void batch_norm_backward_elemt_npu_expand_tensor( + at::Tensor& expand_tensor, + size_t dim_c, + int64_t input_ndim, + at::IntArrayRef input_shape) { + if (input_ndim >2) { + expand_tensor = NPUNativeFunctions::npu_broadcast(expand_tensor, {1, dim_c}).t(); + for (int64_t i = 0; i < input_ndim - 3; i++) { + expand_tensor = expand_tensor.unsqueeze(1); + } + } + expand_tensor = NPUNativeFunctions::npu_broadcast(expand_tensor, input_shape); +} + +at::Tensor NPUNativeFunctions::batch_norm_backward_elemt( + const at::Tensor& grad_out, + const at::Tensor& input, + const at::Tensor& mean, + const at::Tensor& invstd, + const c10::optional& weight_opt, + const at::Tensor& mean_dy, + const at::Tensor& mean_dy_xmu) { + const at::Tensor& weight = c10::value_or_else(weight_opt, [] {return at::Tensor();}); + int64_t input_ndim = input.dim(); + TORCH_CHECK(input_ndim > 1, "input.dim() <= 1") + size_t dim_c = input.size(1); + at::IntArrayRef input_shape = input.sizes(); + at::Tensor mean_expanded(mean); + + batch_norm_backward_elemt_npu_expand_tensor(mean_expanded, dim_c, input_ndim, input_shape); + at::Tensor invstd_expanded(invstd); + + batch_norm_backward_elemt_npu_expand_tensor(invstd_expanded, dim_c, input_ndim, input_shape); + at::Tensor weight_expanded(weight); + + batch_norm_backward_elemt_npu_expand_tensor(weight_expanded, dim_c, input_ndim, input_shape); + at::Tensor mean_dy_expanded(mean_dy); + + batch_norm_backward_elemt_npu_expand_tensor(mean_dy_expanded, dim_c, input_ndim, input_shape); + at::Tensor mean_dy_xmu_expanded(mean_dy_xmu); + + batch_norm_backward_elemt_npu_expand_tensor(mean_dy_xmu_expanded, dim_c, input_ndim, input_shape); + at::Tensor grad_input = OpPreparation::ApplyTensor(input); + return batch_norm_backward_elemt_npu_nocheck( + grad_out, + input, + mean_expanded, + invstd_expanded, + weight_expanded, + mean_dy_expanded, + mean_dy_xmu_expanded, + grad_input); +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee