From c15121e68a031c5b66d062d99d4643142a0a4db9 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 19:08:44 +0800 Subject: [PATCH 1/6] =?UTF-8?q?logsumexp=E7=AE=97=E5=AD=901.8.1=E7=A7=BB?= =?UTF-8?q?=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_logsumexp.py | 109 ++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 4 + .../csrc/aten/ops/LogSumExpKernelNpu.cpp | 51 ++++++++ 3 files changed, 164 insertions(+) create mode 100644 test/test_network_ops/test_logsumexp.py create mode 100644 torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp diff --git a/test/test_network_ops/test_logsumexp.py b/test/test_network_ops/test_logsumexp.py new file mode 100644 index 00000000000..fd47e7d0088 --- /dev/null +++ b/test/test_network_ops/test_logsumexp.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, create_dtype_tensor, UT_FAST_MODE + + +class TestLogsumexp(TestCase): + + def generate_data(self, min, max, shape, dtype): + x = np.random.uniform(min, max, shape).astype(dtype) + npu_x = torch.from_numpy(x) + return npu_x + + def cpu_op_exec(self, input1, dim, keepdim): + output = torch.logsumexp(input1, dim, keepdim=keepdim) + return output + + def npu_op_exec(self, input1, dim, keepdim): + output = torch.logsumexp(input1, dim, keepdim=keepdim) + output = output.to("cpu") + return output + + def cpu_op_out_exec(self, input1, dim, out, keepdim): + torch.logsumexp(input1, dim, keepdim=keepdim, out=out) + return out + + def npu_op_out_exec(self, input1, dim, out, keepdim): + torch.logsumexp(input1, dim, keepdim=keepdim, out=out) + output = out.to("cpu") + return output + + def test_logsumexp_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (3, 4, 2)], [np.float32, 0, (3, 4, 1)], 2, True], + [[np.float32, 0, (3, 4, 2)], [np.float32, 0, (3, 4)], 2, False], + [[np.float32, 0, (3, 4, 2)], [np.float32, 0, (3,)], [1, 2], False], + [[np.float32, 0, (2, 3, 4, 2)], [np.float32, 0, (2, 3, 1, 2)], 2, True], + [[np.float32, 0, (2, 3, 4, 2)], [np.float32, 0, (2, 3, 2)], 2, False], + [[np.float32, 0, (2, 3, 4, 2)], [np.float32, 0, (2, 3)], [2, 3], False], + [[np.float16, 0, (3, 4, 2)], [np.float16, 0, (3, 4, 1)], 2, True], + [[np.float16, 0, (3, 4, 2)], [np.float16, 0, (3, 4)], 2, False], + [[np.float16, 0, (3, 4, 2)], [np.float16, 0, (3,)], [1, 2], False], + [[np.float16, 0, (2, 3, 4, 2)], [np.float16, 0, (2, 3, 1, 2)], 2, True], + [[np.float16, 0, (2, 3, 4, 2)], [np.float16, 0, (2, 3, 2)], 2, False], + [[np.float16, 0, (2, 3, 4, 2)], [np.float16, 0, (2, 3)], [2, 3], False] + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_out, npu_out = create_common_tensor(item[1], 1, 10) + if cpu_input.dtype == torch.float16: + cpu_input = cpu_input.to(torch.float32) + if cpu_out.dtype == torch.float16: + cpu_out = cpu_out.to(torch.float32) + cpu_out_result = self.cpu_op_out_exec(cpu_input, item[2], cpu_out, item[3]) + npu_out_result = self.npu_op_out_exec(npu_input, item[2], npu_out, item[3]) + cpu_out_result = cpu_out_result.to(npu_out_result.dtype) + self.assertRtolEqual(cpu_out_result.numpy(), npu_out_result.numpy()) + + cpu_result = self.cpu_op_exec(cpu_input, item[2], item[3]) + npu_result = self.npu_op_exec(npu_input, item[2], item[3]) + cpu_result = cpu_result.to(npu_result.dtype) + self.assertRtolEqual(cpu_result.numpy(), npu_result.numpy()) + + def test_logsumexp_dimname1(self, device): + cpu_input = self.generate_data(-10, 10, (2, 14, 69, 96, 1824), np.float32) + cpu_input.names = ['A', 'B', 'C', 'D', 'E'] + dim = ['C'] + keepdim = True + cpu_out = self.cpu_op_exec(cpu_input, dim, keepdim) + npu_out = self.npu_op_exec(cpu_input.npu(), dim, keepdim) + self.assertRtolEqual(cpu_out.numpy(), npu_out.numpy()) + + def test_logsumexp_dimname2(self, device): + cpu_input = self.generate_data(-10, 10, (14, 69, 96, 1824), np.float32) + cpu_input.names = ['A', 'B', 'C', 'D'] + dim = ['B', 'C'] + keepdim = False + cpu_out = self.cpu_op_exec(cpu_input, dim, keepdim) + npu_out = self.npu_op_exec(cpu_input.npu(), dim, keepdim) + self.assertRtolEqual(cpu_out.numpy(), npu_out.numpy()) + + def test_logsumexp_dimname3(self, device): + cpu_input = self.generate_data(-10, 10, (14, 69, 96, 1824), np.float32) + cpu_input.names = ['A', 'B', 'C', 'D'] + dim = ['B', 'C', 'D'] + keepdim = False + cpu_out = self.cpu_op_exec(cpu_input, dim, keepdim) + npu_out = self.npu_op_exec(cpu_input.npu(), dim, keepdim) + self.assertRtolEqual(cpu_out.numpy(), npu_out.numpy()) + + +instantiate_device_type_tests(TestLogsumexp, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 412fb9da7ce..cb0e7cb3666 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,6 +240,10 @@ supported: - addcdiv - addcdiv_ - addcdiv.out + - logsumexp + - logsumexp.out + - logsumexp.names + - logsumexp.names_out custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor diff --git a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp new file mode 100644 index 00000000000..e5dba81532b --- /dev/null +++ b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +c10::SmallVector logsumexp_npu_output_size( + const at::Tensor& self, + at::IntArrayRef dims, + bool keepdim) { + return reduce_ops_npu_output_size(self, dims, keepdim); +} + +at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::DimnameList dims, bool keepdim, at::Tensor& result) { + return logsumexp_out(self, dimnames_to_positions(self, dims), keepdim, result); +} + +at::Tensor& NPUNativeFunctions::logsumexp_out(const Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self) + return at::native::logsumexp_out(result, self, dims, keepdim); +} + +at::Tensor NPUNativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dims, bool keepdim) { + auto outputSize = logsumexp_npu_output_size(self, dims, keepdim); + at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options()); + return logsumexp_out(self, dims, keepdim, result); +} + +at::Tensor NPUNativeFunctions::logsumexp(const at::Tensor& self, at::DimnameList dims, bool keepdim) { + return logsumexp(self, dimnames_to_positions(self, dims), keepdim); +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee From fcde34bff45634a3c24845293456ffa397ed74c0 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 20:38:23 +0800 Subject: [PATCH 2/6] =?UTF-8?q?logsumexp=E7=AE=97=E5=AD=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_logsumexp.py | 4 ++-- torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/test/test_network_ops/test_logsumexp.py b/test/test_network_ops/test_logsumexp.py index fd47e7d0088..ac28a164463 100644 --- a/test/test_network_ops/test_logsumexp.py +++ b/test/test_network_ops/test_logsumexp.py @@ -21,8 +21,8 @@ from torch_npu.testing.util_test import create_common_tensor, create_dtype_tenso class TestLogsumexp(TestCase): - def generate_data(self, min, max, shape, dtype): - x = np.random.uniform(min, max, shape).astype(dtype) + def generate_data(self, min1, max1, shape, dtype): + x = np.random.uniform(min1, max1, shape).astype(dtype) npu_x = torch.from_numpy(x) return npu_x diff --git a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp index e5dba81532b..f6b24004fa6 100644 --- a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp @@ -25,23 +25,28 @@ c10::SmallVector logsumexp_npu_output_size( bool keepdim) { return reduce_ops_npu_output_size(self, dims, keepdim); } +at::Tensor& NPUNativeFunctions::logsumexp_out_nocheck(const Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { + return at::native::logsumexp_out(result, self, dims, keepdim); +} at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::DimnameList dims, bool keepdim, at::Tensor& result) { return logsumexp_out(self, dimnames_to_positions(self, dims), keepdim, result); } at::Tensor& NPUNativeFunctions::logsumexp_out(const Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { + auto outputSize = logsumexp_npu_output_size(self, dims, keepdim); OpPreparation::CheckOut( {self}, result, - self) - return at::native::logsumexp_out(result, self, dims, keepdim); + self, + outputSize) + return logsumexp_out_nocheck(result, self, dims, keepdim); } at::Tensor NPUNativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dims, bool keepdim) { auto outputSize = logsumexp_npu_output_size(self, dims, keepdim); at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options()); - return logsumexp_out(self, dims, keepdim, result); + return logsumexp_out_nocheck(self, dims, keepdim, result); } at::Tensor NPUNativeFunctions::logsumexp(const at::Tensor& self, at::DimnameList dims, bool keepdim) { -- Gitee From 1ea717db7c1eb3b5863aac9d6dec6962b9ab14fd Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Tue, 25 Jan 2022 20:40:38 +0800 Subject: [PATCH 3/6] =?UTF-8?q?logsumexp=E7=AE=97=E5=AD=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp index f6b24004fa6..536c904df58 100644 --- a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp @@ -25,7 +25,7 @@ c10::SmallVector logsumexp_npu_output_size( bool keepdim) { return reduce_ops_npu_output_size(self, dims, keepdim); } -at::Tensor& NPUNativeFunctions::logsumexp_out_nocheck(const Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { +at::Tensor& logsumexp_out_nocheck(const at::Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { return at::native::logsumexp_out(result, self, dims, keepdim); } @@ -33,7 +33,7 @@ at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::Dimnam return logsumexp_out(self, dimnames_to_positions(self, dims), keepdim, result); } -at::Tensor& NPUNativeFunctions::logsumexp_out(const Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { +at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::IntArrayRef dims, bool keepdim, at::Tensor& result) { auto outputSize = logsumexp_npu_output_size(self, dims, keepdim); OpPreparation::CheckOut( {self}, -- Gitee From a3358a058c7003e1d84c0d32ce1c4ab031ee71ba Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Wed, 26 Jan 2022 07:21:43 +0000 Subject: [PATCH 4/6] =?UTF-8?q?logsumexp=20yuml=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/npu_native_functions.yaml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index cb0e7cb3666..412fb9da7ce 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -240,10 +240,6 @@ supported: - addcdiv - addcdiv_ - addcdiv.out - - logsumexp - - logsumexp.out - - logsumexp.names - - logsumexp.names_out custom: - func: npu_transpose_to_contiguous(Tensor self) -> Tensor -- Gitee From 107e1d64e04012c436293fb9699d967d3fbd7d86 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Wed, 26 Jan 2022 09:01:13 +0000 Subject: [PATCH 5/6] =?UTF-8?q?logsumexp=20=E7=AE=97=E5=AD=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp index 536c904df58..768f161ca21 100644 --- a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp @@ -39,7 +39,7 @@ at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::IntArr {self}, result, self, - outputSize) + outputSize); return logsumexp_out_nocheck(result, self, dims, keepdim); } -- Gitee From e45ede170b7488b5c2efb650809d1c3bc7b9b907 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Thu, 27 Jan 2022 01:33:12 +0000 Subject: [PATCH 6/6] =?UTF-8?q?logsumexp=20=20=E7=AE=97=E5=AD=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp index 768f161ca21..0a1c2d5f8f9 100644 --- a/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/LogSumExpKernelNpu.cpp @@ -40,7 +40,7 @@ at::Tensor& NPUNativeFunctions::logsumexp_out(const at::Tensor& self, at::IntArr result, self, outputSize); - return logsumexp_out_nocheck(result, self, dims, keepdim); + return logsumexp_out_nocheck(self, dims, keepdim, result); } at::Tensor NPUNativeFunctions::logsumexp(const at::Tensor& self, at::IntArrayRef dims, bool keepdim) { -- Gitee