From eaae3773ccaaa9aa6b810261fe3b3f7bb6bfa809 Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Thu, 17 Feb 2022 16:18:24 +0800
Subject: [PATCH] =?UTF-8?q?=5F=5For=5F=5F1.8.1=E7=AE=97=E5=AD=90=E7=A7=BB?=
 =?UTF-8?q?=E6=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test___or__.py        |  81 ++++++++++++++++
 torch_npu/csrc/aten/ops/__Or__KernelNpu.cpp | 101 ++++++++++++++++++++
 2 files changed, 182 insertions(+)
 create mode 100644 test/test_network_ops/test___or__.py
 create mode 100644 torch_npu/csrc/aten/ops/__Or__KernelNpu.cpp

diff --git a/test/test_network_ops/test___or__.py b/test/test_network_ops/test___or__.py
new file mode 100644
index 00000000000..11c111c33c0
--- /dev/null
+++ b/test/test_network_ops/test___or__.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class Test__Or__(TestCase):
+    """Compare ``Tensor.__or__`` on NPU tensors with the CPU result."""
+
+    def cpu_op_exec(self, input1, input2):
+        """Return ``input1.__or__(input2)`` as an int32 numpy array."""
+        output = input1.__or__(input2)
+        # Cast to int32 so results of every input dtype share one dtype.
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        return output.numpy()
+
+    def npu_op_exec(self, input1, input2):
+        """Same as cpu_op_exec, moving the result to the CPU first."""
+        output = input1.__or__(input2)
+        output = output.to("cpu")
+        if output.dtype != torch.int32:
+            output = output.to(torch.int32)
+        return output.numpy()
+
+    def test___Or___shape_format(self, device):
+        """Check tensor|scalar, tensor|tensor and bool|bool against CPU."""
+        # Each item is [first operand, second operand]. A three-element entry
+        # is passed to create_common_tensor (presumably
+        # [dtype, npu_format, shape]); a one-element entry is a Python scalar.
+        shape_format = [
+            [[np.int32, 0, [256, 1000]], [1]],
+            [[np.int32, 0, [256, 1000]], [np.int32, 0, [256, 1000]]],
+            [[np.int16, 0, [256, 1000]], [2]],
+            [[np.int16, 0, [256, 1000]], [np.int16, 0, [256, 1000]]],
+            [[np.int8, 0, [256, 1000]], [3]],
+            [[np.int8, 0, [256, 1000]], [np.int8, 0, [256, 1000]]],
+        ]
+
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            if len(item[1]) > 1:
+                cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 100)
+            else:
+                # Scalar operand: the same Python number is used on both sides.
+                cpu_input2 = npu_input2 = item[1][0]
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+        cpu_input1 = torch.tensor([True, False, True, False, True], dtype=torch.bool)
+        npu_input1 = torch.tensor([True, False, True, False, True], dtype=torch.bool, device="npu")
+        cpu_input2 = torch.tensor([True, False, True, False, False], dtype=torch.bool)
+        npu_input2 = torch.tensor([True, False, True, False, False], dtype=torch.bool, device="npu")
+        cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+
+instantiate_device_type_tests(Test__Or__, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/__Or__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__Or__KernelNpu.cpp
new file mode 100644
index 00000000000..ead6cdd15c2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/__Or__KernelNpu.cpp
@@ -0,0 +1,101 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+namespace {
+
+// Selects the NPU operator: "LogicalOr" when `tensor` is Bool, "BitwiseOr"
+// for every other dtype.
+string or___real_op_name(const at::Tensor& tensor) {
+  return (tensor.dtype() == at::ScalarType::Bool) ? "LogicalOr" : "BitwiseOr";
+}
+
+} // namespace
+
+// Returns the tensor handed to OpPreparation::ApplyTensor as the template for
+// the result: `other` when `self` is a scalar wrapped into a tensor,
+// otherwise `self`.
+at::Tensor or___dest_output(const at::Tensor& self, const at::Tensor& other) {
+  return CalcuOpUtil::is_scalar_wrapped_to_tensor(self) ? other : self;
+}
+
+// Runs the OR operator on `self` and the scalar `other`, writing into
+// `result`. The scalar input is registered with self's scalar type.
+at::Tensor& or___out_scalar_npu(at::Tensor& result, const at::Tensor& self, at::Scalar other) {
+  OpCommand cmd;
+  cmd.Name(or___real_op_name(self))
+      .Input(self)
+      .Input(other, self.scalar_type())
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+// Runs the OR operator on two tensors, writing into `result`. A 0-dim
+// operand that is not on the NPU is converted with item() and handled by
+// the scalar path instead.
+at::Tensor& or___out_tensor_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  if (other.dim() == 0 && !other.is_npu()) {
+    or___out_scalar_npu(result, self, other.item());
+  } else if (self.dim() == 0 && !self.is_npu()) {
+    or___out_scalar_npu(result, other, self.item());
+  } else {
+    // NOTE(review): the operator is chosen from self's dtype only; operands
+    // with different dtypes are not promoted here -- confirm callers never
+    // mix Bool with integer tensors.
+    OpCommand cmd;
+    cmd.Name(or___real_op_name(self))
+        .Input(self)
+        .Input(other)
+        .Output(result)
+        .Run();
+  }
+
+  return result;
+}
+
+// Tensor | Tensor: allocates the result with the size returned by
+// broadcast_ops_npu_output_size, then delegates to or___out_tensor_npu.
+at::Tensor NPUNativeFunctions::__or__(const at::Tensor& self, const at::Tensor& other) {
+  at::Tensor outputTensor = or___dest_output(self, other);
+  auto outputSize = broadcast_ops_npu_output_size(self, other);
+  at::Tensor result = OpPreparation::ApplyTensor(outputTensor, outputSize);
+  or___out_tensor_npu(result, self, other);
+
+  return result;
+}
+
+// Tensor | Scalar: allocates the result from `self`, then delegates to
+// or___out_scalar_npu.
+at::Tensor NPUNativeFunctions::__or__(const at::Tensor& self, at::Scalar other) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  or___out_scalar_npu(result, self, other);
+
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee