From 58cf426a45c0f1d45a7b587a595157289e8584ba Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 16:43:53 +0800 Subject: [PATCH 1/4] =?UTF-8?q?equal=E7=AE=97=E5=AD=90=E7=A7=BB=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_equal.py | 109 ++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 1 + torch_npu/csrc/aten/ops/EqualKernelNpu.cpp | 55 +++++++++ 3 files changed, 165 insertions(+) create mode 100644 test/test_network_ops/test_equal.py create mode 100644 torch_npu/csrc/aten/ops/EqualKernelNpu.cpp diff --git a/test/test_network_ops/test_equal.py b/test/test_network_ops/test_equal.py new file mode 100644 index 0000000000..2239bcd6f6 --- /dev/null +++ b/test/test_network_ops/test_equal.py @@ -0,0 +1,109 @@ +# Copyright (c) 2021 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import copy +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE + + +class TestTensorEqual(TestCase): + + def cpu_op_exec(self, input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output).to("cpu") + output = output.numpy() + return output + + def test_tensor_equal_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (4, 3)], [np.float32, 0, (4, 3)]], + [[np.float32, 29, (4, 3, 1)], [np.float32, 29, (4, 1, 5)]], + [[np.float32, 2, (4, 3, 2)], [np.float32, 2, (4, 3, 2)]], + [[np.float32, 3, (7, 3, 2)], [np.float32, 3, (7, 3, 2)]], + [[np.float32, 4, (8, 4)], [np.float32, 4, (8, 4)]], + [[np.int32, 0, (2, 3)], [np.int32, 0, (2, 3)]], + [[np.int32, 2, (4, 3, 1)], [np.int32, 2, (4, 1, 5)]], + [[np.int32, 2, (4, 3, 2)], [np.int32, 2, (4, 3, 2)]], + [[np.int32, -1, (7, 3, 2)], [np.int32, -1, (7, 3, 2)]], + [[np.int8, 0, (7, 3)], [np.int8, 0, (7, 3)]], + [[np.int8, 2, (4, 3, 2)], [np.int8, 2, (4, 3, 2)]], + [[np.uint8, 0, (3, 2)], [np.uint8, 0, (3, 2)]], + [[np.uint8, 2, (4, 3, 2)], [np.uint8, 2, (4, 3, 2)]] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + cpu_input1, npu_input1 = create_common_tensor(shape_format[11][0], 1, 100) + cpu_input2 = cpu_input1 + npu_input2 = npu_input1 + cpu_output = 
self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_tensor_equal_float16_shape_format(self, device): + def cpu_op_exec_fp16(input1, input2): + output = np.array_equal(input1, input2) + output = torch.tensor(output) + output = output.numpy() + output = output.astype(np.float16) + return output + + def npu_op_exec_fp16(input1, input2): + output = torch.equal(input1, input2) + output = torch.tensor(output).to("cpu") + output = output.numpy() + output = output.astype(np.float16) + return output + + shape_format = [ + [[np.float16, 0, (4, 3)], [np.float16, 0, (4, 3)]], + [[np.float16, 29, (4, 3, 1)], [np.float16, 29, (4, 1, 5)]], + [[np.float16, 2, (4, 3, 2)], [np.float16, 2, (4, 3, 2)]], + [[np.float16, 3, (7, 3, 2)], [np.float16, 3, (7, 3, 2)]], + [[np.float16, 4, (8, 4)], [np.float16, 4, (8, 4)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], 1, 100) + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = npu_op_exec_fp16(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + cpu_input1, npu_input1 = create_common_tensor(shape_format[2][0], 2, 100) + cpu_input2 = cpu_input1 + npu_input2 = npu_input1 + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = npu_op_exec_fp16(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestTensorEqual, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 412fb9da7c..f7beeb5cf5 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,6 +240,7 @@ supported: - addcdiv - addcdiv_ - addcdiv.out + - equal 
custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor diff --git a/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp new file mode 100644 index 0000000000..0847d0c8c5 --- /dev/null +++ b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp @@ -0,0 +1,55 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +bool NPUNativeFunctions::equal(const at::Tensor& self, const at::Tensor& other) { + // check the shape of self and other + if(self.sizes() != other.sizes()) { + return false; + } + + TORCH_CHECK( + self.scalar_type() == other.scalar_type(), + "Expected object of scalar type ", + self.scalar_type(), + ", but got ", + other.scalar_type(), + " for argument #2 'other' in call to equal_npu"); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + {1}, + self.options().dtype(at::kBool), + ACL_FORMAT_ND); + + // calculate the output result of the NPU + OpCommand cmd; + cmd.Name("TensorEqual") + .Input(self) + .Input(other) + .Output(result) + .Run(); + + return result.item().to<bool>(); +} +} // namespace native +} // namespace at \ No newline at end of 
file -- Gitee From 140a4189c15d360bb392272360ee1f43f505c9b4 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 16:50:08 +0800 Subject: [PATCH 2/4] =?UTF-8?q?equal=E6=B3=A8=E9=87=8A=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/EqualKernelNpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp index 0847d0c8c5..4ae82b7198 100644 --- a/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/EqualKernelNpu.cpp @@ -52,4 +52,4 @@ bool NPUNativeFunctions::equal(const at::Tensor& self, const at::Tensor& other) return result.item().to<bool>(); } } // namespace native -} // namespace at \ No newline at end of file +} // namespace at_npu \ No newline at end of file -- Gitee From 0488c7c6503ed7e334b0a10eded80ee0bfea3df2 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 17:44:45 +0800 Subject: [PATCH 3/4] =?UTF-8?q?equal=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_equal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_network_ops/test_equal.py b/test/test_network_ops/test_equal.py index 2239bcd6f6..a0f23a41ec 100644 --- a/test/test_network_ops/test_equal.py +++ b/test/test_network_ops/test_equal.py @@ -13,14 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import sys + import copy import torch import torch_npu import numpy as np from torch_npu.testing.common_utils import TestCase, run_tests from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests -from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE +from torch_npu.testing.util_test import create_common_tensor, create_dtype_tensor, UT_FAST_MODE class TestTensorEqual(TestCase): -- Gitee From e6ba2c88846920f84bbc5f26d979904cd094d1da Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Wed, 26 Jan 2022 12:18:03 +0800 Subject: [PATCH 4/4] =?UTF-8?q?equal=20yuml=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/npu_native_functions.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index f7beeb5cf5..04d32c0669 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,8 +240,6 @@ supported: - addcdiv - addcdiv_ - addcdiv.out - - equal - custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor variants: function, method -- Gitee