From 3b42f2f4000d3d37e1f12c9c71b1f5f5a302637a Mon Sep 17 00:00:00 2001 From: wangxiao Date: Tue, 1 Mar 2022 17:02:50 +0800 Subject: [PATCH 1/2] floor, remainder, sort, upsample_bilinear2d, upsample_bilinear2d_backward, upsample_nearest2d, upsample_nearest2d_backward --- test/test_network_ops/test_floor.py | 155 ++++++++++++ test/test_network_ops/test_remainder.py | 224 ++++++++++++++++++ test/test_network_ops/test_sort.py | 92 +++++++ .../test_upsample_bilinear2d.py | 54 +++++ .../test_upsample_bilinear2d_backward.py | 61 +++++ .../test_upsample_nearest2d.py | 54 +++++ .../test_upsample_nearest2d_backward.py | 63 +++++ torch_npu/csrc/aten/ops/FloorKernelNpu.cpp | 57 +++++ .../csrc/aten/ops/RemainderKernelNpu.cpp | 141 +++++++++++ torch_npu/csrc/aten/ops/SortKernelNpu.cpp | 110 +++++++++ .../UpsampleBilinear2dBackwardKernelNpu.cpp | 145 ++++++++++++ .../aten/ops/UpsampleBilinear2dKernelNpu.cpp | 99 ++++++++ .../UpsampleNearest2dBackwardKernelNpu.cpp | 74 ++++++ .../aten/ops/UpsampleNearest2dKernelNpu.cpp | 84 +++++++ 14 files changed, 1413 insertions(+) create mode 100644 test/test_network_ops/test_floor.py create mode 100644 test/test_network_ops/test_remainder.py create mode 100644 test/test_network_ops/test_sort.py create mode 100644 test/test_network_ops/test_upsample_bilinear2d.py create mode 100644 test/test_network_ops/test_upsample_bilinear2d_backward.py create mode 100644 test/test_network_ops/test_upsample_nearest2d.py create mode 100644 test/test_network_ops/test_upsample_nearest2d_backward.py create mode 100644 torch_npu/csrc/aten/ops/FloorKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/RemainderKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/SortKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/UpsampleNearest2dBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/UpsampleNearest2dKernelNpu.cpp diff --git a/test/test_network_ops/test_floor.py b/test/test_network_ops/test_floor.py new file mode 100644 index 0000000000..800e5fca2d --- /dev/null +++ b/test/test_network_ops/test_floor.py @@ -0,0 +1,155 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestFloor(TestCase): + def cpu_op_exec(self, input1): + output = torch.floor(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.floor(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_inter_exec(self, input1): + torch.floor_(input1) + output = input1.numpy() + return output + + def npu_op_inter_exec(self, input1): + torch.floor_(input1) + output = input1.to("cpu") + output = output.numpy() + return output + + def cpu_op_out_exec(self, input1, output): + torch.floor(input1, out = output) + output = output.numpy() + return output + + def npu_op_out_exec(self, input1, output): + torch.floor(input1, out = output) + output = output.to("cpu") + output = output.numpy() + return output + + def test_floor_float32_shape_format(self, device="npu"): + format_list = [0, 3] + shape_list = [[256, 1, 1, 1], [1024, 32, 7, 7], [1024, 32, 7], [1024, 32], [1024]] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_inter_float32_shape_format(self, device="npu"): + format_list = [0, 3] + shape_list = [[256, 1, 1, 1], [1024, 32, 7, 7], [1024, 32, 7], [1024, 32], [1024]] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_inter_exec(cpu_input) + npu_output = self.npu_op_inter_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_out_float32_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, 0, [1024, 32, 7, 7]], [np.float32, 0, [1024, 32, 7, 7]]], + [[np.float32, 0, [1024, 32, 7]], [np.float32, 0, [1024, 32]]], + [[np.float32, 0, [1024, 32]], [np.float32, 0, [1024, 32]]], + [[np.float32, 0, [1024]], [np.float32, 0, [1024, 1]]], + [[np.float32, 3, [1024, 32, 7, 7]], [np.float32, 3, [1024, 32, 7, 7]]], + [[np.float32, 3, [1024, 32, 7]], [np.float32, 3, [1024, 32]]], + [[np.float32, 3, [1024, 32]], [np.float32, 3, [1024, 20]]], + [[np.float32, 3, [1024]], [np.float32, 3, [1024]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_output, npu_output = create_common_tensor(item[1], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_out_exec(npu_input, npu_output) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_float16_shape_format(self, device="npu"): + format_list = [0, 3] + shape_list = [[256, 1, 1, 1], [1024, 32, 7, 7], [1024, 32, 7], [1024, 32], [1024]] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + if item[0] == np.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + if item[0] == np.float16: + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_inter_float16_shape_format(self, device="npu"): + format_list = [0, 3] + 
shape_list = [[256, 1, 1, 1], [1024, 32, 7, 7], [1024, 32, 7], [1024, 32], [1024]] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + if item[0] == np.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_inter_exec(cpu_input) + npu_output = self.npu_op_inter_exec(npu_input) + if item[0] == np.float16: + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_out_float16_shape_format(self, device="npu"): + shape_format = [ + [[np.float16, 0, [1024, 32, 7, 7]], [np.float16, 0, [1024, 32, 7, 7]]], + [[np.float16, 0, [1024, 32, 7]], [np.float16, 0, [1024, 32]]], + [[np.float16, 0, [1024, 32]], [np.float16, 0, [1024, 32]]], + [[np.float16, 0, [1024]], [np.float16, 0, [1024, 1]]], + [[np.float16, 3, [1024, 32, 7, 7]], [np.float16, 3, [1024, 32, 7, 7]]], + [[np.float16, 3, [1024, 32, 7]], [np.float16, 3, [1024, 32]]], + [[np.float16, 3, [1024, 32]], [np.float16, 3, [1024, 20]]], + [[np.float16, 3, [1024]], [np.float16, 3, [1024]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_output, npu_output = create_common_tensor(item[1], 1, 100) + if item[0][0] == np.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = cpu_output.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_out_exec(npu_input, npu_output) + if item[0][0] == np.float16: + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_remainder.py b/test/test_network_ops/test_remainder.py new file mode 100644 index 0000000000..b28251efd0 --- /dev/null +++ b/test/test_network_ops/test_remainder.py @@ -0,0 +1,224 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestRemainder(TestCase): + def cpu_op_exec(self, input1, input2): + output = torch.remainder(input1, input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + output = torch.remainder(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, input2, out): + output = torch.remainder(input1, input2, out=out) + output = out.to("cpu") + output = output.numpy() + return output + + def cpu_op_inplace_exec(self, input1, input2): + output = input1.remainder_(input2) + output = output.numpy() + return output + + def npu_op_inplace_exec(self, input1, input2): + output = input1.remainder_(input2) + output = input1.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_scalar(self, input1, input2): + output = torch.remainder(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def remainder_out_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item, 0, 100) + npu_input3 = torch.randn(6).to("npu") + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_input2 = cpu_input2.to(torch.float32) + + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + self.assertRtolEqual(cpu_output, npu_output_out) + + def remainder_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item, 0, 100) + npu_input3 = copy.deepcopy(cpu_input1).to("npu") + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + cpu_output_inplace = self.cpu_op_inplace_exec(cpu_input1, cpu_input2) + npu_output_inplace = self.npu_op_inplace_exec(npu_input1, npu_input2) + + cpu_output = cpu_output.astype(npu_output.dtype) + cpu_output_inplace = cpu_output_inplace.astype(npu_output_inplace.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_output_inplace, npu_output_inplace) + self.assertRtolEqual(cpu_output, npu_output_out) + + def remainder_scalar_result(self, shape_format): + for item in shape_format: + scalar = np.random.uniform(0, 100) + cpu_input1, npu_input1 = create_common_tensor(item, 0, 2) + npu_input3 = copy.deepcopy(cpu_input1).to("npu") + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, scalar) + npu_output_scalar = self.npu_op_exec_scalar(npu_input1, scalar) + npu_output_out = self.npu_op_exec_out(npu_input1, scalar, npu_input3) + + cpu_output = cpu_output.astype(npu_output_scalar.dtype) + self.assertRtolEqual(cpu_output, npu_output_scalar) + self.assertRtolEqual(cpu_output, npu_output_out) + + def test_remainder_shape_format_fp16_1d(self, device="npu"): + format_list = [0, 3] + shape_format = [[np.float16, i, [4]] for i in format_list + ] + self.remainder_result(shape_format) + + def 
test_remainder_shape_format_fp32_1d(self, device="npu"): + format_list = [0, 3] + shape_format = [[np.float32, i, [4]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp16_2d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp32_2d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp16_3d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18, 32]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp32_3d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18, 32]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp16_4d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18, 32, 128]] for i in format_list + ] + self.remainder_result(shape_format) + + def test_remainder_shape_format_fp32_4d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18, 32, 128]] for i in format_list + ] + self.remainder_result(shape_format) + + # scalar---------------------------------------------------------- + def test_remainder_scalar_shape_format_fp16_1d(self, device="npu"): + format_list = [0, 3] + shape_format = [[np.float16, i, [4]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp32_1d(self, device="npu"): + format_list = [0, 3] + shape_format = [[np.float32, i, [4]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp16_2d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp32_2d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp16_3d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18, 32]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp32_3d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18, 32]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp16_4d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [4, 18, 32, 128]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_scalar_shape_format_fp32_4d(self, device="npu"): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [4, 18, 32, 128]] for i in format_list + ] + self.remainder_scalar_result(shape_format) + + def test_remainder_mix_dtype_1(self, device="npu"): + npu_input1, npu_input2 = create_common_tensor([np.int32, 0, (2, 3)], 1, 100) + npu_input3, npu_input4 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100) + cpu_output = self.cpu_op_exec(npu_input1, npu_input3) + npu_output = self.npu_op_exec(npu_input1, npu_input3) + self.assertRtolEqual(cpu_output, npu_output) + + def 
test_remainder_mix_dtype_2(self, device="npu"): + npu_input1, npu_input2 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100) + npu_input3 = torch.tensor(3).int() + cpu_output = self.cpu_op_exec(npu_input1, npu_input3) + npu_output = self.npu_op_exec(npu_input1, npu_input3) + self.assertRtolEqual(cpu_output, npu_output) + + def test_remainder_scalar_shape_format_fp32_out_4d(self, device="npu"): + format_list = [0] + shape_format = [[np.float32, i, [4, 18, 32, 128]] for i in format_list + ] + self.remainder_out_result(shape_format) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_sort.py b/test/test_network_ops/test_sort.py new file mode 100644 index 0000000000..27cc84ef1b --- /dev/null +++ b/test/test_network_ops/test_sort.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestSort(TestCase): + def cpu_op_exec(self, input1, dim): + output, indices = torch.sort(input1, dim=dim) + output = output.numpy() + indices = indices.numpy() + return output, indices + + def npu_op_exec(self, input1, dim): + output, indices = torch.sort(input1, dim=dim) + output = output.cpu() + indices = indices.cpu() + output = output.numpy() + indices = indices.numpy() + return output, indices + + def cpu_default_op_exec(self, input1): + output, indices = torch.sort(input1) + output = output.numpy() + indices = indices.numpy() + return output, indices + + def npu_default_op_exec(self, input1): + output, indices = torch.sort(input1) + output = output.cpu() + indices = indices.cpu() + output = output.numpy() + indices = indices.numpy() + return output, indices + + # at present accuracy under FP32 is inadequate + def _test_sort_shape_format_fp32(self, device="npu"): + shape_format = [ + [[np.float32, 0, (8, 4, 3, 9)], 2], + [[np.float32, 0, (2, 3)]], + [[np.float32, 0, (1, 7)], 0], + [[np.float32, 0, (1, 5, 6)], 1], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + if len(item) > 1: + cpu_output, cpu_indices = self.cpu_op_exec(cpu_input1, item[1]) + npu_output, npu_indices = self.npu_op_exec(npu_input1, item[1]) + else: + cpu_output, cpu_indices = self.cpu_default_op_exec(cpu_input1) + npu_output, npu_indices = self.npu_default_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_indices, npu_indices) + + def test_sort_shape_format_fp16(self, device="npu"): + shape_format = [ + [[np.float16, 0, (8, 4, 3, 9)], 2], + [[np.float16, 0, (2, 3)]], + [[np.float16, 0, (1, 7)], 0], + [[np.float16, 0, (1, 5, 6)], 1], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + if len(item) > 1: + cpu_output, cpu_indices = 
self.cpu_op_exec(cpu_input1.to(torch.float32), item[1]) + npu_output, npu_indices = self.npu_op_exec(npu_input1, item[1]) + else: + cpu_output, cpu_indices = self.cpu_default_op_exec(cpu_input1.to(torch.float32)) + npu_output, npu_indices = self.npu_default_op_exec(npu_input1) + self.assertRtolEqual(cpu_output.astype(np.float16), npu_output) + self.assertRtolEqual(cpu_indices, npu_indices) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_upsample_bilinear2d.py b/test/test_network_ops/test_upsample_bilinear2d.py new file mode 100644 index 0000000000..d610e9dcbe --- /dev/null +++ b/test/test_network_ops/test_upsample_bilinear2d.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestUpsampleBilinear2d(TestCase): + + def cpu_op_exec(self, inputs, shapes): + output = torch._C._nn.upsample_bilinear2d(inputs, shapes, True, 0, 0) + output = output.numpy() + return output + + def npu_op_exec(self, inputs, shapes): + output = torch._C._nn.upsample_bilinear2d(inputs, shapes, True, 0, 0) + output = output.to("cpu") + output = output.numpy() + return output + + def test_UpsampleBilinear2d_common_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, -1, (1, 1, 3000, 3000)], (2500, 2500)], + [[np.float32, -1, (4, 3, 1, 5)], (2, 2)], + [[np.float32, -1, (2, 3, 2, 1)], (3, 3)], + [[np.float32, -1, (1, 4, 2, 2)], (4, 4)], + [[np.float16, -1, (4, 10, 16, 14)], (5, 5)], + [[np.float16, -1, (8, 8, 8, 8)], (1, 2)], + [[np.float16, -1, (10, 4, 3, 2)], (2, 4)] + ] + for item in shape_format: + cpu_inputs, npu_inputs = create_common_tensor(item[0], 1, 100) + if cpu_inputs.dtype == torch.float16: + cpu_inputs = cpu_inputs.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_inputs, item[1]) + npu_output = self.npu_op_exec(npu_inputs, item[1]) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_upsample_bilinear2d_backward.py b/test/test_network_ops/test_upsample_bilinear2d_backward.py new file mode 100644 index 0000000000..92688d3ae7 --- /dev/null +++ b/test/test_network_ops/test_upsample_bilinear2d_backward.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestUpsampleBilinear2dBackward(TestCase): + def cpu_op_exec(self, inputs, shapes): + inputs.requires_grad_(True) + output = torch._C._nn.upsample_bilinear2d(inputs, shapes, True, 0, 0) + output.backward(torch.ones_like(output)) + gradcpu = inputs.grad + return output.detach().numpy(), gradcpu.detach().numpy() + + def npu_op_exec(self, inputs, shapes): + inputs.requires_grad_(True) + output = torch._C._nn.upsample_bilinear2d(inputs, shapes, True, 0, 0) + inputback = torch.ones_like(output) + output.backward(inputback) + out = output.to("cpu") + grad = inputs.grad + grad = grad.to("cpu") + return out.detach().numpy(), grad.detach().numpy() + + def test_UpsampleBilinear2d_common_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, -1, (4, 3, 1, 5)], (2, 2)], + [[np.float32, -1, (2, 3, 2, 1)], (3, 3)], + [[np.float32, -1, (1, 4, 2, 2)], (4, 4)], + [[np.float16, -1, (4, 10, 16, 14)], (5, 5)], + [[np.float16, -1, (8, 8, 8, 8)], (1, 2)], + [[np.float16, -1, (10, 4, 3, 2)], (2, 4)] + ] + for item in shape_format: + cpu_inputs, npu_inputs = create_common_tensor(item[0], 1, 100) + if cpu_inputs.dtype == torch.float16: + cpu_inputs = cpu_inputs.to(torch.float32) + cpu_output, cpu_grad = self.cpu_op_exec(cpu_inputs, item[1]) + npu_output, npu_grad = self.npu_op_exec(npu_inputs, item[1]) + cpu_output = cpu_output.astype(npu_output.dtype) + cpu_grad = cpu_grad.astype(npu_grad.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_grad, npu_grad) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_upsample_nearest2d.py b/test/test_network_ops/test_upsample_nearest2d.py new file mode 100644 index 0000000000..5e974f8fe3 --- /dev/null +++ b/test/test_network_ops/test_upsample_nearest2d.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestUpsamleNearest2D(TestCase): + def cpu_op_exec(self, input1, size): + output = torch.nn.functional.interpolate(input1, size, mode="nearest") + output = output.numpy() + return output + + def npu_op_exec(self, input1, size): + output = torch.nn.functional.interpolate(input1, size, mode="nearest") + output = output.to("cpu") + output = output.numpy() + return output + + def test_upsample_nearest2d_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, 0, (5, 3, 6, 4)], [10, 10]], + [[np.float16, 0, (5, 3, 6, 4)], [10, 10]], + [[np.float32, 0, (2, 3, 2, 4)], [10, 10]], + [[np.float16, -1, (2, 3, 2, 3)], [10, 10]] + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 50) + if cpu_input == torch.float16: + cpu_input = cpu_input.to(torch.float32) + + size = item[1] + cpu_output = self.cpu_op_exec(cpu_input, size) + npu_output = self.npu_op_exec(npu_input, size) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_upsample_nearest2d_backward.py b/test/test_network_ops/test_upsample_nearest2d_backward.py new file mode 100644 index 0000000000..0ea9acc84d --- /dev/null +++ b/test/test_network_ops/test_upsample_nearest2d_backward.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestUpsampleNearest2DBackward(TestCase): + def cpu_op_exec(self, input1, size): + input1.requires_grad_(True) + output = F.interpolate(input1, size, mode = "nearest") + output.backward(torch.ones_like(output)) + return output.detach().numpy(), input1.grad.numpy() + + def npu_op_exec(self, input1, size): + input1.requires_grad_(True) + output = F.interpolate(input1, size, mode = "nearest") + inputback = torch.ones_like(output) + output.backward(inputback) + out = output.to("cpu") + grad = input1.grad + grad = grad.to("cpu") + return out.detach().numpy(), grad.detach().numpy() + + def test_upsample_bilinear2d_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, 0, (2, 3, 4, 4)], [2, 2]], + [[np.float16, 0, (2, 3, 4, 4)], [2, 2]], + [[np.float32, 0, (5, 3, 6, 4)], [10, 10]], + [[np.float16, 0, (5, 3, 6, 4)], [10, 10]], + [[np.float32, 0, (2, 3, 2, 4)], [10, 10]], + [[np.float16, -1, (2, 3, 2, 3)], [10, 10]] + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + if cpu_input == torch.float16: + cpu_input = cpu_input.to(torch.float32) + + cpu_output, cpu_grad = self.cpu_op_exec(cpu_input, item[1]) + npu_output, npu_grad = self.npu_op_exec(npu_input, item[1]) + + cpu_grad = cpu_grad.astype(npu_grad.dtype) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_grad, npu_grad) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp b/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp new file mode 100644 index 0000000000..5784de81de --- /dev/null +++ b/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& floor_out_npu_nocheck(const at::Tensor& self, at::Tensor& result) { + OpCommand cmd; + cmd.Name("Floor") + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::floor_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + + OpPipeWithDefinedOut pipe; + return pipe.CheckMemory({self}, {result}) + .Func([&self](at::Tensor& result){floor_out_npu_nocheck(self, result);}) + .Call(result); +} + +at::Tensor& NPUNativeFunctions::floor_(at::Tensor& self) { + NPUNativeFunctions::floor_out(self, self); + + return self; +} + +at::Tensor NPUNativeFunctions::floor(const at::Tensor& self) { + Tensor result = OpPreparation::ApplyTensor(self); + floor_out_npu_nocheck(self, result); + return result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/RemainderKernelNpu.cpp b/torch_npu/csrc/aten/ops/RemainderKernelNpu.cpp new file mode 100644 index 0000000000..257103b57e --- /dev/null +++ b/torch_npu/csrc/aten/ops/RemainderKernelNpu.cpp @@ -0,0 +1,141 @@ +// Copyright (c) 2022 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& remainder_out_scalar_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Scalar other) { + OpCommand cmd; + cmd.Name("FloorMod") + .Input(self) + .Input(other, self.scalar_type()) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::remainder_out( + const at::Tensor& self, + const at::Scalar other, + at::Tensor& result) { + OpPreparation::CheckOut({self}, result, self); + remainder_out_scalar_npu_nocheck(result, self, other); + return result; +} + +at::Tensor& remainder_out_tensor_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { + auto unified_result = OpPreparation::binary_op_check(result, self, other, true); + if (other.dim() == 0) { + NPUNativeFunctions::remainder_out(self, other.item(), result); + } else { + OpCommand cmd; + cmd.Name("FloorMod") + .Expect(unified_result) + .Input(self) + .Input(other) + .Output(result) + .Run(); + } + + return result; +} + +at::Tensor& NPUNativeFunctions::remainder_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + at::Tensor outputTensor = CalcuOpUtil::is_scalar_wrapped_to_tensor(self) ? 
other : self; + auto outputSize = broadcast_ops_npu_output_size(self, other); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(outputTensor), + self.scalar_type(), + outputSize); + remainder_out_tensor_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor NPUNativeFunctions::remainder(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + at::Tensor outputTensor = isSelfWrapped ? other : self; + + auto outputSize = broadcast_ops_npu_output_size(self, other); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, + outputTensor.options(), + CalcuOpUtil::get_tensor_npu_format(outputTensor)); + + // calculate the output result of the NPU + remainder_out_tensor_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor NPUNativeFunctions::remainder(const at::Tensor& self, at::Scalar other) { + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(self); + + // calculate the output result of the NPU + remainder_out_scalar_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor& NPUNativeFunctions::remainder_(at::Tensor& self, const at::Tensor& other) { + at::SmallVector inputs = {other}; + at::SmallVector outputs = {self}; + CalcuOpUtil::check_memory_over_laps(inputs, outputs); + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = remainder_out_tensor_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + remainder_out_tensor_npu_nocheck(self, self, other); + } + + return self; +} + + +at::Tensor& NPUNativeFunctions::remainder_(at::Tensor& self, at::Scalar other) { + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = remainder_out_scalar_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + remainder_out_scalar_npu_nocheck(self, self, other); + } + return self; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/SortKernelNpu.cpp b/torch_npu/csrc/aten/ops/SortKernelNpu.cpp new file mode 100644 index 0000000000..1a1fb161fb --- /dev/null +++ b/torch_npu/csrc/aten/ops/SortKernelNpu.cpp @@ -0,0 +1,110 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +tuple sort_out_npu_no_transpose( + const at::Tensor& self, + int64_t dim, + bool descending, + at::Tensor& values, + at::Tensor& indices) { + OpCommand cmd; + cmd.Name("Sort") + .Input(self) + .Output(values) + .Output(indices) + .Attr("axis", dim) + .Attr("descending", descending) + .Run(); + + return std::tie(values, indices); +} + +tuple NPUNativeFunctions::sort_out( + const at::Tensor& self, + int64_t dim, + bool descending, + at::Tensor& values, + at::Tensor& indices) { + dim = CalcuOpUtil::make_wrap_dim(dim, self.dim()); + int64_t lastDim = CalcuOpUtil::make_wrap_dim(-1, self.dim()); + + if (dim != lastDim) { + at::SmallVector perm; + for (int64_t i = 0; i < self.dim(); i++) { + perm.emplace_back(i); + } + std::swap(perm[dim], perm[lastDim]); + + at::Tensor transposeSelf = at::npu_transpose(self, perm); + auto outputSize = transpose_npu_output_size(values, perm); + at::Tensor transposeValues = OpPreparation::ApplyTensor(values, outputSize); + at::Tensor transposeIndices =OpPreparation::ApplyTensor(indices, outputSize); + + sort_out_npu_no_transpose( + transposeSelf, lastDim, descending, transposeValues, transposeIndices); + + at::npu_transpose_out(values, transposeValues, perm); + at::npu_transpose_out(indices, transposeIndices, perm); + } else { + sort_out_npu_no_transpose( + self, lastDim, descending, values, indices); + } + + // indices dtype transform Int64 + indices = indices.to(at::kLong); + + return std::tie(values, indices); +} + +tuple NPUNativeFunctions::sort_out( + const at::Tensor& self, + at::Dimname dim, + bool descending, + at::Tensor& values, + at::Tensor& indices) { + return NPUNativeFunctions::sort_out(self, dimname_to_position(self, dim), descending, values, indices); +} + +tuple NPUNativeFunctions::sort( + const at::Tensor& self, + int64_t dim, + bool descending) { + auto outputSize = input_same_output_size(self); + + at::Tensor values = OpPreparation::ApplyTensor(self); + at::Tensor indices = OpPreparation::ApplyTensorWithFormat( + outputSize, self.options().dtype(at::kInt), ACL_FORMAT_NCHW); + + NPUNativeFunctions::sort_out(self, dim, descending, values, indices); + + return std::tie(values, indices); +} + +tuple NPUNativeFunctions::sort( + const at::Tensor& self, + at::Dimname dim, + bool descending) { + return NPUNativeFunctions::sort(self, dimname_to_position(self, dim), descending); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp new file mode 100644 index 0000000000..b9acc93baf --- /dev/null +++ b/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp @@ -0,0 +1,145 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+bool upsample_bilinear2d_backward_check_is_aicore(
+    const at::Tensor& grad_output) {
+  int64_t H = grad_output.size(2);
+  int64_t W = grad_output.size(3);
+  // if H or W is greater than 10000, fall back to the AI CPU operator
+  if (H > 10000 || W > 10000) {
+    return false;
+  }
+  return true;
+}
+
+at::Tensor& upsample_bilinear2d_backward_out_npu_nocheck(
+    at::Tensor& grad_input,
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  bool isAicore = upsample_bilinear2d_backward_check_is_aicore(grad_output);
+  OpCommand cmd;
+  if (isAicore) {
+    at::Tensor original_image = OpPreparation::ApplyTensor(grad_output, input_size);
+    bool half_pixel_centers = !align_corners;
+    cmd.Name("ResizeBilinearV2Grad")
+        .Input(grad_output)
+        .Input(original_image)
+        .Output(grad_input)
+        .Attr("align_corners", align_corners)
+        .Attr("half_pixel_centers", half_pixel_centers)
+        .Run();
+  } else {
+    cmd.Name("PTUpsampleBilinear2dGrad")
+        .Input(grad_output)
+        .Output(grad_input)
+        .Attr("output_size", output_size)
+        .Attr("input_size", input_size)
+        .Attr("align_corners", align_corners);
+    if (scales_h.has_value()) {
+      cmd.Attr("scales_h", static_cast<float>(scales_h.value()));
+    }
+    if (scales_w.has_value()) {
+      cmd.Attr("scales_w", static_cast<float>(scales_w.value()));
+    }
+    cmd.Run();
+  }
+  return grad_input;
+}
+
+at::Tensor& NPUNativeFunctions::upsample_bilinear2d_backward_out(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    Tensor& grad_input) {
+
+  OpPreparation::CheckOut(
+      {grad_output},
+      grad_input,
+      grad_output,
+      input_size);
+  if (!NpuUtils::check_match(&grad_input)) {
+    at::Tensor result_contiguous = NpuUtils::format_contiguous(grad_input);
+
+    upsample_bilinear2d_backward_out_npu_nocheck(
+        result_contiguous,
+        grad_output,
+        output_size,
+        input_size,
+        align_corners,
+        scales_h,
+        scales_w);
+    NpuUtils::format_fresh_view(grad_input, result_contiguous);
+  } else {
+    upsample_bilinear2d_backward_out_npu_nocheck(
+        grad_input,
+        grad_output,
+        output_size,
+        input_size,
+        align_corners,
+        scales_h,
+        scales_w);
+  }
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::upsample_bilinear2d_backward(
+    const at::Tensor& grad_output_ex,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  at::Tensor grad_output = grad_output_ex;
+  bool isAicore = upsample_bilinear2d_backward_check_is_aicore(grad_output);
+  if (!isAicore) {
+    if (grad_output.scalar_type() != at::ScalarType::Float) {
+      grad_output = grad_output.npu_dtype_cast(at::ScalarType::Float);
+    }
+  }
+  auto outputSize = upsample_bilinear2d_backward_npu_output_size(
+      grad_output, output_size, input_size, align_corners, scales_h, scales_w);
+  aclFormat format = isAicore ?
+      ACL_FORMAT_NC1HWC0 : (aclFormat)CalcuOpUtil::get_tensor_npu_format(grad_output);
+  at::Tensor grad_input = OpPreparation::ApplyTensorWithFormat(grad_output, outputSize, format);
+
+  upsample_bilinear2d_backward_out_npu_nocheck(
+      grad_input,
+      grad_output,
+      output_size,
+      input_size,
+      align_corners,
+      scales_h,
+      scales_w);
+  if (grad_input.dtype() != grad_output_ex.dtype()) {
+    grad_input = grad_input.to(grad_output_ex.dtype());
+  }
+  return grad_input;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp
new file mode 100644
index 0000000000..689e94a399
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& upsample_bilinear2d_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    at::IntArrayRef output_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  OpCommand cmd;
+  bool half_pixel_centers = !align_corners;
+  int64_t H = output_size[0];
+  int64_t W = output_size[1];
+  SmallVector<int64_t, SIZE> attr_size = {H, W};
+  cmd.Name("ResizeBilinearV2")
+      .Input(self)
+      .Input(attr_size, at::kInt)
+      .Output(result)
+      .Attr("align_corners", align_corners)
+      .Attr("half_pixel_centers", half_pixel_centers)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::upsample_bilinear2d_out(
+    const at::Tensor& self_ex,
+    at::IntArrayRef output_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    at::Tensor& result){
+  at::Tensor self = self_ex;
+  if (self.scalar_type() != at::ScalarType::Float) {
+    self = self.npu_dtype_cast(at::ScalarType::Float);
+  }
+  auto outputSize = upsample_bilinear2d_npu_output_size(
+      self, output_size, align_corners, scales_h, scales_w);
+
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self,
+      outputSize);
+  if (!NpuUtils::check_match(&result)) {
+    at::Tensor result_contiguous = NpuUtils::format_contiguous(result);
+
+    upsample_bilinear2d_out_npu_nocheck(
+        result_contiguous, self, output_size, align_corners, scales_h, scales_w);
+    NpuUtils::format_fresh_view(result, result_contiguous);
+  } else {
+    upsample_bilinear2d_out_npu_nocheck(
+        result, self, output_size, align_corners, scales_h, scales_w);
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::upsample_bilinear2d(
+    const at::Tensor& self_ex,
+    at::IntArrayRef output_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  at::Tensor self = self_ex;
+  if (self.scalar_type() != at::ScalarType::Float) {
+    self = self.npu_dtype_cast(at::ScalarType::Float);
+  }
+  auto outputSize = upsample_bilinear2d_npu_output_size(
+      self, output_size, align_corners, scales_h, scales_w);
+  at::Tensor result = OpPreparation::ApplyTensor(outputSize, self.options(), self);
+
+  upsample_bilinear2d_out_npu_nocheck(
+      result, self, output_size, align_corners, scales_h, scales_w);
+  if (result.dtype() != self_ex.dtype()) {
+    result = result.npu_dtype_cast(self_ex.scalar_type());
+  }
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/UpsampleNearest2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleNearest2dBackwardKernelNpu.cpp
new file mode 100644
index 0000000000..0ffab86967
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/UpsampleNearest2dBackwardKernelNpu.cpp
@@ -0,0 +1,74 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& NPUNativeFunctions::upsample_nearest2d_backward_out(
+    const at::Tensor& grads,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    at::Tensor& y) {
+  at::SmallVector<int64_t, SIZE> outputSize = {input_size[2], input_size[3]};
+  OpCommand cmd;
+  cmd.Name("ResizeNearestNeighborV2Grad")
+      .Input(grads)
+      .Input(outputSize, at::kInt)
+      .Output(y)
+      .Attr("align_corners", false)
+      .Attr("half_pixel_centers", false)
+      .Run();
+
+  return y;
+}
+
+at::Tensor NPUNativeFunctions::upsample_nearest2d_backward(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  at::Tensor grads = grad_output;
+  if (grad_output.scalar_type() != at::ScalarType::Float) {
+    grads = grad_output.to(at::kFloat);
+  }
+  at::Tensor grad_input = OpPreparation::ApplyTensor(
+      input_size, grads.options(), grad_output);
+  NPUNativeFunctions::upsample_nearest2d_backward_out(
+      grads, output_size, input_size, scales_h, scales_w, grad_input);
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::upsample_nearest2d_backward(
+    const at::Tensor& grad_output,
+    c10::optional<at::IntArrayRef> output_size,
+    at::IntArrayRef input_size,
+    c10::optional<c10::ArrayRef<double>> scale_factors) {
+  auto osize = CalcuOpUtil::compute_output_size(input_size, output_size, scale_factors);
+  auto scales_h = CalcuOpUtil::get_scale_value(scale_factors, 0);
+  auto scales_w = CalcuOpUtil::get_scale_value(scale_factors, 1);
+  at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output, input_size);
+  NPUNativeFunctions::upsample_nearest2d_backward_out(grad_output, osize, input_size, scales_h, scales_w, grad_input);
+  return grad_input;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/UpsampleNearest2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleNearest2dKernelNpu.cpp
new file mode 100644
index 0000000000..cb1fcfbbdb 100644
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/UpsampleNearest2dKernelNpu.cpp
@@ -0,0 +1,84 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::SmallVector<int64_t, SIZE> upsample_nearest2d_npu_output_size(
+    const at::Tensor& input,
+    at::IntArrayRef output_size){
+  int64_t N = input.size(0);
+  int64_t C = input.size(1);
+  int64_t H = output_size[0];
+  int64_t W = output_size[1];
+  at::SmallVector<int64_t, SIZE> outputSize = {N, C, H, W};
+
+  return outputSize;
+}
+
+at::Tensor& NPUNativeFunctions::upsample_nearest2d_out(
+    const at::Tensor& self,
+    at::IntArrayRef output_size,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w,
+    at::Tensor& result) {
+  at::SmallVector<int64_t, SIZE> outputSize = upsample_nearest2d_npu_output_size(self, output_size);
+  if (!result.sizes().equals(outputSize)){
+    result.resize_(outputSize);
+  }
+  at::SmallVector<int64_t, N> outputSizeVec = array_to_small_vector(output_size);
+  OpCommand cmd;
+  cmd.Name("ResizeNearestNeighborV2")
+      .Input(self)
+      .Input(outputSizeVec, at::kInt)
+      .Output(result)
+      .Attr("align_corners", false)
+      .Attr("half_pixel_centers", false)
+      .Run();
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::upsample_nearest2d(
+    const at::Tensor& self,
+    at::IntArrayRef output_size,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  at::SmallVector<int64_t, SIZE> outputSize = upsample_nearest2d_npu_output_size(self, output_size);
+
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+
+  NPUNativeFunctions::upsample_nearest2d_out(self, output_size, scales_h, scales_w, result);
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::upsample_nearest2d(
+    const at::Tensor& input,
+    c10::optional<at::IntArrayRef> output_size,
+    c10::optional<c10::ArrayRef<double>> scale_factors) {
+  auto osize = CalcuOpUtil::compute_output_size(input.sizes(), output_size, scale_factors);
+  auto scale_h = CalcuOpUtil::get_scale_value(scale_factors, 0);
+  auto scale_w = CalcuOpUtil::get_scale_value(scale_factors, 1);
+  at::Tensor result = OpPreparation::ApplyTensor(input, osize);
+  NPUNativeFunctions::upsample_nearest2d_out(input, osize, scale_h, scale_w, result);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee

From 9910354a6c942102b71ce2f9bd2d1b195da0ddc5 Mon Sep 17 00:00:00 2001
From: wangxiao
Date: Tue, 1 Mar 2022 20:52:58 +0800
Subject: [PATCH 2/2] fix bugs of floor, sort, upsample_bilinear2d, upsample_bilinear2d_backward, upsample_nearest2d, upsample_nearest2d_backward

---
 test/test_network_ops/test_upsample_nearest2d.py          | 2 +-
 test/test_network_ops/test_upsample_nearest2d_backward.py | 2 +-
 torch_npu/csrc/aten/npu_native_functions.yaml             | 2 --
 torch_npu/csrc/aten/ops/FloorKernelNpu.cpp                | 2 +-
 torch_npu/csrc/aten/ops/SortKernelNpu.cpp                 | 
6 +++--- .../csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp | 4 ++-- torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp | 8 ++++---- 7 files changed, 12 insertions(+), 14 deletions(-) diff --git a/test/test_network_ops/test_upsample_nearest2d.py b/test/test_network_ops/test_upsample_nearest2d.py index 5e974f8fe3..a73cbc8bb8 100644 --- a/test/test_network_ops/test_upsample_nearest2d.py +++ b/test/test_network_ops/test_upsample_nearest2d.py @@ -41,7 +41,7 @@ class TestUpsamleNearest2D(TestCase): for item in shape_format: cpu_input, npu_input = create_common_tensor(item[0], 0, 50) - if cpu_input == torch.float16: + if cpu_input.dtype == torch.float16: cpu_input = cpu_input.to(torch.float32) size = item[1] diff --git a/test/test_network_ops/test_upsample_nearest2d_backward.py b/test/test_network_ops/test_upsample_nearest2d_backward.py index 0ea9acc84d..6e74d03d57 100644 --- a/test/test_network_ops/test_upsample_nearest2d_backward.py +++ b/test/test_network_ops/test_upsample_nearest2d_backward.py @@ -48,7 +48,7 @@ class TestUpsampleNearest2DBackward(TestCase): for item in shape_format: cpu_input, npu_input = create_common_tensor(item[0], 0, 100) - if cpu_input == torch.float16: + if cpu_input.dtype == torch.float16: cpu_input = cpu_input.to(torch.float32) cpu_output, cpu_grad = self.cpu_op_exec(cpu_input, item[1]) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index b4c04c619f..9210637081 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1676,8 +1676,6 @@ supported: - replication_pad3d - replication_pad3d_backward.grad_input - replication_pad3d_backward - - upsample_bilinear2d.vec - - upsample_bilinear2d_backward.vec - upsample_trilinear3d.vec - upsample_trilinear3d_backward.vec - upsample_bicubic2d.vec diff --git a/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp b/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp index 5784de81de..ffecfd9a82 100644 --- a/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/FloorKernelNpu.cpp @@ -48,7 +48,7 @@ at::Tensor& NPUNativeFunctions::floor_(at::Tensor& self) { } at::Tensor NPUNativeFunctions::floor(const at::Tensor& self) { - Tensor result = OpPreparation::ApplyTensor(self); + at::Tensor result = OpPreparation::ApplyTensor(self); floor_out_npu_nocheck(self, result); return result; } diff --git a/torch_npu/csrc/aten/ops/SortKernelNpu.cpp b/torch_npu/csrc/aten/ops/SortKernelNpu.cpp index 1a1fb161fb..2c27cf1144 100644 --- a/torch_npu/csrc/aten/ops/SortKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/SortKernelNpu.cpp @@ -54,7 +54,7 @@ tuple NPUNativeFunctions::sort_out( } std::swap(perm[dim], perm[lastDim]); - at::Tensor transposeSelf = at::npu_transpose(self, perm); + at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm); auto outputSize = transpose_npu_output_size(values, perm); at::Tensor transposeValues = OpPreparation::ApplyTensor(values, outputSize); at::Tensor transposeIndices =OpPreparation::ApplyTensor(indices, outputSize); @@ -62,8 +62,8 @@ tuple NPUNativeFunctions::sort_out( sort_out_npu_no_transpose( transposeSelf, lastDim, descending, transposeValues, transposeIndices); - at::npu_transpose_out(values, transposeValues, perm); - at::npu_transpose_out(indices, transposeIndices, perm); + NPUNativeFunctions::npu_transpose_out(transposeValues, perm, values); + NPUNativeFunctions::npu_transpose_out(transposeIndices, perm, indices); } else { sort_out_npu_no_transpose( self, lastDim, descending, 
values, indices);
diff --git a/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp
index b9acc93baf..3db040139d 100644
--- a/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/UpsampleBilinear2dBackwardKernelNpu.cpp
@@ -76,7 +76,7 @@ at::Tensor& NPUNativeFunctions::upsample_bilinear2d_backward_out(
     bool align_corners,
     c10::optional<double> scales_h,
     c10::optional<double> scales_w,
-    Tensor& grad_input) {
+    at::Tensor& grad_input) {
 
   OpPreparation::CheckOut(
       {grad_output},
@@ -119,7 +119,7 @@ at::Tensor NPUNativeFunctions::upsample_bilinear2d_backward(
   bool isAicore = upsample_bilinear2d_backward_check_is_aicore(grad_output);
   if (!isAicore) {
     if (grad_output.scalar_type() != at::ScalarType::Float) {
-      grad_output = grad_output.npu_dtype_cast(at::ScalarType::Float);
+      grad_output = NPUNativeFunctions::npu_dtype_cast(grad_output, at::ScalarType::Float);
     }
   }
   auto outputSize = upsample_bilinear2d_backward_npu_output_size(
diff --git a/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp
index 689e94a399..15b3318e69 100644
--- a/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/UpsampleBilinear2dKernelNpu.cpp
@@ -30,7 +30,7 @@ at::Tensor& upsample_bilinear2d_out_npu_nocheck(
   bool half_pixel_centers = !align_corners;
   int64_t H = output_size[0];
   int64_t W = output_size[1];
-  SmallVector<int64_t, SIZE> attr_size = {H, W};
+  at::SmallVector<int64_t, SIZE> attr_size = {H, W};
   cmd.Name("ResizeBilinearV2")
       .Input(self)
       .Input(attr_size, at::kInt)
@@ -50,7 +50,7 @@ at::Tensor& NPUNativeFunctions::upsample_bilinear2d_out(
     at::Tensor& result){
   at::Tensor self = self_ex;
   if (self.scalar_type() != at::ScalarType::Float) {
-    self = self.npu_dtype_cast(at::ScalarType::Float);
+    self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float);
   }
   auto outputSize = upsample_bilinear2d_npu_output_size(
       self, output_size, align_corners, scales_h, scales_w);
@@ -81,7 +81,7 @@ at::Tensor NPUNativeFunctions::upsample_bilinear2d(
     c10::optional<double> scales_w) {
   at::Tensor self = self_ex;
   if (self.scalar_type() != at::ScalarType::Float) {
-    self = self.npu_dtype_cast(at::ScalarType::Float);
+    self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float);
   }
   auto outputSize = upsample_bilinear2d_npu_output_size(
       self, output_size, align_corners, scales_h, scales_w);
@@ -90,7 +90,7 @@ at::Tensor NPUNativeFunctions::upsample_bilinear2d(
   upsample_bilinear2d_out_npu_nocheck(
       result, self, output_size, align_corners, scales_h, scales_w);
   if (result.dtype() != self_ex.dtype()) {
-    result = result.npu_dtype_cast(self_ex.scalar_type());
+    result = NPUNativeFunctions::npu_dtype_cast(result, self_ex.scalar_type());
   }
   return result;
 }
-- 
Gitee