diff --git a/test/test_network_ops/test_xor.py b/test/test_network_ops/test_xor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f4f916fa6cdc40e9de9c72392295908e1e536eb
--- /dev/null
+++ b/test/test_network_ops/test_xor.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestXor(TestCase):
+
+    def generate_bool_data(self, shape):
+        input1 = np.random.uniform(0, 1, shape)
+        input2 = np.random.uniform(0, 1, shape)
+        input1 = input1.reshape(-1)
+        input2 = input2.reshape(-1)
+        len1 = len(input1)
+        len2 = len(input2)
+
+        # threshold the uniform samples at 0.5 to obtain random bool arrays
+        for i in range(len1):
+            if input1[i] < 0.5:
+                input1[i] = 0
+        for i in range(len2):
+            if input2[i] < 0.5:
+                input2[i] = 0
+        input1 = input1.astype(np.bool_)
+        input2 = input2.astype(np.bool_)
+        input1 = input1.reshape(shape)
+        input2 = input2.reshape(shape)
+        # convert from numpy.ndarray to torch.Tensor
+        npu_input1 = torch.from_numpy(input1)
+        npu_input2 = torch.from_numpy(input2)
+
+        return npu_input1, npu_input2
+
+    def generate_single_bool_data(self, shape):
+        input1 = np.random.uniform(0, 1, shape)
+        input1 = input1.reshape(-1)
+        len3 = len(input1)
+        for i in range(len3):
+            if input1[i] < 0.5:
+                input1[i] = 0
+        input1 = input1.astype(np.bool_)
+        input1 = input1.reshape(shape)
+        # convert from numpy.ndarray to torch.Tensor
+        npu_input1 = torch.from_numpy(input1)
+
+        return npu_input1
+
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input2 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+
+        # convert from numpy.ndarray to torch.Tensor
+        npu_input1 = torch.from_numpy(input1)
+        npu_input2 = torch.from_numpy(input2)
+
+        return npu_input1, npu_input2
+
+    def generate_single_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+
+        return npu_input1
+
+    def cpu_op_exec(self, input1, input2):
+        output = input1 ^ input2
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2):
+        input1 = input1.to("npu")
+        input2 = input2.to("npu")
+        output = input1.__xor__(input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_scalar(self, input1, input2):
+        input1 = input1.to("npu")
+        output = input1.__xor__(input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_xor_tensor_int32(self, device):
+        npu_input1 = self.generate_single_data(0, 100, (10, 10), np.int32)
+        npu_input2 = self.generate_single_data(0, 100, (10, 10), np.int32)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_tensor_int16(self, device):
+        npu_input1 = self.generate_single_data(0, 100, (10, 10), np.int16)
+        npu_input2 = self.generate_single_data(0, 100, (10, 10), np.int16)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_tensor_int8(self, device):
+        npu_input1 = self.generate_single_data(0, 100, (10, 10), np.int8)
+        npu_input2 = self.generate_single_data(0, 100, (10, 10), np.int8)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_int32(self, device):
+        npu_input = self.generate_single_data(0, 100, (1, 10), np.int32)
+        npu_input_scalar = np.random.randint(0, 100)
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_int16(self, device):
+        npu_input = self.generate_single_data(0, 100, (10, 20), np.int16)
+        npu_input_scalar = np.random.randint(0, 100)
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_int8(self, device):
+        npu_input = self.generate_single_data(0, 100, (20, 10), np.int8)
+        npu_input_scalar = np.random.randint(0, 100)
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_tensor_uint8(self, device):
+        npu_input1 = self.generate_single_data(0, 100, (10, 10), np.uint8)
+        npu_input2 = self.generate_single_data(0, 100, (10, 10), np.uint8)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_uint8(self, device):
+        npu_input = self.generate_single_data(0, 100, (5, 10), np.uint8)
+        npu_input_scalar = np.random.randint(0, 100)
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_bool1(self, device):
+        npu_input = self.generate_single_bool_data((10, 10))
+        npu_input_scalar = True
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_scalar_bool2(self, device):
+        npu_input = self.generate_single_bool_data((10, 10))
+        npu_input_scalar = False
+        cpu_output = self.cpu_op_exec(npu_input, npu_input_scalar)
+        npu_output = self.npu_op_exec_scalar(npu_input, npu_input_scalar)
+        self.assertEqual(cpu_output, npu_output)
+
+    def test_xor_tensor_bool(self, device):
+        npu_input1, npu_input2 = self.generate_bool_data((10, 10))
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestXor, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/__Xor__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__Xor__KernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..02848845051b305be0595a7dbd2ffbe2d9ed855b
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/__Xor__KernelNpu.cpp
@@ -0,0 +1,165 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& not_out_npu(at::Tensor& result, const at::Tensor& self) {
+  OpCommand cmd;
+  cmd.Name("LogicalNot")
+      .Input(self)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& not_out_npu(at::Tensor& result, const at::Scalar self) {
+  OpCommand cmd;
+  cmd.Name("LogicalNot")
+      .Input(self, self.type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& and_out_npu(at::Tensor& result, const at::Tensor& self, const at::Tensor& other) {
+  OpCommand cmd;
+  cmd.Name("LogicalAnd")
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& and_out_npu(at::Tensor& result, const at::Tensor& self, const at::Scalar other) {
+  OpCommand cmd;
+  cmd.Name("LogicalAnd")
+      .Input(self)
+      .Input(other, self.scalar_type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& or_out_npu(at::Tensor& result, const at::Tensor& self, const at::Scalar other) {
+  OpCommand cmd;
+  cmd.Name("LogicalOr")
+      .Input(self)
+      .Input(other, self.scalar_type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& xor_out_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  // bool tensors have no BitwiseXor, so compose xor as (!self && other) || (self && !other)
+  if (self.dtype() == at::ScalarType::Bool) {
+    auto not_self_result = OpPreparation::ApplyTensor(self, outputSize);
+    not_out_npu(not_self_result, self);
+
+    auto not_other_result = OpPreparation::ApplyTensor(self, outputSize);
+    not_out_npu(not_other_result, other);
+
+    auto not_self_and_other = OpPreparation::ApplyTensor(self, outputSize);
+    and_out_npu(not_self_and_other, not_self_result, other);
+
+    auto self_and_not_other = OpPreparation::ApplyTensor(self, outputSize);
+    and_out_npu(self_and_not_other, self, not_other_result);
+
+    OpCommand cmd;
+    cmd.Name("LogicalOr")
+        .Input(not_self_and_other)
+        .Input(self_and_not_other)
+        .Output(result)
+        .Run();
+  } else {
+    OpCommand cmd;
+    cmd.Name("BitwiseXor")
+        .Input(self)
+        .Input(other)
+        .Output(result)
+        .Run();
+  }
+
+  return result;
+}
+
+at::Tensor& xor_out_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Scalar other) {
+  // bool input: compose xor as (!self && other) || !(!self || other), where !(!self || other) == (self && !other)
+  if (self.dtype() == at::ScalarType::Bool) {
+    auto not_self_result = OpPreparation::ApplyTensor(self);
+    not_out_npu(not_self_result, self);
+
+    auto not_self_or_other = OpPreparation::ApplyTensor(self);
+    or_out_npu(not_self_or_other, not_self_result, other);
+
+    auto not_not_self_or_other = OpPreparation::ApplyTensor(self);
+    not_out_npu(not_not_self_or_other, not_self_or_other);
+
+    auto not_self_and_other = OpPreparation::ApplyTensor(self);
+    and_out_npu(not_self_and_other, not_self_result, other);
+
+    OpCommand cmd;
+    cmd.Name("LogicalOr")
+        .Input(not_self_and_other)
+        .Input(not_not_self_or_other)
+        .Output(result)
+        .Run();
+
+  } else {
+    OpCommand cmd;
+    cmd.Name("BitwiseXor")
+        .Input(self)
+        .Input(other, self.scalar_type())
+        .Output(result)
+        .Run();
+  }
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::__xor__(const at::Tensor& self, const at::Tensor& other) {
+  // calculate the output size
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+
+  // calculate the output result of the NPU
+  xor_out_npu(result, self, other);
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::__xor__(const at::Tensor& self, const at::Scalar other) {
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  // calculate the output result of the NPU
+  xor_out_npu(result, self, other);
+
+  return result;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file