From 40b0c7a2b4a0e1aa63b8000f828354c86c90c5a4 Mon Sep 17 00:00:00 2001
From: zhoufan37
Date: Wed, 9 Mar 2022 18:23:41 +0800
Subject: [PATCH] add upsampleBicubic2d Operator

---
 .../test_upsample_bicubic2d.py                |  92 +++++++++++++
 .../test_upsample_bicubic2d_backward.py       | 101 ++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   2 -
 .../UpSampleBicubic2dBackwardKernelNpu.cpp    | 124 +++++++++++++++++
 .../aten/ops/UpsampleBicubic2dKernelNpu.cpp   | 126 ++++++++++++++++++
 5 files changed, 443 insertions(+), 2 deletions(-)
 create mode 100644 test/test_network_ops/test_upsample_bicubic2d.py
 create mode 100644 test/test_network_ops/test_upsample_bicubic2d_backward.py
 create mode 100644 torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/UpsampleBicubic2dKernelNpu.cpp

diff --git a/test/test_network_ops/test_upsample_bicubic2d.py b/test/test_network_ops/test_upsample_bicubic2d.py
new file mode 100644
index 0000000000..38e16e9228
--- /dev/null
+++ b/test/test_network_ops/test_upsample_bicubic2d.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestUpsampleBicubic2d(TestCase):
+
+    def cpu_op_exec(self, input1, output_size, align_corners, scale_h, scale_w):
+        output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, output_size, align_corners, scale_h, scale_w):
+        output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def create_shape_format32(self):
+        dtype_list = [np.float32]
+        format_list = [-1]
+        shape_list = [(1, 1, 1, 1), (1, 31, 149, 2), (32, 32, 32, 32), (65535, 2, 4, 8)]
+        size_list = [(2, 2), (4, 4)]
+        align_corners_list = [True, False]
+        scale_h = [0, 0.5]
+        scale_w = [0, 0.5]
+
+        shape_format = [[[i, j, k], h, f, e, g] for i in dtype_list
+                        for j in format_list for k in shape_list
+                        for h in size_list for f in align_corners_list
+                        for e in scale_h for g in scale_w]
+
+        return shape_format
+
+    def create_shape_format16(self):
+        dtype_list1 = [np.float16]
+        format_list1 = [-1]
+        shape_list1 = [(1, 1, 1, 1), (1, 31, 149, 2), (32, 32, 32, 32), (65535, 2, 4, 8)]
+        size_list1 = [(2, 2), (4, 4)]
+        align_corners_list1 = [True, False]
+        scale_h1 = [0, 0.5]
+        scale_w1 = [0, 0.5]
+
+        shape_format1 = [[[i, j, k], h, f, e, g] for i in dtype_list1
+                         for j in format_list1 for k in shape_list1
+                         for h in size_list1 for f in align_corners_list1
+                         for e in scale_h1 for g in scale_w1]
+
+        return shape_format1
+
+    def test_upsample_bicubic2d_common_shape_format(self):
+        for item in self.create_shape_format32():
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 255)
+            cpu_output = self.cpu_op_exec(cpu_input1, item[1], item[2], item[3], item[4])
+            npu_output = self.npu_op_exec(npu_input1, item[1], item[2], item[3], item[4])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_upsample_bicubic2d_float16_shape_format(self):
+        def cpu_op_exec_fp16(input1, output_size, align_corners, scale_h, scale_w):
+            input1 = input1.to(torch.float32)
+            output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+            output = output.numpy()
+            output = output.astype(np.float16)
+            return output
+
+        for item in self.create_shape_format16():
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 255)
+            cpu_output = cpu_op_exec_fp16(cpu_input1, item[1], item[2], item[3], item[4])
+            npu_output = self.npu_op_exec(npu_input1, item[1], item[2], item[3], item[4])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
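For reference, the private hook torch._C._nn.upsample_bicubic2d driven by the test above is the binding that the public torch.nn.functional.interpolate entry point reaches for mode="bicubic" on 4-D input. A minimal sketch of the equivalent public-API call (illustrative only, not part of the patch):

    import torch

    x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
    # equivalent to torch._C._nn.upsample_bicubic2d(x, (8, 8), False, None, None)
    y = torch.nn.functional.interpolate(x, size=(8, 8), mode="bicubic", align_corners=False)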
diff --git a/test/test_network_ops/test_upsample_bicubic2d_backward.py b/test/test_network_ops/test_upsample_bicubic2d_backward.py
new file mode 100644
index 0000000000..ea37ece37d
--- /dev/null
+++ b/test/test_network_ops/test_upsample_bicubic2d_backward.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestUpsampleBicubic2dBackward(TestCase):
+
+    def cpu_op_exec(self, input1, output_size, align_corners, scale_h, scale_w):
+        input1.requires_grad = True
+        output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+        output.backward(torch.ones_like(output))
+        output_grad = input1.grad
+        output_grad = output_grad.detach().numpy()
+        return output_grad
+
+    def npu_op_exec(self, input1, output_size, align_corners, scale_h, scale_w):
+        input1.requires_grad = True
+        output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+        output.backward(torch.ones_like(output))
+        output_grad = input1.grad
+        output_grad = output_grad.to("cpu").detach().numpy()
+        return output_grad
+
+    def backward_create_shape_format32(self):
+        dtype_list2 = [np.float32]
+        format_list2 = [-1]
+        shape_list2 = [(1, 1, 1, 1), (1, 31, 149, 2), (32, 32, 32, 32)]
+        size_list2 = [(2, 2), (4, 4)]
+        align_corners_list2 = [True, False]
+        scale_h2 = [0, 0.5]
+        scale_w2 = [0, 0.5]
+
+        shape_format = [[[i, j, k], h, f, e, g] for i in dtype_list2
+                        for j in format_list2 for k in shape_list2
+                        for h in size_list2 for f in align_corners_list2
+                        for e in scale_h2 for g in scale_w2]
+
+        return shape_format
+
+    def backward_create_shape_format16(self):
+        dtype_list3 = [np.float16]
+        format_list3 = [-1]
+        shape_list3 = [(1, 1, 1, 1), (1, 31, 149, 2), (32, 32, 32, 32)]
+        size_list3 = [(2, 2), (4, 4)]
+        align_corners_list3 = [True, False]
+        scale_h3 = [0, 0.5]
+        scale_w3 = [0, 0.5]
+
+        shape_format1 = [[[i, j, k], h, f, e, g] for i in dtype_list3
+                         for j in format_list3 for k in shape_list3
+                         for h in size_list3 for f in align_corners_list3
+                         for e in scale_h3 for g in scale_w3]
+
+        return shape_format1
+
+    def test_upsample_bicubic2d_common_shape_format(self):
+        for item in self.backward_create_shape_format32():
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 255)
+            cpu_output = self.cpu_op_exec(cpu_input1, item[1], item[2], item[3], item[4])
+            npu_output = self.npu_op_exec(npu_input1, item[1], item[2], item[3], item[4])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+    def test_upsample_bicubic2d_float16_shape_format(self):
+        def cpu_op_exec_fp16(input1, output_size, align_corners, scale_h, scale_w):
+            input1 = input1.to(torch.float32)
+            input1.requires_grad = True
+            output = torch._C._nn.upsample_bicubic2d(input1, output_size, align_corners, scale_h, scale_w)
+            output.backward(torch.ones_like(output))
+            output_grad = input1.grad
+            output_grad = output_grad.detach().numpy()
+            output_grad = output_grad.astype(np.float16)
+            return output_grad
+
+        for item in self.backward_create_shape_format16():
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 255)
+            cpu_output = cpu_op_exec_fp16(cpu_input1, item[1], item[2], item[3], item[4])
+            npu_output = self.npu_op_exec(npu_input1, item[1], item[2], item[3], item[4])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
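The test above validates the NPU gradient against the CPU implementation. As an additional check, the analytic gradient can be compared against numeric differentiation with torch.autograd.gradcheck; a minimal sketch (illustrative only; gradcheck needs double precision and runs here on CPU):

    import torch

    def fn(t):
        # arguments: (input, output_size, align_corners, scales_h, scales_w)
        return torch._C._nn.upsample_bicubic2d(t, (6, 6), False, None, None)

    x = torch.rand(1, 1, 3, 3, dtype=torch.double, requires_grad=True)
    torch.autograd.gradcheck(fn, (x,))  # raises if analytic and numeric gradients diverge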
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 9144830bc9..c5ba06b765 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1673,8 +1673,6 @@ supported:
   - replication_pad3d_backward
   - upsample_trilinear3d.vec
   - upsample_trilinear3d_backward.vec
-  - upsample_bicubic2d.vec
-  - upsample_bicubic2d_backward.vec
   - upsample_nearest1d.vec
   - upsample_nearest1d_backward.vec
   - upsample_nearest2d.vec
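With the yaml registration updated and the two kernels below in place, the operator can be smoke-tested end to end through the public API; a minimal sketch (assumes a built torch_npu with an available device; the device string "npu:0" is an illustrative placeholder):

    import torch
    import torch_npu

    x = torch.rand(1, 3, 4, 4).to("npu:0")
    y = torch.nn.functional.interpolate(x, size=(8, 8), mode="bicubic", align_corners=True)
    print(y.shape)  # expected: torch.Size([1, 3, 8, 8])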
diff --git a/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp
new file mode 100644
index 0000000000..d16c71cdc8
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp
@@ -0,0 +1,124 @@
+// Copyright (c) 2020, Huawei Technologies.
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& upsample_bicubic2d_backward_out_nocheck(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    at::Tensor& grad_input) {
+
+  TORCH_CHECK(
+      output_size.size() == 2,
+      "It is expected output_size equals to 2, but got size ",
+      output_size.size());
+
+  TORCH_CHECK(
+      input_size.size() == 4,
+      "It is expected input_size equals to 4, but got size ",
+      input_size.size());
+
+  float temp_h = 0.0;
+  float temp_w = 0.0;
+  if (scales_h.has_value()) {
+    temp_h = (float)scales_h.value();
+  }
+  if (scales_w.has_value()) {
+    temp_w = (float)scales_w.value();
+  }
+  c10::SmallVector<float, SIZE> scales = {temp_h, temp_w};
+  c10::SmallVector<float, SIZE> roi = {};
+  string coordinate_transformation_mode =
+      align_corners ? "align_corners" : "half_pixel";
+
+  float cu = -0.75;
+  int64_t ex = 0;
+  float ext = 0.0;
+  string mode = "cubic";
+  string ne = "round_prefer_floor";
+
+  OpCommand cmd;
+  cmd.Name("ResizeGradD")
+      .Input(grad_output)
+      .Output(grad_input)
+      .Attr("scales", scales)
+      .Attr("roi", roi)
+      .Attr("original_size", input_size)
+      .Attr("coordinate_transformation_mode", coordinate_transformation_mode)
+      .Attr("cubic_coeff_a", cu)
+      .Attr("exclude_outside", ex)
+      .Attr("extrapolation_value", ext)
+      .Attr("mode", mode)
+      .Attr("nearest_mode", ne)
+      .Run();
+
+  return grad_input;
+}
+
+at::Tensor& NPUNativeFunctions::upsample_bicubic2d_backward_out(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    at::Tensor& grad_input) {
+
+  auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size);
+
+  OpPreparation::CheckOut(
+      {grad_output},
+      grad_input,
+      CalcuOpUtil::get_tensor_npu_format(grad_output),
+      grad_output.scalar_type(),
+      outputSize);
+
+  if (!NpuUtils::check_match(&grad_input)) {
+    at::Tensor contiguousResult = NpuUtils::format_contiguous(grad_input);
+    upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, contiguousResult);
+    NpuUtils::format_fresh_view(grad_input, contiguousResult);
+  } else {
+    upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input);
+  }
+
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::upsample_bicubic2d_backward(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  // construct the output tensor of the NPU
+  auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size);
+  at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, grad_output.options());
+  // calculate the output result of the NPU
+  return upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, result);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
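Both kernels fix cubic_coeff_a at -0.75, which selects the Keys cubic convolution kernel that PyTorch's own bicubic path and the ONNX Resize default use. For reference, a sketch of the weight function this coefficient parameterizes (not used by the NPU path itself):

    def cubic_weight(x, a=-0.75):
        # Keys (1981) cubic convolution kernel; a = -0.75 matches the
        # PyTorch bicubic implementation and the ONNX Resize default.
        x = abs(x)
        if x <= 1:
            return (a + 2) * x**3 - (a + 3) * x**2 + 1
        if x < 2:
            return a * x**3 - 5 * a * x**2 + 8 * a * x - 4 * a
        return 0.0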
"align_corners" : "half_pixel"; + + float cu = -0.75; + int64_t ex = 0; + float ext = 0.0; + string mode = "cubic"; + string ne = "round_prefer_floor"; + + OpCommand cmd; + cmd.Name("ResizeGradD") + .Input(grad_output) + .Output(grad_input) + .Attr("scales", scales) + .Attr("roi", roi) + .Attr("original_size", input_size) + .Attr("coordinate_transformation_mode", coordinate_transformation_mode) + .Attr("cubic_coeff_a", cu) + .Attr("exclude_outside", ex) + .Attr("extrapolation_value", ext) + .Attr("mode", mode) + .Attr("nearest_mode", ne) + .Run(); + + return grad_input; +} + +at::Tensor& NPUNativeFunctions::upsample_bicubic2d_backward_out( + const at::Tensor& grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w, + at::Tensor& grad_input) { + + auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size); + + OpPreparation::CheckOut( + {grad_output}, + grad_input, + CalcuOpUtil::get_tensor_npu_format(grad_output), + grad_output.scalar_type(), + outputSize); + + if (!NpuUtils::check_match(&grad_input)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(grad_input); + upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, contiguousResult); + NpuUtils::format_fresh_view(grad_input, contiguousResult); + } else { + upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + return grad_input; +} + +at::Tensor NPUNativeFunctions::upsample_bicubic2d_backward( + const at::Tensor& grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners, + c10::optional scales_h, + c10::optional scales_w) { + // construct the output tensor of the NPU + auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size); + at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, grad_output.options()); + // calculate the output result of the NPU + return upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, result); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/UpsampleBicubic2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleBicubic2dKernelNpu.cpp new file mode 100644 index 0000000000..19e9b552ad --- /dev/null +++ b/torch_npu/csrc/aten/ops/UpsampleBicubic2dKernelNpu.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
-- 
Gitee