diff --git a/test/test_network_ops/test_grid_sampler.py b/test/test_network_ops/test_grid_sampler.py index 3ec50d5188210491daeed653cb56445173326cd3..4296cf46a8006c8d491065c5bb327f8b35f80603 100644 --- a/test/test_network_ops/test_grid_sampler.py +++ b/test/test_network_ops/test_grid_sampler.py @@ -22,7 +22,7 @@ from torch_npu.testing.testcase import TestCase, run_tests from torch_npu.testing.common_utils import create_common_tensor class TestGridSampler(TestCase): - def test_grid_sampler_fp32(self, device="npu"): + def test_grid_sampler_fp32(self): format_list = [0] shape_list = [[100, 1, 28, 28], [100, 64, 32, 28]] shape_format = [ @@ -36,13 +36,13 @@ class TestGridSampler(TestCase): npu_output = self.npu_op_exec(npu_input, npu_sample) self.assertRtolEqual(cpu_output, npu_output) - def test_grid_sampler_fp16(self, device="npu"): + def test_grid_sampler_fp16(self): format_list = [0] shape_list = [[1, 1, 3, 3], [1, 2, 3, 4]] shape_format = [ [np.float16, j, k] for j in format_list for k in shape_list ] - sample_format = [np.float32, 0, [1, 2, 2, 2]] + sample_format = [np.float16, 0, [1, 2, 2, 2]] for item in shape_format: cpu_input, npu_input = create_common_tensor(item, 0, 10) cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) diff --git a/test/test_network_ops/test_grid_sampler_backwawrd.py b/test/test_network_ops/test_grid_sampler_backwawrd.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b9b012735223cab5be77fb73ff5bc1c6210bd7 --- /dev/null +++ b/test/test_network_ops/test_grid_sampler_backwawrd.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestGridSampler(TestCase): + def cpu_op_exec(self, input1, sample): + input1.requires_grad = True + sample.requires_grad = True + output = torch.grid_sampler(input1, sample, 0, 0, True) + output.backward(torch.ones_like(output)) + input_grad = input1.grad.numpy() + sample_grad = sample.grad.numpy() + return input_grad, sample_grad + + def npu_op_exec(self, input1, sample): + input1.requires_grad = True + sample.requires_grad = True + output = torch.grid_sampler(input1, sample, 0, 0, True) + output.backward(torch.ones_like(output)) + input_grad = input1.grad.to("cpu").numpy() + sample_grad = sample.grad.to("cpu").numpy() + return input_grad, sample_grad + + def result_grid_sampler(self, shape_format, sample_format): + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_sample = cpu_sample.to(torch.float32) + cpu_grad1, cpu_grad2 = self.cpu_op_exec(cpu_input, cpu_sample) + npu_grad1, npu_grad2 = self.npu_op_exec(npu_input, npu_sample) + cpu_grad1 = cpu_grad1.astype(npu_grad1.dtype) + cpu_grad2 = npu_grad2.astype(npu_grad1.dtype) + self.assertRtolEqual(cpu_grad1, npu_grad1) + self.assertRtolEqual(cpu_grad2, npu_grad2) + + def test_grid_sampler_fp32(self): + format_list = [0] + shape_list = [[100, 1, 28, 28], [100, 64, 32, 28]] + shape_format = [ + [np.float32, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float32, 0, [100, 1, 1, 2]] + self.result_grid_sampler(shape_format, sample_format) + + def test_grid_sampler_fp16(self): + format_list = [0] + shape_list = [[1, 1, 3, 3], [1, 2, 3, 4]] + shape_format = [ + [np.float16, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float16, 0, [1, 2, 2, 2]] + self.result_grid_sampler(shape_format, sample_format) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSamplerKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSamplerKernelNpu.cpp deleted file mode 100644 index 1f1de5e316a00fb0b6f85ae2c6118aff11aab2f2..0000000000000000000000000000000000000000 --- a/torch_npu/csrc/aten/ops/GridSamplerKernelNpu.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// All rights reserved. -// -// Licensed under the BSD 3-Clause License (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://opensource.org/licenses/BSD-3-Clause -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "torch_npu/csrc/framework/utils/OpAdapter.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" - -namespace at_npu { -namespace native { - -at::Tensor& grid_sampler_npu_nocheck( - const at::Tensor& self, - const at::Tensor& grid, - int64_t interpolation_mode, - int64_t padding_mode, - bool align_corners, - at::Tensor& result) { - OpCommand cmd; - cmd.Name("GridSampler2D") - .Input(self) - .Input(grid) - .Output(result) - .Attr("interpolation_mode", interpolation_mode) - .Attr("padding_mode", padding_mode) - .Attr("align_corners", align_corners) - .Run(); - return result; -} - -at::Tensor NPUNativeFunctions::grid_sampler( - const at::Tensor& self, - const at::Tensor& grid, - int64_t interpolation_mode, - int64_t padding_mode, - bool align_corners) { - at::Tensor formatCastOfSelf = NPUNativeFunctions::npu_format_cast(self, ACL_FORMAT_ND); - at::Tensor formatCastOfGrid = NPUNativeFunctions::npu_format_cast(grid, ACL_FORMAT_ND); - if (formatCastOfSelf.scalar_type() == at::ScalarType::Half) { - formatCastOfSelf = NPUNativeFunctions::npu_dtype_cast(formatCastOfSelf, at::ScalarType::Float); - } - if (formatCastOfGrid.scalar_type() == at::ScalarType::Half) { - formatCastOfGrid = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrid, at::ScalarType::Float); - } - c10::SmallVector outputSize = {formatCastOfSelf.size(0), - formatCastOfSelf.size(1), - formatCastOfGrid.size(1), - formatCastOfGrid.size(2)}; - auto result = OpPreparation::ApplyTensorWithFormat( - outputSize, formatCastOfSelf.options(), ACL_FORMAT_ND); - grid_sampler_npu_nocheck( - formatCastOfSelf, - formatCastOfGrid, - interpolation_mode, - padding_mode, - align_corners, - result); - if (result.scalar_type() != self.scalar_type()) { - result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Half); - } - return result; -} - -} // namespace native -} // namespace at_npu \ No newline at end of file