From 05faff8d9343549d4d215f27685641928d3660ba Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Tue, 1 Mar 2022 18:14:26 +0800 Subject: [PATCH] Add rre Operator --- .../test_network_ops/test_rrelu_with_noise.py | 84 +++++++++++ .../test_rrelu_with_noise_backward.py | 92 +++++++++++++ .../ops/RreluWithNoiseBackwardKernelNpu.cpp | 42 ++++++ .../csrc/aten/ops/RreluWithNoiseKernelNpu.cpp | 130 ++++++++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 test/test_network_ops/test_rrelu_with_noise.py create mode 100644 test/test_network_ops/test_rrelu_with_noise_backward.py create mode 100644 torch_npu/csrc/aten/ops/RreluWithNoiseBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/RreluWithNoiseKernelNpu.cpp diff --git a/test/test_network_ops/test_rrelu_with_noise.py b/test/test_network_ops/test_rrelu_with_noise.py new file mode 100644 index 0000000000..3ed5c59654 --- /dev/null +++ b/test/test_network_ops/test_rrelu_with_noise.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020, Huawei Technologies. +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestRreluWithNoise(TestCase): + def cpu_op_exec(self, input1, input2): + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output = output.to("cpu") + output = output.numpy() + return output + + def create_shape_format(self, device="npu"): + format_list = [0, 3] + dtype_list = [np.float32] + shape_list = [(1, 6, 4), (1, 4, 8), (1, 6, 8), + (2, 4, 5), (2, 5, 10), (2, 4, 10)] + + shape_format = [[[i, j, k]] for i in dtype_list + for j in format_list for k in shape_list] + + return shape_format + + def create_shape_format16(self, device="npu"): + format_list1 = [0, 3] + dtype_list1 = [np.float16] + shape_list1 = [(1, 6, 4), (1, 4, 8), (1, 6, 8), + (2, 4, 5), (2, 5, 10), (2, 4, 10)] + + shape_format1 = [[[i, j, k]] for i in dtype_list1 + for j in format_list1 for k in shape_list1] + + return shape_format1 + + def test_leaky_relu_shape_format(self, device="npu"): + for item in self.create_shape_format(device="npu"): + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_leaky_relu_shape_format_fp16(self, device="npu"): + def cpu_op_exec_fp16(input1, input2): + input1 = input1.to(torch.float32) + input2 = input2.to(torch.float32) + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output = output.numpy() + output = output.astype(np.float16) + return output + + for item in self.create_shape_format16(device="npu"): + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 1, 100) + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_rrelu_with_noise_backward.py b/test/test_network_ops/test_rrelu_with_noise_backward.py new file mode 100644 index 0000000000..4ba2539211 --- /dev/null +++ b/test/test_network_ops/test_rrelu_with_noise_backward.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020, Huawei Technologies. +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestRreluWithNoiseBackward(TestCase): + def cpu_op_exec(self, input1, input2): + input1.requires_grad_(True) + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output.backward(torch.ones_like(input1)) + res = input1.grad + res = res.numpy() + return res + + def npu_op_exec(self, input1, input2): + input1.requires_grad_(True) + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output.backward(torch.ones_like(input1)) + res = input1.grad.to("cpu") + res = res.numpy() + return res + + def create_shape_format(self, device="npu"): + format_list2 = [0, 3] + dtype_list2 = [np.float32] + shape_list2 = [(1, 6, 4), (1, 4, 8), (1, 6, 8), + (2, 4, 5), (2, 5, 10), (2, 4, 10)] + + shape_format2 = [[[i, j, k]] for i in dtype_list2 + for j in format_list2 for k in shape_list2] + + return shape_format2 + + def create_shape_format16(self, device="npu"): + format_list3 = [0, 3] + dtype_list3 = [np.float16] + shape_list3 = [(1, 6, 4), (1, 4, 8), (1, 6, 8), + (2, 4, 5), (2, 5, 10), (2, 4, 10)] + + shape_format3 = [[[i, j, k]] for i in dtype_list3 + for j in format_list3 for k in shape_list3] + + return shape_format3 + + def test_leaky_relu_shape_format(self, device="npu"): + for item in self.create_shape_format(device="npu"): + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_leaky_relu_shape_format_fp16(self, device="npu"): + def cpu_op_exec_fp16(input1, input2): + input1 = input1.to(torch.float32) + input2 = input2.to(torch.float32) + input1.requires_grad_(True) + output = torch._C._nn.rrelu_with_noise(input1, input2, 0.1, 0.3) + output.backward(torch.ones_like(input1)) + res = input1.grad + res = res.numpy() + res = res.astype(np.float16) + return res + + for item in self.create_shape_format16(device="npu"): + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 1, 100) + cpu_output = cpu_op_exec_fp16(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/RreluWithNoiseBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/RreluWithNoiseBackwardKernelNpu.cpp new file mode 100644 index 0000000000..61679d504e --- /dev/null +++ b/torch_npu/csrc/aten/ops/RreluWithNoiseBackwardKernelNpu.cpp @@ -0,0 +1,42 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor NPUNativeFunctions::rrelu_with_noise_backward( + const at::Tensor& grad_output, + const at::Tensor& self_or_result, + const at::Tensor& noise, + at::Scalar lower, + at::Scalar upper, + bool training, + bool is_result) { + auto folat_lower = lower.toFloat(); + auto float_upper = upper.toFloat(); + if (training && (float_upper - folat_lower > 1E-6)) { + return grad_output.mul(noise); + } else { + at::Scalar negative_slope = (folat_lower + float_upper) / 2; + return NPUNativeFunctions::leaky_relu_backward(grad_output, self_or_result, negative_slope, is_result); + } +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/RreluWithNoiseKernelNpu.cpp b/torch_npu/csrc/aten/ops/RreluWithNoiseKernelNpu.cpp new file mode 100644 index 0000000000..49a4ac6ca0 --- /dev/null +++ b/torch_npu/csrc/aten/ops/RreluWithNoiseKernelNpu.cpp @@ -0,0 +1,130 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include + +namespace at_npu { +namespace native { + +void _rrelu_with_noise_train( + at::Tensor& output, + const at::Tensor& input, + const at::Tensor& noise, + at::Scalar lower_, + at::Scalar upper_, + c10::optional generator) { + float lower = lower_.toFloat(); + float upper = upper_.toFloat(); + auto shape = output.sizes(); + auto noise_shape = noise.sizes(); + at::Tensor tmp_tensor = output.contiguous(); + at::Tensor output_data = tmp_tensor.reshape({output.numel()}); + at::Tensor input_data = input.reshape({input.numel()}); + at::Tensor tmp_noise = noise; + tmp_noise = tmp_noise.reshape({tmp_noise.numel()}); + auto gen = at::get_generator_or_default(generator, at::detail::getDefaultCPUGenerator()); + + for (int64_t i = 0; i < input.numel(); i++) { + if (input_data[i].item().toFloat() <= 0) { + at::uniform_real_distribution uniform(lower, upper); + const float r = uniform(gen); + output_data[i] = input_data[i] * r; + tmp_noise[i] = r; + } else { + tmp_noise[i] = 1; + output_data[i] = input_data[i]; + } + } + if (!output.is_contiguous()) { + output.copy_(tmp_tensor); + } + tmp_noise.reshape(noise_shape); + noise.copy_(tmp_noise); + output.reshape(shape); +} + +at::Tensor& rrelu_with_noise_out_nocheck( + const at::Tensor& self, + const at::Tensor& noise, + at::Scalar lower, + at::Scalar upper, + bool training, + c10::optional generator, + at::Tensor& output) { + + if (training) { + _rrelu_with_noise_train(output, self.contiguous(), noise, lower, upper, generator); + return output; + } else { + auto float_lower = lower.toFloat(); + auto float_upper = upper.toFloat(); + at::Scalar negative_slope = (float_lower + float_upper) / 2; + return NPUNativeFunctions::leaky_relu_out(self, negative_slope, output); + } +} + +at::Tensor NPUNativeFunctions::rrelu_with_noise( + const at::Tensor& self, + const at::Tensor& noise, + at::Scalar lower, + at::Scalar upper, + bool training, + c10::optional generator) { + auto output = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + return rrelu_with_noise_out_nocheck(self, noise, lower, upper, training, generator, output); +} + +at::Tensor& NPUNativeFunctions::rrelu_with_noise_( + at::Tensor& self, + const at::Tensor& noise, + at::Scalar lower, + at::Scalar upper, + bool training, + c10::optional generator) { + return NPUNativeFunctions::rrelu_with_noise_out(self, noise, lower, upper, training, generator, self); +} + +at::Tensor& NPUNativeFunctions::rrelu_with_noise_out( + const at::Tensor& self, + const at::Tensor& noise, + at::Scalar lower, + at::Scalar upper, + bool training, + c10::optional generator, + at::Tensor& output) { + + OpPreparation::CheckOut( + {self}, + output, + ACL_FORMAT_ND, + self.scalar_type(), + self.sizes()); + + if (!NpuUtils::check_match(&output)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(output); + at::Tensor newResult = rrelu_with_noise_out_nocheck(self, noise, lower, upper, training, generator, contiguousResult); + NpuUtils::format_fresh_view(output, newResult); + } else { + rrelu_with_noise_out_nocheck(self, noise, lower, upper, training, generator, output); + } + + return output; +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee