From 75a3b9a89a1a6ba415debd259db7cc587dbbb0f2 Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Tue, 25 Jan 2022 19:33:02 +0800
Subject: [PATCH 1/2] 1.8_prelu

---
 test/test_network_ops/test_prelu.py           | 56 +++++++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  1 +
 torch_npu/csrc/aten/ops/PreluKernelNpu.cpp    | 38 +++++++++++++
 3 files changed, 95 insertions(+)
 create mode 100644 test/test_network_ops/test_prelu.py
 create mode 100644 torch_npu/csrc/aten/ops/PreluKernelNpu.cpp

diff --git a/test/test_network_ops/test_prelu.py b/test/test_network_ops/test_prelu.py
new file mode 100644
index 00000000000..8a69c5a4f19
--- /dev/null
+++ b/test/test_network_ops/test_prelu.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestPrelu(TestCase):
+
+    def cpu_op_exec(self, input1, input2):
+        output = input1.prelu(input2)
+        return output.numpy()
+
+    def npu_op_exec(self, input1, input2):
+        output = input1.prelu(input2)
+        output = output.to("cpu")
+        if output.dtype != torch.float32:
+            output = output.to(torch.float32)
+        return output.numpy()
+
+    def test_prelu_shape_format(self, device):
+        shape_format = [
+            [[np.float32, 0, [1, 1]], [np.float32, 0, 1]],
+            [[np.float32, 0, [2, 2]], [np.float32, 0, 1]],
+            [[np.float16, 0, [1, 1]], [np.float16, 0, 1]],
+            [[np.float16, 0, [2, 2]], [np.float16, 0, 1]]
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 10)
+            cpu_input2, npu_input2 = create_common_tensor(item[1], 0, 10)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            if cpu_input2.dtype == torch.float16:
+                cpu_input2 = cpu_input2.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestPrelu, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 412fb9da7ce..58ba15b4907 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -240,6 +240,7 @@ supported:
   - addcdiv
   - addcdiv_
   - addcdiv.out
+  - prelu

 custom:
   - func: npu_transpose_to_contiguous(Tensor self) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/PreluKernelNpu.cpp b/torch_npu/csrc/aten/ops/PreluKernelNpu.cpp
new file mode 100644
index 00000000000..2b9baa4be89
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/PreluKernelNpu.cpp
@@ -0,0 +1,38 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::prelu(const at::Tensor& self, const at::Tensor& weight_) {
+  auto input = self.contiguous();
+  auto weight = weight_.contiguous();
+
+  // calculate the output size
+  auto outputSize = input_same_output_size(self);
+  at::Tensor result = OpPreparation::ApplyTensor(input, outputSize);
+
+  OpCommand cmd;
+  cmd.Name("PRelu")
+      .Input(input)
+      .Input(weight)
+      .Output(result)
+      .Run();
+  return result;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee

From f39f493fc5a5c572604e686a18dfd7df4e7eade8 Mon Sep 17 00:00:00 2001
From: shenpengcheng
Date: Tue, 25 Jan 2022 19:40:53 +0800
Subject: [PATCH 2/2] 1.8_prelu

---
 test/test_network_ops/test_prelu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_network_ops/test_prelu.py b/test/test_network_ops/test_prelu.py
index 8a69c5a4f19..cebcf1f1e85 100644
--- a/test/test_network_ops/test_prelu.py
+++ b/test/test_network_ops/test_prelu.py
@@ -29,7 +29,7 @@ class TestPrelu(TestCase):
         output = input1.prelu(input2)
         output = output.to("cpu")
         if output.dtype != torch.float32:
-            output = output.to(torch.float32)
+            output = output.to(torch.float32)
         return output.numpy()
 
     def test_prelu_shape_format(self, device):
-- 
Gitee
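A quick way to exercise the new operator outside of the unit test is sketched below. This is not part of the patch series: it assumes an Ascend NPU device is available and that, after importing torch_npu, tensors can be moved to it with the "npu" device string (the device naming is illustrative, not confirmed by the patch).

# Usage sketch (assumption: an Ascend NPU is present and reachable as "npu").
import torch
import torch_npu  # registers the NPU backend, including the PRelu kernel added above

x = torch.randn(2, 2).to("npu")            # activation tensor on the NPU
weight = torch.full((1,), 0.25).to("npu")  # single shared PReLU slope, as in the test cases

y = x.prelu(weight)   # dispatches to NPUNativeFunctions::prelu via the new yaml entry
print(y.to("cpu"))    # copy back to host to inspect the result

The test file in the first patch follows the same pattern, comparing the NPU result against the CPU reference implementation with assertRtolEqual.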