From c24bd882b8e3f2611d3cca4887191e82bccad58c Mon Sep 17 00:00:00 2001 From: pipihugh Date: Tue, 22 Mar 2022 16:37:26 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dgrid=5Fsampler=5F2d=E7=AE=97?= =?UTF-8?q?=E5=AD=90=E7=B2=BE=E5=BA=A6=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_grid_sampler_2d.py | 3 +-- torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp | 8 ++------ 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/test/test_network_ops/test_grid_sampler_2d.py b/test/test_network_ops/test_grid_sampler_2d.py index 9ea285d514..eba4536728 100644 --- a/test/test_network_ops/test_grid_sampler_2d.py +++ b/test/test_network_ops/test_grid_sampler_2d.py @@ -65,12 +65,11 @@ class TestGridSampler2D(TestCase): ] for item in shape_format: cpu_input, npu_input = create_common_tensor(item[0], 1, 100) - cpu_grid, npu_grid = create_common_tensor(item[1], -1, 1) + cpu_grid, npu_grid = create_common_tensor(item[1], -3, 3) cpu_output = self.cpu_op_fp16_exec(cpu_input, cpu_grid) npu_output = self.npu_op_exec(npu_input, npu_grid) self.assertRtolEqual(cpu_output, npu_output) - if __name__ == "__main__": run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp index 674479dbe2..38e32044c8 100644 --- a/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/GridSampler2dKernelNpu.cpp @@ -49,16 +49,13 @@ at::Tensor NPUNativeFunctions::grid_sampler_2d( at::Tensor result = OpPreparation::ApplyTensorWithFormat(dtypeCastOfSelf, outputSize, ACL_FORMAT_ND); - c10::SmallVectorinterMode = {"bilinear", "nearest", "bicubic"}; - c10::SmallVectorpaddingMode = {"zeros", "border", "reflection"}; - OpCommand cmd; cmd.Name("GridSampler2D") .Input(dtypeCastOfSelf) .Input(dtypeCastOfGrid) .Output(result) - .Attr("interpolation_mode", interMode[interpolation_mode]) - .Attr("padding_mode", paddingMode[padding_mode]) + .Attr("interpolation_mode", interpolation_mode) + .Attr("padding_mode", padding_mode) .Attr("align_corners", align_corners) .Run(); @@ -68,6 +65,5 @@ at::Tensor NPUNativeFunctions::grid_sampler_2d( } return result; } - } // namespace native } // namespace at_npu -- Gitee