diff --git a/test/test_network_ops/test_elu.py b/test/test_network_ops/test_elu.py
new file mode 100644
index 0000000000000000000000000000000000000000..514a4e843521ce4340727f3704096c3c21da2665
--- /dev/null
+++ b/test/test_network_ops/test_elu.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestElu(TestCase):
+    def get_format1(self):
+        shape_format = [
+            [[np.float16, 0, (65535, 1, 1, 1)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 8192)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 16384)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 32768)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 65535)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 131072)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 196608)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 262144)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 393216)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 524288)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 655360)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 786432)], -2, 2],
+        ]
+        return shape_format
+
+    def get_format2(self):
+        shape_format = [
+            [[np.float32, 0, (1, 31, 149, 2)], -1.1754943508e-38, -1.1754943508e-38],
+            [[np.float32, 0, (1, 32, 31, 1)], -3402823500.0, 3402823500.0],
+            [[np.float32, 0, (128,)], 3402823500, 3402800000],
+            [[np.float32, 0, (184965, 1)], -9.313225746154785e-10, 9.313225746154785e-10],
+            [[np.float32, 0, (1, 31, 149, 2)], -3402823500.0, -3402823500.0],
+            [[np.float32, 0, (1, 31, 149, 2)], -3402823500.0, 3402823500.0],
+            [[np.float32, 0, (1, 31, 149, 2)], -9.313225746154785e-10, 9.313225746154785e-10],
+            [[np.float32, 0, (2, 31, 149, 2)], -1.1754943508e-38, 1.1754943508e-38],
+            [[np.float32, 0, (4, 31, 149, 2)], 1.1754943508e-38, 1.1754943508e-38],
+            [[np.float32, 0, (2048, 31, 1, 2)], -1.1754943508e-38, -1.1754943508e-38],
+            [[np.float32, 0, (8, 7, 149)], -1.1754943508e-38, 1.1754943508e-38]
+        ]
+        return shape_format
+
+    def cpu_op_exec(self, input1):
+        output = torch.nn.functional.elu(input1)
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_fp16(self, input1):
+        input1 = input1.to(torch.float32)
+        output = torch.nn.functional.elu(input1)
+        output = output.to(torch.float16)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.nn.functional.elu(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_elu_common_shape_format_fp16(self):
+        shape_format = self.get_format1()
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], item[1], item[2])
+            cpu_output = self.cpu_op_exec_fp16(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_elu_common_shape_format_fp32(self):
+        shape_format = self.get_format2()
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], item[1], item[2])
+            cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/test_network_ops/test_elu_backward.py b/test/test_network_ops/test_elu_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d8feaba0efc6f423b34b4bb02072e65474915f8
--- /dev/null
+++ b/test/test_network_ops/test_elu_backward.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+import torch.nn as nn
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestEluBackward(TestCase):
+    def cpu_op_exec(self, input1):
+        flag = 0
+        if input1.dtype == torch.float16:
+            input1 = input1.to(torch.float32)
+            flag = 1
+        input1.requires_grad = True
+        output = torch.nn.functional.elu(input1)
+        output.backward(torch.ones_like(output))
+        output_grad = input1.grad
+        if flag:
+            output_grad = output_grad.to(torch.float16)
+            output = output.to(torch.float16)
+        output_grad = output_grad.detach().numpy()
+        output = output.detach().numpy()
+        return output, output_grad
+
+    def npu_op_exec(self, input1):
+        input1.requires_grad = True
+        output = torch.nn.functional.elu(input1)
+        output.backward(torch.ones_like(output))
+        output_grad = input1.grad
+        output_grad = output_grad.to("cpu")
+        output = output.to("cpu")
+        output_grad = output_grad.detach().numpy()
+        output = output.detach().numpy()
+        return output, output_grad
+
+    def test_elu_common_shape_format_fp16(self):
+        shape_format = [
+            [[np.float16, 0, (65535, 1, 1, 1)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 8192)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 16384)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 32768)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 65535)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 131072)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 196608)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 262144)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 393216)], -2, 2],
+            [[np.float16, 2, (1, 1, 1, 524288)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 655360)], -2, 2],
+            [[np.float16, 0, (1, 1, 1, 786432)], -2, 2],
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], item[1], item[2])
+            cpu_output, cpu_output_grad = self.cpu_op_exec(cpu_input1)
+            npu_output, npu_output_grad = self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_output_grad, npu_output_grad)
+
+    def test_elu_common_shape_format_fp32(self):
+        shape_format1 = [
+            [[np.float32, 0, (1, 31, 149, 2)], -1.1754943508e-38, -1.1754943508e-38],
+            [[np.float32, 0, (1, 32, 31, 1)], -3402823500.0, 3402823500.0],
+            [[np.float32, 0, (128,)], 3402823500, 3402800000],
+            [[np.float32, 0, (184965, 1)], -9.313225746154785e-10, 9.313225746154785e-10],
+            [[np.float32, 0, (1, 31, 149, 2)], -3402823500.0, -3402823500.0],
+            [[np.float32, 2, (1, 31, 149, 2)], -3402823500.0, 3402823500.0],
+            [[np.float32, 0, (1, 31, 149, 2)], -9.313225746154785e-10, 9.313225746154785e-10],
+            [[np.float32, 0, (2, 31, 149, 2)], -1.1754943508e-38, 1.1754943508e-38],
+            [[np.float32, 0, (4, 31, 149, 2)], 1.1754943508e-38, 1.1754943508e-38],
+            [[np.float32, 0, (2048, 31, 1, 2)], -1.1754943508e-38, -1.1754943508e-38],
+            [[np.float32, 0, (8, 7, 149)], -1.1754943508e-38, 1.1754943508e-38]
+        ]
+        for item in shape_format1:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], item[1], item[2])
+            cpu_output, cpu_output_grad = self.cpu_op_exec(cpu_input1)
+            npu_output, npu_output_grad = self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_output_grad, npu_output_grad)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/EluBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/EluBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cef3f2ce1e8208722d45ceb684155c589e1e5adc
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/EluBackwardKernelNpu.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/KernelNpuOutputSize.h"
+#include "torch_npu/csrc/framework/utils/NpuUtils.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+at::Tensor& elu_backward_out_npu(at::Tensor& grad_input, const at::Tensor& grad_output, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale, const at::Tensor& output) {
+  float value = CalcuOpUtil::get_scalar_float_value(alpha);
+  OpCommand cmd;
+  cmd.Name("EluGradV2")
+      .Input(grad_output)
+      .Input(output)
+      .Output(grad_input)
+      .Attr("alpha", value)
+      .Run();
+  return grad_input;
+}
+at::Tensor NPUNativeFunctions::elu_backward(const at::Tensor& grad_output, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale, bool is_result, const at::Tensor& output) {
+  at::Tensor result = OpPreparation::ApplyTensor(grad_output);
+  elu_backward_out_npu(result, grad_output, alpha, scale, input_scale, output);
+  return result;
+}
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/EluKernelNpu.cpp b/torch_npu/csrc/aten/ops/EluKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..40344436075802fd6c2fdbae460a0554cb6b9f16
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/EluKernelNpu.cpp
@@ -0,0 +1,64 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+at::Tensor& elu_out_nocheck(const at::Tensor& self, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale, at::Tensor& result) {
+  float alphaValue = CalcuOpUtil::get_scalar_float_value(alpha);
+  float scaleValue = CalcuOpUtil::get_scalar_float_value(scale);
+  float inputScaleValue = CalcuOpUtil::get_scalar_float_value(input_scale);
+  OpCommand cmd;
+  cmd.Name("Elu")
+      .Input(self)
+      .Output(result)
+      .Attr("alpha", alphaValue)
+      .Attr("scale", scaleValue)
+      .Attr("input_scale", inputScaleValue)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::elu_out(const at::Tensor& self, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale, at::Tensor& result) {
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self);
+  if (!NpuUtils::check_match(&result)) {
+    at::Tensor contiguousResult = NpuUtils::format_contiguous(result);
+    at::Tensor checkResult = elu_out_nocheck(self, alpha, scale, input_scale, contiguousResult);
+    NpuUtils::format_fresh_view(result, checkResult);
+  } else {
+    elu_out_nocheck(self, alpha, scale, input_scale, result);
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::elu(const at::Tensor& self, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  elu_out_nocheck(self, alpha, scale, input_scale, result);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::elu_(at::Tensor& self, at::Scalar alpha, at::Scalar scale, at::Scalar input_scale) {
+  elu_out(self, alpha, scale, input_scale, self);
+  return self;
+}
+} // namespace native
+} // namespace at_npu
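For reference, a minimal sketch of the call path this patch wires up, outside the test harness. This is not part of the patch; the `.npu()` device transfer, device availability, and the explicit `alpha` value are assumptions based on standard torch_npu and PyTorch usage:

    import torch
    import torch_npu  # importing torch_npu registers the NPU backend and these kernels

    # assumption: an NPU device is available
    x = torch.randn(2, 3).npu()
    x.requires_grad_()

    # forward dispatches to NPUNativeFunctions::elu (the "Elu" op above)
    y = torch.nn.functional.elu(x, alpha=1.0)

    # backward dispatches to NPUNativeFunctions::elu_backward ("EluGradV2")
    y.backward(torch.ones_like(y))
    print(x.grad.cpu())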