From 71dcdb2a9f57cd6a81e37367db558577610f0a88 Mon Sep 17 00:00:00 2001
From: pipihugh
Date: Wed, 9 Mar 2022 14:39:04 +0800
Subject: [PATCH] =?UTF-8?q?1.8=E7=AE=97=E5=AD=90=E7=A7=BB=E6=A4=8D?=
 =?UTF-8?q?=EF=BC=9Aargmin=E3=80=81atan2=E3=80=81cdist=5Fbackward=E3=80=81?=
 =?UTF-8?q?cross=E3=80=81diag?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_argmin.py          |  72 +++++++++++
 test/test_network_ops/test_atan2.py           |  74 +++++++++++
 test/test_network_ops/test_cdist_backward.py  | 108 ++++++++++++++++
 test/test_network_ops/test_copy_.py           |  53 ++++++++
 test/test_network_ops/test_cross.py           |  71 ++++++++++
 test/test_network_ops/test_diag.py            | 122 ++++++++++++++++++
 test/test_network_ops/test_to.py              |  50 +++++++
 torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp   |  50 +++++++
 torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp    |  70 ++++++++++
 .../csrc/aten/ops/CdistBackwardKernelNpu.cpp  |  85 ++++++++++++
 torch_npu/csrc/aten/ops/CrossKernelNpu.cpp    |  72 +++++++++++
 torch_npu/csrc/aten/ops/DiagKernelNpu.cpp     | 100 ++++++++++++++
 12 files changed, 927 insertions(+)
 create mode 100644 test/test_network_ops/test_argmin.py
 create mode 100644 test/test_network_ops/test_atan2.py
 create mode 100644 test/test_network_ops/test_cdist_backward.py
 create mode 100644 test/test_network_ops/test_copy_.py
 create mode 100644 test/test_network_ops/test_cross.py
 create mode 100644 test/test_network_ops/test_diag.py
 create mode 100644 test/test_network_ops/test_to.py
 create mode 100644 torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/CrossKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/DiagKernelNpu.cpp

