From 2bfceadcf7210699faed27c2cf89f9cc54d77516 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 16:27:02 +0800 Subject: [PATCH 01/12] add files: test_logical_or.py LogicalOrKernelNpu.cpp --- test/test_network_ops/test_logical_or.py | 140 ++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 3 + .../csrc/aten/ops/LogicalOrKernelNpu.cpp | 75 ++++++++++ 3 files changed, 218 insertions(+) create mode 100644 test/test_network_ops/test_logical_or.py create mode 100644 torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py new file mode 100644 index 0000000000..be14da7001 --- /dev/null +++ b/test/test_network_ops/test_logical_or.py @@ -0,0 +1,140 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import copy +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestLogicalOr(TestCase): + + def generate_single_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + + return npu_input1 + + def generate_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + + return npu_input1, npu_input2 + + def generate_three_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input3 = np.random.uniform(min_d, max_d, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + npu_input3 = torch.from_numpy(input3) + + return npu_input1, npu_input2, npu_input3 + + def generate_bool_data(self, min_d, max_d, shape, dtype): + input1 = np.random.randint(min_d, max_d, shape).astype(dtype) + input2 = np.random.randint(min_d, max_d, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + + return npu_input1, npu_input2 + + def cpu_op_exec(self, input1, input2): + output = torch.logical_or(input1, input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.logical_or(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_out(self, input1, input2, input3): + torch.logical_or(input1, input2, out=input3) + output = input3.numpy() + return output + + def npu_op_exec_out(self, input1, input2, input3): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = input3.to("npu") + torch.logical_or(input1, input2, out=output) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_(self, input1, input2): + output = torch.Tensor.logical_or_(input1, input2) + output = output.numpy() + return output + + def npu_op_exec_(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.Tensor.logical_or_(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def logical_or_out_result(self, shape_format): + for item in shape_format: + cpu_input1 = torch.randn(item[0])<0 + cpu_input2 = torch.randn(item[0])<0 + cpu_input3 = torch.randn(item[1])<0 + cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + npu_output_out = self.npu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + + self.assertRtolEqual(cpu_output_out, npu_output_out) + + def test_logical_or_out(self, device): + shape_format = [ + [[128, 116, 14, 14], [256, 116, 1, 1, 28]], + [[128, 3, 224, 224], [3, 3, 3]], + [[128, 116, 14, 14], [128, 116, 14, 14]], + [[256, 128, 7, 7], [128, 256, 3, 3, 28]], + [[2, 3, 3, 3], [3, 1, 3]], + [[128, 232, 7, 7], [128, 232, 7, 7]], + ] + self.logical_or_out_result(shape_format) + + def test_logical_or_bool(self, device): + npu_input1, npu_input2 = self.generate_bool_data(0, 2, (10, 64), np.bool) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_logical_or_inplace_bool(self, device): + npu_input1, npu_input2 = self.generate_bool_data(0, 2, (10, 64), np.bool) + cpu_output = self.cpu_op_exec_(npu_input1, npu_input2) + npu_output = self.npu_op_exec_(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestLogicalOr, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 412fb9da7c..0166917590 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,6 +240,9 @@ supported: - addcdiv - addcdiv_ - addcdiv.out + - logical_or + - logical_or_ + - logical.out custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp new file mode 100644 index 0000000000..d9a71e8cc8 --- /dev/null +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/framework/utils/OpTemplate.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& logical_or_out_npu_nocheck( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("LogicalOr") + .Input(self) + .Input(other) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::logical_or_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + result.scalar_type(), + outputSize); + logical_or_out_npu_nocheck(self, other, result); + + return result; +} + +at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + auto outputSize = broadcast_ops_npu_output_size(self, other); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + logical_or_out_npu_nocheck(self, other, result); + + return result.toType(kBool); +} + + +at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& other) { + OpPreparation::CheckMemory({self, other},{self}) + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = logical_or_out_npu_nocheck(contiguousSelf, other, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + logical_or_out_npu_nocheck(self, other, self); + } + + return self; +} + +} // namespace native +} // namespace at \ No newline at end of file -- Gitee From 8d06ddb30b7ccd2fa35aaaf5b20e5655c93a3958 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 16:27:56 +0800 Subject: [PATCH 02/12] =?UTF-8?q?yaml=EF=BC=9Alogical=5For.out?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/npu_native_functions.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 0166917590..2ce537abc1 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -242,7 +242,7 @@ supported: - addcdiv.out - logical_or - logical_or_ - - logical.out + - logical_or.out custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor -- Gitee From bd391171b6051d527522ece7c04ce5b8e40b9c5d Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:06:25 +0800 Subject: [PATCH 03/12] =?UTF-8?q?=E5=88=86=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index d9a71e8cc8..bac0ef2600 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -59,7 +59,7 @@ at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tens at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& other) { - OpPreparation::CheckMemory({self, other},{self}) + OpPreparation::CheckMemory({self, other},{self}); if (!NpuUtils::check_match(&self)) { at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); at::Tensor result = logical_or_out_npu_nocheck(contiguousSelf, other, contiguousSelf); -- Gitee From 20228c8f210c6202fe97219b9b6af7f75a269750 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:12:37 +0800 Subject: [PATCH 04/12] =?UTF-8?q?logical=5For=5Fout=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index bac0ef2600..6ff797c02e 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -35,7 +35,10 @@ at::Tensor& logical_or_out_npu_nocheck( return result; } -at::Tensor& NPUNativeFunctions::logical_or_out(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) { +at::Tensor& NPUNativeFunctions::logical_or_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { auto outputSize = broadcast_ops_npu_output_size(self, other); OpPreparation::CheckOut( {self}, @@ -54,7 +57,7 @@ at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tens at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); logical_or_out_npu_nocheck(self, other, result); - return result.toType(kBool); + return result.toType(at::kBool); } -- Gitee From c6b0bab21c85c00a380928461fdc321111476782 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:17:04 +0800 Subject: [PATCH 05/12] codecheck --- test/test_network_ops/test_logical_or.py | 2 +- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py index be14da7001..f02d35c446 100644 --- a/test/test_network_ops/test_logical_or.py +++ b/test/test_network_ops/test_logical_or.py @@ -17,13 +17,13 @@ import copy import torch import torch_npu import numpy as np + from torch_npu.testing.common_utils import TestCase, run_tests from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor class TestLogicalOr(TestCase): - def generate_single_data(self, min_d, max_d, shape, dtype): input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) npu_input1 = torch.from_numpy(input1) diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index 6ff797c02e..ea05512437 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -75,4 +75,4 @@ at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& } } // namespace native -} // namespace at \ No newline at end of file +} // namespace at_npu \ No newline at end of file -- Gitee From be1f7659ed7eea937c2564f2bd64015772d16247 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:44:19 +0800 Subject: [PATCH 06/12] codecheck --- test/test_network_ops/test_logical_or.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py index f02d35c446..7e423b57fe 100644 --- a/test/test_network_ops/test_logical_or.py +++ b/test/test_network_ops/test_logical_or.py @@ -104,9 +104,9 @@ class TestLogicalOr(TestCase): def logical_or_out_result(self, shape_format): for item in shape_format: - cpu_input1 = torch.randn(item[0])<0 - cpu_input2 = torch.randn(item[0])<0 - cpu_input3 = torch.randn(item[1])<0 + cpu_input1 = torch.randn(item[0]) < 0 + cpu_input2 = torch.randn(item[0]) < 0 + cpu_input3 = torch.randn(item[1]) < 0 cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) npu_output_out = self.npu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) -- Gitee From 3d6fdbe41aaeeb8f1e76d3a4c5f0f86f127488f2 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:46:07 +0800 Subject: [PATCH 07/12] =?UTF-8?q?=E7=A9=BA=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index ea05512437..bd95781769 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -75,4 +75,4 @@ at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& } } // namespace native -} // namespace at_npu \ No newline at end of file +} // namespace at_npu -- Gitee From 4131c5d28f64aa9437d038b59e9f74b95eb79d28 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 17:49:34 +0800 Subject: [PATCH 08/12] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=A4=9A=E4=BD=99?= =?UTF-8?q?=E7=A9=BA=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index bd95781769..0b10f685b8 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -60,7 +60,6 @@ at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tens return result.toType(at::kBool); } - at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& other) { OpPreparation::CheckMemory({self, other},{self}); if (!NpuUtils::check_match(&self)) { -- Gitee From 4286d10fa0321434293ffd6d633718a150ce5907 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Wed, 26 Jan 2022 13:24:34 +0800 Subject: [PATCH 09/12] =?UTF-8?q?=E5=8E=BB=E9=99=A4yaml=E7=9A=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/npu_native_functions.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 2ce537abc1..412fb9da7c 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,9 +240,6 @@ supported: - addcdiv - addcdiv_ - addcdiv.out - - logical_or - - logical_or_ - - logical_or.out custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor -- Gitee From 2baf6513063bdbfddeefaf6c97c5d1b7067d5145 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Wed, 26 Jan 2022 13:33:40 +0800 Subject: [PATCH 10/12] codecheck --- test/test_network_ops/test_logical_or.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py index 7e423b57fe..ef9e180191 100644 --- a/test/test_network_ops/test_logical_or.py +++ b/test/test_network_ops/test_logical_or.py @@ -22,7 +22,6 @@ from torch_npu.testing.common_utils import TestCase, run_tests from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor - class TestLogicalOr(TestCase): def generate_single_data(self, min_d, max_d, shape, dtype): input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) -- Gitee From e04bceb1e0a5cb6bdcd59f028d21a1346e3fef2e Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Wed, 26 Jan 2022 14:54:38 +0800 Subject: [PATCH 11/12] =?UTF-8?q?=E5=88=A0=E9=99=A4=E7=A9=BA=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_logical_or.py | 11 ----------- torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp | 4 ---- 2 files changed, 15 deletions(-) diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py index ef9e180191..dba839692d 100644 --- a/test/test_network_ops/test_logical_or.py +++ b/test/test_network_ops/test_logical_or.py @@ -26,39 +26,29 @@ class TestLogicalOr(TestCase): def generate_single_data(self, min_d, max_d, shape, dtype): input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) npu_input1 = torch.from_numpy(input1) - return npu_input1 def generate_data(self, min_d, max_d, shape, dtype): input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) - - #modify from numpy.ndarray to torch.tensor npu_input1 = torch.from_numpy(input1) npu_input2 = torch.from_numpy(input2) - return npu_input1, npu_input2 def generate_three_data(self, min_d, max_d, shape, dtype): input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) input3 = np.random.uniform(min_d, max_d, shape).astype(dtype) - - #modify from numpy.ndarray to torch.tensor npu_input1 = torch.from_numpy(input1) npu_input2 = torch.from_numpy(input2) npu_input3 = torch.from_numpy(input3) - return npu_input1, npu_input2, npu_input3 def generate_bool_data(self, min_d, max_d, shape, dtype): input1 = np.random.randint(min_d, max_d, shape).astype(dtype) input2 = np.random.randint(min_d, max_d, shape).astype(dtype) - - #modify from numpy.ndarray to torch.tensor npu_input1 = torch.from_numpy(input1) npu_input2 = torch.from_numpy(input2) - return npu_input1, npu_input2 def cpu_op_exec(self, input1, input2): @@ -108,7 +98,6 @@ class TestLogicalOr(TestCase): cpu_input3 = torch.randn(item[1]) < 0 cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) npu_output_out = self.npu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) - self.assertRtolEqual(cpu_output_out, npu_output_out) def test_logical_or_out(self, device): diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp index 0b10f685b8..9f50bd132e 100644 --- a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -31,7 +31,6 @@ at::Tensor& logical_or_out_npu_nocheck( .Input(other) .Output(result) .Run(); - return result; } @@ -47,7 +46,6 @@ at::Tensor& NPUNativeFunctions::logical_or_out( result.scalar_type(), outputSize); logical_or_out_npu_nocheck(self, other, result); - return result; } @@ -56,7 +54,6 @@ at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tens auto outputSize = broadcast_ops_npu_output_size(self, other); at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); logical_or_out_npu_nocheck(self, other, result); - return result.toType(at::kBool); } @@ -69,7 +66,6 @@ at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& } else { logical_or_out_npu_nocheck(self, other, self); } - return self; } -- Gitee From 14430160e44c8c91babe2d59e11daf6fdd0b1e09 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 16:27:02 +0800 Subject: [PATCH 12/12] add files: test_logical_or.py LogicalOrKernelNpu.cpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit yaml:logical_or.out 分号 logical_or_out参数修改 codecheck codecheck 空行 删除多余空行 去除yaml的修改 codecheck 删除空行 --- test/test_network_ops/test_logical_or.py | 128 ++++++++++++++++++ .../csrc/aten/ops/LogicalOrKernelNpu.cpp | 73 ++++++++++ 2 files changed, 201 insertions(+) create mode 100644 test/test_network_ops/test_logical_or.py create mode 100644 torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp diff --git a/test/test_network_ops/test_logical_or.py b/test/test_network_ops/test_logical_or.py new file mode 100644 index 0000000000..dba839692d --- /dev/null +++ b/test/test_network_ops/test_logical_or.py @@ -0,0 +1,128 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestLogicalOr(TestCase): + def generate_single_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + return npu_input1 + + def generate_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + return npu_input1, npu_input2 + + def generate_three_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input3 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + npu_input3 = torch.from_numpy(input3) + return npu_input1, npu_input2, npu_input3 + + def generate_bool_data(self, min_d, max_d, shape, dtype): + input1 = np.random.randint(min_d, max_d, shape).astype(dtype) + input2 = np.random.randint(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + return npu_input1, npu_input2 + + def cpu_op_exec(self, input1, input2): + output = torch.logical_or(input1, input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.logical_or(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_out(self, input1, input2, input3): + torch.logical_or(input1, input2, out=input3) + output = input3.numpy() + return output + + def npu_op_exec_out(self, input1, input2, input3): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = input3.to("npu") + torch.logical_or(input1, input2, out=output) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_(self, input1, input2): + output = torch.Tensor.logical_or_(input1, input2) + output = output.numpy() + return output + + def npu_op_exec_(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.Tensor.logical_or_(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def logical_or_out_result(self, shape_format): + for item in shape_format: + cpu_input1 = torch.randn(item[0]) < 0 + cpu_input2 = torch.randn(item[0]) < 0 + cpu_input3 = torch.randn(item[1]) < 0 + cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + npu_output_out = self.npu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + self.assertRtolEqual(cpu_output_out, npu_output_out) + + def test_logical_or_out(self, device): + shape_format = [ + [[128, 116, 14, 14], [256, 116, 1, 1, 28]], + [[128, 3, 224, 224], [3, 3, 3]], + [[128, 116, 14, 14], [128, 116, 14, 14]], + [[256, 128, 7, 7], [128, 256, 3, 3, 28]], + [[2, 3, 3, 3], [3, 1, 3]], + [[128, 232, 7, 7], [128, 232, 7, 7]], + ] + self.logical_or_out_result(shape_format) + + def test_logical_or_bool(self, device): + npu_input1, npu_input2 = self.generate_bool_data(0, 2, (10, 64), np.bool) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_logical_or_inplace_bool(self, device): + npu_input1, npu_input2 = self.generate_bool_data(0, 2, (10, 64), np.bool) + cpu_output = self.cpu_op_exec_(npu_input1, npu_input2) + npu_output = self.npu_op_exec_(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestLogicalOr, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp new file mode 100644 index 0000000000..9f50bd132e --- /dev/null +++ b/torch_npu/csrc/aten/ops/LogicalOrKernelNpu.cpp @@ -0,0 +1,73 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/framework/utils/OpTemplate.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& logical_or_out_npu_nocheck( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("LogicalOr") + .Input(self) + .Input(other) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::logical_or_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + result.scalar_type(), + outputSize); + logical_or_out_npu_nocheck(self, other, result); + return result; +} + +at::Tensor NPUNativeFunctions::logical_or(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + auto outputSize = broadcast_ops_npu_output_size(self, other); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + logical_or_out_npu_nocheck(self, other, result); + return result.toType(at::kBool); +} + +at::Tensor& NPUNativeFunctions::logical_or_(at::Tensor& self, const at::Tensor& other) { + OpPreparation::CheckMemory({self, other},{self}); + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = logical_or_out_npu_nocheck(contiguousSelf, other, contiguousSelf); + NpuUtils::format_fresh_view(self, result); + } else { + logical_or_out_npu_nocheck(self, other, self); + } + return self; +} + +} // namespace native +} // namespace at_npu -- Gitee