diff --git a/test/test_network_ops/test_grid_sampler_2d.py b/test/test_network_ops/test_grid_sampler_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..af67846b8712f369613c2fd632713103e3395e4f
--- /dev/null
+++ b/test/test_network_ops/test_grid_sampler_2d.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestGridSampler2D(TestCase):
+    """Compare torch.grid_sampler_2d results for NPU tensors against CPU tensors."""
+
+    def cpu_op_exec(self, input1, grid):
+        """Call grid_sampler_2d on the given tensors and return a numpy array."""
+        # Positional args: interpolation_mode=0, padding_mode=0, align_corners=True.
+        output = torch.grid_sampler_2d(input1, grid, 0, 0, True)
+        return output.numpy()
+
+    def npu_op_exec(self, input1, grid):
+        """Call grid_sampler_2d, move the result to CPU and return a numpy array."""
+        # Same mode arguments as cpu_op_exec so both outputs are comparable.
+        output = torch.grid_sampler_2d(input1, grid, 0, 0, True)
+        return output.to("cpu").numpy()
+
+    def cpu_op_fp16_exec(self, input1, grid):
+        """Build the float16 reference: compute in float32, then cast to float16."""
+        output = self.cpu_op_exec(input1.to(torch.float32), grid.to(torch.float32))
+        return output.astype(np.float16)
+
+    def _shape_format(self, dtype):
+        """Return the [input_item, grid_item] cases with every dtype set to dtype."""
+        # Each item is handed unchanged to create_common_tensor.
+        return [
+            [[dtype, 0, (1, 2, 4, 20)], [dtype, 0, (1, 10, 8, 2)]],
+            [[dtype, 0, (1, 4, 64, 10)], [dtype, 0, (1, 2, 32, 2)]],
+            [[dtype, 0, (2, 2048, 7, 7)], [dtype, 0, (2, 2048, 14, 2)]],
+            [[dtype, 4, (32, 1, 3, 3)], [dtype, 4, (32, 20, 30, 2)]],
+            [[dtype, 29, (1, 2, 10, 128)], [dtype, 4, (1, 10, 5, 2)]],
+        ]
+
+    def _check_cases(self, dtype, cpu_exec):
+        """Run every case for dtype and compare cpu_exec output with NPU output."""
+        for input_item, grid_item in self._shape_format(dtype):
+            cpu_input, npu_input = create_common_tensor(input_item, 1, 100)
+            cpu_grid, npu_grid = create_common_tensor(grid_item, -1, 1)
+            cpu_output = cpu_exec(cpu_input, cpu_grid)
+            npu_output = self.npu_op_exec(npu_input, npu_grid)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_grid_sampler_2d_shape_format(self, device):
+        """float32 cases: NPU output must match the CPU output."""
+        self._check_cases(np.float32, self.cpu_op_exec)
+
+    def test_grid_sampler_2d_fp16_shape_format(self, device):
+        """float16 cases: NPU output must match the float32-computed reference."""
+        self._check_cases(np.float16, self.cpu_op_fp16_exec)
+
+
+instantiate_device_type_tests(TestGridSampler2D, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..674479dbe23b3eb768c83d7018e2c789e3a8118f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Runs the NPU "GridSampler2D" operator on `self` and `grid`.
+//
+// interpolation_mode: 0 -> "bilinear", 1 -> "nearest", 2 -> "bicubic".
+// padding_mode: 0 -> "zeros", 1 -> "border", 2 -> "reflection".
+// align_corners is forwarded to the operator unchanged.
+//
+// Half inputs are cast to Float before the operator runs, and the result is
+// cast back to the dtype of `self`. The result has shape
+// {self.size(0), self.size(1), grid.size(1), grid.size(2)}.
+// Raises (via TORCH_CHECK) when a mode is out of range or the shapes mismatch.
+at::Tensor NPUNativeFunctions::grid_sampler_2d(
+    const at::Tensor& self,
+    const at::Tensor& grid,
+    int64_t interpolation_mode,
+    int64_t padding_mode,
+    bool align_corners) {
+  TORCH_CHECK(
+      (0 <= interpolation_mode && interpolation_mode <= 2),
+      "interpolation_mode must be in range [0~2].");
+  TORCH_CHECK(
+      (0 <= padding_mode && padding_mode <= 2),
+      "padding_mode must be in range [0~2].");
+  // The output shape below reads self dims 0-1 and grid dims 1-2, so reject
+  // malformed shapes here with a clear message.
+  TORCH_CHECK(
+      self.dim() == 4 && grid.dim() == 4,
+      "grid_sampler_2d expects 4-D input and grid, but got input with ",
+      self.dim(), " dims and grid with ", grid.dim(), " dims.");
+  TORCH_CHECK(
+      self.size(0) == grid.size(0),
+      "grid_sampler_2d expects input and grid to have the same batch size.");
+  TORCH_CHECK(
+      grid.size(3) == 2,
+      "grid_sampler_2d expects grid to have size 2 in its last dimension.");
+
+  at::Tensor dtypeCastOfSelf = self;
+  at::Tensor dtypeCastOfGrid = grid;
+  if (dtypeCastOfSelf.scalar_type() == c10::ScalarType::Half) {
+    dtypeCastOfSelf = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfSelf, c10::ScalarType::Float);
+  }
+  if (dtypeCastOfGrid.scalar_type() == c10::ScalarType::Half) {
+    dtypeCastOfGrid = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfGrid, c10::ScalarType::Float);
+  }
+
+  c10::SmallVector<int64_t, 4> outputSize = {
+      dtypeCastOfSelf.size(0),
+      dtypeCastOfSelf.size(1),
+      dtypeCastOfGrid.size(1),
+      dtypeCastOfGrid.size(2)};
+
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(dtypeCastOfSelf, outputSize, ACL_FORMAT_ND);
+
+  // Operator attribute strings, indexed by the integer mode arguments.
+  c10::SmallVector<std::string, 3> interMode = {"bilinear", "nearest", "bicubic"};
+  c10::SmallVector<std::string, 3> paddingMode = {"zeros", "border", "reflection"};
+
+  OpCommand cmd;
+  cmd.Name("GridSampler2D")
+      .Input(dtypeCastOfSelf)
+      .Input(dtypeCastOfGrid)
+      .Output(result)
+      .Attr("interpolation_mode", interMode[interpolation_mode])
+      .Attr("padding_mode", paddingMode[padding_mode])
+      .Attr("align_corners", align_corners)
+      .Run();
+
+  // Undo the Half -> Float cast so the result dtype matches `self`.
+  c10::ScalarType selfScalarType(self.scalar_type());
+  if (result.scalar_type() != selfScalarType) {
+    result = NPUNativeFunctions::npu_dtype_cast(result, selfScalarType);
+  }
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu