diff --git a/test/test_network_ops/test_equal.py b/test/test_network_ops/test_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f23a41ecb6d5c673c22f88167e2822d01d4473 --- /dev/null +++ b/test/test_network_ops/test_equal.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, create_dtype_tensor, UT_FAST_MODE + + +class TestTensorEqual(TestCase): + + def cpu_op_exec(self, input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output).to("cpu") + output = output.numpy() + return output + + def test_tensor_equal_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (4, 3)], [np.float32, 0, (4, 3)]], + [[np.float32, 29, (4, 3, 1)], [np.float32, 29, (4, 1, 5)]], + [[np.float32, 2, (4, 3, 2)], [np.float32, 2, (4, 3, 2)]], + [[np.float32, 3, (7, 3, 2)], [np.float32, 3, (7, 3, 2)]], + [[np.float32, 4, (8, 4)], [np.float32, 4, (8, 4)]], + [[np.int32, 0, (2, 3)], [np.int32, 0, (2, 3)]], + [[np.int32, 2, (4, 3, 1)], [np.int32, 2, (4, 1, 5)]], + [[np.int32, 2, (4, 3, 2)], [np.int32, 2, (4, 3, 2)]], + [[np.int32, -1, (7, 3, 2)], [np.int32, -1, (7, 3, 2)]], + [[np.int8, 0, (7, 3)], [np.int8, 0, (7, 3)]], + [[np.int8, 2, (4, 3, 2)], [np.int8, 2, (4, 3, 2)]], + [[np.uint8, 0, (3, 2)], [np.uint8, 0, (3, 2)]], + [[np.uint8, 2, (4, 3, 2)], [np.uint8, 2, (4, 3, 2)]] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + cpu_input1, npu_input1 = create_common_tensor(shape_format[11][0], 1, 100) + cpu_input2 = cpu_input1 + npu_input2 = npu_input1 + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_tensor_equal_float16_shape_format(self, device): + def cpu_op_exec_fp16(input1, input2): + output = np.array_equal(input1, input2) + output = torch.tensor(output) + output = output.numpy() + output = output.astype(np.float16) + return output + + def npu_op_exec_fp16(input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output).to("cpu") + output = output.numpy() + output = output.astype(np.float16) + return output + + shape_format = [ + [[np.float16, 0, (4, 3)], [np.float16, 0, (4, 3)]], + [[np.float16, 29, (4, 3, 1)], [np.float16, 29, (4, 1, 5)]], + [[np.float16, 2, (4, 3, 2)], [np.float16, 2, (4, 3, 2)]], + [[np.float16, 3, (7, 3, 2)], [np.float16, 3, (7, 3, 2)]], + [[np.float16, 4, (8, 4)], [np.float16, 4, (8, 4)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 100) + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = npu_op_exec_fp16(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + cpu_input1, npu_input1 = create_common_tensor(shape_format[2][0], 2, 100) + cpu_input2 = cpu_input1 + npu_input2 = npu_input1 + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = npu_op_exec_fp16(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestTensorEqual, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 412fb9da7ce2fdf366e37a24e4b21404c4e90c55..04d32c06691661d1ab6333864d45c1f5de397393 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,7 +240,6 @@ supported: - addcdiv - addcdiv_ - addcdiv.out - custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor variants: function, method diff --git a/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4ae82b7198f78b40537923bdd1899ccd671caf6d --- /dev/null +++ b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +bool NPUNativeFunctions::equal(const at::Tensor& self, const at::Tensor& other) { + // check the shape of self and other + if(self.sizes() != other.sizes()) { + return false; + } + + TORCH_CHECK( + self.scalar_type() == other.scalar_type(), + "Expected object of scalar type ", + self.scalar_type(), + ", but got ", + other.scalar_type(), + " for argument #2 'other' in call to equal_npu"); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + {1}, + self.options().dtype(at::kBool), + ACL_FORMAT_ND); + + // calculate the output result of the NPU + OpCommand cmd; + cmd.Name("TensorEqual") + .Input(self) + .Input(other) + .Output(result) + .Run(); + + return result.item().to(); +} +} // namespace native +} // namespace at_npu \ No newline at end of file