From 9b247710ee6b960fc99a9171bdc1211353d7ccd8 Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Thu, 17 Feb 2022 20:45:08 +0800 Subject: [PATCH] Add UpsampleLinear1d Operator --- .../test_upsample_linear1d.py | 108 ++++++++++++++ .../test_upsample_linear1d_backward.py | 130 +++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 2 - .../ops/UpsampleLinear1dBackwardKernelNpu.cpp | 123 ++++++++++++++++ .../aten/ops/UpsampleLinear1dKernelNpu.cpp | 135 ++++++++++++++++++ 5 files changed, 496 insertions(+), 2 deletions(-) create mode 100644 test/test_network_ops/test_upsample_linear1d.py create mode 100644 test/test_network_ops/test_upsample_linear1d_backward.py create mode 100644 torch_npu/csrc/aten/ops/UpsampleLinear1dBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp diff --git a/test/test_network_ops/test_upsample_linear1d.py b/test/test_network_ops/test_upsample_linear1d.py new file mode 100644 index 00000000000..4e1c9e80d9f --- /dev/null +++ b/test/test_network_ops/test_upsample_linear1d.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestUpsampleLinear1D(TestCase): + def cpu_op_exec(self, input1, size, align_corners): + out_result = torch.ones(input1.shape[0], input1.shape[1], size[0], dtype=input1.dtype) + output = torch._C._nn.upsample_linear1d(input=input1, output_size=size, align_corners=align_corners) + torch._C._nn.upsample_linear1d(input=input1, output_size=size, align_corners=align_corners, out=out_result) + return output.numpy(), out_result.numpy() + + def npu_op_exec(self, input1, size, align_corners): + out_result = torch.ones(input1.shape[0], input1.shape[1], size[0], dtype=input1.dtype) + out_result = out_result.to("npu") + output = torch._C._nn.upsample_linear1d(input=input1, output_size=size, align_corners=align_corners) + torch._C._nn.upsample_linear1d(input=input1, output_size=size, align_corners=align_corners, out=out_result) + output = output.to("cpu") + out_result = out_result.to("cpu") + return output.numpy(), out_result.numpy() + + def test_upsample_linear1d_shapeformat(self, dev): + test_cases = [ + [[np.float16, 0, (1, 1, 1, 2)], [4, ], True], + [[np.float16, 0, (2, 1, 1, 4)], [8, ], True], + [[np.float16, 0, (2, 2, 1, 3)], [1, ], True], + [[np.float16, 0, (2, 1, 1, 1)], [4, ], False], + [[np.float16, 0, (4, 1, 1, 2)], [4, ], False], + [[np.float16, 0, (1, 1, 1, 1)], [1, ], False], + + [[np.float32, 0, (1, 1, 1, 2)], [4, ], True], + [[np.float32, 0, (2, 1, 1, 2)], [4, ], True], + [[np.float32, 0, (2, 2, 1, 3)], [1, ], True], + [[np.float32, 0, (3, 1, 1, 1)], [2, ], False], + [[np.float32, 0, (4, 1, 1, 1)], [2, ], False], + [[np.float32, 0, (1, 1, 1, 1)], [1, ], False], + + [[np.float16, 0, (9, 7, 1, 2)], [15, ], True], + [[np.float16, 0, (8, 7, 1, 1)], [2, ], True], + [[np.float16, 0, (17, 2, 1, 3)], [1, ], True], + [[np.float16, 0, (6, 4, 1, 1)], [3, ], False], + [[np.float16, 0, (8, 7, 1, 2)], [4, ], False], + [[np.float16, 0, (2, 7, 1, 7)], [1, ], False], + + [[np.float32, 0, (9, 7, 1, 2)], [7, ], True], + [[np.float32, 0, (8, 3, 1, 1)], [2, ], True], + [[np.float32, 0, (8, 3, 1, 1)], [2, ], True], + [[np.float32, 0, (17, 2, 1, 3)], [1, ], True], + [[np.float32, 0, (9, 7, 1, 2)], [7, ], False], + [[np.float32, 0, (8, 3, 1, 3)], [2, ], False], + [[np.float32, 0, (2, 7, 1, 7)], [1, ], False], + + [[np.float16, 0, (9, 7, 1, 2)], [17, ], True], + [[np.float16, 0, (17, 13, 1, 15)], [16, ], True], + [[np.float16, 0, (61, 41, 1, 1)], [7, ], False], + [[np.float16, 0, (38, 7, 1, 7)], [16, ], False], + [[np.float32, 0, (997, 3, 1, 1)], [32, ], True], + [[np.float32, 0, (627, 2, 1, 3)], [17, ], False], + [[np.float32, 0, (78, 73, 1, 1)], [48, ], False], + [[np.float32, 0, (65535, 2, 1, 4)], [8, ], False], + [[np.float16, 0, (65535, 2, 1, 4)], [8, ], False], + [[np.float32, 0, (10086, 3, 1, 17)], [57, ], False], + [[np.float16, 0, (10086, 3, 1, 17)], [57, ], False] + ] + for item in test_cases: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + + if cpu_input.dim() == 4: + cpu_input = cpu_input.squeeze(2) + + if npu_input.dim() == 4: + npu_input = npu_input.squeeze(2) + + size = item[1] + align_corners = item[2] + + npu_output ,npu_out_result = self.npu_op_exec(npu_input, size, align_corners) + cpu_output ,cpu_out_result = self.cpu_op_exec(cpu_input, size, align_corners) + + cpu_output = cpu_output.astype(npu_output.dtype) + cpu_out_result = cpu_out_result.astype(npu_out_result.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_out_result, npu_out_result) + +instantiate_device_type_tests(TestUpsampleLinear1D, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_upsample_linear1d_backward.py b/test/test_network_ops/test_upsample_linear1d_backward.py new file mode 100644 index 00000000000..6ef6f081dca --- /dev/null +++ b/test/test_network_ops/test_upsample_linear1d_backward.py @@ -0,0 +1,130 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestUpsampleLinear1DBackward(TestCase): + def cpu_op_exec(self, input1, grads, size, align_corners): + input1.requires_grad_(True) + output = torch._C._nn.upsample_linear1d(input1, size, align_corners=align_corners) + output.backward(grads) + gradcpu = input1.grad + return output.detach().numpy(), gradcpu.detach().numpy() + + def npu_op_exec(self, input1, grads, size, align_corners): + input1.requires_grad_(True) + output = torch._C._nn.upsample_linear1d(input1, size, align_corners=align_corners) + output = output.to("npu") + output.backward(grads) + gradnpu = input1.grad + gradnpu = gradnpu.to("cpu") + output = output.to("cpu") + return output.detach().numpy(), gradnpu.detach().numpy() + + def test_upsample_linear1d_backward_shapeformat(self, dev): + test_cases = [ + [[np.float16, 0, (1, 1, 1, 2)], [4, ], True], + [[np.float16, 0, (2, 1, 1, 4)], [8, ], True], + [[np.float16, 0, (2, 2, 1, 3)], [1, ], True], + [[np.float16, 0, (2, 1, 1, 1)], [4, ], False], + [[np.float16, 0, (4, 1, 1, 2)], [4, ], False], + [[np.float16, 0, (1, 1, 1, 1)], [1, ], False], + + [[np.float32, 0, (1, 1, 1, 2)], [4, ], True], + [[np.float32, 0, (2, 1, 1, 2)], [4, ], True], + [[np.float32, 0, (2, 2, 1, 3)], [1, ], True], + [[np.float32, 0, (3, 1, 1, 1)], [2, ], False], + [[np.float32, 0, (4, 1, 1, 1)], [2, ], False], + [[np.float32, 0, (1, 1, 1, 1)], [1, ], False], + + [[np.float32, 0, (9, 7, 1, 2)], [15, ], True], + [[np.float16, 0, (8, 7, 1, 1)], [2, ], True], + [[np.float16, 0, (17, 2, 1, 3)], [1, ], True], + [[np.float16, 0, (6, 4, 1, 1)], [3, ], False], + [[np.float16, 0, (8, 7, 1, 2)], [4, ], False], + [[np.float16, 0, (2, 7, 1, 7)], [1, ], False], + + [[np.float32, 0, (9, 7, 1, 2)], [7, ], True], + [[np.float32, 0, (8, 3, 1, 1)], [2, ], True], + [[np.float32, 0, (8, 3, 1, 1)], [2, ], True], + [[np.float32, 0, (17, 2, 1, 3)], [1, ], True], + [[np.float32, 0, (9, 7, 1, 2)], [7, ], False], + [[np.float32, 0, (8, 3, 1, 3)], [2, ], False], + [[np.float32, 0, (2, 7, 1, 7)], [1, ], False], + + [[np.float16, 0, (9, 7, 1, 2)], [17, ], True], + [[np.float16, 0, (17, 13, 1, 15)], [16, ], True], + [[np.float16, 0, (61, 41, 1, 1)], [7, ], False], + [[np.float16, 0, (38, 7, 1, 7)], [16, ], False], + [[np.float32, 0, (997, 3, 1, 1)], [32, ], True], + [[np.float32, 0, (627, 2, 1, 3)], [17, ], False], + [[np.float32, 0, (78, 73, 1, 1)], [48, ], False], + [[np.float32, 0, (6553, 2, 1, 2)], [4, ], False], + [[np.float16, 0, (6553, 2, 1, 2)], [4, ], False], + [[np.float32, 0, (1008, 3, 1, 2)], [4, ], False], + [[np.float16, 0, (1008, 3, 1, 2)], [4, ], False] + ] + for item in test_cases: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + + size = list(item[0][2]) + size[3] = item[1][0] + + grad_item = [] + grad_item.append(item[0][0]) + grad_item.append(item[0][1]) + grad_item.append(size) + cpu_grads, npu_grads = create_common_tensor(grad_item, 0, 100) + + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + + if cpu_grads.dtype == torch.float16: + cpu_grads = cpu_grads.to(torch.float32) + + if cpu_input.dim() == 4: + cpu_input = cpu_input.squeeze(2) + + if npu_input.dim() == 4: + npu_input = npu_input.squeeze(2) + + if cpu_grads.dim() == 4: + cpu_grads = cpu_grads.squeeze(2) + + if npu_grads.dim() == 4: + npu_grads = npu_grads.squeeze(2) + + size = item[1] + align_corners = item[2] + + cpu_output, cpu_grad = self.cpu_op_exec(cpu_input, cpu_grads, size, align_corners) + npu_output, npu_grad = self.npu_op_exec(npu_input, npu_grads, size, align_corners) + + cpu_output = cpu_output.astype(npu_output.dtype) + cpu_grad = cpu_grad.astype(npu_grad.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_grad, npu_grad) + + + +instantiate_device_type_tests(TestUpsampleLinear1DBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 0ab0fa4062e..648f1f10d7d 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1679,8 +1679,6 @@ supported: - replication_pad3d - replication_pad3d_backward.grad_input - replication_pad3d_backward - - upsample_linear1d.vec - - upsample_linear1d_backward.vec - upsample_bilinear2d.vec - upsample_bilinear2d_backward.vec - upsample_trilinear3d.vec diff --git a/torch_npu/csrc/aten/ops/UpsampleLinear1dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleLinear1dBackwardKernelNpu.cpp new file mode 100644 index 00000000000..3a6686dc564 --- /dev/null +++ b/torch_npu/csrc/aten/ops/UpsampleLinear1dBackwardKernelNpu.cpp @@ -0,0 +1,123 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +static inline void upsample_linear1d_backward_check( + const at::Tensor& grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size) { + + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + input_size.size() == 3, + "It is expected input_size equals to 3, but got size ", + input_size.size()); + + TORCH_CHECK( + (grad_output.size(1) != 0 && grad_output.size(2) != 0) && grad_output.dim() == 3, + "Non-empty 3D data tensor expected but got a tensor with sizes ", + grad_output.sizes()); + + int64_t output_width = grad_output.size(2); + int64_t input_width = input_size[2]; + + TORCH_CHECK( + output_width > 0 && input_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); +} + +at::Tensor& upsample_linear1d_backward_out( + at::Tensor& result, + const at::Tensor& grad_output, + at::IntArrayRef input_size, + bool align_corners, + c10::optional scales) { + + c10::SmallVector sc = {}; + if (scales.has_value()) { + sc.push_back(scales.value()); + } else { + float temp = float(grad_output.size(3)) / float(input_size[2]); + sc.push_back(temp); + } + string coordinate_transformation_mode = + align_corners ? "align_corners" : "half_pixel"; + + // executing the NPU operator + OpCommand cmd; + cmd.Name("ResizeGradD") + .Input(grad_output) + .Output(result) + .Attr("original_size", input_size) + .Attr("scales", sc) + .Attr("coordinate_transformation_mode", coordinate_transformation_mode) + .Attr("mode", (string)"linear") + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::upsample_linear1d_backward( + const at::Tensor& grad_output, + at::IntArrayRef output_size, + at::IntArrayRef input_size, + bool align_corners, + c10::optional scales) { + upsample_linear1d_backward_check(grad_output, output_size, input_size); + at::Tensor _grad_output = grad_output; + if(grad_output.scalar_type() != at::ScalarType::Float) + { + _grad_output = NPUNativeFunctions::npu_dtype_cast(_grad_output, at::ScalarType::Float); + } + + // calculate the output size + int64_t N = _grad_output.size(0); + int64_t C = _grad_output.size(1); + int64_t W = input_size[2]; + + c10::SmallVector outputSize = {N, C, W}; + + // Since only NCHW format input is currently supported, first convert the + // input grad_output (3 dimensions) to 4 dimensions as the input of npu + auto grad_output_4dim = _grad_output.unsqueeze(2); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(_grad_output, outputSize); + + // calculate the output result of the NPU + upsample_linear1d_backward_out( + result, grad_output_4dim, input_size, align_corners, scales); + + if (result.dtype() != grad_output.dtype()) { + result = result.to(grad_output.dtype()); + } + + return result; +} +} // namespace native +} // namespace at \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp new file mode 100644 index 00000000000..7da1537463d --- /dev/null +++ b/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp @@ -0,0 +1,135 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +static inline void upsample_linear1d_check( + const at::Tensor& self, + at::IntArrayRef output_size) { + TORCH_CHECK( + output_size.size() == 1, + "It is expected output_size equals to 1, but got size ", + output_size.size()); + + TORCH_CHECK( + (self.size(1) != 0 && self.size(2) != 0) && self.dim() == 3, + "Non-empty 3D data tensor expected but got a tensor with sizes ", + self.sizes()); + + int64_t input_width = self.size(2); + int64_t output_width = output_size[0]; + + TORCH_CHECK( + input_width > 0 && output_width > 0, + "Input and output sizes should be greater than 0, but got input (W: ", + input_width, + ") and output (W: ", + output_width, + ")"); +} + +c10::SmallVector upsample_linear1d_npu_input( + const c10::SmallVector& inputTensor) { + return CalcuOpUtil::create_npu_input_tensor_desc(inputTensor); +} + +c10::SmallVector upsample_linear1d_npu_output( + const c10::SmallVector& outputTensor) { + return CalcuOpUtil::create_npu_output_tensor_desc(outputTensor); +} + +c10::SmallVector upsample_linear1d_npu_attr( + const at::Tensor& self, + at::IntArrayRef output_size, + bool align_corners, + c10::optional scales) { + NPUAttrDesc npuAttrSizes = NPUAttrDesc("sizes", output_size); + // If the value of scales is not passed in, then according to the formula: + // scale = output_size / input.dim3 + // to calculate the value of scale + c10::SmallVector sc = {}; + if (scales.has_value()) { + sc.push_back(scales.value()); + } else { + float temp = float(output_size[0]) / float(self.size(3)); + sc.push_back(temp); + } + NPUAttrDesc npuAttrScales = NPUAttrDesc("scales", sc); + + string coordinate_transformation_mode = + align_corners ? "align_corners" : "half_pixel"; + NPUAttrDesc npuAttrCoordinateTransformationMode = NPUAttrDesc( + "coordinate_transformation_mode", coordinate_transformation_mode); + + string mode = "linear"; + NPUAttrDesc npuAttrMode = NPUAttrDesc("mode", mode); + + // required attr + c10::SmallVector attrs = { + npuAttrSizes, + npuAttrCoordinateTransformationMode, + npuAttrMode, + npuAttrScales}; + return attrs; +} + +at::Tensor& NPUNativeFunctions::upsample_linear1d_out( + const at::Tensor& self, + at::IntArrayRef output_size, + bool align_corners, + c10::optional scales, + at::Tensor& result) { + upsample_linear1d_check(self,output_size); + // Since only NCHW format input is currently supported, first convert the + // input self (3 dimensions) to 4 dimensions as the input of npu + auto input_4dim = self.unsqueeze(2); + // constructs the input and output NPUTensorDesc + auto inputs = upsample_linear1d_npu_input({input_4dim}); + auto outputs = upsample_linear1d_npu_output({result}); + + // constructs the attr of the NPUAttrDesc + auto attrs = upsample_linear1d_npu_attr(input_4dim, output_size, align_corners, scales); + + // executing the NPU operator + CalcuOpUtil::execute_npu_operate("ResizeD", inputs, outputs, attrs); + return result; +} + +at::Tensor NPUNativeFunctions::upsample_linear1d( + const at::Tensor& self, + at::IntArrayRef output_size, + bool align_corners, + c10::optional scales) { + // calculate the output size + auto outputSize = upsample_linear1d_npu_output_size( + self, output_size, align_corners, scales); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options()); + + // calculate the output result of the NPU + NPUNativeFunctions::upsample_linear1d_out( + self, output_size, align_corners, scales, result); + + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee