diff --git a/test/test_network_ops/test_gelu_backward.py b/test/test_network_ops/test_gelu_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..4952e4f602c2c3abbc653cb33dda8d28aa5e3281
--- /dev/null
+++ b/test/test_network_ops/test_gelu_backward.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestGeluBackward(TestCase):
+    def generate_single_data(self, min_val, max_val, shape, dtype):
+        input1 = np.random.uniform(min_val, max_val, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+        return npu_input1
+
+    def cpu_op_exec(self, input1):
+        input1.requires_grad_(True)
+        output = torch.nn.functional.gelu(input1)
+        z = output.sum()
+        z.backward()
+        res = input1.grad
+        return res.detach().numpy()
+
+    def npu_op_exec(self, input1):
+        input1 = input1.to("npu")
+        input1.requires_grad = True
+        output = torch.nn.functional.gelu(input1)
+        z = output.sum()
+        z.backward()
+        res = input1.grad.to("cpu")
+        return res.detach().numpy()
+
+    def test_gelu_backward_float32_1(self, device):
+        input1 = self.generate_single_data(0, 100, (4, 3, 1, 1), np.float32)
+        cpu_input1 = copy.deepcopy(input1)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_backward_float32_2(self, device):
+        input1 = self.generate_single_data(0, 100, (15, 3, 1), np.float32)
+        cpu_input1 = copy.deepcopy(input1)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_backward_float32_3(self, device):
+        input1 = self.generate_single_data(0, 100, (4, 4), np.float32)
+        cpu_input1 = copy.deepcopy(input1)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_gelu_backward_float16(self, device):
+        input1 = self.generate_single_data(0, 100, (5, 10, 100), np.float16)
+        cpu_input1 = input1.to(torch.float32)
+        cpu_output = self.cpu_op_exec(cpu_input1)
+        cpu_output = cpu_output.astype(np.float16)
+        npu_output = self.npu_op_exec(input1)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestGeluBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0e6c809960ee86b9e3efae4209cbde611f91b9f2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GeluBackwardKernelNpu.cpp
@@ -0,0 +1,50 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& gelu_backward_out_npu_nocheck(
+    at::Tensor& grad_input,
+    const at::Tensor& grad,
+    const at::Tensor& self) {
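+  // The GeluGrad op takes a third input slot (the forward output y) in
+  // addition to dy and x. It is not consumed when computing this gradient,
+  // so the incoming grad tensor is reused as a placeholder for it.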
+  at::Tensor unused = grad;
+  OpCommand cmd;
+  cmd.Name("GeluGrad")
+     .Input(grad)
+     .Input(self)
+     .Input(unused)
+     .Output(grad_input)
+     .Run();
+
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::gelu_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self) {
+  at::Tensor grad_input = OpPreparation::ApplyTensor(self);
+  gelu_backward_out_npu_nocheck(grad_input, grad, self);
+  return grad_input;
+}
+
+} // namespace native
+} // namespace at_npu