From 03530ead40c278e888cbaf7d0d4059af4ea39327 Mon Sep 17 00:00:00 2001
From: pipihugh
Date: Thu, 17 Feb 2022 17:07:16 +0800
Subject: [PATCH] Port the log1p operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_log1p.py        | 84 ++++++++++++++++++++++
 torch_npu/csrc/aten/ops/Log1pKernelNpu.cpp | 53 ++++++++++++++
 2 files changed, 137 insertions(+)
 create mode 100644 test/test_network_ops/test_log1p.py
 create mode 100644 torch_npu/csrc/aten/ops/Log1pKernelNpu.cpp

diff --git a/test/test_network_ops/test_log1p.py b/test/test_network_ops/test_log1p.py
new file mode 100644
index 0000000000..878375d0fa
--- /dev/null
+++ b/test/test_network_ops/test_log1p.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestLog1p(TestCase):
+    def cpu_op_exec(self, input1):
+        output = torch.log1p(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.log1p(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_(self, input1):
+        input1.log1p_()
+        input1 = input1.numpy()
+        return input1
+
+    def npu_op_exec_(self, input1):
+        input1.log1p_()
+        input1 = input1.to("cpu")
+        input1 = input1.numpy()
+        return input1
+
+    def cpu_op_exec_fp16(self, input1):
+        input1 = input1.to(torch.float32)
+        output = torch.log1p(input1)
+        output = output.numpy()
+        output = output.astype(np.float16)
+        return output
+
+    def test_log1p_common_shape_format(self, device):
+        shape_format = [
+            [[np.float32, 0, 1]],
+            [[np.float32, 0, (64, 10)]],
+            [[np.float32, 4, (32, 1, 3, 3)]],
+            [[np.float32, 29, (10, 128)]]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -1, 1)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            cpu_output1 = self.cpu_op_exec_(cpu_input)
+            npu_output1 = self.npu_op_exec_(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_log1p_float16_shape_format(self, device):
+        shape_format = [
+            [[np.float16, -1, 1]],
+            [[np.float16, -1, (64, 10)]],
+            [[np.float16, -1, (31, 1, 3)]]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -1, 1)
+            cpu_output = self.cpu_op_exec_fp16(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestLog1p, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
+
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/Log1pKernelNpu.cpp b/torch_npu/csrc/aten/ops/Log1pKernelNpu.cpp
new file mode 100644
index 0000000000..5b1505134c
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/Log1pKernelNpu.cpp
@@ -0,0 +1,53 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/OpTemplate.h"
+#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& log1p_out_npu_nocheck(const at::Tensor& self, at::Tensor& result){
+  OpCommand cmd;
+  cmd.Name("Log1p")
+      .Input(self)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::log1p_out(const at::Tensor& self, at::Tensor& result){
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self);
+  OpPipeWithDefinedOut pipe;
+  return pipe.CheckMemory({self}, {result})
+      .Func([&self](at::Tensor& result){log1p_out_npu_nocheck(self, result);})
+      .Call(result);
+}
+
+at::Tensor NPUNativeFunctions::log1p(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  log1p_out_npu_nocheck(self, result);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::log1p_(at::Tensor& self) {
+  log1p_out(self, self);
+  return self;
+}
+}}
-- 
Gitee
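
For context, a minimal usage sketch of the ported operator follows (not part of the patch). It mirrors what test_log1p.py exercises and assumes a working torch_npu build with a reachable NPU device; the "npu" device string and the 1e-3 tolerance are illustrative assumptions, not values taken from this change.

# Usage sketch: out-of-place and in-place log1p on an NPU tensor, compared
# against the CPU result (assumes torch_npu is installed and an NPU device
# is available; the "npu" device string is an assumption here).
import numpy as np
import torch
import torch_npu  # registers the NPU backend, including the Log1p kernel

x_cpu = torch.rand(64, 10, dtype=torch.float32)  # values in [0, 1)
x_npu = x_cpu.to("npu")

y_cpu = torch.log1p(x_cpu)   # CPU reference
y_npu = torch.log1p(x_npu)   # dispatches to NPUNativeFunctions::log1p
x_npu.log1p_()               # in-place variant, NPUNativeFunctions::log1p_

np.testing.assert_allclose(y_npu.cpu().numpy(), y_cpu.numpy(), rtol=1e-3)
np.testing.assert_allclose(x_npu.cpu().numpy(), y_cpu.numpy(), rtol=1e-3)

The in-place path is worth checking separately because log1p_ routes through log1p_out with self as both input and output, which is the branch that goes through OpPreparation::CheckOut and the CheckMemory pipe in the kernel.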