From 1a82fae6e8c10cac2133ca448e90a5bd6001c21a Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Sun, 20 Feb 2022 15:12:45 +0800 Subject: [PATCH] =?UTF-8?q?GridSampler2dBackward=E7=AE=97=E5=AD=90?= =?UTF-8?q?=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit code check code fix --- .../test_grid_sampler_2d_backward.py | 100 ++++++++++++++++++ .../ops/GridSampler2dBackwardKernelNpu.cpp | 78 ++++++++++++++ 2 files changed, 178 insertions(+) create mode 100644 test/test_network_ops/test_grid_sampler_2d_backward.py create mode 100644 torch_npu/csrc/aten/ops/GridSampler2dBackwardKernelNpu.cpp diff --git a/test/test_network_ops/test_grid_sampler_2d_backward.py b/test/test_network_ops/test_grid_sampler_2d_backward.py new file mode 100644 index 0000000000..e18f97c0cd --- /dev/null +++ b/test/test_network_ops/test_grid_sampler_2d_backward.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestGridSampler2dBackward(TestCase): + def get_attrs(self): + attrs = [ + [0, True], + [1, True], + [0, False], + [1, False] + ] + return attrs + + def cpu_op_exec(self, input1, sample, pad_mode, align): + input1.requires_grad = True + sample.requires_grad = True + out = torch.grid_sampler_2d(input1, sample, 0, pad_mode, align) + out.backward(torch.ones_like(out)) + dx = input1.grad.numpy() + dgrid = sample.grad.numpy() + return dx, dgrid + + def npu_op_exec(self, input1, sample, pad_mode, align): + input1.requires_grad = True + sample.requires_grad = True + out = torch.grid_sampler_2d(input1, sample, 0, pad_mode, align) + out.backward(torch.ones_like(out)) + dx = input1.grad + dgrid = sample.grad + dx = dx.to("cpu").numpy() + dgrid = dgrid.to("cpu").numpy() + return dx, dgrid + + def test_grid_sampler_2d_backward_fp32(self, device): + shape_list = [[100, 1, 28, 28], [100, 64, 32, 28]] + shape_format = [ + [np.float32, -1, j] for j in shape_list + ] + sample_format = [np.float32, -1, [100, 1, 1, 2]] + attrs = self.get_attrs() + for item in shape_format: + for attr in attrs: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output_dx, cpu_output_dgrid = self.cpu_op_exec(cpu_input, cpu_sample, *attr) + npu_output_dx, npu_output_dgrid = self.npu_op_exec(npu_input, npu_sample, *attr) + self.assertRtolEqual(cpu_output_dx, npu_output_dx) + self.assertRtolEqual(cpu_output_dgrid, npu_output_dgrid) + + def test_grid_sampler_2d_backward_fp16(self, device): + def cpu_op_fp16_exec(input1, sample, pad_mode, align): + input1 = input1.to(torch.float32) + sample = sample.to(torch.float32) + input1.requires_grad = True + sample.requires_grad = True + out = torch.grid_sampler(input1, sample, 0, pad_mode, align) + out.backward(torch.ones_like(out)) + dx = input1.grad + dgrid = sample.grad + dx = dx.numpy().astype(np.float16) + dgrid = dgrid.numpy().astype(np.float16) + return dx, dgrid + + shape_list = [[100, 1, 28, 28], [100, 64, 32, 28]] + shape_format = [ + [np.float16, -1, j] for j in shape_list + ] + sample_format = [np.float16, -1, [100, 1, 1, 2]] + attrs = self.get_attrs() + for item in shape_format: + for attr in attrs: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output_dx, cpu_output_dgrid = cpu_op_fp16_exec(cpu_input, cpu_sample, *attr) + npu_output_dx, npu_output_dgrid = self.npu_op_exec(npu_input, npu_sample, *attr) + self.assertRtolEqual(cpu_output_dx, npu_output_dx) + self.assertRtolEqual(cpu_output_dgrid, npu_output_dgrid) + +instantiate_device_type_tests(TestGridSampler2dBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSampler2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler2dBackwardKernelNpu.cpp new file mode 100644 index 0000000000..69090441c4 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GridSampler2dBackwardKernelNpu.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +std::tuple NPUNativeFunctions::grid_sampler_2d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + TORCH_CHECK( + (0 <= interpolation_mode && interpolation_mode <= 2), + "interpolation_mode must be in range [0~2].") + TORCH_CHECK( + (0 <= padding_mode && padding_mode <= 2), + "padding_mode must be in range [0~2].") + + at::Tensor formatCastOfGrad = grad; + at::Tensor formatCastOfInput = input; + at::Tensor formatCastOfGrid = grid; + if (formatCastOfGrad.scalar_type() == at::ScalarType::Half) { + formatCastOfGrad = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrad, at::ScalarType::Float); + } + if (formatCastOfInput.scalar_type() == at::ScalarType::Half) { + formatCastOfInput = NPUNativeFunctions::npu_dtype_cast(formatCastOfInput, at::ScalarType::Float); + } + if (formatCastOfGrid.scalar_type() == at::ScalarType::Half) { + formatCastOfGrid = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrid, at::ScalarType::Float); + } + + at::Tensor dx = OpPreparation::ApplyTensor(formatCastOfInput); + at::Tensor dgrid = OpPreparation::ApplyTensor(formatCastOfGrid); + + c10::SmallVectorinterMode = {"bilinear", "nearest", "bicubic"}; + c10::SmallVectorpaddingMode = {"zeros", "border", "reflection"}; + + OpCommand cmd; + cmd.Name("GridSampler2DGrad") + .Input(formatCastOfGrad) + .Input(formatCastOfInput) + .Input(formatCastOfGrid) + .Output(dx) + .Output(dgrid) + .Attr("interpolation_mode", interMode[interpolation_mode]) + .Attr("padding_mode", paddingMode[padding_mode]) + .Attr("align_corners", align_corners) + .Run(); + + at::ScalarType inputScalarType(input.scalar_type()); + if (dx.scalar_type() != inputScalarType) { + dx = NPUNativeFunctions::npu_dtype_cast(dx, inputScalarType); + dgrid = NPUNativeFunctions::npu_dtype_cast(dgrid, inputScalarType); + } + + return std::tuple(dx, dgrid); +} +} // namespace native +} // namespace at_npu -- Gitee