diff --git a/test/test_network_ops/test_reflection_pad2d_backward.py b/test/test_network_ops/test_reflection_pad2d_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a8853c85ea7bfff3eb6abff60daf48bf1ed2e9 --- /dev/null +++ b/test/test_network_ops/test_reflection_pad2d_backward.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestReflectionPad2dBackward(TestCase): + def cpu_op_exec(self, input1, pad): + m = torch.nn.ReflectionPad2d(pad) + input1.requires_grad = True + output = m(input1) + output.backward(torch.ones_like(output)) + input_grad = input1.grad + output = output.numpy() + input_grad = input_grad.numpy() + return output, input_grad + + def npu_op_exec(self, input1, pad): + m = torch.nn.ReflectionPad2d(pad).to("npu") + input1.requires_grad = True + output = m(input1) + output.backward(torch.ones_like(output)) + input_grad = input1.grad + output = output.to("cpu") + output = output.detach().numpy() + input_grad = input_grad.cpu().numpy() + return output, input_grad + + def test_reflectionPad2d_backward_shape_format_fp16(self, device): + shape_format = [ + [[np.float16, 0, (1, 1, 37, 37)], [2, 2, 2, 2]], + [[np.float16, 3, (1, 1, 4, 3)], 2], + [[np.float16, 0, (1, 1, 17, 17)], [1, 2, 2, 2]], + ] + + def cpu_op_exec_fp16(input1, pad): + input1 = input1.to(torch.float32) + input1.requires_grad = True + m = torch.nn.ReflectionPad2d(pad) + output = m(input1) + output.backward(torch.ones_like(output)) + output = output.detach().numpy() + input_grad = input1.grad + input_grad = input_grad.numpy().astype(np.float16) + output = output.astype(np.float16) + return output, input_grad + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_output, cpu_grad = cpu_op_exec_fp16(cpu_input1, item[1]) + npu_output, npu_grad = self.npu_op_exec(npu_input1, item[1]) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_grad, npu_grad) + +instantiate_device_type_tests(TestReflectionPad2dBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/ReflectionPad2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/ReflectionPad2dBackwardKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0c2c3731cb5da0bac3d292605a123ec1e6eecfd --- /dev/null +++ b/torch_npu/csrc/aten/ops/ReflectionPad2dBackwardKernelNpu.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& reflection_pad2d_backward_out_npu_nocheck( + const at::Tensor& gradOutput, + const at::Tensor& input, + at::IntArrayRef padding, + at::Tensor& gradInput) { + TORCH_CHECK(input.scalar_type() != at::ScalarType::Float, + "PadV3Grad don't supports torch.float!"); + c10::SmallVector vectorInt; + c10::SmallVector paddingsVector = array_to_small_vector(padding); + paddingsVector.resize(2 * input.dim(), 0); + for (int64_t i = paddingsVector.size(); i > 0; i -= 2) { + vectorInt.emplace_back(paddingsVector[i - 2]); + vectorInt.emplace_back(paddingsVector[i - 1]); + } + OpCommand cmd; + cmd.Name("PadV3Grad") + .Input(gradOutput) + .Input(vectorInt, at::kInt) + .Output(gradInput) + .Attr("mode", (string)"reflect") + .Attr("paddings_contiguous", true) + .Run(); + return gradInput; +} + +at::Tensor& NPUNativeFunctions::reflection_pad2d_backward_out( + const at::Tensor& gradOutput, + const at::Tensor& input, + at::IntArrayRef padding, + at::Tensor& gradInput) { + OpPreparation::CheckOut( + {input, gradOutput}, + gradInput, + input); + OpPipeWithDefinedOut pipe; + return pipe.CheckMemory({input, gradOutput}, {gradInput}) + .Func([&gradOutput, &input, &padding](at::Tensor& gradInput) + {reflection_pad2d_backward_out_npu_nocheck( + gradOutput, + input, + padding, + gradInput);}) + .Call(gradInput); +} + +at::Tensor NPUNativeFunctions::reflection_pad2d_backward( + const at::Tensor& gradOutput, + const at::Tensor& input, + at::IntArrayRef padding) { + at::Tensor gradInput = OpPreparation::ApplyTensor(input); + reflection_pad2d_backward_out_npu_nocheck( + gradOutput, + input, + padding, + gradInput); + return gradInput; +} +} // namespace native +} // namespace at_npu