From 7293c5367f722ed63216a71d6794a900e336c92b Mon Sep 17 00:00:00 2001 From: wangxiao Date: Thu, 27 Jan 2022 17:45:09 +0800 Subject: [PATCH 1/4] __and__, __rshift__, addmv, addr, affine_grid_generator, all, any, arange, argmax --- test/test_network_ops/test___and__.py | 64 +++++++ test/test_network_ops/test__rshift__.py | 70 +++++++ test/test_network_ops/test_addmv.py | 107 +++++++++++ test/test_network_ops/test_addr.py | 74 ++++++++ .../test_affine_grid_generator.py | 84 +++++++++ test/test_network_ops/test_all.py | 90 +++++++++ test/test_network_ops/test_any.py | 95 ++++++++++ test/test_network_ops/test_arange.py | 55 ++++++ test/test_network_ops/test_argmax.py | 102 ++++++++++ torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp | 85 +++++++++ torch_npu/csrc/aten/ops/AddrKernelNpu.cpp | 78 ++++++++ .../aten/ops/AffineGridGeneratorKernelNpu.cpp | 79 ++++++++ torch_npu/csrc/aten/ops/AllKernelNpu.cpp | 112 +++++++++++ torch_npu/csrc/aten/ops/AnyKernelNpu.cpp | 115 ++++++++++++ torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp | 177 ++++++++++++++++++ torch_npu/csrc/aten/ops/ArgmaxKernelNpu.cpp | 49 +++++ torch_npu/csrc/aten/ops/__And__KernelNpu.cpp | 88 +++++++++ .../csrc/aten/ops/__Rshift__KernelNpu.cpp | 69 +++++++ 18 files changed, 1593 insertions(+) create mode 100644 test/test_network_ops/test___and__.py create mode 100644 test/test_network_ops/test__rshift__.py create mode 100644 test/test_network_ops/test_addmv.py create mode 100644 test/test_network_ops/test_addr.py create mode 100644 test/test_network_ops/test_affine_grid_generator.py create mode 100644 test/test_network_ops/test_all.py create mode 100644 test/test_network_ops/test_any.py create mode 100644 test/test_network_ops/test_arange.py create mode 100644 test/test_network_ops/test_argmax.py create mode 100644 torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AddrKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AllKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/AnyKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/ArgmaxKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/__And__KernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp diff --git a/test/test_network_ops/test___and__.py b/test/test_network_ops/test___and__.py new file mode 100644 index 00000000000..e9a17293140 --- /dev/null +++ b/test/test_network_ops/test___and__.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class Test__And__(TestCase): + def cpu_op_exec(self, input1, input2): + output = input1.__and__(input2) + if output.dtype != torch.int32: + output = output.to(torch.int32) + return output.numpy() + + def npu_op_exec(self, input1, input2): + output = input1.__and__(input2) + output = output.to("cpu") + if output.dtype != torch.int32: + output = output.to(torch.int32) + return output.numpy() + + def test___And___shape_format(self, device): + shape_format = [ + [[np.int32, 0, [256, 1000]], [1]], + [[np.int32, 0, [256, 1000]], [np.int32, 0, [256, 1000]]], + [[np.int16, 0, [256, 1000]], [2]], + [[np.int16, 0, [256, 1000]], [np.int16, 0, [256, 1000]]], + [[np.int8, 0, [256, 1000]], [3]], + [[np.int8, 0, [256, 1000]], [np.int8, 0, [256, 1000]]], + [[np.bool, 0, [256, 1000]], [np.bool, 0, [256, 1000]]], + ] + + for item in shape_format: + if len(item[1]) > 1: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + else: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input1, item[1][0]) + npu_output = self.npu_op_exec(npu_input1, item[1][0]) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(Test__And__, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test__rshift__.py b/test/test_network_ops/test__rshift__.py new file mode 100644 index 00000000000..9d421865def --- /dev/null +++ b/test/test_network_ops/test__rshift__.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestRshift(TestCase): + def cpu_op_exec(self, input1, input2): + output = input1.__rshift__(input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + output = input1.__rshift__(input2) + output = output.to("cpu") + output = output.numpy() + return output + + def test_rshift_tensor(self, device): + format_list = [0] + shape_list = [(256, 32, 56)] + shape_format = [ + [np.int32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2 = torch.tensor([1]).to(torch.int32) + npu_input2 = cpu_input2.npu() + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_rshift_scalar(self, device): + format_list = [0] + shape_list = [(256, 32, 56)] + shape_format = [ + [np.int32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2 = torch.tensor(1).to(torch.int32) + npu_input2 = cpu_input2.npu() + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestRshift, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_addmv.py b/test/test_network_ops/test_addmv.py new file mode 100644 index 00000000000..a71f786a59a --- /dev/null +++ b/test/test_network_ops/test_addmv.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAddmv(TestCase): + def cpu_op_exec(self, a, b, c, alpha, beta): + '''output = alpha * a @ b + beta * c''' + output = torch.addmv(c, a, b, alpha=alpha, beta=beta) + output = output.numpy() + return output + + def npu_op_exec(self, a, b, c, alpha, beta): + output = torch.addmv(c, a, b, alpha=alpha, beta=beta) + output = output.to('cpu') + output = output.numpy() + return output + + def npu_op_exec_out(self,a, b, c, beta, alpha, input1): + torch.addmv(c, a, b, alpha=alpha, beta=beta, out=input1) + output = input1.to("cpu") + output = output.numpy() + return output + + def test_addmv_fp16(self, device): + shape_format = [ + [[np.float16, 3, (2, 3)], [np.float16, 3, (3,)], [np.float16, 3, (2, )]] + + ] + for item in shape_format: + + input_a, npu_input_a = create_common_tensor(item[0], -2, 2) + input_b, npu_input_b= create_common_tensor(item[1], -2, 2) + input_c, npu_input_c= create_common_tensor(item[2], -2, 2) + + input_a = input_a.to(torch.float32) + input_b = input_b.to(torch.float32) + input_c = input_c.to(torch.float32) + + cpu_output = self.cpu_op_exec(input_a, input_b, input_c, 1, 1) + npu_output = self.npu_op_exec(npu_input_a, npu_input_b, npu_input_c, 1, 1) + + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + def test_addmv_out_fp16(self, device): + shape_format = [ + [[np.float16, 3, (2, 3)], [np.float16, 3, (3,)], [np.float16, 3, (2, )], [np.float16, 3, (10,)]] + + ] + for item in shape_format: + + input_a, npu_input_a = create_common_tensor(item[0], -2, 2) + input_b, npu_input_b= create_common_tensor(item[1], -2, 2) + input_c, npu_input_c= create_common_tensor(item[2], -2, 2) + input, npu_input= create_common_tensor(item[3], -2, 2) + + input_a = input_a.to(torch.float32) + input_b = input_b.to(torch.float32) + input_c = input_c.to(torch.float32) + + cpu_output = self.cpu_op_exec(input_a, input_b, input_c, 1, 1) + npu_output= self.npu_op_exec_out(npu_input_a, npu_input_b, npu_input_c, 1, 1, npu_input) + cpu_output = cpu_output.astype(np.float16) + + self.assertRtolEqual(cpu_output, npu_output) + + def test_addmv_fp32(self, device): + shape_format = [ + [[np.float32, 0, (2, 3)], [np.float32, 0, (3,)], [np.float32, 0, (2, )]], + [[np.float32, 0, (3168, 320)], [np.float32, 0, (320,)], [np.float32, 0, (3168, )]], + ] + for item in shape_format: + + input_a, npu_input_a = create_common_tensor(item[0], -2, 2) + input_b, npu_input_b= create_common_tensor(item[1], -2, 2) + input_c, npu_input_c= create_common_tensor(item[2], -2, 2) + + + cpu_output = self.cpu_op_exec(input_a, input_b, input_c, 1, 1) + npu_output = self.npu_op_exec(npu_input_a, npu_input_b, npu_input_c, 1, 1) + + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestAddmv, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() + + diff --git a/test/test_network_ops/test_addr.py b/test/test_network_ops/test_addr.py new file mode 100644 index 00000000000..afec4c3a7e6 --- /dev/null +++ b/test/test_network_ops/test_addr.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAddr(TestCase): + def cpu_op_exec(self,input1, vec1, vec2, beta, alpha): + output = torch.addr(beta, input1, alpha, vec1, vec2) + output = output.numpy() + return output + + def npu_op_exec(self,input1, vec1, vec2, beta, alpha): + output = torch.addr(beta, input1, alpha, vec1, vec2) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self,input1, input2, vec1, vec2, beta, alpha): + torch.addr(beta, input1, alpha, vec1, vec2, out=input2) + output = input2.to("cpu") + output = output.numpy() + return output + + def test_addr_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (5,3)], [np.float32, 0, (5)], [np.float32, 0, (3)]], + [[np.int32, 0, (5,3)], [np.int32, 0, (5)], [np.int32, 0, (3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_vec1, npu_vec1 = create_common_tensor(item[1], 1, 100) + cpu_vec2, npu_vec2 = create_common_tensor(item[2], 1, 100) + beta = 1 + alpha = 1 + cpu_output = self.cpu_op_exec(cpu_input1, cpu_vec1, cpu_vec2, beta, alpha) + npu_output = self.npu_op_exec(npu_input1, npu_vec1, npu_vec2, beta, alpha) + self.assertRtolEqual(cpu_output, npu_output) + + def test_addr_out_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (5,3)], [np.float32, 0, (5,3)], [np.float32, 0, (5)], [np.float32, 0, (3)]], + [[np.int32, 0, (5,3)], [np.int32, 0, (5,3)], [np.int32, 0, (5)], [np.int32, 0, (3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 100) + cpu_vec1, npu_vec1 = create_common_tensor(item[2], 1, 100) + cpu_vec2, npu_vec2 = create_common_tensor(item[3], 1, 100) + beta = 1 + alpha = 1 + cpu_output = self.cpu_op_exec(cpu_input1, cpu_vec1, cpu_vec2, beta, alpha) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_vec1, npu_vec2, beta, alpha) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAddr, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_affine_grid_generator.py b/test/test_network_ops/test_affine_grid_generator.py new file mode 100644 index 00000000000..eb19f60e0a3 --- /dev/null +++ b/test/test_network_ops/test_affine_grid_generator.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np +from torch.nn import functional as F + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestAffineGridGenerator(TestCase): + def cpu_op_exec(self, theta, size, align_corners): + output = F.affine_grid(theta, torch.Size(size), align_corners) + output = output.numpy() + return output + + def npu_op_exec(self, theta, size, align_corners): + theta = theta.npu() + output = torch.affine_grid_generator(theta, size, align_corners) + output = output.cpu().numpy() + return output + + def test_affine_grid_generator_2D(self, device): + theta_list = [[1, 0, 0], + [0, 1, 0], + ] + size = (1, 3, 10, 10) + align_corners_list = [True, False] + dtype_list = [torch.float32, torch.float16] + shape_format = [ + [theta_list, size, i, j] for i in align_corners_list for j in dtype_list + ] + for item in shape_format: + theta = torch.tensor([item[0]], dtype = item[3]) + cpu_input = theta + npu_input = theta + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input, item[1], item[2]) + npu_output = self.npu_op_exec(npu_input, item[1], item[2]) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output, 0.001) + + def test_affine_grid_generator_3D(self, device): + theta_list = [[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + ] + size = (1, 3, 10, 10, 10) + align_corners_list = [True, False] + dtype_list = [torch.float16, torch.float32] + shape_format = [ + [theta_list, size, i, j] for i in align_corners_list for j in dtype_list + ] + for item in shape_format: + theta = torch.tensor([item[0]], dtype = item[3]) + cpu_input = theta + npu_input = theta + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input, item[1], item[2]) + npu_output = self.npu_op_exec(npu_input, item[1], item[2]) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output, 0.001) + +instantiate_device_type_tests(TestAffineGridGenerator, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_all.py b/test/test_network_ops/test_all.py new file mode 100644 index 00000000000..9d096ee453a --- /dev/null +++ b/test/test_network_ops/test_all.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestAll(TestCase): + def create_bool_tensor(self, shape, minValue, maxValue): + input1 = np.random.uniform(minValue, maxValue, shape) + input1 = input1 > 0.5 + cpu_input = torch.from_numpy(input1) + npu_input = torch.from_numpy(input1).to("npu") + return cpu_input, npu_input + + def cpu_op_exec(self, input1): + output = input1.all() + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = input1.all() + output = output.to("cpu") + output = output.numpy() + return output + + def test_all_shape_format(self, device): + shape_list = [[1024], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024], [2, 0, 2]] + for item in shape_list: + cpu_input, npu_input = self.create_bool_tensor(item, 0, 1) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual( + cpu_output.astype( + np.int32), npu_output.astype( + np.int32)) + + def cpu_op_exec1(self, input1, dim): + output = input1.all(dim=dim) + output = output.numpy() + return output + + def npu_op_exec1(self, input1, dim): + output = input1.all(dim=dim) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_out_exec1(self, input1, dim): + shape = list(input1.shape) + output0 = torch.randn(shape) > 0 + output1 = torch.randn(shape.pop()) > 0 + output0 = output0.npu() + output1 = output1.npu() + torch.all(input1, dim=dim, keepdim = False, out = output0) + torch.all(input1, dim=dim, keepdim = False, out = output1) + output0 = output0.to("cpu").numpy() + output1 = output1.to("cpu").numpy() + return output0, output1 + + def test_alld_shape_format(self, device): + shape_list = [[1024], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024]] + for item in shape_list: + cpu_input, npu_input = self.create_bool_tensor(item, 0, 1) + cpu_output = self.cpu_op_exec1(cpu_input, 0) + npu_output = self.npu_op_exec1(npu_input, 0) + npu_out0, npu_out1 = self.npu_op_out_exec1(npu_input, 0) + self.assertRtolEqual(cpu_output.astype(np.int32), npu_output.astype(np.int32)) + self.assertRtolEqual(cpu_output.astype(np.int32), npu_out0.astype(np.int32)) + self.assertRtolEqual(cpu_output.astype(np.int32), npu_out1.astype(np.int32)) + + +instantiate_device_type_tests(TestAll, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_any.py b/test/test_network_ops/test_any.py new file mode 100644 index 00000000000..0241c17349e --- /dev/null +++ b/test/test_network_ops/test_any.py @@ -0,0 +1,95 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestAny(TestCase): + def create_bool_tensor(self, shape, minValue, maxValue): + input1 = np.random.uniform(minValue, maxValue, shape) + cpu_input = torch.from_numpy(input1) > 0.5 + npu_input = (torch.from_numpy(input1) > 0.5).to("npu") + return cpu_input, npu_input + + def cpu_op_exec(self, input1): + output = input1.any() + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = input1.any() + output = output.to("cpu") + output = output.numpy() + return output + + def test_any_shape_format(self, device): + shape_list = [[], + [1024], + [32, 1024], + [32, 8, 1024], + [128, 32, 8, 1024]] + + for item in shape_list: + cpu_input, npu_input = self.create_bool_tensor(item, 0, 1) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual( + cpu_output.astype(np.int32), + npu_output.astype(np.int32)) + + def cpu_op_exec1(self, input1, dim, keepdim): + output = input1.any(dim=dim, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_exec1(self, input1, dim, keepdim): + output = input1.any(dim=dim, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_out_exec1(self, input1, dim, keepdim): + shape = list(input1.shape) + output0 = torch.randn(shape)>0 + output1 = torch.randn(shape.pop())>0 + output0 = output0.npu() + output1 = output1.npu() + torch.any(input1, dim=dim, keepdim=keepdim, out=output0) + torch.any(input1, dim=dim, keepdim=keepdim, out=output1) + output0 = output0.to("cpu").numpy() + output1 = output1.to("cpu").numpy() + return output0, output1 + + def test_anyd_shape_format(self, device): + shape_list = [[ [1024], 0, False], + [ [32, 1024], 1, False], + [ [32, 8, 1024], 2, True ], + [ [128, 32, 8, 1024], 3, True ]] + + for item in shape_list: + cpu_input, npu_input = self.create_bool_tensor(item[0], 0, 1) + cpu_output = self.cpu_op_exec1(cpu_input, item[1], item[2]) + npu_output = self.npu_op_exec1(npu_input, item[1], item[2]) + npu_out0, npu_out1 = self.npu_op_out_exec1(npu_input, item[1], item[2]) + self.assertRtolEqual(cpu_output.astype(np.int32),npu_output.astype(np.int32)) + self.assertRtolEqual(cpu_output.astype(np.int32),npu_out0.astype(np.int32)) + self.assertRtolEqual(cpu_output.astype(np.int32),npu_out1.astype(np.int32)) + +instantiate_device_type_tests(TestAny, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_arange.py b/test/test_network_ops/test_arange.py new file mode 100644 index 00000000000..b2b0ca16b20 --- /dev/null +++ b/test/test_network_ops/test_arange.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestArange(TestCase): + def test_arange(self, device): + shape_format = [ + [0, 100, 2, torch.float32], + [1, 100, 1, torch.int32], + [5, 100, 3, torch.int64], + ] + + for item in shape_format: + cpu_output = torch.arange(item[0], item[1], item[2], dtype=item[3], + device="cpu").numpy() + npu_output = torch.arange(item[0], item[1], item[2], dtype=item[3], + device="npu").cpu().numpy() + self.assertRtolEqual(cpu_output, npu_output) + + def test_arange_out(self, device): + shape_format = [ + [0, 100, 1, torch.float32, [np.float32, 0, [10]]], + [1, 100, 2, torch.int32, [np.int32, 0, [20]]], + [5, 100, 3, torch.int64, [np.int64, 0, [30]]], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[4], 0, 10) + cpu_output = torch.arange(item[0], item[1], item[2], + dtype=item[3], device="cpu").numpy() + npu_output = torch.arange(item[0], item[1], item[2], out = npu_input1, + dtype=item[3], device="npu").cpu().numpy() + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestArange, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_argmax.py b/test/test_network_ops/test_argmax.py new file mode 100644 index 00000000000..36707cf0e94 --- /dev/null +++ b/test/test_network_ops/test_argmax.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestArgmax(TestCase): + def cpu_op_exec(self, input1): + output = torch.argmax(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.argmax(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def test_argmax_shape_format_fp16(self, device): + format_list = [0] + shape_list = [[5], [2, 4], [2, 2, 4], [2, 3, 3, 4]] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -10, 10) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_argmax_shape_format_fp32(self, device): + format_list = [0] + shape_list = [[5], [2, 4], [2, 2, 4], [2, 3, 3, 4]] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -10, 10) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def cpu_op_exec1(self, input1, dim): + output = torch.argmax(input1, dim) + output = output.numpy() + return output + + def npu_op_exec1(self, input1, dim): + output = torch.argmax(input1, dim) + output = output.to("cpu") + output = output.numpy() + return output + + def test_argmaxd_shape_format_fp16(self, device): + format_list = [0] + shape_list = [[5], [2, 4], [2, 2, 4], [2, 3, 3, 4]] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -10, 10) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec1(cpu_input, -1) + npu_output = self.npu_op_exec1(npu_input, -1) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_argmaxd_shape_format_fp32(self, device): + format_list = [0] + shape_list = [[5], [2, 4], [2, 2, 4], [2, 3, 3, 4]] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -10, 10) + cpu_output = self.cpu_op_exec1(cpu_input, -1) + npu_output = self.npu_op_exec1(npu_input, -1) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestArgmax, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp b/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp new file mode 100644 index 00000000000..d56497a1a40 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AddmvKernelNpu.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& NPUNativeFunctions::addmv_out( + const at::Tensor& self, + const at::Tensor& mat, + const at::Tensor& vec, + at::Scalar beta, + at::Scalar alpha, + at::Tensor& result) { + NpuUtils::check_1d(vec, "vec", "addmv"); + + at::Tensor mat1 = vec.unsqueeze(1); + + // matmul mat*alpha + at::Tensor mat_alpha = at::mul(mat, alpha); + + // matmul*alpha + at::Tensor mmMulResult = at::mm(mat_alpha, mat1); + + at::Tensor mmMulResult1 = mmMulResult.squeeze(); + + // calculate the output size + auto outputSize = addmv_npu_output_size(self, mat, vec, beta, alpha); + + if (!result.sizes().equals(outputSize)) { + result.resize_(outputSize); + } + // matmul*alpha+self*beta + at::add_out(result, mmMulResult1, self, beta); + + return result; +} + +at::Tensor NPUNativeFunctions::addmv( + const at::Tensor& self, + const at::Tensor& mat, + const at::Tensor& vec, + at::Scalar beta, + at::Scalar alpha) { + auto outputSize = addmv_npu_output_size(self, mat, vec, beta, alpha); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + addmv_out(self, mat, vec, beta, alpha, result); + + return result; +} + +at::Tensor& NPUNativeFunctions::addmv_( + at::Tensor& self, + const at::Tensor& mat, + const at::Tensor& vec, + at::Scalar beta, + at::Scalar alpha) { + OpPreparation::CheckMemory({self, mat, vec}, {self}); + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = + addmv_out(contiguousSelf, mat, vec, beta, alpha, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + addmv_out(self, mat, vec, beta, alpha, self); + } + return self; +} + +} // namespace native +} // namespace at_npu + + diff --git a/torch_npu/csrc/aten/ops/AddrKernelNpu.cpp b/torch_npu/csrc/aten/ops/AddrKernelNpu.cpp new file mode 100644 index 00000000000..00c47b2e9f2 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AddrKernelNpu.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& NPUNativeFunctions::addr_out( + const at::Tensor& self, + const at::Tensor& vec1, + const at::Tensor& vec2, + at::Scalar beta, + at::Scalar alpha, + at::Tensor& result) { + NpuUtils::check_1d(vec1, "vec1", "addr"); + NpuUtils::check_1d(vec2, "vec2", "addr"); + + at::Tensor mat1 = vec1.unsqueeze(1); + at::Tensor mat2 = vec2.unsqueeze(0); + + // vecmul vec1&vec2 + at::Tensor mmResult = at::mm(mat1, mat2); + + // matmul*alpha + at::Tensor mmMulResult = at::mul(mmResult, alpha); + + // matmul*alpha+self*beta + at::add_out(result, mmMulResult, self, beta); + + return result; +} + +at::Tensor NPUNativeFunctions::addr( + const at::Tensor& self, + const at::Tensor& vec1, + const at::Tensor& vec2, + at::Scalar beta, + at::Scalar alpha) { + auto outputSize = addr_npu_output_size(self, vec1, vec2, beta, alpha); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + addr_out(self, vec1, vec2, beta, alpha, result); + + return result; +} + +at::Tensor& NPUNativeFunctions::addr_( + at::Tensor& self, + const at::Tensor& vec1, + const at::Tensor& vec2, + at::Scalar beta, + at::Scalar alpha) { + OpPreparation::CheckMemory({self, vec1, vec2}, {self}); + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = + addr_out(contiguousSelf, vec1, vec2, beta, alpha, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + addr_out(self, vec1, vec2, beta, alpha, self); + } + + return self; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp b/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp new file mode 100644 index 00000000000..d21c6256c61 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp @@ -0,0 +1,79 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& affine_grid_generator_npu_nocheck( + at::Tensor& result, + const at::Tensor& theta, + at::IntArrayRef size, + bool align_corners) { + int value = size.size(); + vector tmp_vector = {}; + for (int i = 0; i < value; i++) { + tmp_vector.emplace_back(size[i]); + } + at::Tensor sizeTensor_cpu = from_blob((void*)tmp_vector.data(), {value}, at::kInt); + at::Tensor sizeTensor_npu = CalcuOpUtil::copy_tensor_host_to_device(sizeTensor_cpu); + + OpCommand cmd; + cmd.Name("AffineGrid") + .Input(theta) + .InputPair(sizeTensor_npu, sizeTensor_cpu) + .Output(result) + .Attr("align_corners", align_corners) + .Run(); + + return result; +} + +at::Tensor NPUNativeFunctions::affine_grid_generator( + const at::Tensor& theta, + at::IntArrayRef size, + bool align_corners) { + TORCH_CHECK(size.size() == 4 || size.size() == 5, + "AffineGridGenerator needs 4d or 5d size(input)."); + // calculate the output size + SmallVector outputSize = { }; + if(size.size() == 4) { + outputSize = {size[0], size[2] * size[3], 2}; + } else { + outputSize = {size[0], size[2] * size[3] * size[4], 3}; + } + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(theta, outputSize); + // calculate the output result of the NPU + affine_grid_generator_npu_nocheck( + result, + theta, + size, + align_corners); + + if(size.size() == 4) { + result = result.view({size[0], size[2], size[3], 2}); + } else { + result = result.view({size[0], size[2], size[3], size[4], 3}); + } + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/AllKernelNpu.cpp b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp new file mode 100644 index 00000000000..0d12ca06f37 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp @@ -0,0 +1,112 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "c10/npu/OptionsManager.h" + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using namespace at::native::npu; + +inline at::Tensor all_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + c10::SmallVector dimList, + bool keepdim) { + OpCommand cmd; + cmd.Name("ReduceAll") + .Input(self) + .Input(dimList, at::kLong) + .Output(result) + .Attr("keep_dims", keepdim) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::all_out( + const at::Tensor& self, + int64_t dim, + bool keepdim, + at::Tensor& result) { + c10::SmallVector dimList = {dim}; + + // check result for return + auto outputSize = reduce_ops_npu_output_size(self, dimList, keepdim); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + self.scalar_type(), + outputSize); + + // calculate the output result of the NPU + all_out_npu_nocheck( + result, self, dimList, keepdim); + + return result; +} + +at::Tensor NPUNativeFunctions::all(const at::Tensor& self, int64_t dim, bool keepdim) { + TORCH_CHECK(self.scalar_type() == at::ScalarType::Bool || self.scalar_type() == at::ScalarType::Byte, + "all only supports torch.uint8 and torch.bool dtypes"); + if (self.numel() == 0) { + at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(kInt), self).fill_(1).to(at::ScalarType::Bool); + return res; + } + + // calculate the output size + at::IntArrayRef dims(dim); + auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + all_out_npu_nocheck(result, self, {dim}, keepdim); + + return result; +} + +at::Tensor NPUNativeFunctions::all(const at::Tensor& self) { + TORCH_CHECK(self.scalar_type() == at::ScalarType::Bool || self.scalar_type() == at::ScalarType::Byte, + "all only supports torch.uint8 and torch.bool dtypes"); + if (self.numel() == 0) { + at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(kInt), self).fill_(1).to(at::ScalarType::Bool); + return res; + } + + // calculate the output size + at::IntArrayRef dims; + auto outputSize = reduce_ops_npu_output_size(self, dims, false); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + all_out_npu_nocheck( + result, + self, + CalcuOpUtil::get_dimlist_for_tensor(self), + false); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp b/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp new file mode 100644 index 00000000000..1353fd53ee6 --- /dev/null +++ b/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp @@ -0,0 +1,115 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +inline at::Tensor& any_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + at::SmallVector dimList, + bool keepdim) { + OpCommand cmd; + cmd.Name("ReduceAny") + .Input(self) + .Input(dimList) + .Output(result) + .Attr("keep_dims", keepdim) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::any_out( + const at::Tensor& self, + int64_t dim, + bool keepdim, + at::Tensor& result) { + at::SmallVector dimList; + if (dim == LLONG_MIN) { + dimList = CalcuOpUtil::get_dimlist_for_tensor(self); + } else { + dimList = {dim}; + } + + // check result for return + auto outputSize = reduce_ops_npu_output_size(self, dimList, keepdim); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + self.scalar_type(), + outputSize); + + // calculate the output result of the NPU + any_out_npu_nocheck(result, self, dimList, keepdim); + + return result; +} + +at::Tensor NPUNativeFunctions::any(const at::Tensor& self, int64_t dim, bool keepdim) { + // calculate the output size + at::IntArrayRef dims(dim); + auto outputSize = reduce_ops_npu_output_size(self, dims, keepdim); + + // construct the output tensor of the NPU + at::Tensor result = at::empty_with_format( + outputSize, self.options(), CalcuOpUtil::get_tensor_npu_format(self)); + + // calculate the output result of the NPU + if (dim == LLONG_MIN) { + any_out_npu_nocheck( + result, self, CalcuOpUtil::get_dimlist_for_tensor(self), keepdim); + } else { + any_out_npu_nocheck(result, self, {dim}, keepdim); + } + + return result; +} + +at::Tensor NPUNativeFunctions::any(const at::Tensor& self) { + // when self's dim = 0, convert [1] tensor and reduce it + if (self.dim() == 0) { + at::Tensor self_tmp = self; + self_tmp = at::empty_with_format( + {1}, + self.options().dtype(at::ScalarType::Float), + CalcuOpUtil::get_tensor_npu_format(self)) + .fill_(self.item()) + .to(at::ScalarType::Bool); + return NPUNativeFunctions::any(self_tmp, 0, false); + } + + // calculate the output size + IntArrayRef dims; + auto outputSize = reduce_ops_npu_output_size(self, dims, false); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + any_out_npu_nocheck( + result, self, CalcuOpUtil::get_dimlist_for_tensor(self), false); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp new file mode 100644 index 00000000000..94f5169863b --- /dev/null +++ b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp @@ -0,0 +1,177 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +// bool inputs are considered integral +static inline bool allIntegral( + std::initializer_list> l) { + for (Scalar& s : l) { + if (!s.isIntegral(true)) { + return false; + } + } + return true; +} + + +at::Tensor& arange_out_npu_nocheck( + at::Tensor& result, + at::Scalar start, + at::Scalar end, + at::Scalar step) { + OpCommand cmd; + cmd.Name("Range") + .Input(start, result.scalar_type()) + .Input(end, result.scalar_type()) + .Input(step, result.scalar_type()) + .Output(result) + .Run(); + + return result; +} + +at::Tensor NPUNativeFunctions::arange( + at::Scalar start, + at::Scalar end, + at::Scalar step, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + + auto device = device_or_default(device_opt); + TensorOptions option; + option = option.dtype(dtype_opt) + .layout(layout_opt) + .device(device) + .pinned_memory(pin_memory_opt); + + float start_value = CalcuOpUtil::get_scalar_float_value(start); + float end_value = CalcuOpUtil::get_scalar_float_value(end); + float step_value = CalcuOpUtil::get_scalar_float_value(step); + + // Check step start end + TORCH_CHECK(step_value != 0, "step must be nonzero"); + TORCH_CHECK(((step_value > 0) && (end_value >= start_value)) || ((step_value < 0) && (end_value <= start_value)), + "upper bound and larger bound inconsistent with step sign"); + + bool set_to_integral_dtype = + !option.has_dtype() && allIntegral({start, end, step}); + + // check start == end + if (set_to_integral_dtype) { + option = option.dtype(at::ScalarType::Int); + } + at::Tensor result_check = OpPreparation::ApplyTensorWithFormat({0}, option, ACL_FORMAT_ND); + + if (start_value == end_value) { + return result_check; + } + + // calculate the output size + double size_arange = std::ceil(static_cast(end.toDouble() - start.toDouble()) + / step.toDouble()); + int64_t size_value = static_cast(size_arange); + at::SmallVector outputSize = {size_value}; + at::Tensor result = OpPreparation::ApplyTensorWithFormat(outputSize, option, ACL_FORMAT_ND); + + if(option.dtype() == at::kHalf) { + result = result.to(at::kFloat); + } + + arange_out_npu_nocheck(result, start, end, step); + + if(option.dtype() == at::kHalf) { + result = result.to(at::kHalf); + } + + return result; +} + +at::Tensor NPUNativeFunctions::arange( + at::Scalar start, + at::Scalar end, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + + return NPUNativeFunctions::arange(start, end, 1, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + + +at::Tensor NPUNativeFunctions::arange( + at::Scalar end, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + + return NPUNativeFunctions::arange(0, end, dtype_opt, layout_opt, device_opt, pin_memory_opt); // start = 0 +} + +at::Tensor& 
NPUNativeFunctions::arange_out(
+    at::Scalar start,
+    at::Scalar end,
+    at::Scalar step,
+    at::Tensor& result) {
+  float start_value = CalcuOpUtil::get_scalar_float_value(start);
+  float end_value = CalcuOpUtil::get_scalar_float_value(end);
+  float step_value = CalcuOpUtil::get_scalar_float_value(step);
+
+  // Check step start end
+  TORCH_CHECK(step_value != 0, "step must be nonzero");
+  TORCH_CHECK(((step_value > 0) && (end_value >= start_value)) || ((step_value < 0) && (end_value <= start_value)),
+      "upper bound and larger bound inconsistent with step sign");
+
+  // calculate the output size
+  double size_arange = std::ceil(static_cast<double>(end.toDouble() - start.toDouble())
+      / step.toDouble());
+  int64_t size_value = static_cast<int64_t>(size_arange);
+  SmallVector outputSize = {size_value};
+  result.resize_(outputSize);
+
+  arange_out_npu_nocheck(result, start, end, step);
+
+  return result;
+}
+
+at::Tensor& arange_other_out_npu(at::Scalar start, at::Scalar end, at::Tensor& result) {
+  return arange_out(start, end, 1, result);
+}
+
+at::Tensor& NPUNativeFunctions::arange_out(Scalar end, at::Tensor& result) {
+  return arange_other_out_npu(0, end, result);
+}
+
+at::Tensor NPUNativeFunctions::_dim_arange(const at::Tensor& self, int64_t dim) {
+  c10::optional dtype_opt(at::kInt);
+  c10::optional layout_opt(self.options().layout());
+  c10::optional device_opt(self.options().device());
+  c10::optional pin_memory_opt(self.options().pinned_memory());
+
+  at::Tensor result = NPUNativeFunctions::arange(self.size(dim), dtype_opt, layout_opt, device_opt, pin_memory_opt);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/ArgmaxKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArgmaxKernelNpu.cpp
new file mode 100644
index 00000000000..1e0707fbfa2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ArgmaxKernelNpu.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::argmax(const at::Tensor& self, at::optional<int64_t> dim, bool keepdim) {
+  at::Tensor input = dim.has_value() ? self : self.reshape({-1});
+  int64_t realDim = dim.has_value() ? dim.value() : 0;
+  bool realKeepDim = dim.has_value() ?
keepdim : false; + + // calculate the output size + auto outputSize = reduce_ops_npu_output_size(input, realDim, realKeepDim); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options().dtype(at::kInt)); + + at::SmallVector DimVec = {realDim}; + // calculate the output result of the NPU + OpCommand cmd; + cmd.Name("ArgMaxV2") + .Input(input) + .Input(DimVec, at::kInt) + .Output(result) + .Attr("keep_dims", realKeepDim) + .Run(); + + result = result.to(at::ScalarType::Long); + return result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp new file mode 100644 index 00000000000..2e972330cad --- /dev/null +++ b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp @@ -0,0 +1,88 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor __and___dest_output(const at::Tensor& self, const at::Tensor& other) { + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + + if (not isSelfWrapped) { + return self; + } else { + return other; + } +} + +at::Tensor& __and___out_npu_nocheck( + const at::Tensor& self, + const at::Scalar other, + at::Tensor& result) { + OpCommand cmd; + cmd.Name((self.scalar_type() == at::ScalarType::Bool) ? "LogicalAnd" : "BitwiseAnd") + .Input(self) + .Input(other,self.scalar_type()) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& __and___out_npu_nocheck( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + if (other.dim() == 0 && !other.is_npu()) { + __and___out_npu_nocheck(self, other.item(),result); + } else if (self.dim() == 0 && !self.is_npu()) { + __and___out_npu_nocheck(other, self.item(),result); + } else { + OpCommand cmd; + cmd.Name((self.scalar_type() == at::ScalarType::Bool) ? 
"LogicalAnd" : "BitwiseAnd") + .Input(self) + .Input(other) + .Output(result) + .Run(); + } + + return result; +} + +at::Tensor NPUNativeFunctions::__and___tensor(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + at::Tensor outputTensor = __and___dest_output(self, other); + auto outputSize = broadcast_ops_npu_output_size(self, other); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(outputTensor, outputSize); + // calculate the output result of the NPU + __and___out_npu_nocheck(self, other,result); + return result; +} + +at::Tensor NPUNativeFunctions::__and___scalar(const at::Tensor& self, at::Scalar other) { + at::Tensor result = OpPreparation::ApplyTensor(self); + __and___out_npu_nocheck(self, other,result); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp new file mode 100644 index 00000000000..e5e78322642 --- /dev/null +++ b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp @@ -0,0 +1,69 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& __rshift___out_npu_nocheck(
+    const at::Tensor& self,
+    at::Scalar other,
+    at::Tensor& result) {
+  OpCommand cmd;
+  cmd.Name("RightShift")
+      .Input(self)
+      .Input(other, self.scalar_type())
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor& __rshift___out_npu_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& other,
+    at::Tensor& result) {
+  OpCommand cmd;
+  cmd.Name("RightShift")
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::__rshift___tensor(const at::Tensor& self, const at::Tensor& other) {
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  __rshift___out_npu_nocheck( self, other,result);
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::__rshift___scalar(const at::Tensor& self, at::Scalar other) {
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  __rshift___out_npu_nocheck(self, other, result);
+
+  return result;
+}
+
+} // namespace native
+} // namespace at
\ No newline at end of file
-- 
Gitee

From 7259e9af10bf702725b5f373d56052ca9360c2f4 Mon Sep 17 00:00:00 2001
From: wangxiao
Date: Thu, 27 Jan 2022 19:43:49 +0800
Subject: [PATCH 2/4] __and__, __rshift__, affine_grid_generator, all, any, arange fi bug

---
 torch_npu/csrc/aten/npu_native_functions.yaml |  4 ---
 .../aten/ops/AffineGridGeneratorKernelNpu.cpp |  6 ++--
 torch_npu/csrc/aten/ops/AllKernelNpu.cpp      |  5 ++-
 torch_npu/csrc/aten/ops/AnyKernelNpu.cpp      |  2 +-
 torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp   | 31 ++++++++++---------
 torch_npu/csrc/aten/ops/__And__KernelNpu.cpp  |  4 +--
 .../csrc/aten/ops/__Rshift__KernelNpu.cpp     |  4 +--
 7 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 62f461b32c5..1f08a26654f 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -97,13 +97,9 @@ supported:
   - affine_grid_generator_backward
   - all.dim
   - all.out
-  - all.dimname
-  - all.dimname_out
   - allclose
   - any.dim
   - any.out
-  - any.dimname
-  - any.dimname_out
   - arange
   - arange.start
   - arange.start_step
diff --git a/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp b/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp
index d21c6256c61..8b7e0f3bd42 100644
--- a/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/AffineGridGeneratorKernelNpu.cpp
@@ -13,7 +13,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
- +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" @@ -30,7 +30,7 @@ at::Tensor& affine_grid_generator_npu_nocheck( for (int i = 0; i < value; i++) { tmp_vector.emplace_back(size[i]); } - at::Tensor sizeTensor_cpu = from_blob((void*)tmp_vector.data(), {value}, at::kInt); + at::Tensor sizeTensor_cpu = at::from_blob((void*)tmp_vector.data(), {value}, at::kInt); at::Tensor sizeTensor_npu = CalcuOpUtil::copy_tensor_host_to_device(sizeTensor_cpu); OpCommand cmd; @@ -51,7 +51,7 @@ at::Tensor NPUNativeFunctions::affine_grid_generator( TORCH_CHECK(size.size() == 4 || size.size() == 5, "AffineGridGenerator needs 4d or 5d size(input)."); // calculate the output size - SmallVector outputSize = { }; + at::SmallVector outputSize = { }; if(size.size() == 4) { outputSize = {size[0], size[2] * size[3], 2}; } else { diff --git a/torch_npu/csrc/aten/ops/AllKernelNpu.cpp b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp index 0d12ca06f37..ae1d4a7771a 100644 --- a/torch_npu/csrc/aten/ops/AllKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp @@ -21,7 +21,6 @@ namespace at_npu { namespace native { -using namespace at::native::npu; inline at::Tensor all_out_npu_nocheck( at::Tensor& result, @@ -66,7 +65,7 @@ at::Tensor NPUNativeFunctions::all(const at::Tensor& self, int64_t dim, bool kee TORCH_CHECK(self.scalar_type() == at::ScalarType::Bool || self.scalar_type() == at::ScalarType::Byte, "all only supports torch.uint8 and torch.bool dtypes"); if (self.numel() == 0) { - at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(kInt), self).fill_(1).to(at::ScalarType::Bool); + at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(at::kInt), self).fill_(1).to(at::ScalarType::Bool); return res; } @@ -87,7 +86,7 @@ at::Tensor NPUNativeFunctions::all(const at::Tensor& self) { TORCH_CHECK(self.scalar_type() == at::ScalarType::Bool || self.scalar_type() == at::ScalarType::Byte, "all only supports torch.uint8 and torch.bool dtypes"); if (self.numel() == 0) { - at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(kInt), self).fill_(1).to(at::ScalarType::Bool); + at::Tensor res = OpPreparation::ApplyTensor({}, self.options().dtype(at::kInt), self).fill_(1).to(at::ScalarType::Bool); return res; } diff --git a/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp b/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp index 1353fd53ee6..7cdac57d78e 100644 --- a/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AnyKernelNpu.cpp @@ -98,7 +98,7 @@ at::Tensor NPUNativeFunctions::any(const at::Tensor& self) { } // calculate the output size - IntArrayRef dims; + at::IntArrayRef dims; auto outputSize = reduce_ops_npu_output_size(self, dims, false); // construct the output tensor of the NPU diff --git a/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp index 94f5169863b..494594327d4 100644 --- a/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/ArangeKernelNpu.cpp @@ -23,8 +23,8 @@ namespace native { // bool inputs are considered integral static inline bool allIntegral( - std::initializer_list> l) { - for (Scalar& s : l) { + std::initializer_list> l) { + for (at::Scalar& s : l) { if (!s.isIntegral(true)) { return false; } @@ -54,12 +54,12 @@ at::Tensor NPUNativeFunctions::arange( at::Scalar end, at::Scalar step, c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, + c10::optional layout_opt, 
+ c10::optional device_opt, c10::optional pin_memory_opt) { auto device = device_or_default(device_opt); - TensorOptions option; + at::TensorOptions option; option = option.dtype(dtype_opt) .layout(layout_opt) .device(device) @@ -111,8 +111,8 @@ at::Tensor NPUNativeFunctions::arange( at::Scalar start, at::Scalar end, c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, + c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt) { return NPUNativeFunctions::arange(start, end, 1, dtype_opt, layout_opt, device_opt, pin_memory_opt); @@ -121,9 +121,9 @@ at::Tensor NPUNativeFunctions::arange( at::Tensor NPUNativeFunctions::arange( at::Scalar end, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, c10::optional pin_memory_opt) { return NPUNativeFunctions::arange(0, end, dtype_opt, layout_opt, device_opt, pin_memory_opt); // start = 0 @@ -147,7 +147,7 @@ at::Tensor& NPUNativeFunctions::arange_out( double size_arange = std::ceil(static_cast(end.toDouble() - start.toDouble()) / step.toDouble()); int64_t size_value = static_cast(size_arange); - SmallVector outputSize = {size_value}; + at::SmallVector outputSize = {size_value}; result.resize_(outputSize); arange_out_npu_nocheck(result, start, end, step); @@ -156,17 +156,18 @@ at::Tensor& NPUNativeFunctions::arange_out( } at::Tensor& arange_other_out_npu(at::Scalar start, at::Scalar end, at::Tensor& result) { - return arange_out(start, end, 1, result); + at::Scalar step = 1; + return NPUNativeFunctions::arange_out(start, end, step, result); } -at::Tensor& NPUNativeFunctions::arange_out(Scalar end, at::Tensor& result) { +at::Tensor& NPUNativeFunctions::arange_out(at::Scalar end, at::Tensor& result) { return arange_other_out_npu(0, end, result); } at::Tensor NPUNativeFunctions::_dim_arange(const at::Tensor& self, int64_t dim) { c10::optional dtype_opt(at::kInt); - c10::optional layout_opt(self.options().layout()); - c10::optional device_opt(self.options().device()); + c10::optional layout_opt(self.options().layout()); + c10::optional device_opt(self.options().device()); c10::optional pin_memory_opt(self.options().pinned_memory()); at::Tensor result = NPUNativeFunctions::arange(self.size(dim), dtype_opt, layout_opt, device_opt, pin_memory_opt); diff --git a/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp index 2e972330cad..7c463729bb4 100644 --- a/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp @@ -65,7 +65,7 @@ at::Tensor& __and___out_npu_nocheck( return result; } -at::Tensor NPUNativeFunctions::__and___tensor(const at::Tensor& self, const at::Tensor& other) { +at::Tensor NPUNativeFunctions::__and__(const at::Tensor& self, const at::Tensor& other) { // calculate the output size at::Tensor outputTensor = __and___dest_output(self, other); auto outputSize = broadcast_ops_npu_output_size(self, other); @@ -77,7 +77,7 @@ at::Tensor NPUNativeFunctions::__and___tensor(const at::Tensor& self, const at:: return result; } -at::Tensor NPUNativeFunctions::__and___scalar(const at::Tensor& self, at::Scalar other) { +at::Tensor NPUNativeFunctions::__and__(const at::Tensor& self, at::Scalar other) { at::Tensor result = OpPreparation::ApplyTensor(self); __and___out_npu_nocheck(self, other,result); diff --git a/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp index 
e5e78322642..fdc790be75d 100644 --- a/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp @@ -48,7 +48,7 @@ at::Tensor& __rshift___out_npu_nocheck( return result; } -at::Tensor NPUNativeFunctions::__rshift___tensor(const at::Tensor& self, const at::Tensor& other) { +at::Tensor NPUNativeFunctions::__rshift__(const at::Tensor& self, const at::Tensor& other) { // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensor(self); __rshift___out_npu_nocheck( self, other,result); @@ -56,7 +56,7 @@ at::Tensor NPUNativeFunctions::__rshift___tensor(const at::Tensor& self, const a return result; } -at::Tensor NPUNativeFunctions::__rshift___scalar(const at::Tensor& self, at::Scalar other) { +at::Tensor NPUNativeFunctions::__rshift__(const at::Tensor& self, at::Scalar other) { // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensor(self); -- Gitee From f9f7de9663a5d9342f31799efff72498043620c5 Mon Sep 17 00:00:00 2001 From: wangxiao Date: Thu, 27 Jan 2022 20:27:13 +0800 Subject: [PATCH 3/4] __and__, __rshift__, any, addmv clean code --- test/test_network_ops/test___and__.py | 4 ++-- test/test_network_ops/test_addmv.py | 16 ++++++++-------- test/test_network_ops/test_any.py | 4 ++-- torch_npu/csrc/aten/ops/__And__KernelNpu.cpp | 16 ++++++++-------- torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp | 8 ++++---- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/test/test_network_ops/test___and__.py b/test/test_network_ops/test___and__.py index e9a17293140..9d13f1db08e 100644 --- a/test/test_network_ops/test___and__.py +++ b/test/test_network_ops/test___and__.py @@ -25,14 +25,14 @@ class Test__And__(TestCase): def cpu_op_exec(self, input1, input2): output = input1.__and__(input2) if output.dtype != torch.int32: - output = output.to(torch.int32) + output = output.to(torch.int32) return output.numpy() def npu_op_exec(self, input1, input2): output = input1.__and__(input2) output = output.to("cpu") if output.dtype != torch.int32: - output = output.to(torch.int32) + output = output.to(torch.int32) return output.numpy() def test___And___shape_format(self, device): diff --git a/test/test_network_ops/test_addmv.py b/test/test_network_ops/test_addmv.py index a71f786a59a..73a787b05be 100644 --- a/test/test_network_ops/test_addmv.py +++ b/test/test_network_ops/test_addmv.py @@ -47,8 +47,8 @@ class TestAddmv(TestCase): for item in shape_format: input_a, npu_input_a = create_common_tensor(item[0], -2, 2) - input_b, npu_input_b= create_common_tensor(item[1], -2, 2) - input_c, npu_input_c= create_common_tensor(item[2], -2, 2) + input_b, npu_input_b = create_common_tensor(item[1], -2, 2) + input_c, npu_input_c = create_common_tensor(item[2], -2, 2) input_a = input_a.to(torch.float32) input_b = input_b.to(torch.float32) @@ -68,16 +68,16 @@ class TestAddmv(TestCase): for item in shape_format: input_a, npu_input_a = create_common_tensor(item[0], -2, 2) - input_b, npu_input_b= create_common_tensor(item[1], -2, 2) - input_c, npu_input_c= create_common_tensor(item[2], -2, 2) - input, npu_input= create_common_tensor(item[3], -2, 2) + input_b, npu_input_b = create_common_tensor(item[1], -2, 2) + input_c, npu_input_c = create_common_tensor(item[2], -2, 2) + _, npu_input = create_common_tensor(item[3], -2, 2) input_a = input_a.to(torch.float32) input_b = input_b.to(torch.float32) input_c = input_c.to(torch.float32) cpu_output = self.cpu_op_exec(input_a, input_b, input_c, 1, 1) - npu_output= 
self.npu_op_exec_out(npu_input_a, npu_input_b, npu_input_c, 1, 1, npu_input) + npu_output = self.npu_op_exec_out(npu_input_a, npu_input_b, npu_input_c, 1, 1, npu_input) cpu_output = cpu_output.astype(np.float16) self.assertRtolEqual(cpu_output, npu_output) @@ -90,8 +90,8 @@ class TestAddmv(TestCase): for item in shape_format: input_a, npu_input_a = create_common_tensor(item[0], -2, 2) - input_b, npu_input_b= create_common_tensor(item[1], -2, 2) - input_c, npu_input_c= create_common_tensor(item[2], -2, 2) + input_b, npu_input_b = create_common_tensor(item[1], -2, 2) + input_c, npu_input_c = create_common_tensor(item[2], -2, 2) cpu_output = self.cpu_op_exec(input_a, input_b, input_c, 1, 1) diff --git a/test/test_network_ops/test_any.py b/test/test_network_ops/test_any.py index 0241c17349e..c1cef3b617a 100644 --- a/test/test_network_ops/test_any.py +++ b/test/test_network_ops/test_any.py @@ -65,8 +65,8 @@ class TestAny(TestCase): def npu_op_out_exec1(self, input1, dim, keepdim): shape = list(input1.shape) - output0 = torch.randn(shape)>0 - output1 = torch.randn(shape.pop())>0 + output0 = torch.randn(shape) > 0 + output1 = torch.randn(shape.pop()) > 0 output0 = output0.npu() output1 = output1.npu() torch.any(input1, dim=dim, keepdim=keepdim, out=output0) diff --git a/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp index 7c463729bb4..fae417fb64a 100644 --- a/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/__And__KernelNpu.cpp @@ -21,7 +21,7 @@ namespace at_npu { namespace native { -at::Tensor __and___dest_output(const at::Tensor& self, const at::Tensor& other) { +at::Tensor and_dest_output(const at::Tensor& self, const at::Tensor& other) { bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); if (not isSelfWrapped) { @@ -31,7 +31,7 @@ at::Tensor __and___dest_output(const at::Tensor& self, const at::Tensor& other) } } -at::Tensor& __and___out_npu_nocheck( +at::Tensor& and_out_npu_nocheck( const at::Tensor& self, const at::Scalar other, at::Tensor& result) { @@ -45,14 +45,14 @@ at::Tensor& __and___out_npu_nocheck( return result; } -at::Tensor& __and___out_npu_nocheck( +at::Tensor& and_out_npu_nocheck( const at::Tensor& self, const at::Tensor& other, at::Tensor& result) { if (other.dim() == 0 && !other.is_npu()) { - __and___out_npu_nocheck(self, other.item(),result); + and_out_npu_nocheck(self, other.item(),result); } else if (self.dim() == 0 && !self.is_npu()) { - __and___out_npu_nocheck(other, self.item(),result); + and_out_npu_nocheck(other, self.item(),result); } else { OpCommand cmd; cmd.Name((self.scalar_type() == at::ScalarType::Bool) ? 
"LogicalAnd" : "BitwiseAnd") @@ -67,19 +67,19 @@ at::Tensor& __and___out_npu_nocheck( at::Tensor NPUNativeFunctions::__and__(const at::Tensor& self, const at::Tensor& other) { // calculate the output size - at::Tensor outputTensor = __and___dest_output(self, other); + at::Tensor outputTensor = and_dest_output(self, other); auto outputSize = broadcast_ops_npu_output_size(self, other); // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensor(outputTensor, outputSize); // calculate the output result of the NPU - __and___out_npu_nocheck(self, other,result); + and_out_npu_nocheck(self, other,result); return result; } at::Tensor NPUNativeFunctions::__and__(const at::Tensor& self, at::Scalar other) { at::Tensor result = OpPreparation::ApplyTensor(self); - __and___out_npu_nocheck(self, other,result); + and_out_npu_nocheck(self, other,result); return result; } diff --git a/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp index fdc790be75d..954ac65d503 100644 --- a/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/__Rshift__KernelNpu.cpp @@ -20,7 +20,7 @@ namespace at_npu { namespace native { -at::Tensor& __rshift___out_npu_nocheck( +at::Tensor& rshift_out_npu_nocheck( const at::Tensor& self, at::Scalar other, at::Tensor& result) { @@ -34,7 +34,7 @@ at::Tensor& __rshift___out_npu_nocheck( return result; } -at::Tensor& __rshift___out_npu_nocheck( +at::Tensor& rshift_out_npu_nocheck( const at::Tensor& self, const at::Tensor& other, at::Tensor& result) { @@ -51,7 +51,7 @@ at::Tensor& __rshift___out_npu_nocheck( at::Tensor NPUNativeFunctions::__rshift__(const at::Tensor& self, const at::Tensor& other) { // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensor(self); - __rshift___out_npu_nocheck( self, other,result); + rshift_out_npu_nocheck(self, other,result); return result; } @@ -60,7 +60,7 @@ at::Tensor NPUNativeFunctions::__rshift__(const at::Tensor& self, at::Scalar oth // construct the output tensor of the NPU at::Tensor result = OpPreparation::ApplyTensor(self); - __rshift___out_npu_nocheck(self, other, result); + rshift_out_npu_nocheck(self, other, result); return result; } -- Gitee From 3a6d0a692642fa6a935ad19c5478f9252b36ca48 Mon Sep 17 00:00:00 2001 From: wangxiao Date: Fri, 28 Jan 2022 09:11:18 +0800 Subject: [PATCH 4/4] rm redundant code --- torch_npu/csrc/aten/ops/AllKernelNpu.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_npu/csrc/aten/ops/AllKernelNpu.cpp b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp index ae1d4a7771a..d303b5cddbd 100644 --- a/torch_npu/csrc/aten/ops/AllKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AllKernelNpu.cpp @@ -13,8 +13,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "c10/npu/OptionsManager.h" - #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" #include "torch_npu/csrc/framework/utils/OpAdapter.h" #include "torch_npu/csrc/aten/NPUNativeFunctions.h" -- Gitee