diff --git a/test/test_network_ops/test_silu.py b/test/test_network_ops/test_silu.py
new file mode 100644
index 0000000000000000000000000000000000000000..0da283cafc8101b5547fed1e222836fa571065be
--- /dev/null
+++ b/test/test_network_ops/test_silu.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestSilu(TestCase):
+
+    def cpu_op_exec(self, input1):
+        output = input1 * torch.sigmoid(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch_npu.npu_silu(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_inplace(self, input1):
+        torch_npu.npu_silu_(input1)
+        output = input1.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_silu_shape_format_fp16(self):
+        format_list = [0]
+        shape_list = [1, (64, 10), (32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float16, i, j] for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_input = cpu_input.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_silu_shape_format_fp32(self):
+        format_list = [0, 3, 4, 29]
+        shape_list = [1, (32, 32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float32, i, j] for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_silu_inplace_shape_format_fp16(self):
+        format_list = [0]
+        shape_list = [1, (64, 10), (32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float16, i, j] for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_input = cpu_input.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec_inplace(npu_input)
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_silu_inplace_shape_format_fp32(self):
+        format_list = [0, 3, 4, 29]
+        shape_list = [1, (32, 32, 3, 3), (256, 2048, 7, 7)]
+        shape_format = [
+            [np.float32, i, j] for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec_inplace(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_silu_backward.py b/test/test_network_ops/test_silu_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4e319e81a7180540b8aaf9209a609bcac819a9
--- /dev/null
+++ b/test/test_network_ops/test_silu_backward.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+def input_grad_hook(grad):
+    global input_grad
+    input_grad = grad
+    input_grad = input_grad.numpy()
+
+
+def npu_input_grad_hook(grad):
+    global npu_input_grad
+    npu_input_grad = grad.to("cpu")
+    npu_input_grad = npu_input_grad.numpy()
+
+
+class TestSiluBackward(TestCase):
+    def cpu_op_exec(self, input1, is_contiguous=True):
+        if not is_contiguous:
+            input1 = input1.as_strided([2, 2], [1, 2], 1)
+        input1.requires_grad = True
+        input1.register_hook(input_grad_hook)
+        output = input1 * torch.sigmoid(input1)
+        z = output.sum()
+        z.backward()
+
+    def npu_op_exec(self, input1, is_contiguous=True):
+        if not is_contiguous:
+            input1 = input1.as_strided([2, 2], [1, 2], 1)
+        input1.requires_grad = True
+        input1.register_hook(npu_input_grad_hook)
+
+        output = torch_npu.npu_silu(input1)
+        z = output.sum()
+        z.backward()
+        input1 = input1.cpu()
+
+    def test_silu_backward_shape_format_fp32(self):
+        format_list = [0, 3, 4, 29]
+        shape_list = [(256, 2048, 7, 7)]
+        shape_format = [
+            [np.float32, i, j] for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            input1, npu_input1 = create_common_tensor(item, 1, 100)
+            input2, npu_input2 = create_common_tensor(item, 1, 100)
+            self.cpu_op_exec(input1)
+            self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(input_grad, npu_input_grad)
+
+            self.cpu_op_exec(input2, False)
+            self.npu_op_exec(npu_input2, False)
+            self.assertRtolEqual(input_grad, npu_input_grad)
+
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 0dfd6c955f99e4b4f1ee07d0203613215f20062b..f794be4e55b7e46b46793010b04f51eb2ae36f72 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1926,6 +1926,8 @@ custom:
   - func: npu_rotated_box_encode(Tensor self, Tensor gt_bboxes, Tensor weight) -> Tensor
   - func: npu_rotated_box_decode(Tensor self, Tensor deltas, Tensor weight) -> Tensor
   - func: npu_rotated_overlaps(Tensor self, Tensor query_boxes, bool trans=False) -> Tensor
+  - func: npu_silu_backward(Tensor grad_output, Tensor x0, Tensor x1) -> Tensor
+  - func: npu_rotated_iou(Tensor self, Tensor query_boxes, bool trans=False, int mode=0, bool is_cross=True, float v_threshold=0.0, float e_threshold=0.0) -> Tensor
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
@@ -1944,4 +1946,6 @@ custom_autograd:
   - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
     variants: function, method
   - func: npu_dtype_cast(Tensor self, ScalarType dtype) -> Tensor
-    variants: function, method
\ No newline at end of file
+    variants: function, method
+  - func: npu_silu(Tensor self) -> Tensor
+  - func: npu_silu_(Tensor(a!) self) -> Tensor(a!)
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp b/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6338bf911b00942411014cd2abda5384ae5f78e9
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp
@@ -0,0 +1,81 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& rotated_iou_npu_nocheck(
+    at::Tensor& iou,
+    const at::Tensor& boxes,
+    const at::Tensor& query_boxes,
+    bool trans,
+    int64_t mode,
+    bool is_cross,
+    double v_threshold,
+    double e_threshold) {
+  string mode_str = (mode == 0) ? "iou" : "iof";
+
+  OpCommand cmd;
+  cmd.Name("RotatedIou")
+      .Input(boxes)
+      .Input(query_boxes)
+      .Output(iou)
+      .Attr("trans", trans)
+      .Attr("mode", mode_str)
+      .Attr("is_cross", is_cross)
+      .Attr("value", static_cast<float>(v_threshold))
+      .Attr("value", static_cast<float>(e_threshold))
+      .Run();
+  return iou;
+}
+
+at::Tensor NPUNativeFunctions::npu_rotated_iou(
+    const at::Tensor& boxes,
+    const at::Tensor& query_boxes,
+    bool trans,
+    int64_t mode,
+    bool is_cross,
+    double v_threshold,
+    double e_threshold) {
+  TORCH_CHECK(boxes.ndimension() == 3 && query_boxes.ndimension() == 3, "boxes and query_boxes must be 3-dimensional");
+
+  auto origin_dtype = boxes.scalar_type();
+
+  at::Tensor boxesOk = boxes.permute({0, 2, 1});
+  if (boxesOk.scalar_type() == at::kHalf) {
+    boxesOk = NPUNativeFunctions::npu_dtype_cast(boxesOk, at::kFloat);
+  }
+  at::Tensor query_boxesOk = query_boxes.permute({0, 2, 1});
+  if (query_boxesOk.scalar_type() == at::kHalf) {
+    query_boxesOk = NPUNativeFunctions::npu_dtype_cast(query_boxesOk, at::kFloat);
+  }
+
+  int64_t B = boxesOk.size(0);
+  int64_t N = boxesOk.size(-1);
+  int64_t K = query_boxesOk.size(-1);
+
+  c10::SmallVector<int64_t, SIZE> output_size({B, N, K});
+  at::Tensor iou = OpPreparation::ApplyTensor(boxesOk, output_size);
+
+  rotated_iou_npu_nocheck(iou, boxesOk, query_boxesOk, trans, mode, is_cross, v_threshold, e_threshold);
+  iou = NPUNativeFunctions::npu_dtype_cast(iou, origin_dtype);
+  return iou;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/SiluKernelNpu.cpp b/torch_npu/csrc/aten/ops/SiluKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..eaf9b562b2abc6ad92a1c867b0f7eccd1651f782
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/SiluKernelNpu.cpp
@@ -0,0 +1,125 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// npu_silu maps to the NPU "Swish" operator; npu_silu_backward maps to "SwishGrad",
+// which takes the upstream gradient, the forward input and the forward output.
+
+#include <torch/csrc/autograd/custom_function.h>
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+at::Tensor& silu_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) {
+  OpCommand cmd;
+  cmd.Name("Swish")
+      .Input(self)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& silu_out_npu(const at::Tensor& self, at::Tensor& out) {
+  OpPreparation::CheckOut(
+      {self},
+      out,
+      self);
+
+  if (!NpuUtils::check_match(&out)) {
+    at::Tensor contiguousOut = NpuUtils::format_contiguous(out);
+    at::Tensor newOut = silu_out_npu_nocheck(contiguousOut, self);
+    NpuUtils::format_fresh_view(out, newOut);
+  } else {
+    silu_out_npu_nocheck(out, self);
+  }
+
+  return out;
+}
+
+at::Tensor silu_kernel_npu(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  silu_out_npu_nocheck(result, self);
+
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::npu_silu_(at::Tensor& self) {
+  silu_out_npu(self, self);
+  return self;
+}
+
+at::Tensor& silu_backward_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& grad_output,
+    const at::Tensor& x0,
+    const at::Tensor& x1) {
+
+  OpCommand cmd;
+  cmd.Name("SwishGrad")
+      .Input(grad_output)
+      .Input(x0)
+      .Input(x1)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::npu_silu_backward(const at::Tensor& grad_output, const at::Tensor& x0, const at::Tensor& x1) {
+  // construct the output tensor of the NPU
+  at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output);
+
+  // calculate the output result of the NPU
+  silu_backward_out_npu_nocheck(grad_input, grad_output, x0, x1);
+
+  return grad_input;
+}
+
+class NPUSiluFunction : public torch::autograd::Function<NPUSiluFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self) {
+    at::AutoNonVariableTypeMode g;
+    at::Tensor result = silu_kernel_npu(self);
+    ctx->save_for_backward({self, result});
+
+    return result;
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+    auto result = saved[1];
+
+    at::Tensor output = NPUNativeFunctions::npu_silu_backward(grad_outputs[0], input, result);
+    tensor_list outputlist = {output};
+
+    return outputlist;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_silu(const at::Tensor& self) {
+  return NPUSiluFunction::apply(self);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file