From d8dc61fb02ebd556a4ac22b85eb9773ebd875397 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Mon, 7 Feb 2022 09:52:26 +0000
Subject: [PATCH 1/2] Migrate the bitwiseand operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_bitwiseand.py   | 208 ++++++++++++++++++
 .../csrc/aten/ops/BitwiseAndKernelNpu.cpp  | 174 +++++++++++++++
 2 files changed, 382 insertions(+)
 create mode 100644 test/test_network_ops/test_bitwiseand.py
 create mode 100644 torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp

diff --git a/test/test_network_ops/test_bitwiseand.py b/test/test_network_ops/test_bitwiseand.py
new file mode 100644
index 00000000000..6ad1a567d23
--- /dev/null
+++ b/test/test_network_ops/test_bitwiseand.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+import unittest
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE
+
+
+class TestBitwiseAnd(TestCase):
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input2 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+
+        # convert from numpy.ndarray to torch.tensor
+        npu_input1 = torch.from_numpy(input1)
+        npu_input2 = torch.from_numpy(input2)
+
+        return npu_input1, npu_input2
+
+    def generate_bool_data(self, min_d, max_d, shape):
+        input1 = np.random.uniform(min_d, max_d, shape)
+        input2 = np.random.uniform(min_d, max_d, shape)
+        input1 = input1.reshape(-1)
+        input2 = input2.reshape(-1)
+        for i in range(len(input1)):
+            if input1[i] < 0.5:
+                input1[i] = 0
+        for i in range(len(input2)):
+            if input2[i] < 0.5:
+                input2[i] = 0
+        input1 = input1.astype(np.bool_)
+        input2 = input2.astype(np.bool_)
+        input1 = input1.reshape(shape)
+        input2 = input2.reshape(shape)
+        # convert from numpy.ndarray to torch.tensor
+        npu_input1 = torch.from_numpy(input1)
+        npu_input2 = torch.from_numpy(input2)
+
+        return npu_input1, npu_input2
+
+    def generate_single_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+
+        return npu_input1
+
+    def generate_single_bool_data(self, min_d, max_d, shape):
+        input1 = np.random.uniform(min_d, max_d, shape)
+        input1 = input1.reshape(-1)
+        for i in range(len(input1)):
+            if input1[i] < 0.5:
+                input1[i] = 0
+        input1 = input1.astype(np.bool_)
+        input1 = input1.reshape(shape)
+        # convert from numpy.ndarray to torch.tensor
+        npu_input1 = torch.from_numpy(input1)
+
+        return npu_input1
+
+    def generate_three_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input2 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input3 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+
+        # convert from numpy.ndarray to torch.tensor
+        npu_input1 = torch.from_numpy(input1)
+        npu_input2 = torch.from_numpy(input2)
+        npu_input3 = torch.from_numpy(input3)
+
+        return npu_input1, npu_input2, npu_input3
+
+    def npu_op_exec_out(self, input1, input2, input3):
+        input1 = input1.to("npu")
+        input2 = input2.to("npu")
+        output = input3.to("npu")
+        torch.bitwise_and(input1, input2, out=output)
+        output = output.to("cpu")
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def npu_mix_op_exec(self, input1, input2):
+        input1 = input1.to("npu")
+        input2 = input2.to("npu")
+        output = torch.bitwise_and(input1, input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec(self, input1, input2):
+        output = torch.bitwise_and(input1, input2)
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_out(self, input1, input2, output):
+        torch.bitwise_and(input1, input2, out=output)
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_scalar(self, input1, scalar):
+        output = torch.bitwise_and(input1, scalar)
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_scalar_out(self, input1, scalar, output):
+        torch.bitwise_and(input1, scalar, out=output)
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_scalar_out(self, input1, input2, output):
+        input1 = input1.to("npu")
+        output = output.to("npu")
+        torch.bitwise_and(input1, input2, out=output)
+        output = output.to("cpu")
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        output = output.numpy()
+        return output
+
+    def bitwise_and_tensor_out_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100)
+            cpu_input3, npu_input3 = create_common_tensor(item[1], 0, 100)
+            cpu_output_out = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3)
+            cpu_output_out = cpu_output_out.astype(npu_output_out.dtype)
+
+            self.assertRtolEqual(cpu_output_out, npu_output_out)
+
+    def test_bitwise_and_tensor_out(self, device):
+        shape_format = [
+            [[np.int16, 0, [128, 3, 224, 224]], [np.int16, 0, [3, 3, 3]]],
+            [[np.int16, 0, [128, 116, 14, 14]], [np.int16, 0, [128, 116, 14, 14]]],
+            [[np.int32, 0, [256, 128, 7, 7]], [np.int32, 0, [128, 256, 3, 3]]],
+            [[np.int32, 0, [2, 3, 3, 3]], [np.int32, 0, [3, 1, 3]]],
+            [[np.int32, 0, [128, 232, 7, 7]], [np.int32, 0, [128, 232, 7, 7]]],
+        ]
+        self.bitwise_and_tensor_out_result(shape_format)
+
+    def bitwise_and_scalar_out_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100)
+            scalar = np.random.randint(1, 5)
+            cpu_output_out = self.cpu_op_exec_scalar(cpu_input1, scalar)
+            npu_output_out = self.npu_op_exec_scalar_out(npu_input1, scalar, npu_input2)
+            cpu_output_out = cpu_output_out.astype(npu_output_out.dtype)
+            self.assertRtolEqual(cpu_output_out, npu_output_out)
+
+    def test_bitwise_and_scalar_out(self, device):
+        shape_format = [
+            [[np.int16, 0, [16, 3, 1111, 1212]], [np.int16, 0, [3, 3, 3]]],
+            [[np.int16, 0, [128, 116, 14, 14]], [np.int16, 0, [128, 116, 14, 14]]],
+            [[np.int32, 0, [1313, 3, 3, 3]], [np.int32, 0, [3, 1, 3]]],
+            [[np.int32, 0, [128, 232, 7, 7]], [np.int32, 0, [128, 232, 7, 7]]],
+        ]
+        self.bitwise_and_scalar_out_result(shape_format)
+
+    def test_bitwise_and_bool_scalar(self, device):
+        npu_input1, npu_input2 = self.generate_data(0, 100, (2, 3), np.int32)
+        cpu_output = self.cpu_op_exec_out(npu_input1, True, npu_input1)
+        npu_output = self.npu_op_exec_scalar_out(npu_input1, True, npu_input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_bitwise_and_int16_diff(self, device):
+        npu_input1 = self.generate_single_data(0, 100, (1, 6), np.int16)
+        npu_input2 = self.generate_single_data(0, 100, (1, 1), np.int16)
+        cpu_output = self.cpu_op_exec_out(npu_input1, npu_input2, npu_input1)
+        npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_bitwise_and_mix_dtype(self, device):
+        npu_input1, npu_input2 = self.generate_data(0, 100, (1, 6), np.int32)
+        npu_input3, npu_input4 = self.generate_data(0, 100, (1, 6), np.int16)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input3)
+        npu_output = self.npu_mix_op_exec(npu_input1, npu_input3)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestBitwiseAnd, globals(), except_for='cpu')
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp b/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp
new file mode 100644
index 00000000000..c5f47a7f12e
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp
@@ -0,0 +1,174 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at {
+namespace native {
+
+at::Tensor& bitwise_and_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Scalar other) {
+  // executing the NPU operator
"LogicalAnd" : "BitwiseAnd"; + + OpCommand cmd; + cmd.Name(real_op_name) + .Input(self) + .Input(other, self.scalar_type()) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_and_out_npu( + at::Tensor& result, + const at::Tensor& self, + const at::Scalar other) { + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(self), + self.scalar_type(), + self.sizes()); + + bitwise_and_out_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor& bitwise_and_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { + auto unified_result = OpPreparation::binary_op_check(result, self, other, true); + if (other.dim() == 0 && !other.is_npu()) { + bitwise_and_out_npu(result, self, other.item()); + } else if (self.dim() == 0 && !self.is_npu()) { + bitwise_and_out_npu(result, other, self.item()); + } else { + // executing the NPU operator + string real_op_name = (self.dtype() == at::ScalarType::Bool) ? "LogicalAnd" : "BitwiseAnd"; + + OpCommand cmd; + cmd.Name(real_op_name) + .Expect(unified_result) + .Input(self) + .Input(other) + .Output(result) + .Run(); + } + + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_and_out_npu( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& other) { + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + + at::Tensor outputTensor; + if (not isSelfWrapped) { + outputTensor = self; + } else { + outputTensor = other; + } + + auto outputSize = broadcast_ops_npu_output_size(self, other); + + OpPreparation::CheckOut( + {self}, + result, + CalcuOpUtil::get_tensor_npu_format(outputTensor), + outputTensor.scalar_type(), + outputSize); + + bitwise_and_out_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, const at::Tensor& other) { + // calculate the output size + bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); + + at::Tensor outputTensor; + if (not isSelfWrapped) { + outputTensor = self; + } else { + outputTensor = other; + } + + auto outputSize = broadcast_ops_npu_output_size(self, other); + + // construct the output at::Tensor of the NPU + at::Tensor result = at::empty_with_format( + outputSize, + outputTensor.options(), + CalcuOpUtil::get_tensor_npu_format(outputTensor)); + + // calculate the output result of the NPU + bitwise_and_out_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, at::Scalar other) { + // calculate the output size + auto outputSize = input_same_output_size(self); + + // construct the output at::Tensor of the NPU + at::Tensor result = at::empty_with_format( + outputSize, self.options(), CalcuOpUtil::get_tensor_npu_format(self)); + + // calculate the output result of the NPU + bitwise_and_out_npu_nocheck(result, self, other); + + return result; +} + +at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, const at::Tensor& other) { + c10::SmallVector inputs = {self, other}; + c10::SmallVector outputs = {self}; + CalcuOpUtil::check_memory_over_laps(inputs, outputs); + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + at::Tensor result = bitwise_and_out_npu_nocheck(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, result); + } else { + bitwise_and_out_npu_nocheck(self, self, other); + } + + return self; +} + +at::Tensor& 
+at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, at::Scalar other) {
+  if (!NpuUtils::check_match(&self)) {
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(self);
+    at::Tensor result = bitwise_and_out_npu_nocheck(contiguousSelf, contiguousSelf, other);
+    NpuUtils::format_fresh_view(self, result);
+  } else {
+    bitwise_and_out_npu_nocheck(self, self, other);
+  }
+
+  return self;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee

From 1dba7bd41ed87e6e9953e86a179bb52fe819e1d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=AD=99=E5=B9=B8?= <346736790@qq.com>
Date: Tue, 8 Feb 2022 02:14:01 +0000
Subject: [PATCH 2/2] Migrate the BitwiseAnd operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Migrate the BitwiseAnd operator: rename the kernel entry points to the
NPUNativeFunctions::bitwise_and / bitwise_and_ / bitwise_and_out
signatures and move the implementation from namespace at to at_npu.
---
 .../csrc/aten/ops/BitwiseAndKernelNpu.cpp | 28 ++++++++-----------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp b/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp
index c5f47a7f12e..144f5da9644 100644
--- a/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/BitwiseAndKernelNpu.cpp
@@ -12,11 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "torch_npu/csrc/framework/utils/OpAdapter.h"
 #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 
-namespace at {
+namespace at_npu {
 namespace native {
 
 at::Tensor& bitwise_and_out_npu_nocheck(
@@ -36,10 +36,7 @@ at::Tensor& bitwise_and_out_npu_nocheck(
   return result;
 }
 
-at::Tensor& NPUNativeFunctions::bitwise_and_out_npu(
-    at::Tensor& result,
-    const at::Tensor& self,
-    const at::Scalar other) {
+at::Tensor & NPUNativeFunctions::bitwise_and_out(const at::Tensor & self, at::Scalar other, at::Tensor & result) {
   OpPreparation::CheckOut(
       {self},
       result,
@@ -58,9 +55,9 @@ at::Tensor& bitwise_and_out_npu_nocheck(
     const at::Tensor& other) {
   auto unified_result = OpPreparation::binary_op_check(result, self, other, true);
   if (other.dim() == 0 && !other.is_npu()) {
-    bitwise_and_out_npu(result, self, other.item());
+    NPUNativeFunctions::bitwise_and_out(self, other.item(), result);
   } else if (self.dim() == 0 && !self.is_npu()) {
-    bitwise_and_out_npu(result, other, self.item());
+    NPUNativeFunctions::bitwise_and_out(other, self.item(), result);
   } else {
     // executing the NPU operator
"LogicalAnd" : "BitwiseAnd"; @@ -77,10 +74,7 @@ at::Tensor& bitwise_and_out_npu_nocheck( return result; } -at::Tensor& NPUNativeFunctions::bitwise_and_out_npu( - at::Tensor& result, - const at::Tensor& self, - const at::Tensor& other) { +at::Tensor & NPUNativeFunctions::bitwise_and_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & result) { bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); at::Tensor outputTensor; @@ -104,7 +98,7 @@ at::Tensor& NPUNativeFunctions::bitwise_and_out_npu( return result; } -at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, const at::Tensor& other) { +at::Tensor NPUNativeFunctions::bitwise_and(const at::Tensor& self, const at::Tensor& other) { // calculate the output size bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(self); @@ -129,7 +123,7 @@ at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, const at: return result; } -at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, at::Scalar other) { +at::Tensor NPUNativeFunctions::bitwise_and(const at::Tensor& self, at::Scalar other) { // calculate the output size auto outputSize = input_same_output_size(self); @@ -143,7 +137,7 @@ at::Tensor NPUNativeFunctions::bitwise_and_npu(const at::Tensor& self, at::Scala return result; } -at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, const at::Tensor& other) { +at::Tensor& NPUNativeFunctions::bitwise_and_(at::Tensor& self, const at::Tensor& other) { c10::SmallVector inputs = {self, other}; c10::SmallVector outputs = {self}; CalcuOpUtil::check_memory_over_laps(inputs, outputs); @@ -159,7 +153,7 @@ at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, const at::Ten return self; } -at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, at::Scalar other) { +at::Tensor& NPUNativeFunctions::bitwise_and_(at::Tensor& self, at::Scalar other) { if (!NpuUtils::check_match(&self)) { at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); at::Tensor result = bitwise_and_out_npu_nocheck(contiguousSelf, contiguousSelf, other); @@ -171,4 +165,4 @@ at::Tensor& NPUNativeFunctions::bitwise_and_npu_(at::Tensor& self, at::Scalar ot return self; } } // namespace native -} // namespace at_npu \ No newline at end of file +} // namespace at_npu -- Gitee