diff --git a/test/test_network_ops/test_argmin.py b/test/test_network_ops/test_argmin.py
new file mode 100644
index 0000000000..d63c756e0c
--- /dev/null
+++ b/test/test_network_ops/test_argmin.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# coding: utf-8
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+from torch_npu.testing.decorator import Dtypes, instantiate_tests
+
+@instantiate_tests
+class TestArgmin(TestCase):
+    """Checks torch.argmin on the NPU against the CPU implementation."""
+
+    @Dtypes(torch.float, torch.half)
+    def test_argmin(self, device, dtype):
+        """argmin without `dim` returns the flat index of the smallest element."""
+        input_values = [-1000, -1, 0, 0.5, 1, 2, 1000]
+        # -1000 is the unique minimum and sits at index 0.
+        expected_index = 0
+        input_tensor = torch.tensor(input_values, dtype=dtype, device=device)
+        self.assertEqual(expected_index, torch.argmin(input_tensor).cpu().item())
+
+    def cpu_op_exec(self, input1, dims, keepdim=False):
+        """Run argmin on `input1` and return the indices as an int32 ndarray."""
+        output = torch.argmin(input1, dim=dims, keepdim=keepdim)
+        # Indices are normalised to int32 so both execution paths compare equal.
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, dims, keepdim=False):
+        """Run argmin on `input1`, copy the result to CPU, return int32 ndarray."""
+        output = torch.argmin(input1, dim=dims, keepdim=keepdim)
+        output = output.to("cpu")
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def test_argmin_shape_format(self):
+        """Compare CPU and NPU argmin over several dtypes, formats and dims."""
+        # Each entry: [[dtype, npu_format, shape], dim, keepdim]
+        shape_format = [
+            [ [np.float32, 0, (6, 4)], 0, False],
+            [ [np.float32, 2, (6, 4)], 1, True ],
+            [ [np.float32, 0, (2, 4, 5)], 2, True ],
+            [ [np.float32, 0, (1, 2, 3, 3)], 2, False],
+            [ [np.float32, 0, (1, 2, 3, 3)], 3, True ],
+            [ [np.float16, 0, (6, 4)], 0, False],
+            [ [np.float16, 2, (6, 4)], 1, True ],
+            [ [np.float16, 0, (2, 4, 5)], 2, True ],
+            [ [np.float16, 3, (1, 2, 3, 3)], 2, False],
+            [ [np.float16, 0, (1, 2, 3, 3)], 3, True ],
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, item[1], keepdim=item[2])
+            npu_output = self.npu_op_exec(npu_input, item[1], keepdim=item[2])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_atan2.py b/test/test_network_ops/test_atan2.py
new file mode 100644
index 0000000000..bded83715f
--- /dev/null
+++ b/test/test_network_ops/test_atan2.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestAtan2(TestCase):
+    """Checks torch.atan2 (functional and out= variants) on the NPU against CPU."""
+
+    def cpu_op_exec(self,input1, input2):
+        """Return atan2(input1, input2) as an ndarray."""
+        return torch.atan2(input1, input2).numpy()
+
+    def npu_op_exec(self,input1, input2):
+        """Return atan2(input1, input2), copied to CPU, as an ndarray."""
+        return torch.atan2(input1, input2).to("cpu").numpy()
+
+    def npu_op_exec_out(self,input1, input2, out):
+        """Write atan2(input1, input2) into `out` and return it as an ndarray."""
+        torch.atan2(input1, input2, out=out)
+        return out.to("cpu").numpy()
+
+    def test_atan2_common_shape_format(self):
+        """Compare CPU and NPU results, including an out= tensor of another shape."""
+        # Each entry: [spec of both inputs, spec of the pre-allocated out tensor]
+        shape_format = [
+            [[np.float16, 0, [4, 12, 12, 128]], [np.float16, 0, [4]]],
+            [[np.float16, 0, [4, 128]], [np.float16, 0, [4, 256, 12]]],
+            [[np.float32, 0, [4, 12, 12, 128]], [np.float32, 0, [4]]],
+            [[np.float32, 0, [4, 128]], [np.float32, 0, [4, 256, 12]]],
+            [[np.float16, 2, [4, 12, 12, 128]], [np.float16, 0, [4]]],
+            [[np.float16, 3, [4, 128]], [np.float16, 0, [4, 256, 12]]],
+            [[np.float32, 2, [4, 12, 12, 128]], [np.float32, 0, [4]]],
+            [[np.float32, 3, [4, 128]], [np.float32, 0, [4, 256, 12]]],
+        ]
+        for input_spec, out_spec in shape_format:
+            cpu_lhs, npu_lhs = create_common_tensor(input_spec, -1, 1)
+            cpu_rhs, npu_rhs = create_common_tensor(input_spec, -1, 1)
+            _, npu_out = create_common_tensor(out_spec, -1, 1)
+            # The CPU reference is computed in float32 for half inputs.
+            if cpu_lhs.dtype == torch.float16:
+                cpu_lhs = cpu_lhs.to(torch.float32)
+            if cpu_rhs.dtype == torch.float16:
+                cpu_rhs = cpu_rhs.to(torch.float32)
+            expected = self.cpu_op_exec(cpu_lhs, cpu_rhs)
+            actual = self.npu_op_exec(npu_lhs, npu_rhs)
+            actual_out = self.npu_op_exec_out(npu_lhs, npu_rhs, npu_out)
+            expected = expected.astype(actual.dtype)
+            self.assertRtolEqual(expected, actual)
+            self.assertRtolEqual(expected, actual_out)
+
+    def test_atan2_mix_dtype(self):
+        """atan2 of a float32 tensor with a float16 tensor matches on CPU and NPU."""
+        # create_common_tensor returns a (cpu, npu) pair, as used in the test above.
+        cpu_fp32, npu_fp32 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100)
+        cpu_fp16, npu_fp16 = create_common_tensor([np.float16, 0, (2, 3)], 1, 100)
+        expected = self.cpu_op_exec(cpu_fp32, cpu_fp16)
+        actual = self.npu_op_exec(npu_fp32, npu_fp16)
+        self.assertRtolEqual(expected, actual)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_cdist_backward.py b/test/test_network_ops/test_cdist_backward.py
new file mode 100644
index 0000000000..fb8e5c5d5d
--- /dev/null
+++ b/test/test_network_ops/test_cdist_backward.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class Testcdist(TestCase):
+    """Checks the gradient of torch.cdist w.r.t. x1 on the NPU against CPU."""
+
+    def generate_data(self, min_n, max_n, shape_x, shape_y, src_type):
+        """Return two uniformly sampled ndarrays; the seed is fixed on every call."""
+        np.random.seed(10086)
+        samples = [
+            np.random.uniform(min_n, max_n, shape).astype(src_type)
+            for shape in (shape_x, shape_y)
+        ]
+        return samples[0], samples[1]
+
+    def cdist_backward(self, x1, x2, p, grad, cdist):
+        """Hand-written gradient of cdist w.r.t. x1, used as the fp16 CPU reference."""
+        lhs = torch.unsqueeze(x1, -2)
+        rhs = torch.unsqueeze(x2, -3)
+        grad_col = torch.unsqueeze(grad, -1)
+        dist_col = torch.unsqueeze(cdist, -1)
+        delta = lhs - rhs
+        delta_abs = torch.abs(delta)
+        # Zero distances are replaced by one so the divisions below stay finite.
+        safe_dist = torch.where(dist_col == 0, torch.ones_like(dist_col), dist_col)
+        sign = torch.where(delta > 0, torch.ones_like(delta), torch.full_like(delta, -1))
+        sign = torch.where(delta == 0, torch.zeros_like(delta), sign)
+
+        def zero_where_coincident(values):
+            # Pairs at distance zero contribute no gradient.
+            return torch.where(dist_col == 0, torch.zeros_like(values), values)
+
+        if p == 0.0:
+            res = torch.zeros_like(delta)
+        elif p == 1.0:
+            res = grad_col * sign
+        elif p < 2.0:
+            res = zero_where_coincident(
+                sign * torch.pow(delta_abs, p - 1.0) * grad_col / torch.pow(safe_dist, p - 1.0))
+        elif p == 2.0:
+            res = zero_where_coincident(grad_col * delta / safe_dist)
+        elif p == float("inf"):
+            mask = torch.where(dist_col - delta_abs > 0, torch.zeros_like(delta), torch.ones_like(delta))
+            res = grad_col * sign * mask
+        else:
+            res = zero_where_coincident(
+                delta * torch.pow(delta_abs, p - 2) * grad_col / torch.pow(safe_dist, p - 1.0))
+        return torch.sum(res, -2)
+
+    def op_exec(self, x1, x2, p, device='cpu'):
+        """Return d(cdist)/d(x1) for an all-ones upstream gradient as an ndarray."""
+        cpu_fp16 = device == 'cpu' and x1.dtype == np.float16
+        if cpu_fp16:
+            # The CPU reference for half inputs is computed in float32.
+            x1 = x1.astype(np.float32)
+            x2 = x2.astype(np.float32)
+        x1 = torch.tensor(x1, device=device, requires_grad=True)
+        x2 = torch.tensor(x2, device=device, requires_grad=True)
+        y = torch.cdist(x1, x2, p)
+        grad = torch.ones_like(y, requires_grad=True, device=device)
+        if cpu_fp16:
+            # Round the distances through half precision before the manual backward.
+            y = y.half().float()
+            manual_grad = self.cdist_backward(x1, x2, p, grad, y)
+            return manual_grad.detach().numpy().astype('float16')
+        y.backward(grad, retain_graph=True)
+        return x1.grad.detach().cpu().numpy()
+
+    def test_cdis_backward_common_shape(self):
+        """Compare CPU and NPU gradients for 2-D and batched inputs."""
+        # Each entry: [dtype, shape of x1, shape of x2]
+        shape_items = [
+            [np.float16, (5, 10), (4, 10)],
+            [np.float16, (20, 5, 10), (20, 4, 10)],
+            [np.float32, (5, 10), (4, 10)],
+            [np.float32, (20, 5, 10), (20, 4, 10)],
+        ]
+        p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]
+        for dtype, shape_x, shape_y in shape_items:
+            for p in p_ranges:
+                input1, input2 = self.generate_data(-1, 1, shape_x, shape_y, dtype)
+                cpu_output = self.op_exec(input1, input2, p, device='cpu')
+                npu_output = self.op_exec(input1, input2, p, device='npu')
+                self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_cdis_backward_input_range(self):
+        """Compare CPU and NPU gradients for several input value ranges."""
+        dtype, shape_x, shape_y = np.float32, (20, 5, 5), (20, 4, 5)
+        p_ranges = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5]
+        input_ranges = [(-0.1, 0.1), (-10, 10), (-20, 20)]
+        for p in p_ranges:
+            for low, high in input_ranges:
+                input1, input2 = self.generate_data(low, high, shape_x, shape_y, dtype)
+                cpu_output = self.op_exec(input1, input2, p, device='cpu')
+                npu_output = self.op_exec(input1, input2, p, device='npu')
+                self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_copy_.py b/test/test_network_ops/test_copy_.py
new file mode 100644
index 0000000000..d72ffa1aa5
--- /dev/null
+++ b/test/test_network_ops/test_copy_.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestCopy(TestCase):
+    """Checks Tensor.copy_ between NPU tensors against the CPU result."""
+
+    def cpu_op_exec(self, input1, input2):
+        """Copy `input2` into `input1` in place and return the result as an ndarray."""
+        return input1.copy_(input2).numpy()
+
+    def npu_op_exec(self, input1, input2):
+        """Move both tensors to the NPU, copy_ there, and return a CPU ndarray."""
+        destination = input1.to("npu")
+        source = input2.to("npu")
+        return destination.copy_(source).to("cpu").numpy()
+
+    def test_copy__(self):
+        """Compare CPU and NPU copy_ for every dtype/format/shape combination."""
+        format_list = [0]
+        shape_list = [(4, 1), (4, 3, 1)]
+        dtype_list = [np.float32, np.int32, np.float16]
+        # Build [dtype, format, shape] specs, dtype varying slowest.
+        shape_format = []
+        for dtype in dtype_list:
+            for npu_format in format_list:
+                for shape in shape_list:
+                    shape_format.append([dtype, npu_format, shape])
+        for spec in shape_format:
+            cpu_dst, npu_dst = create_common_tensor(spec, 0, 100)
+            cpu_src, npu_src = create_common_tensor(spec, 0, 100)
+            expected = self.cpu_op_exec(cpu_dst, cpu_src)
+            actual = self.npu_op_exec(npu_dst, npu_src)
+            self.assertRtolEqual(expected, actual)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_cross.py b/test/test_network_ops/test_cross.py
new file mode 100644
index 0000000000..e7e3e08c21
--- /dev/null
+++ b/test/test_network_ops/test_cross.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestCross(TestCase):
+    """Checks torch.cross on the NPU against the CPU implementation."""
+
+    def cpu_op_exec(self, input1, input2):
+        """Cross product with the default dimension, returned as an ndarray."""
+        output = torch.cross(input1, input2)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_dim(self, input1, input2, dim):
+        """Cross product along `dim`, returned as an ndarray."""
+        output = torch.cross(input1, input2, dim)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2):
+        """Cross product with the default dimension, copied to CPU as an ndarray."""
+        output = torch.cross(input1, input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_dim(self, input1, input2, dim):
+        """Cross product along `dim`, copied to CPU as an ndarray."""
+        output = torch.cross(input1, input2, dim)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_cross_shape_format_dim(self):
+        """Compare CPU and NPU results when `dim` is given explicitly."""
+        # Each entry: [[dtype, npu_format, shape], dim]; shape[dim] is always 3.
+        shape_format_dim = [
+            [[np.float32, -1, (4, 3)], 1],
+            [[np.float32, -1, (4, 3, 3)], 1],
+            [[np.float32, -1, (3, 2)], 0],
+        ]
+        for item in shape_format_dim:
+            cpu_input_left, npu_input_left = create_common_tensor(item[0], -2, 2)
+            cpu_input_right, npu_input_right = create_common_tensor(item[0], -2, 2)
+            cpu_output = self.cpu_op_exec_dim(cpu_input_left, cpu_input_right, item[1])
+            npu_output = self.npu_op_exec_dim(npu_input_left, npu_input_right, item[1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_cross_shape_format(self):
+        """Compare CPU and NPU results when `dim` is left to its default."""
+        shape_format = [
+            [[np.float32, -1, (5, 3)]],
+            [[np.float32, -1, (5, 3, 4)]],
+        ]
+        for item in shape_format:
+            cpu_input_left, npu_input_left = create_common_tensor(item[0], -2, 2)
+            cpu_input_right, npu_input_right = create_common_tensor(item[0], -2, 2)
+            cpu_output = self.cpu_op_exec(cpu_input_left, cpu_input_right)
+            npu_output = self.npu_op_exec(npu_input_left, npu_input_right)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_diag.py b/test/test_network_ops/test_diag.py
new file mode 100644
index 0000000000..521d5fc440
--- /dev/null
+++ b/test/test_network_ops/test_diag.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestDiag(TestCase):
+    """Checks torch.diag (functional and out= variants) on the NPU against CPU."""
+
+    def cpu_op_exec(self, input1, diagonal):
+        """Return diag(input1, diagonal) as an ndarray."""
+        output = torch.diag(input1, diagonal=diagonal)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, diagonal):
+        """Return diag(input1, diagonal), copied to CPU, as an ndarray."""
+        output = torch.diag(input1, diagonal=diagonal)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_out(self, input1, diagonal, out):
+        """Write diag(input1, diagonal) into `out` and return it as an ndarray."""
+        torch.diag(input1, diagonal=diagonal, out=out)
+        output = out.numpy()
+        return output
+
+    def npu_op_exec_out(self, input1, diagonal, out):
+        """Write diag(input1, diagonal) into `out`; return a CPU ndarray copy."""
+        torch.diag(input1, diagonal=diagonal, out=out)
+        output = out.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_fp16(self, input1, diagonal):
+        """Compute diag in float32 on CPU and cast the ndarray back to float16."""
+        input1 = input1.to(torch.float32)
+        output = torch.diag(input1, diagonal)
+        output = output.numpy()
+        output = output.astype(np.float16)
+        return output
+
+    def generate_npu_data(self, min_d, max_d, shape, dtype):
+        """Return two random tensors of `shape`; both are created on the CPU."""
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        output1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+        npu_output1 = torch.from_numpy(output1)
+        return npu_input1, npu_output1
+
+    def test_diag_common_shape_format(self):
+        """Compare CPU and NPU diag for float32 vectors and matrices."""
+        shape_format = [
+            [[np.float32, -1, [16]], 0], # test the condition of 1-dimension
+            [[np.float32, -1, [1024]], 0],
+            [[np.float32, -1, [5, 5]], 0], # test the condition of 2-dimension
+            [[np.float32, -1, [256, 256]], 0],
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 0, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, item[1])
+            npu_output = self.npu_op_exec(npu_input, item[1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_diag_float32_out(self):
+        """Compare CPU diag with NPU diag(out=...) for float32 inputs."""
+        # Each entry: [input spec, spec of the pre-allocated out tensor, diagonal]
+        shape_format = [
+            [[np.float32, -1, [16]], [np.float32, -1, [20]], 0], # test the condition of 1-dimension
+            [[np.float32, -1, [1024]], [np.float32, -1, [20, 20]], 0],
+            [[np.float32, -1, [5, 5]], [np.float32, -1, [5, 5, 5]], 0], # test the condition of 2-dimension
+            [[np.float32, -1, [256, 256]], [np.float32, -1, [256]], 0],
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100)
+            cpu_output = self.cpu_op_exec(cpu_input1, item[2])
+            npu_output = self.npu_op_exec_out(npu_input1, item[2], npu_input2)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+        # generate_npu_data returns CPU tensors, so they are moved to the NPU
+        # here; otherwise this check would compare the CPU result with itself.
+        cpu_input1, cpu_out1 = self.generate_npu_data(0, 100, (5,5), np.float32)
+        cpu_output = self.cpu_op_exec(cpu_input1, 0)
+        npu_output = self.npu_op_exec_out(cpu_input1.to("npu"), 0, cpu_out1.to("npu"))
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_diag_float16_out(self):
+        """Compare CPU diag with NPU diag(out=...) for float16 inputs."""
+        shape_format = [
+            [[np.float16, -1, [16]], [np.float16, -1, [20]], 0], # test the condition of 1-dimension
+            [[np.float16, -1, [1024]], [np.float16, -1, [20, 20]], 0],
+            [[np.float16, -1, [5, 5]], [np.float16, -1, [5, 5, 5]], 0], # test the condition of 2-dimension
+            [[np.float16, -1, [256, 256]], [np.float16, -1, [256]], 0],
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100)
+            # The CPU reference is computed in float32 for half inputs.
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1, item[2])
+            npu_output = self.npu_op_exec_out(npu_input1, item[2], npu_input2)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_diag_float16_shape_format(self):
+        """Compare CPU and NPU diag for float16 vectors and matrices."""
+        shape_format = [
+            [[np.float16, -1, [4]], 0], # test the condition of 1-dimension
+            [[np.float16, -1, [512]], 0],
+            [[np.float16, -1, [4, 4]], 0], # test the condition of 2-dimension
+            [[np.float16, -1, [256, 256]], 0],
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_output = self.cpu_op_exec_fp16(cpu_input, item[1])
+            npu_output = self.npu_op_exec(npu_input, item[1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_to.py b/test/test_network_ops/test_to.py
new file mode 100644
index 0000000000..3f8c7076be
--- /dev/null
+++ b/test/test_network_ops/test_to.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestTo(TestCase):
+    """Checks Tensor.to(dtype or device) for NPU tensors against CPU tensors."""
+
+    def cpu_op_exec(self, input1, target):
+        """Apply .to(target) and return the result as a CPU ndarray."""
+        return input1.to(target).cpu().numpy()
+
+    def npu_op_exec(self,input1, target):
+        """Apply .to(target) and return the result as a CPU ndarray."""
+        return input1.to(target).cpu().numpy()
+
+    def test_to(self, device="npu"):
+        """Compare .to() results for every input spec and conversion target."""
+        # Each entry: [dtype, npu_format, shape]
+        shape_format = [
+            [np.float32, 0, [3, 3]],
+            [np.float16, 0, [4, 3]],
+            [np.int32, 0, [3, 5]],
+        ]
+        targets = [torch.float16, torch.float32, torch.int32, 'cpu', 'npu']
+        for spec in shape_format:
+            cpu_tensor, npu_tensor = create_common_tensor(spec, -100, 100)
+            for target in targets:
+                expected = self.cpu_op_exec(cpu_tensor, target)
+                actual = self.npu_op_exec(npu_tensor, target)
+                self.assertRtolEqual(expected, actual)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp
new file mode 100644
index 0000000000..c0fd98cc59
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ArgminKernelNpu.cpp
@@ -0,0 +1,50 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Returns the indices of the minimum values of `self` as an int64 tensor.
+//
+// self:    input tensor; must contain at least one element.
+// dim:     dimension to reduce; when absent the input is flattened and the
+//          reduction runs over dimension 0 of the flattened tensor.
+// keepdim: whether the reduced dimension is kept; ignored when `dim` is absent.
+at::Tensor NPUNativeFunctions::argmin(
+    const at::Tensor& self,
+    c10::optional<int64_t> dim,
+    bool keepdim) {
+  TORCH_CHECK(
+      self.numel() > 0,
+      "cannot perform reduction function argmin on a "
+      "tensor with no elements because the operation does not have an identity");
+  const bool hasDim = dim.has_value();
+  at::Tensor input = hasDim ? self : self.reshape({-1});
+  int64_t realDim = hasDim ? dim.value() : 0;
+  bool realKeepDim = hasDim ? keepdim : false;
+  auto outputSize = reduce_ops_npu_output_size(input, realDim, realKeepDim);
+  // The operator output is allocated as int32 and cast to int64 afterwards.
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      self.options().dtype(at::kInt),
+      ACL_FORMAT_ND);
+  // NOTE(review): `N` is assumed to be the SmallVector capacity constant
+  // provided by the included framework headers - confirm.
+  c10::SmallVector<int64_t, N> DimVec = {realDim};
+  OpCommand cmd;
+  cmd.Name("ArgMin")
+      .Input(input)
+      .Input(DimVec, at::kInt)
+      .Output(result)
+      .Attr("keep_dims", realKeepDim)
+      .Run();
+  result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Long);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp b/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp
new file mode 100644
index 0000000000..433cac4f6a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/Atan2KernelNpu.cpp
@@ -0,0 +1,70 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Launches the "Atan2" operator on (self, other) and writes into `result`.
+// No shape, dtype or layout validation of `result` is done here; callers are
+// expected to have prepared it.
+at::Tensor& atan2_out_npu_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    at::Tensor& result) {
+  auto unified_result = OpPreparation::binary_op_check(result, self, other, true);
+  OpCommand cmd;
+  cmd.Name("Atan2")
+      .Expect(unified_result)
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+// out= variant of atan2: validates `result` against the broadcast shape of
+// the inputs and then computes into it.
+at::Tensor& NPUNativeFunctions::atan2_out(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    at::Tensor& result) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  // Both operands are passed so the check sees every input of the operator.
+  OpPreparation::CheckOut(
+      {self, other},
+      result,
+      self,
+      outputSize);
+  if (!NpuUtils::check_match(&result)) {
+    // Compute into a contiguous temporary and copy it back into the caller's
+    // tensor. No new local named `result` is declared here: a shadowing local
+    // would make format_fresh_view write to the temporary instead of `result`.
+    at::Tensor contiguousResult = NpuUtils::format_contiguous(result);
+    atan2_out_npu_nocheck(self, other, contiguousResult);
+    NpuUtils::format_fresh_view(result, contiguousResult);
+  } else {
+    atan2_out_npu_nocheck(self, other, result);
+  }
+  return result;
+}
+
+// Functional variant: allocates the broadcast-shaped result from `self`.
+at::Tensor NPUNativeFunctions::atan2(const at::Tensor& self, const at::Tensor& other) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+  atan2_out_npu_nocheck(self, other, result);
+  return result;
+}
+
+// In-place variant: the result is written back into `self`.
+at::Tensor& NPUNativeFunctions::atan2_(at::Tensor& self, const at::Tensor& other) {
+  OpPreparation::CheckMemory({self, other}, {self});
+  NPUNativeFunctions::atan2_out(self, other, self);
+  return self;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp
new file mode 100644
index 0000000000..085eb8548f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/CdistBackwardKernelNpu.cpp
@@ -0,0 +1,85 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cmath>
+#include <limits>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Validates the arguments of _cdist_backward: all tensors must be contiguous,
+// x1/x2 must live on a supported device and have at least two dimensions, and
+// `p` must be representable as a float (infinity is allowed).
+static void check_cdist_backward_input(
+    const at::Tensor& grad,
+    const at::Tensor& x1,
+    const at::Tensor& x2,
+    const double p,
+    const at::Tensor& cdist) {
+  TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous");
+  TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous");
+  TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous");
+  TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous");
+  // The caller inserts a unit dimension at position dim - 1 (x1) and dim - 2
+  // (x2); with fewer than two dimensions that position would be invalid.
+  TORCH_CHECK(x1.dim() >= 2, "_cdist_backward requires X1 to be at least 2-dimensional, got: ", x1.dim());
+  TORCH_CHECK(x2.dim() >= 2, "_cdist_backward requires X2 to be at least 2-dimensional, got: ", x2.dim());
+  auto device1 = x1.device().type();
+  TORCH_CHECK(device1 == at::kCPU || device1 == at::kCUDA || device1 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X1 got: ", device1);
+  auto device2 = x2.device().type();
+  TORCH_CHECK(device2 == at::kCPU || device2 == at::kCUDA || device2 == at::kNPU, "_cdist_backward only supports CPU, CUDA and NPU devices, X2 got: ", device2);
+  // Infinity is exempt from the range check: the caller maps it to float
+  // infinity explicitly, which a plain `p <= max` test would make unreachable.
+  TORCH_CHECK(std::isinf(p) || p <= std::numeric_limits<float>::max(), "npu dose not support float64" );
+}
+
+// Gradient of cdist(x1, x2, p) with respect to x1.
+//
+// grad:  upstream gradient, same shape as `cdist`.
+// x1,x2: the operands of the forward cdist call.
+// p:     norm degree; passed to the operator as a float attribute.
+// cdist: result of the forward cdist call.
+// Returns a tensor with the shape of x1.
+at::Tensor NPUNativeFunctions::_cdist_backward(
+    const at::Tensor& grad,
+    const at::Tensor& x1,
+    const at::Tensor& x2,
+    const double p,
+    const at::Tensor& cdist) {
+  check_cdist_backward_input(grad, x1, x2, p, cdist);
+  const float p_float = std::isinf(p)
+      ? std::numeric_limits<float>::infinity()
+      : static_cast<float>(p);
+
+  auto dim1 = x1.dim();
+  auto dim2 = x2.dim();
+  // Insert unit dimensions so that all four operands broadcast to one common
+  // shape: x1 gains one before its last dim, x2 before its last two dims,
+  // grad and cdist gain a trailing one.
+  auto tensor1_expand_size = array_to_small_vector(x1.sizes());
+  tensor1_expand_size.insert(tensor1_expand_size.begin() + (dim1 - 1), 1);
+  auto tensor2_expand_size = array_to_small_vector(x2.sizes());
+  tensor2_expand_size.insert(tensor2_expand_size.begin() + (dim2 - 2), 1);
+  auto grad_expand_size = array_to_small_vector(grad.sizes());
+  grad_expand_size.insert(grad_expand_size.end(), 1);
+  auto cdist_expand_size = array_to_small_vector(cdist.sizes());
+  cdist_expand_size.insert(cdist_expand_size.end(), 1);
+  auto tensor_broadcast_size = at::infer_size(tensor1_expand_size, tensor2_expand_size);
+
+  // The operator receives materialised (contiguous) broadcast copies.
+  at::Tensor tensor1_broadcast = x1.view(tensor1_expand_size).expand(tensor_broadcast_size).contiguous();
+  at::Tensor tensor2_broadcast = x2.view(tensor2_expand_size).expand(tensor_broadcast_size).contiguous();
+  at::Tensor grad_broadcast = grad.view(grad_expand_size).expand(tensor_broadcast_size).contiguous();
+  at::Tensor cdist_broadcast = cdist.view(cdist_expand_size).expand(tensor_broadcast_size).contiguous();
+  auto outputSize = x1.sizes();
+  at::Tensor result = OpPreparation::ApplyTensor(tensor1_broadcast, outputSize);
+
+  OpCommand cmd;
+  cmd.Name("CdistGrad")
+      .Input(grad_broadcast)
+      .Input(tensor1_broadcast)
+      .Input(tensor2_broadcast)
+      .Input(cdist_broadcast)
+      .Attr("p", p_float)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/CrossKernelNpu.cpp b/torch_npu/csrc/aten/ops/CrossKernelNpu.cpp
new file mode 100644
index 0000000000..6f3b663c86
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/CrossKernelNpu.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+namespace {
+// Value forwarded as the "dim" attribute when the caller does not pass a dim.
+// NOTE(review): presumably a sentinel that lets the Cross operator choose the
+// dimension itself - confirm against the operator specification.
+constexpr int64_t kCrossDimNotGiven = -65530;
+} // namespace
+
+// Picks the tensor whose NPU format the result should follow: `other` when
+// `self` is a scalar wrapped into a tensor, `self` otherwise.
+at::Tensor cross_dest_output(const at::Tensor& self, const at::Tensor& other) {
+  bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self);
+  return isSelfWrapped ? other : self;
+}
+
+// Launches the NPU "Cross" operator on (self, other) and writes into `result`
+// without validating or resizing `result`; callers prepare it beforehand.
+// Returns `result`.
+at::Tensor& cross_out_npu_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    c10::optional<int64_t> dim,
+    at::Tensor& result) {
+  int64_t realDim = dim.has_value() ? dim.value() : kCrossDimNotGiven;
+  OpCommand cmd;
+  cmd.Name("Cross")
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Attr("dim", realDim)
+      .Run();
+  return result;
+}
+
+// out= variant of cross: checks `result` against the broadcast output size,
+// self's dtype and the format of the tensor chosen by cross_dest_output,
+// then runs the operator into it. Returns `result`.
+at::Tensor& NPUNativeFunctions::cross_out(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    c10::optional<int64_t> dim,
+    at::Tensor& result) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  at::Tensor outputTensor = cross_dest_output(self, other);
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      CalcuOpUtil::get_tensor_npu_format(outputTensor),
+      self.scalar_type(),
+      outputSize);
+  cross_out_npu_nocheck(self, other, dim, result);
+  return result;
+}
+
+// Functional variant of cross: allocates the result with the broadcast output
+// size and self's options, then runs the operator. Returns the new tensor.
+at::Tensor NPUNativeFunctions::cross(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    c10::optional<int64_t> dim) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  at::Tensor outputTensor = cross_dest_output(self, other);
+  at::Tensor result = OpPreparation::ApplyTensor(outputSize, self.options(), outputTensor);
+  cross_out_npu_nocheck(self, other, dim, result);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/DiagKernelNpu.cpp b/torch_npu/csrc/aten/ops/DiagKernelNpu.cpp
new file mode 100644
index 0000000000..ac13a89763
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/DiagKernelNpu.cpp
@@ -0,0 +1,100 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+namespace {
+// Computes the output shape of diag for an already validated input.
+//   * 1-D input: a square matrix of side size(0) + diagonal.
+//   * 2-D input: a vector holding the length of the selected diagonal.
+// `diagonal` is expected to be the value produced by check_and_wrap_diagonal.
+c10::SmallVector<int64_t, SIZE> diag_npu_output_size(
+    const at::Tensor& self,
+    int64_t diagonal) {
+  c10::SmallVector<int64_t, SIZE> shape;
+  if (self.dim() == 1) {
+    shape.emplace_back(self.size(0) + diagonal);
+    shape.emplace_back(self.size(0) + diagonal);
+    return shape;
+  }
+  int64_t m = self.size(0);
+  int64_t n = self.size(1);
+  if (m == n) {
+    shape.emplace_back(m - diagonal);
+  } else if (m < n) {
+    // A wide matrix keeps all m rows until the offset exceeds n - m.
+    shape.emplace_back(diagonal <= n - m ? m : n - diagonal);
+  } else {
+    shape.emplace_back(n - diagonal);
+  }
+  return shape;
+}
+
+// Shared argument validation for diag and diag_out.
+// Checks that `self` is 1-D or 2-D, wraps `diagonal` with make_wrap_dim and,
+// for 2-D input, checks that the wrapped value does not exceed either size.
+// Returns the wrapped diagonal.
+//
+// NOTE(review): make_wrap_dim is a dimension-index helper. Using it on a
+// diagonal offset presumably limits accepted offsets to
+// [-self.dim(), self.dim() - 1] and maps negative offsets onto non-negative
+// ones, which does not look like torch.diag semantics. Confirm this matches
+// what the Diag/DiagPart operators support before relying on it.
+int64_t check_and_wrap_diagonal(const at::Tensor& self, int64_t diagonal) {
+  TORCH_CHECK((self.dim() == 1) || (self.dim() == 2),
+      "Value should be a 1-dimensional tensor or 2-dimensional tensor, but got ", self.dim());
+  diagonal = make_wrap_dim(diagonal, self.dim());
+  TORCH_CHECK((self.dim() == 1) || (self.dim() == 2 && diagonal <= self.size(0) && diagonal <= self.size(1)),
+      "If the value is 2-dimensional tensor, the diagonal shoule less than shape.Diagonal is ", diagonal);
+  return diagonal;
+}
+} // namespace
+
+// Launches the NPU operator for diag without validating `result`:
+// "Diag" for 1-D input, "DiagPart" for any other rank. Returns `result`.
+at::Tensor& diag_out_npu_nocheck(
+    const at::Tensor& self,
+    int64_t diagonal,
+    at::Tensor& result) {
+  OpCommand cmd;
+  if (self.dim() == 1) {
+    cmd.Name("Diag");
+  } else {
+    cmd.Name("DiagPart");
+  }
+  cmd.Input(self)
+      .Output(result)
+      .Attr("diagonal", diagonal)
+      .Run();
+  return result;
+}
+
+// out= variant of diag: validates the arguments, checks `result` against the
+// expected size, dtype and format, then runs the operator through
+// OpPipeWithDefinedOut. Returns the tensor produced by the pipe.
+at::Tensor& NPUNativeFunctions::diag_out(
+    const at::Tensor& self,
+    int64_t diagonal,
+    at::Tensor& result) {
+  diagonal = check_and_wrap_diagonal(self, diagonal);
+
+  auto outputSize = diag_npu_output_size(self, diagonal);
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      CalcuOpUtil::get_tensor_npu_format(self),
+      self.scalar_type(),
+      outputSize);
+  OpPipeWithDefinedOut pipe;
+  return pipe.CheckMemory({self}, {result})
+      .Func([&self, &diagonal](at::Tensor& result){diag_out_npu_nocheck(self, diagonal, result);})
+      .Call(result);
+}
+
+// Functional variant of diag: validates the arguments, allocates the result
+// from `self` with the computed size and runs the operator.
+// Returns the new tensor.
+at::Tensor NPUNativeFunctions::diag(
+    const at::Tensor& self,
+    int64_t diagonal) {
+  diagonal = check_and_wrap_diagonal(self, diagonal);
+
+  auto outputSize = diag_npu_output_size(self, diagonal);
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+  diag_out_npu_nocheck(self, diagonal, result);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee