From 8075eaba89ecd2e64bb7765f04b455f1f5daf60f Mon Sep 17 00:00:00 2001
From: pipihugh
Date: Mon, 14 Feb 2022 09:58:43 +0800
Subject: [PATCH] Port the max_unpool2d and max_unpool2d_backward operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Implement max_unpool2d as a scatter along the flattened spatial
dimension and max_unpool2d_backward as the matching gather, and add
CPU-vs-NPU consistency tests for both operators.
---
 test/test_network_ops/test_max_unpool2d.py    | 33 +++++++
 .../test_max_unpool2d_backward.py             | 53 +++++++++++
 .../pooling/MaxUnpool2dBackwardKernelNpu.cpp  | 86 ++++++++++++++++++
 .../aten/ops/pooling/MaxUnpool2dKernelNpu.cpp | 88 +++++++++++++++++++
 4 files changed, 260 insertions(+)
 create mode 100644 test/test_network_ops/test_max_unpool2d.py
 create mode 100644 test/test_network_ops/test_max_unpool2d_backward.py
 create mode 100644 torch_npu/csrc/aten/ops/pooling/MaxUnpool2dBackwardKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/pooling/MaxUnpool2dKernelNpu.cpp

diff --git a/test/test_network_ops/test_max_unpool2d.py b/test/test_network_ops/test_max_unpool2d.py
new file mode 100644
index 00000000000..ba14a812163
--- /dev/null
+++ b/test/test_network_ops/test_max_unpool2d.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestMaxunpool2d(TestCase):
+    def test_max_unpool2d(self, device):
+        input1 = torch.tensor([[[[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.], [13., 14., 15., 16.]]]])
+        pool2d = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
+        out, ind = pool2d(input1)
+        unpool2d = torch.nn.MaxUnpool2d(2, stride=2)
+        npu_out = unpool2d(out.npu(), ind.npu())
+        cpu_out = unpool2d(out, ind)
+        self.assertRtolEqual(cpu_out, npu_out.cpu())
+
+instantiate_device_type_tests(TestMaxunpool2d, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
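For reference, the behaviour checked by the test above can be worked out by
hand with the CPU reference modules (an illustrative sketch, not part of the
patch):

    import torch

    pool = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
    unpool = torch.nn.MaxUnpool2d(2, stride=2)
    x = torch.arange(1., 17.).reshape(1, 1, 4, 4)  # same 4x4 input as the test
    out, ind = pool(x)    # out = [[6., 8.], [14., 16.]], ind = [[5, 7], [13, 15]]
    y = unpool(out, ind)  # each maximum returns to its flat index, zeros elsewhere:
                          # [[ 0.,  0.,  0.,  0.],
                          #  [ 0.,  6.,  0.,  8.],
                          #  [ 0.,  0.,  0.,  0.],
                          #  [ 0., 14.,  0., 16.]]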
diff --git a/test/test_network_ops/test_max_unpool2d_backward.py b/test/test_network_ops/test_max_unpool2d_backward.py
new file mode 100644
index 00000000000..b01008ce7a0
--- /dev/null
+++ b/test/test_network_ops/test_max_unpool2d_backward.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestMaxunpool2dBackward(TestCase):
+    def test_maxunpool2d_backward(self, device):
+        input1 = torch.tensor([[[[1., 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]]])
+        pool2d = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
+        out, ind = pool2d(input1)
+        unpool2d = torch.nn.MaxUnpool2d(2, stride=2)
+        npu_upinput = out.npu()
+        npu_ind = ind.npu()
+        npu_upinput.requires_grad = True
+        out.requires_grad = True
+        npu_out = unpool2d(npu_upinput, npu_ind)
+        npu_out.backward(torch.ones_like(npu_out))
+        npu_grad = npu_upinput.grad
+        cpu_out = unpool2d(out, ind)
+        cpu_out.backward(torch.ones_like(cpu_out))
+        cpu_grad = out.grad
+        self.assertRtolEqual(cpu_grad, npu_grad.cpu())
+
+        # clear the gradients accumulated above before the second backward pass
+        out.grad = None
+        npu_upinput.grad = None
+        cpu_out = unpool2d(out, ind)
+        grad_input = torch.randn(cpu_out.shape)
+        cpu_out.backward(grad_input)
+        cpu_grad = out.grad
+        npu_out = unpool2d(npu_upinput, npu_ind)
+        npu_out.backward(grad_input.npu())
+        npu_grad = npu_upinput.grad
+        self.assertRtolEqual(cpu_grad, npu_grad.cpu())
+
+instantiate_device_type_tests(TestMaxunpool2dBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
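The backward of an index-based scatter is a gather: each element of the unpool
input receives the output gradient stored at its scatter target. A CPU-only
sketch of the reduction the kernel below performs (illustrative, not part of
the patch; the index values follow the 4x4 example above):

    import torch

    g = torch.ones(1, 1, 4, 4)                  # upstream gradient w.r.t. the unpool output
    ind = torch.tensor([[[[5, 7], [13, 15]]]])  # flat indices produced by the pooling step
    # gradient w.r.t. the unpool input: read g at each scatter target
    grad_in = g.reshape(1, 1, 16).gather(2, ind.reshape(1, 1, 4)).reshape(1, 1, 2, 2)
    # grad_in == torch.ones(1, 1, 2, 2)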
diff --git a/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dBackwardKernelNpu.cpp
new file mode 100644
index 00000000000..91bab6bbb5e
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dBackwardKernelNpu.cpp
@@ -0,0 +1,86 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& NPUNativeFunctions::max_unpool2d_backward_out(
+    const at::Tensor& gradOutput,
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    at::IntArrayRef outputSize,
+    at::Tensor& gradInput) {
+  OpPreparation::CheckOut(
+      {self, gradOutput},
+      gradInput,
+      self);
+  TORCH_CHECK(
+      outputSize.size() == 2,
+      "There should be exactly two elements (height, width) in outputSize");
+  TORCH_CHECK(
+      (self.ndimension() == 3 || self.ndimension() == 4),
+      "Input to max_unpooling2d should be a 3d or 4d Tensor");
+  TORCH_CHECK(
+      self.sizes() == indices.sizes(),
+      "Shape of indices should match shape of input");
+  TORCH_CHECK(self.numel() > 0, "Input must be non-empty");
+
+  auto oheight = outputSize[0];
+  auto owidth = outputSize[1];
+  int64_t n = 1;
+  int64_t c = self.size(0);
+  int64_t h = self.size(1);
+  int64_t w = self.size(2);
+  int64_t selfDim = self.ndimension();
+
+  if (selfDim == 4) {
+    n = self.size(0);
+    c = self.size(1);
+    h = self.size(2);
+    w = self.size(3);
+  }
+
+  auto gradOutputContiguous = gradOutput.contiguous();
+  auto indicesContiguous = indices.contiguous();
+  gradOutputContiguous = gradOutputContiguous.reshape({n, c, oheight * owidth});
+  indicesContiguous = indicesContiguous.reshape({n, c, h * w});
+  gradInput.resize_as_(self);
+  gradInput.zero_();
+  gradInput = gradInput.reshape({n, c, h * w});
+  const int dim = 2;
+
+  // the forward pass scatters by index, so the gradient is a gather of gradOutput at those indices
+  gradInput = gather_out(gradOutputContiguous, dim, indicesContiguous, false, gradInput);
+  if (selfDim == 3) {
+    gradInput = gradInput.reshape({c, h, w});
+  } else {
+    gradInput = gradInput.reshape({n, c, h, w});
+  }
+  return gradInput;
+}
+
+at::Tensor NPUNativeFunctions::max_unpool2d_backward(
+    const at::Tensor& gradOutput,
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    at::IntArrayRef outputSize) {
+  auto gradInput = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  max_unpool2d_backward_out(gradOutput, self, indices, outputSize, gradInput);
+  return gradInput;
+}
+} // namespace native
+} // namespace at_npu
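The forward kernel below realizes max_unpool2d as a single scatter along
dim 2 of the (N, C, H*W) view, matching the CPU semantics. A CPU-only sketch
of that step (illustrative, not part of the patch):

    import torch

    x = torch.tensor([[[[6., 8.], [14., 16.]]]])  # pooled values
    ind = torch.tensor([[[[5, 7], [13, 15]]]])    # their flat target positions
    # scatter the pooled values into a zeroed (N, C, oH*oW) buffer, then restore the 4-d shape
    out = torch.zeros(1, 1, 16).scatter(2, ind.reshape(1, 1, 4), x.reshape(1, 1, 4))
    out = out.reshape(1, 1, 4, 4)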
diff --git a/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dKernelNpu.cpp
new file mode 100644
index 00000000000..4ab035a4e29
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/pooling/MaxUnpool2dKernelNpu.cpp
@@ -0,0 +1,88 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& NPUNativeFunctions::max_unpool2d_out(
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    at::IntArrayRef outputSize,
+    at::Tensor& output) {
+  OpPreparation::CheckOut(
+      {self, indices},
+      output,
+      self);
+  TORCH_CHECK(
+      outputSize.size() == 2,
+      "There should be exactly two elements (height, width) in outputSize");
+  TORCH_CHECK(
+      (self.ndimension() == 3 || self.ndimension() == 4),
+      "Input to max_unpooling2d should be a 3d or 4d Tensor");
+  TORCH_CHECK(
+      self.sizes() == indices.sizes(),
+      "Shape of indices should match shape of input");
+  TORCH_CHECK(self.numel() > 0, "Input must be non-empty");
+
+  auto oheight = outputSize[0];
+  auto owidth = outputSize[1];
+  auto selfContiguous = self.contiguous();
+  auto indicesContiguous = indices.contiguous();
+  int64_t h = -1;
+  int64_t w = -1;
+  int64_t selfDim = self.ndimension();
+  int64_t numBatch = -1;
+  int64_t numChannels = -1;
+  if (selfDim == 3) {
+    numChannels = self.size(0);
+    h = self.size(1);
+    w = self.size(2);
+    output.resize_({numChannels, oheight * owidth});
+    selfContiguous = selfContiguous.reshape({numChannels, h * w});
+    indicesContiguous = indicesContiguous.reshape({numChannels, h * w});
+  } else {
+    numBatch = self.size(0);
+    numChannels = self.size(1);
+    h = self.size(2);
+    w = self.size(3);
+    output.resize_({numBatch, numChannels, oheight * owidth});
+    selfContiguous = selfContiguous.reshape({numBatch, numChannels, h * w});
+    indicesContiguous = indicesContiguous.reshape({numBatch, numChannels, h * w});
+  }
+
+  output.zero_();
+  int64_t dim = 2;
+  // scatter the pooled values to their flat spatial positions in the zeroed output
+  output = output.scatter(dim, indicesContiguous, selfContiguous);
+  if (selfDim == 3) {
+    output = output.reshape({numChannels, oheight, owidth});
+  } else {
+    output = output.reshape({numBatch, numChannels, oheight, owidth});
+  }
+  return output;
+}
+
+at::Tensor NPUNativeFunctions::max_unpool2d(
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    at::IntArrayRef output_size) {
+  auto output = OpPreparation::ApplyTensor(self, {0});
+  max_unpool2d_out(self, indices, output_size, output);
+  return output;
+}
+} // namespace native
+} // namespace at_npu
-- Gitee