From 47df278610d2b3cb49ce70d09dd8748268c6b2db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=89=9B=E4=BD=9C=E4=B8=9C?= Date: Tue, 25 Jan 2022 02:18:13 +0000 Subject: [PATCH] add log --- test/test_network_ops/test_log.py | 152 ++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 3 + torch_npu/csrc/aten/ops/LogKernalNpu.cpp | 78 +++++++++ 3 files changed, 233 insertions(+) create mode 100644 test/test_network_ops/test_log.py create mode 100644 torch_npu/csrc/aten/ops/LogKernalNpu.cpp diff --git a/test/test_network_ops/test_log.py b/test/test_network_ops/test_log.py new file mode 100644 index 0000000..304ab96 --- /dev/null +++ b/test/test_network_ops/test_log.py @@ -0,0 +1,152 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestLog(TestCase):
+    """Compare torch.log, torch.log(out=...) and torch.log_ on NPU against CPU.
+
+    The ``*_op_exec`` helpers each run one variant of the operator and return
+    the result as a numpy array.  The ``test_*`` methods hand a CPU helper and
+    an NPU helper to ``_check``, which compares their results.  The ``device``
+    argument of the tests is not used by them (presumably it is supplied by
+    instantiate_device_type_tests -- TODO confirm).
+    """
+
+    def cpu_op_exec(self, input1):
+        """Return torch.log(input1) as a numpy array (CPU reference)."""
+        return torch.log(input1).numpy()
+
+    def npu_op_exec(self, input1):
+        """Return torch.log(input1) computed on the NPU as a numpy array."""
+        npu_input = input1.to("npu")
+        return torch.log(npu_input).to("cpu").numpy()
+
+    def npu_op_exec_out(self, input1):
+        """Return torch.log(input1, out=...) computed on the NPU as numpy."""
+        input1 = input1.to("npu")
+        # NOTE(review): Tensor.to returns self when the device already matches,
+        # so ``output`` may alias ``input1`` here -- confirm this is intended.
+        output = input1.to("npu")
+        torch.log(input1, out=output)
+        return output.to("cpu").numpy()
+
+    def cpu_inp_op_exec(self, input1):
+        """Apply torch.log_ in place on CPU and return the result as numpy."""
+        return torch.log_(input1).numpy()
+
+    def npu_inp_op_exec(self, input1):
+        """Apply torch.log_ in place on the NPU and return the mutated input."""
+        input1 = input1.to("npu")
+        torch.log_(input1)
+        # Read back the input itself rather than the return value, so the
+        # comparison only passes if the tensor was really updated in place.
+        return input1.to("cpu").numpy()
+
+    def cpu_inp_uncon_op_exec(self, input1):
+        """Apply torch.log_ on CPU to a non-contiguous 2x2 strided view."""
+        # 2x2 view with strides (1, 2) starting at storage offset 2.
+        input1 = input1.as_strided([2, 2], [1, 2], 2)
+        return torch.log_(input1).numpy()
+
+    def npu_inp_uncon_op_exec(self, input1):
+        """Apply torch.log_ on the NPU to a non-contiguous 2x2 strided view."""
+        input1 = input1.to("npu")
+        # 2x2 view with strides (1, 2) starting at storage offset 2.
+        input1 = input1.as_strided([2, 2], [1, 2], 2)
+        torch.log_(input1)
+        # As in npu_inp_op_exec, return the tensor that was operated on.
+        return input1.to("cpu").numpy()
+
+    def _check(self, dtype, shape_list, cpu_exec, npu_exec, format_list=(3,)):
+        """Assert that cpu_exec and npu_exec agree for every format/shape pair.
+
+        Args:
+            dtype: numpy dtype handed to create_common_tensor.
+            shape_list: tensor shapes to test.
+            cpu_exec: helper producing the CPU reference result.
+            npu_exec: helper producing the NPU result.
+            format_list: NPU format ids handed to create_common_tensor.
+
+        For np.float16 the CPU reference is computed in float32 and cast back
+        to float16 before the comparison.
+        """
+        is_half = dtype == np.float16
+        for npu_format in format_list:
+            for shape in shape_list:
+                item = [dtype, npu_format, shape]
+                # 0 and 100 are forwarded unchanged; presumably they are the
+                # lower and upper bound of the generated values -- TODO confirm.
+                cpu_input1, npu_input1 = create_common_tensor(item, 0, 100)
+                if is_half:
+                    cpu_input1 = cpu_input1.to(torch.float32)
+                cpu_output = cpu_exec(cpu_input1)
+                npu_output = npu_exec(npu_input1)
+                if is_half:
+                    cpu_output = cpu_output.astype(np.float16)
+                self.assertRtolEqual(cpu_output, npu_output)
+
+    # Test cases
+    def test_log_shape_format_fp32(self, device):
+        """torch.log on float32 input."""
+        self._check(np.float32, [(4, 4)], self.cpu_op_exec, self.npu_op_exec)
+
+    def test_log_shape_format_fp16(self, device):
+        """torch.log on float16 input."""
+        self._check(np.float16, [(4, 4)], self.cpu_op_exec, self.npu_op_exec)
+
+    def test_log_out_shape_format_fp32(self, device):
+        """torch.log with out= on float32 input."""
+        self._check(
+            np.float32, [(4, 4)], self.cpu_op_exec, self.npu_op_exec_out)
+
+    def test_log_out_shape_format_fp16(self, device):
+        """torch.log with out= on float16 input."""
+        self._check(
+            np.float16, [(4, 4)], self.cpu_op_exec, self.npu_op_exec_out)
+
+    def test_log_inp_shape_format_fp32(self, device):
+        """torch.log_ on float32 input."""
+        self._check(
+            np.float32, [(4, 4)], self.cpu_inp_op_exec, self.npu_inp_op_exec)
+
+    def test_log_inp_shape_format_fp16(self, device):
+        """torch.log_ on float16 input."""
+        self._check(
+            np.float16, [(4, 4)], self.cpu_inp_op_exec, self.npu_inp_op_exec)
+
+    def test_log_inp_uncon_shape_format_fp32(self, device):
+        """torch.log_ on a non-contiguous view of float32 input."""
+        self._check(
+            np.float32, [(8, 6)],
+            self.cpu_inp_uncon_op_exec, self.npu_inp_uncon_op_exec)
+
+    def test_log_inp_uncon_shape_format_fp16(self, device):
+        """torch.log_ on a non-contiguous view of float16 input."""
+        self._check(
+            np.float16, [(8, 6)],
+            self.cpu_inp_uncon_op_exec, self.npu_inp_uncon_op_exec)
+
+
+instantiate_device_type_tests(TestLog, globals(), except_for="cpu")
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 412fb9d..a26bd0c 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -13,6 +13,9 @@ supported:
   - log_softmax.Dimname
   - _log_softmax
   - _log_softmax_backward_data
