From 8a6dd6fa3f84d2662aa5789d9bf15a6ec64544e2 Mon Sep 17 00:00:00 2001 From: 18190895210 Date: Tue, 25 Jan 2022 19:13:07 +0800 Subject: [PATCH] =?UTF-8?q?logical=5Fand=E7=AE=97=E5=AD=90=E7=A7=BB?= =?UTF-8?q?=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_logical_and.py | 121 ++++++++++++++++++ .../csrc/aten/ops/LogicalAndKernelNpu.cpp | 72 +++++++++++ 2 files changed, 193 insertions(+) create mode 100644 test/test_network_ops/test_logical_and.py create mode 100644 torch_npu/csrc/aten/ops/LogicalAndKernelNpu.cpp diff --git a/test/test_network_ops/test_logical_and.py b/test/test_network_ops/test_logical_and.py new file mode 100644 index 0000000000..5542a2971c --- /dev/null +++ b/test/test_network_ops/test_logical_and.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestLogicalAnd(TestCase): + def generate_single_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + return npu_input1 + + def generate_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + return npu_input1, npu_input2 + + def generate_three_data(self, min_d, max_d, shape, dtype): + input1 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input2 = np.random.uniform(min_d, max_d, shape).astype(dtype) + input3 = np.random.uniform(min_d, max_d, shape).astype(dtype) + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + npu_input3 = torch.from_numpy(input3) + return npu_input1, npu_input2, npu_input3 + + def cpu_op_exec(self, input1, input2): + output = torch.logical_and(input1, input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.logical_and(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_out(self, input1, input2, input3): + torch.logical_and(input1, input2, out=input3) + output = input3.numpy() + return output + + def npu_op_exec_out(self, input1, input2, input3): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = input3.to("npu") + torch.logical_and(input1, input2, out=output) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_(self, input1, input2): + output = torch.Tensor.logical_and_(input1, input2) + output = output.numpy() + return output + + def npu_op_exec_(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.Tensor.logical_and_(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def logical_and_out_result(self, shape_format): + for item in shape_format: + cpu_input1 = torch.randn(item[0]) < 0 + cpu_input2 = torch.randn(item[0]) < 0 + cpu_input3 = torch.randn(item[1]) < 0 + cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + npu_output_out = self.npu_op_exec_out(cpu_input1, cpu_input2, cpu_input3) + self.assertRtolEqual(cpu_output_out, npu_output_out) + + def test_logical_and_out(self, device): + shape_format = [ + [[128, 116, 14, 14], [256, 116, 1, 1, 28]], + [[128, 3, 224, 224], [3, 3, 3]], + [[128, 116, 14, 14], [128, 116, 14, 14]], + [[256, 128, 7, 7], [128, 256, 3, 3, 28]], + [[2, 3, 3, 3], [3, 1, 3]], + [[128, 232, 7, 7], [128, 232, 7, 7]], + ] + self.logical_and_out_result(shape_format) + + def test_logical_and_bool(self, device): + npu_input1, npu_input2 = self.generate_data(0, 2, (2, 5), np.bool) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2).astype(np.float32) + npu_output = self.npu_op_exec(npu_input1, npu_input2).astype(np.float32) + self.assertRtolEqual(cpu_output, npu_output) + + def test_logical_and_inplace_bool(self, device): + npu_input1, npu_input2 = self.generate_data(0, 2, (2, 5), np.bool) + cpu_output = self.cpu_op_exec_(npu_input1, npu_input2).astype(np.float32) + npu_output = self.npu_op_exec_(npu_input1, npu_input2).astype(np.float32) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestLogicalAnd, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/LogicalAndKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogicalAndKernelNpu.cpp new file mode 100644 index 0000000000..20dba47788 --- /dev/null +++ b/torch_npu/csrc/aten/ops/LogicalAndKernelNpu.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/framework/utils/OpTemplate.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& logical_and_out_npu_nocheck( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("LogicalAnd") + .Input(self) + .Input(other) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::logical_and_out( + const at::Tensor& self, + const at::Tensor& other, + at::Tensor& result) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + result.scalar_type(), + outputSize); + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + logical_and_out_npu_nocheck(self, other, contiguousResult); + NpuUtils::format_fresh_view(result, contiguousResult); + } else { + logical_and_out_npu_nocheck(self, other, result); + } + return result; +} + +at::Tensor NPUNativeFunctions::logical_and(const at::Tensor& self, const at::Tensor& other) { + auto outputSize = broadcast_ops_npu_output_size(self, other); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + logical_and_out_npu_nocheck(self, other, result); + result = NPUNativeFunctions::npu_dtype_cast(result, at::kBool); + return result; +} + +at::Tensor& NPUNativeFunctions::logical_and_(at::Tensor& self, const at::Tensor& other) { + OpPreparation::CheckMemory({self, other}, {self}); + logical_and_out(self, other, self); + return self; +} +} // namespace native +} // namespace at_npu -- Gitee