# Copyright (c) 2020, Huawei Technologies.All rights reserved.
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch_npu
import numpy as np

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor


class TestHardSwish(TestCase):
    """Forward-pass checks for ``torch.nn.functional.hardswish``.

    Each test builds a pair of tensors with ``create_common_tensor``, runs
    hardswish on both, and compares the results with ``assertRtolEqual``.
    """

    def cpu_op_exec(self, input1):
        """Apply out-of-place hardswish to ``input1`` and return a NumPy array."""
        return torch.nn.functional.hardswish(input1, inplace=False).numpy()

    def npu_op_exec(self, input1):
        """Apply out-of-place hardswish to ``input1``, move the result to CPU
        and return it as a NumPy array."""
        result = torch.nn.functional.hardswish(input1, inplace=False)
        return result.to("cpu").numpy()

    def create_shape_format16(self):
        """Return every ``[dtype, format, shape]`` combination for float16.

        Combinations are ordered with dtype outermost and shape innermost.
        """
        dtypes = [np.float16]
        formats = [0, 3, 29]
        shapes = [[32], [32, 3], [32, 3, 3], [64, 32, 3, 3]]

        cases = []
        for dtype in dtypes:
            for fmt in formats:
                for shape in shapes:
                    cases.append([dtype, fmt, shape])
        return cases

    def create_shape_format32(self):
        """Return every ``[dtype, format, shape]`` combination for float32.

        Combinations are ordered with dtype outermost and shape innermost.
        """
        dtypes = [np.float32]
        formats = [0, 3, 29]
        shapes = [[32], [32, 3], [32, 3, 3], [64, 32, 3, 3]]

        cases = []
        for dtype in dtypes:
            for fmt in formats:
                for shape in shapes:
                    cases.append([dtype, fmt, shape])
        return cases

    def test_hardswish_shape_format_fp16(self):
        """Compare hardswish outputs for every float16 case."""
        for case in self.create_shape_format16():
            # NOTE(review): (2, 100) is presumably the (min, max) value range;
            # hardswish is the identity for x >= 3, so most samples would not
            # exercise the non-linear region - confirm this is intended.
            cpu_tensor, npu_tensor = create_common_tensor(case, 2, 100)
            # The reference is computed in float32 - presumably because the CPU
            # kernel lacks half-precision support (TODO confirm).
            if cpu_tensor.dtype == torch.float16:
                cpu_tensor = cpu_tensor.to(torch.float32)

            expected = self.cpu_op_exec(cpu_tensor)
            actual = self.npu_op_exec(npu_tensor)
            # Cast the reference back to the dtype of the result under test
            # before comparing.
            expected = expected.astype(actual.dtype)

            self.assertRtolEqual(expected, actual)

    def test_hardswish_shape_format_fp32(self):
        """Compare hardswish outputs for every float32 case."""
        for case in self.create_shape_format32():
            cpu_tensor, npu_tensor = create_common_tensor(case, 2, 100)
            expected = self.cpu_op_exec(cpu_tensor)
            actual = self.npu_op_exec(npu_tensor)

            self.assertRtolEqual(expected, actual)


if __name__ == "__main__":
    run_tests()
# Copyright (c) 2020, Huawei Technologies.All rights reserved.
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch_npu
import numpy as np

from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor


class TestHardSwishBackWard(TestCase):
    """Forward and gradient checks for ``torch.nn.functional.hardswish``.

    Each test builds a pair of tensors with ``create_common_tensor``, runs
    hardswish plus a backward pass seeded with ones on both, and compares the
    outputs and the input gradients with ``assertRtolEqual``.
    """

    def cpu_op_exec(self, input1):
        """Run hardswish forward/backward on ``input1``.

        Sets ``input1.requires_grad`` and back-propagates a tensor of ones.

        Returns:
            Tuple ``(output, input_gradient)`` as NumPy arrays.
        """
        input1.requires_grad = True
        forward = torch.nn.functional.hardswish(input1, inplace=False)
        forward.backward(torch.ones_like(forward))

        grad_np = input1.grad.detach().numpy()
        forward_np = forward.detach().numpy()
        return forward_np, grad_np

    def npu_op_exec(self, input1):
        """Run hardswish forward/backward on ``input1`` and copy results to CPU.

        Sets ``input1.requires_grad`` and back-propagates a tensor of ones.

        Returns:
            Tuple ``(output, input_gradient)`` as NumPy arrays.
        """
        input1.requires_grad = True
        forward = torch.nn.functional.hardswish(input1, inplace=False)
        forward.backward(torch.ones_like(forward))

        forward_cpu = forward.to("cpu")
        grad_np = input1.grad.to("cpu").detach().numpy()
        return forward_cpu.detach().numpy(), grad_np

    def backward_create_shape_format16(self):
        """Return every ``[dtype, format, shape]`` combination for float16.

        Combinations are ordered with dtype outermost and shape innermost.
        """
        dtypes = [np.float16]
        formats = [0, 3, 29]
        shapes = [[32], [32, 3], [32, 3, 3], [64, 32, 3, 3]]

        cases = []
        for dtype in dtypes:
            for fmt in formats:
                for shape in shapes:
                    cases.append([dtype, fmt, shape])
        return cases

    def backward_create_shape_format32(self):
        """Return every ``[dtype, format, shape]`` combination for float32.

        Combinations are ordered with dtype outermost and shape innermost.
        """
        dtypes = [np.float32]
        formats = [0, 3, 29]
        shapes = [[32], [32, 3], [32, 3, 3], [64, 32, 3, 3]]

        cases = []
        for dtype in dtypes:
            for fmt in formats:
                for shape in shapes:
                    cases.append([dtype, fmt, shape])
        return cases

    def test_hardswish_shape_format_fp16(self):
        """Compare outputs and input gradients for every float16 case."""
        for case in self.backward_create_shape_format16():
            # NOTE(review): (2, 100) is presumably the (min, max) value range;
            # hardswish is the identity for x >= 3, so most samples would see
            # a gradient of 1 - confirm this coverage is intended.
            cpu_tensor, npu_tensor = create_common_tensor(case, 2, 100)
            # The reference is computed in float32 - presumably because the CPU
            # kernel lacks half-precision support (TODO confirm).
            if cpu_tensor.dtype == torch.float16:
                cpu_tensor = cpu_tensor.to(torch.float32)

            expected, expected_grad = self.cpu_op_exec(cpu_tensor)
            actual, actual_grad = self.npu_op_exec(npu_tensor)
            # Cast the references back to the dtypes of the results under test
            # before comparing.
            expected = expected.astype(actual.dtype)
            expected_grad = expected_grad.astype(actual_grad.dtype)

            self.assertRtolEqual(expected, actual)
            self.assertRtolEqual(expected_grad, actual_grad)

    def test_hardswish_shape_format_fp32(self):
        """Compare outputs and input gradients for every float32 case."""
        for case in self.backward_create_shape_format32():
            cpu_tensor, npu_tensor = create_common_tensor(case, 2, 100)
            expected, expected_grad = self.cpu_op_exec(cpu_tensor)
            actual, actual_grad = self.npu_op_exec(npu_tensor)

            self.assertRtolEqual(expected, actual)
            self.assertRtolEqual(expected_grad, actual_grad)


if __name__ == "__main__":
    run_tests()
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& hardswish_backward_nocheck( + at::Tensor& grad_input, + const at::Tensor& grad_output, + const at::Tensor& self) { + + OpCommand cmd; + cmd.Name("HardSwishGrad") + .Input(grad_output) + .Input(self) + .Output(grad_input) + .Run(); + + return grad_input; +} + +at::Tensor NPUNativeFunctions::hardswish_backward(const at::Tensor& grad_output, const at::Tensor& self) { + at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output); + + return hardswish_backward_nocheck(grad_input, grad_output, self); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/HardSwishKernelNpu.cpp b/torch_npu/csrc/aten/ops/HardSwishKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9329f7d6b07093f5a302e1c1bf4972687e8cbd2a --- /dev/null +++ b/torch_npu/csrc/aten/ops/HardSwishKernelNpu.cpp @@ -0,0 +1,61 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& hardswish_out_nocheck(const at::Tensor& self, at::Tensor& result) { + OpCommand cmd; + cmd.Name("HardSwish") + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::hardswish_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut({self}, result, self); + + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + hardswish_out_nocheck(self, contiguousResult); + NpuUtils::format_fresh_view(result, contiguousResult); + } else { + hardswish_out_nocheck(self, result); + } + + return result; +} + +at::Tensor NPUNativeFunctions::hardswish(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + + hardswish_out_nocheck(self, result); + + return result; +} + +at::Tensor& NPUNativeFunctions::hardswish_(at::Tensor& self) { + NPUNativeFunctions::hardswish_out(self, self); + + return self; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file