+  - log
+  - log_
+  - log.out
   - batch_norm
   - native_batch_norm
   - native_batch_norm_backward
diff --git a/torch_npu/csrc/aten/ops/LogKernalNpu.cpp b/torch_npu/csrc/aten/ops/LogKernalNpu.cpp
new file mode 100644
index 0000000..47a0ff2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LogKernalNpu.cpp
@@ -0,0 +1,78 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// NOTE(review): the targets of the next four #include directives are missing
+// from this copy of the patch; restore them from the original commit.
+#include
+#include
+#include
+#include
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+namespace {
+
+// Runs the NPU "Log" operator with `self` as input and `result` as output,
+// then returns `result`. `result` is used as given, without any checks.
+at::Tensor& log_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) {
+  OpCommand cmd;
+  // NOTE(review): base = -1 presumably selects the natural logarithm, with
+  // scale = 1 and shift = 0 leaving the input unchanged -- confirm in CANN docs.
+  cmd.Name("Log")
+      .Input(self)
+      .Output(result)
+      .Attr("base", -1.0f)
+      .Attr("scale", 1.0f)
+      .Attr("shift", 0.0f)
+      .Run();
+  return result;
+}
+
+} // namespace
+
+// out= variant of log. Unless `result` is `self` itself, `result` is first
+// passed to OpPreparation::CheckOut with ND format and self's dtype and sizes.
+at::Tensor& NPUNativeFunctions::log_out(const at::Tensor& self, at::Tensor& result) {
+  if (!result.is_same(self)) {
+    OpPreparation::CheckOut(
+        {self}, result, ACL_FORMAT_ND, self.scalar_type(), self.sizes());
+  }
+
+  OpPipeWithDefinedOut pipe;
+  return pipe.CheckMemory({self}, {result})
+      .Func([&self](at::Tensor& out) { log_out_npu_nocheck(out, self); })
+      .Call(result);
+}
+
+// Functional variant of log: the output comes from OpPreparation::ApplyTensor.
+at::Tensor NPUNativeFunctions::log(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  log_out_npu_nocheck(result, self);
+  return result;
+}
+
+// In-place variant of log. Defined as a NPUNativeFunctions member, like log and
+// log_out, so it provides the `log_` entry listed in npu_native_functions.yaml.
+at::Tensor& NPUNativeFunctions::log_(at::Tensor& self) {
+  NPUNativeFunctions::log_out(self, self);
+  return self;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee