From 1782ddd7fea3817379e457ece5ecfcb2fc6244fc Mon Sep 17 00:00:00 2001 From: wuxiankun Date: Mon, 24 Jan 2022 17:41:37 +0800 Subject: [PATCH 1/3] plug npu_giou --- test/test_network_ops/test_npu_giou.py | 134 ++++++++++++++++++ .../test_npu_giou_backward.py | 87 ++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 3 + .../csrc/aten/ops/GiouBackwardKernelNpu.cpp | 133 +++++++++++++++++ torch_npu/csrc/aten/ops/GiouKernelNpu.cpp | 99 +++++++++++++ 5 files changed, 456 insertions(+) create mode 100644 test/test_network_ops/test_npu_giou.py create mode 100644 test/test_network_ops/test_npu_giou_backward.py create mode 100644 torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/GiouKernelNpu.cpp diff --git a/test/test_network_ops/test_npu_giou.py b/test/test_network_ops/test_npu_giou.py new file mode 100644 index 00000000000..85b4fae5c73 --- /dev/null +++ b/test/test_network_ops/test_npu_giou.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestNpuGiou(TestCase): + def generate_giou_data(self, n, m, dtype): + data_bboxes = np.array([]).astype(dtype) + for i in range(4): + data_bboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, n).astype(dtype) + data_bboxes = np.append(data_bboxes, data_bboxes_array) + data_bboxes = data_bboxes.reshape([4, n]) + data_gtboxes = np.array([]).astype(dtype) + for i in range(4): + data_gtboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, m).astype(dtype) + data_gtboxes = np.append(data_gtboxes, data_gtboxes_array) + data_gtboxes = data_gtboxes.reshape([4, m]) + cpu_input1 = torch.from_numpy(data_bboxes) + cpu_input2 = torch.from_numpy(data_gtboxes) + npu_input1 = cpu_input1.npu() + npu_input2 = cpu_input2.npu() + return cpu_input1, cpu_input2, npu_input1, npu_input2 + + def cpu_op_exec(self, box1, box2, trans=False, is_cross=False, mode="iou"): + box1 = box1.numpy() + box2 = box2.numpy() + dtype = box1.dtype + _, n = box1.shape + _, m = box2.shape + if trans: + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 + else: + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + area1 = w1 * h1 + area2 = w2 * h2 + giou_res = np.array([], dtype=dtype) + + for i in range(n): + for j in range(m): + inter_x1 = max(b1_x1[i], b2_x1[j]) + 
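+                # GIoU reference (editor's note on the math this loop mirrors):
+                # giou_ij = inter/union - (outer - union)/outer, where "outer"
+                # is the smallest box enclosing boxes i and j and
+                # union = area1[i] + area2[j] - inter (1e-16 guards against /0).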
+                inter_x2 = min(b1_x2[i], b2_x2[j])
+                inter_y1 = max(b1_y1[i], b2_y1[j])
+                inter_y2 = min(b1_y2[i], b2_y2[j])
+                outer_x1 = min(b1_x1[i], b2_x1[j])
+                outer_x2 = max(b1_x2[i], b2_x2[j])
+                outer_y1 = min(b1_y1[i], b2_y1[j])
+                outer_y2 = max(b1_y2[i], b2_y2[j])
+                inter_area = max(0, (inter_x2 - inter_x1)) * max(0, (inter_y2 - inter_y1))
+                outer_area = abs(outer_x2 - outer_x1) * abs(outer_y2 - outer_y1)
+                union_area = area1[i] + area2[j] - inter_area + 1e-16
+                other_area = outer_area - union_area
+                giou_ij = inter_area / union_area - other_area / outer_area
+                if not is_cross:
+                    if i == j:
+                        giou_res = np.append(giou_res, giou_ij)
+                else:
+                    giou_res = np.append(giou_res, giou_ij)
+
+        if not is_cross:
+            res = giou_res.reshape(1, n)
+        else:
+            res = giou_res.reshape(n, m)
+            res = np.transpose(res)
+            res = np.transpose(res)
+        return res
+
+    def npu_op_exec(self, box1, box2, trans=False, is_cross=False, mode=0):
+        output = torch_npu.npu_giou(box1, box2, trans, is_cross, mode)
+        output = output.detach().cpu().numpy()
+        return output
+
+    def test_npu_giou_shape_format_fp32(self, device):
+        self._test_npu_giou_shape_format(np.float32)
+
+    def test_npu_giou_shape_format_fp16(self, device):
+        self._test_npu_giou_shape_format(np.float16)
+
+    def _test_npu_giou_shape_format(self, dtype):
+        shape_list = [
+            [10, 10],
+            [12, 12],
+            [100, 100]
+        ]
+        is_trans_list = [True]
+        mode_list = ["iou"]
+        # TODO(Ascend): the backward pass currently only supports mode=="iou",
+        # is_cross==False, is_trans==True; verify the same configuration here.
+        shape_format = [[j, k, m]
+                        for j in shape_list
+                        for k in is_trans_list
+                        for m in mode_list]
+
+        for item in shape_format:
+            mode_digit = 0 if item[-1] == "iou" else 1
+            is_cross = False if item[0][0] == item[0][1] else True
+            cpu_input1, cpu_input2, npu_input1, npu_input2 = self.generate_giou_data(*item[0], dtype)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2, item[1], is_cross, item[-1])
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, item[1], is_cross, mode_digit)
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            if dtype == np.float16:
+                # TODO(Ascend): fp16 precision is limited, so loosen the tolerance
+                self.assertRtolEqual(cpu_output, npu_output, prec16=1e-2)
+            else:
+                self.assertRtolEqual(cpu_output, npu_output)
+
+
+instantiate_device_type_tests(TestNpuGiou, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_npu_giou_backward.py b/test/test_network_ops/test_npu_giou_backward.py
new file mode 100644
index 00000000000..41cdfb3d480
--- /dev/null
+++ b/test/test_network_ops/test_npu_giou_backward.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
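+
+# Usage sketch (hypothetical shapes; assumes an NPU device is available):
+#   boxes = torch.rand(4, 8).npu(); boxes.requires_grad = True
+#   gts = torch.rand(4, 8).npu()
+#   giou = torch_npu.npu_giou(boxes, gts, trans=True, is_cross=False, mode=0)
+#   giou.backward(torch.ones_like(giou))  # grads keep the boxes' [4, n] layout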
+
+import math
+import torch
+import torch_npu
+import numpy as np
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestNpuGiouBackward(TestCase):
+    def generate_giou_data(self, n, m, dtype):
+        data_bboxes = np.array([]).astype(dtype)
+        for i in range(4):
+            data_bboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, n).astype(dtype)
+            data_bboxes = np.append(data_bboxes, data_bboxes_array)
+        data_bboxes = data_bboxes.reshape([4, n])
+        data_gtboxes = np.array([]).astype(dtype)
+        for i in range(4):
+            data_gtboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, m).astype(dtype)
+            data_gtboxes = np.append(data_gtboxes, data_gtboxes_array)
+        data_gtboxes = data_gtboxes.reshape([4, m])
+        cpu_input1 = torch.from_numpy(data_bboxes)
+        cpu_input2 = torch.from_numpy(data_gtboxes)
+        npu_input1 = cpu_input1.npu()
+        npu_input2 = cpu_input2.npu()
+        return cpu_input1, cpu_input2, npu_input1, npu_input2
+
+    def npu_op_exec(self, box1, box2, trans=False, is_cross=False, mode=0):
+        box1.requires_grad = True
+        box2.requires_grad = True
+        output = torch_npu.npu_giou(box1, box2, trans, is_cross, mode)
+        output.backward(torch.ones_like(output))
+        box1_grad = box1.grad
+        box2_grad = box2.grad
+        box1_grad = box1_grad.detach().cpu().numpy()
+        box2_grad = box2_grad.detach().cpu().numpy()
+        output = output.detach().cpu().numpy()
+        return output, box1_grad, box2_grad
+
+    def test_npu_giou_backward_shape_format(self, device):
+        shape_list = [
+            [1, 1]
+        ]
+        is_trans_list = [True]
+        mode_list = ["iou"]
+        # TODO(Ascend): the backward pass currently only supports mode=="iou",
+        # is_cross==False, is_trans==True
+        shape_format = [[j, k, m]
+                        for j in shape_list
+                        for k in is_trans_list
+                        for m in mode_list]
+
+        for item in shape_format:
+            mode_digit = 0 if item[-1] == "iou" else 1
+            is_cross = False if item[0][0] == item[0][1] else True
+            expected_cpu_grad1 = np.array([[1.0218241],
+                                           [-1.4181931],
+                                           [0.18631615],
+                                           [0.1747725]], dtype=np.float32)
+            expected_cpu_grad2 = np.array([[-1.0218241],
+                                           [1.4181931],
+                                           [0.17999186],
+                                           [0.23653218]], dtype=np.float32)
+            _, _, npu_input1, npu_input2 = self.generate_giou_data(*item[0], np.float32)
+            _, npu_grad1, npu_grad2 = self.npu_op_exec(npu_input1, npu_input2, item[1], is_cross, mode_digit)
+            self.assertRtolEqual(expected_cpu_grad1, npu_grad1)
+            self.assertRtolEqual(expected_cpu_grad2, npu_grad2)
+
+
+instantiate_device_type_tests(TestNpuGiouBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 62f461b32c5..62a9b16bdf9 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1913,8 +1913,11 @@ custom:
 - func: npu_stride_add(Tensor self, Tensor other, Scalar offset1, Scalar offset2, Scalar c1_len) -> Tensor
 - func: npu_slice(Tensor self, int[] offsets, int[] size) -> Tensor
 - func: npu_slice.out(Tensor self, int[] offsets, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+- func: npu_giou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor
+- func: npu_giou_backward(Tensor grad, Tensor bboxes, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> (Tensor, Tensor)
 - func: npu_indexing(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0) -> Tensor
 - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
 custom_autograd:
 - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
 - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
+- func: npu_giou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp
new file mode 100644
index 00000000000..ee79835b3b9
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp
@@ -0,0 +1,133 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using namespace torch::autograd;
+
+std::tuple<at::Tensor&, at::Tensor&> giou_backward_inner_out_npu(
+    at::Tensor& dbboxes,
+    at::Tensor& dgtboxes,
+    const at::Tensor& grad,
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  string mode_str = mode == 1 ? "iof" : "iou";
+
+  OpCommand cmd;
+  cmd.Name("GIoUGrad")
+      .Input(grad)
+      .Input(bboxes)
+      .Input(gtboxes)
+      .Output(dbboxes)
+      .Output(dgtboxes)
+      .Attr("trans", trans)
+      .Attr("is_cross", is_cross)
+      .Attr("mode", mode_str)
+      .Run();
+  return std::tie(dbboxes, dgtboxes);
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_giou_backward(
+    const at::Tensor& grad,
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  TORCH_CHECK(trans && !is_cross && mode == 0,
+      "npu_giou_backward currently only supports trans==True, ",
+      "is_cross==False, ",
+      "mode==0 ('iou'); ",
+      "if you need backpropagation, ",
+      "please make sure your parameters match these constraints!");
+  // The op expects grad in the form [n]
+  // Note: temporary workaround! It will be removed once the op handles fp16 itself.
+  at::Tensor gradCp = at::squeeze(grad, 0);
+  if(gradCp.scalar_type() == at::kHalf){
+    gradCp = gradCp.npu_dtype_cast(at::kFloat);
+  }
+  at::Tensor bboxesCp = bboxes;
+  if(bboxesCp.scalar_type() == at::kHalf){
+    bboxesCp = bboxesCp.npu_dtype_cast(at::kFloat);
+  }
+  at::Tensor gtboxesCp = gtboxes;
+  if(gtboxesCp.scalar_type() == at::kHalf){
+    gtboxesCp = gtboxesCp.npu_dtype_cast(at::kFloat);
+  }
+  at::Tensor dbboxes = OpPreparation::ApplyTensor(bboxesCp);
+  at::Tensor dgtboxes = OpPreparation::ApplyTensor(gtboxesCp);
+
+  giou_backward_inner_out_npu(dbboxes, dgtboxes, gradCp, bboxesCp, gtboxesCp, trans, is_cross, mode);
+  if(bboxes.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf){
+    dbboxes = dbboxes.npu_dtype_cast(at::kHalf);
+    dgtboxes = dgtboxes.npu_dtype_cast(at::kHalf);
+  }
+  return std::tie(dbboxes, dgtboxes);
+}
+
+class NPUGiouFunction : public torch::autograd::Function<NPUGiouFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      const at::Tensor& gtboxes,
+      bool trans,
+      bool is_cross,
+      int64_t mode) {
+    ctx->saved_data["trans"] = trans;
+    ctx->saved_data["is_cross"] = is_cross;
+    ctx->saved_data["mode"] = mode;
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({self, gtboxes});
+    return NPUNativeFunctions::npu_giou(self, gtboxes, trans, is_cross, mode);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto trans = ctx->saved_data["trans"].toBool();
+    auto is_cross = ctx->saved_data["is_cross"].toBool();
+    auto mode = ctx->saved_data["mode"].toInt();
+    auto saved = ctx->get_saved_variables();
+    auto bboxes = saved[0];
+    auto gtboxes = saved[1];
+
+    std::tuple<at::Tensor, at::Tensor> result = NPUNativeFunctions::npu_giou_backward(grad_outputs[0], bboxes, gtboxes, trans, is_cross, mode);
+
+    tensor_list output = {std::get<0>(result),
+        std::get<1>(result),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor()};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_giou(const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode) {
+  return NPUGiouFunction::apply(self, gtboxes, trans, is_cross, mode);
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp
new file mode 100644
index 00000000000..3f70d252764
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp
@@ -0,0 +1,99 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
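+
+// npu_giou takes boxes as [4, n] tensors ([x, y, w, h] per column when
+// trans=True) and computes one GIoU value per box pair; the kernel permutes
+// the inputs to the [n, 4] layout the underlying GIoU operator expects.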
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using namespace at::native::npu;
+
+c10::SmallVector<int64_t, SIZE> giou_output_size(
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool is_cross){
+  c10::SmallVector<int64_t, SIZE> output_size;
+  if(is_cross){
+    output_size = {gtboxes.size(0), self.size(0)};
+  } else {
+    output_size = {1, self.size(0)};
+  }
+  return output_size;
+}
+
+at::Tensor& giou_inner_out_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  auto output_size = giou_output_size(self, gtboxes, is_cross);
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self,
+      output_size);
+  string mode_str = mode == 1 ? "iof" : "iou";
+
+  OpCommand cmd;
+  cmd.Name("GIoU")
+      .Input(self)
+      .Input(gtboxes)
+      .Output(result)
+      .Attr("trans", trans)
+      .Attr("is_cross", is_cross)
+      .Attr("mode", mode_str)
+      .Run();
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::npu_giou(
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode) {
+  TORCH_CHECK(trans && !is_cross && mode == 0,
+      "npu_giou currently only supports trans==True, ",
+      "is_cross==False, ",
+      "mode==0 ('iou'); ",
+      "if you need backpropagation, ",
+      "please make sure your parameters match these constraints!");
+  // The op expects boxes of shape [n, 4], but they are passed in as [4, n];
+  // Note: temporary workaround! It will be removed once the op handles fp16 itself.
+  at::Tensor selfCp = self.permute({1, 0});
+  if(selfCp.scalar_type() == at::kHalf){
+    selfCp = selfCp.npu_dtype_cast(at::kFloat);
+  }
+  at::Tensor gtboxesCp = gtboxes.permute({1, 0});
+  if(gtboxesCp.scalar_type() == at::kHalf){
+    gtboxesCp = gtboxesCp.npu_dtype_cast(at::kFloat);
+  }
+  auto output_size = giou_output_size(selfCp, gtboxesCp, is_cross);
+  at::Tensor result = OpPreparation::ApplyTensor(selfCp, output_size);
+
+  giou_inner_out_npu(result, selfCp, gtboxesCp, trans, is_cross, mode);
+  result = result.permute({1, 0});
+  if(self.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf){
+    result = result.npu_dtype_cast(at::kHalf);
+  }
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee

From ab4d599c6fa8abb8ed7d1d9218be8009d63566ed Mon Sep 17 00:00:00 2001
From: wuxiankun
Date: Thu, 27 Jan 2022 14:39:33 +0800
Subject: [PATCH 2/3] npuscatter,yoloboxes,normalizebatch
---
 test/test_network_ops/test_normalize_batch.py |  70 +++++++++
 test/test_network_ops/test_scatterv1.py       |  40 ++++++
 .../test_yolo_boxes_encode.py                 |  41 ++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   7 +-
 .../csrc/aten/ops/GiouBackwardKernelNpu.cpp   | 133 ------------------
 torch_npu/csrc/aten/ops/GiouKernelNpu.cpp     | 118 +++++++++++++++-
 .../csrc/aten/ops/NormalizeBatchKernelNpu.cpp |  66 +++++++++
 .../csrc/aten/ops/ScatterV1KernelNpu.cpp      |  47 +++++++
 .../aten/ops/YoloBoxesEncodeKernelNpu.cpp     |  76 ++++++++++
 9 files changed, 459 insertions(+), 139 deletions(-)
 create mode 100644 test/test_network_ops/test_normalize_batch.py
 create mode 100644 test/test_network_ops/test_scatterv1.py
 create mode 100644 test/test_network_ops/test_yolo_boxes_encode.py
 delete mode 100644 torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp
diff --git 
a/test/test_network_ops/test_normalize_batch.py b/test/test_network_ops/test_normalize_batch.py new file mode 100644 index 00000000000..8756a8f4ae8 --- /dev/null +++ b/test/test_network_ops/test_normalize_batch.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestNormalizeBatch(TestCase): + def cpu_op_exec(self, input1, seq_len, normalize_type): + if normalize_type == 0: + x_mean = torch.zeros((seq_len.shape[0], input1.shape[1]), dtype=input1.dtype, + device=input1.device) + x_std = torch.zeros((seq_len.shape[0], input1.shape[1]), dtype=input1.dtype, + device=input1.device) + for i in range(input1.shape[0]): + x_mean[i, :] = input1[i, :, :seq_len[i]].mean(dim=1) + x_std[i, :] = input1[i, :, :seq_len[i]].std(dim=1) + x_std += 1e-5 + result = (input1 - x_mean.unsqueeze(2)) / x_std.unsqueeze(2) + else: + x_mean = torch.zeros(seq_len.shape, dtype=input1.dtype, + device=input1.device) + x_std = torch.zeros(seq_len.shape, dtype=input1.dtype, + device=input1.device) + for i in range(input1.shape[0]): + x_mean[i] = input1[i, :, :int(seq_len[i])].mean() + x_std[i] = input1[i, :, :int(seq_len[i])].std() + x_std += 1e-5 + result = (input1 - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1) + return result.numpy() + + def npu_op_exec(self, input1, seq_len, normalize_type): + out = torch_npu.npu_normalize_batch(input1, seq_len, normalize_type) + out = out.to("cpu") + return out.detach().numpy() + + def test_normalize_batch(self, device): + shape_format = [ + [[np.float32, -1, [32, 3, 6]], [np.int32, -1, [32]], 0], + [[np.float32, -1, [32, 3, 6]], [np.int32, -1, [32]], 1], + [[np.float32, -1, [16, 6, 1000]], [np.int32, -1, [16]], 0], + [[np.float32, -1, [16, 6, 1000]], [np.int32, -1, [16]], 1] + ] + for item in shape_format: + right_range = item[0][-1][-1] + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 10) + cpu_seqlen, npu_seqlen = create_common_tensor( + item[1], 3, right_range) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_seqlen, item[-1]) + npu_output = self.npu_op_exec(npu_input1, npu_seqlen, item[-1]) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestNormalizeBatch, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_scatterv1.py b/test/test_network_ops/test_scatterv1.py new file mode 100644 index 00000000000..e0f0f8fa4c3 --- /dev/null +++ b/test/test_network_ops/test_scatterv1.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor +from torch_npu.testing.common_utils import TestCase, run_tests + + +class TestScatterV1(TestCase): + def npu_op_exec(self, input1, indices, updates, dim): + output = torch_npu.npu_scatter(input1, indices, updates, dim) + output = output.to("cpu") + output = output.numpy() + return output + + def test_scatterv1(self, device): + input1 = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]]).npu() + indices = torch.tensor([0, 1]).npu().to(torch.int32) + updates = torch.tensor([-1.1993, -1.5247]).npu() + dim = 0 + exoutput = torch.tensor([[-1.1993, 0.1226], [0.9041, -1.5247]]) + output = self.npu_op_exec(input1, indices, updates, dim) + self.assertRtolEqual(exoutput.numpy(), output) + +instantiate_device_type_tests(TestScatterV1, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_yolo_boxes_encode.py b/test/test_network_ops/test_yolo_boxes_encode.py new file mode 100644 index 00000000000..6c32695c20f --- /dev/null +++ b/test/test_network_ops/test_yolo_boxes_encode.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
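+
+# npu_yolo_boxes_encode expects 2D [n, 4] anchor_boxes and gt_bboxes plus a 1D
+# int32 stride of length n; performance_mode selects the "high_performance"
+# kernel path instead of "high_precision" (see YoloBoxesEncodeKernelNpu.cpp).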
+import torch +import torch_npu +import numpy as np +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestYoloBoxesEncode(TestCase): + def npu_op_exec(self, anchor_boxes, gt_bboxes, stride, impl_mode=False): + out = torch_npu.npu_yolo_boxes_encode( + anchor_boxes, gt_bboxes, stride, impl_mode) + out = out.to("cpu") + return out.detach().numpy() + + def test_yolo_boxes_encode(self, device): + anchor_boxes = torch.rand((2, 4), dtype=torch.float32).to("npu") + gt_bboxes = torch.rand((2, 4), dtype=torch.float32).to("npu") + stride = torch.tensor([2, 2], dtype=torch.int32).to("npu") + expect_cpu = torch.tensor([[0.7921727, 0.5314963, -0.74224466, -13.815511], + [0.7360072, 0.58343244, 4.3334002, -0.51378196]], dtype=torch.float32) + npu_output = self.npu_op_exec(anchor_boxes, gt_bboxes, stride, False) + self.assertRtolEqual(expect_cpu.numpy(), npu_output) + + +instantiate_device_type_tests(TestYoloBoxesEncode, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 62a9b16bdf9..300dfc212c1 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1913,10 +1913,15 @@ custom: - func: npu_stride_add(Tensor self, Tensor other, Scalar offset1, Scalar offset2, Scalar c1_len) -> Tensor - func: npu_slice(Tensor self, int[] offsets, int[] size) -> Tensor - func: npu_slice.out(Tensor self, int[] offsets, int[] size, *, Tensor(a!) out) -> Tensor(a!) - - func: npu_giou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor - func: npu_giou_backward(Tensor grad, Tensor bboxes, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> (Tensor, Tensor) - func: npu_indexing(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0) -> Tensor - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!) + - func: npu_normalize_batch(Tensor self, Tensor seq_len, int normalize_type=0) -> Tensor + variants: function, method + - func: npu_yolo_boxes_encode(Tensor self, Tensor gt_bboxes, Tensor stride, bool performance_mode=False) -> Tensor + variants: function, method + - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor + variants: function, method custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp deleted file mode 100644 index ee79835b3b9..00000000000 --- a/torch_npu/csrc/aten/ops/GiouBackwardKernelNpu.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) 2020 Huawei Technologies Co., Ltd -// Copyright (c) 2019, Facebook CORPORATION. -// All rights reserved. 
-//
-// Licensed under the BSD 3-Clause License (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// https://opensource.org/licenses/BSD-3-Clause
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "torch_npu/csrc/framework/utils/OpAdapter.h"
-#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
-#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
-
-namespace at_npu {
-namespace native {
-using namespace torch::autograd;
-
-std::tuple<at::Tensor&, at::Tensor&> giou_backward_inner_out_npu(
-    at::Tensor& dbboxes,
-    at::Tensor& dgtboxes,
-    const at::Tensor& grad,
-    const at::Tensor& bboxes,
-    const at::Tensor& gtboxes,
-    bool trans,
-    bool is_cross,
-    int64_t mode){
-  string mode_str = mode == 1 ? "iof" : "iou";
-
-  OpCommand cmd;
-  cmd.Name("GIoUGrad")
-      .Input(grad)
-      .Input(bboxes)
-      .Input(gtboxes)
-      .Output(dbboxes)
-      .Output(dgtboxes)
-      .Attr("trans", trans)
-      .Attr("is_cross", is_cross)
-      .Attr("mode", mode_str)
-      .Run();
-  return std::tie(dbboxes, dgtboxes);
-}
-
-std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_giou_backward(
-    const at::Tensor& grad,
-    const at::Tensor& bboxes,
-    const at::Tensor& gtboxes,
-    bool trans,
-    bool is_cross,
-    int64_t mode){
-  TORCH_CHECK(trans && !is_cross && mode == 0,
-      "npu_giou_backward currently only supports trans==True, ",
-      "is_cross==False, ",
-      "mode==0 ('iou'); ",
-      "if you need backpropagation, ",
-      "please make sure your parameters match these constraints!");
-  // The op expects grad in the form [n]
-  // Note: temporary workaround! It will be removed once the op handles fp16 itself.
-  at::Tensor gradCp = at::squeeze(grad, 0);
-  if(gradCp.scalar_type() == at::kHalf){
-    gradCp = gradCp.npu_dtype_cast(at::kFloat);
-  }
-  at::Tensor bboxesCp = bboxes;
-  if(bboxesCp.scalar_type() == at::kHalf){
-    bboxesCp = bboxesCp.npu_dtype_cast(at::kFloat);
-  }
-  at::Tensor gtboxesCp = gtboxes;
-  if(gtboxesCp.scalar_type() == at::kHalf){
-    gtboxesCp = gtboxesCp.npu_dtype_cast(at::kFloat);
-  }
-  at::Tensor dbboxes = OpPreparation::ApplyTensor(bboxesCp);
-  at::Tensor dgtboxes = OpPreparation::ApplyTensor(gtboxesCp);
-
-  giou_backward_inner_out_npu(dbboxes, dgtboxes, gradCp, bboxesCp, gtboxesCp, trans, is_cross, mode);
-  if(bboxes.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf){
-    dbboxes = dbboxes.npu_dtype_cast(at::kHalf);
-    dgtboxes = dgtboxes.npu_dtype_cast(at::kHalf);
-  }
-  return std::tie(dbboxes, dgtboxes);
-}
-
-class NPUGiouFunction : public torch::autograd::Function<NPUGiouFunction> {
-public:
-  static at::Tensor forward(AutogradContext *ctx,
-      const at::Tensor& self,
-      const at::Tensor& gtboxes,
-      bool trans,
-      bool is_cross,
-      int64_t mode) {
-    ctx->saved_data["trans"] = trans;
-    ctx->saved_data["is_cross"] = is_cross;
-    ctx->saved_data["mode"] = mode;
-    at::AutoNonVariableTypeMode g;
-    ctx->save_for_backward({self, gtboxes});
-    return NPUNativeFunctions::npu_giou(self, gtboxes, trans, is_cross, mode);
-  }
-
-  static tensor_list backward(AutogradContext *ctx,
-      tensor_list grad_outputs) {
-    auto trans = ctx->saved_data["trans"].toBool();
-    auto is_cross = ctx->saved_data["is_cross"].toBool();
-    auto mode = ctx->saved_data["mode"].toInt();
-    auto saved = ctx->get_saved_variables();
-    auto bboxes = saved[0];
-    auto gtboxes = saved[1];
-
-    std::tuple<at::Tensor, at::Tensor> result = NPUNativeFunctions::npu_giou_backward(grad_outputs[0], bboxes, gtboxes, trans, is_cross, mode);
-
-    tensor_list output = {std::get<0>(result),
-        std::get<1>(result),
-        at::Tensor(),
-        at::Tensor(),
-        at::Tensor()};
-    return output;
-  }
-};
-
-at::Tensor NPUNativeFunctions::npu_giou(const at::Tensor& self,
-    const at::Tensor& gtboxes,
-    bool trans,
-    bool is_cross,
-    int64_t mode) {
-  return NPUGiouFunction::apply(self, gtboxes, trans, is_cross, mode);
-}
-
-} // namespace native
-} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp
index 3f70d252764..a7956d338ea 100644
--- a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp
@@ -16,11 +16,12 @@
 
 #include "torch_npu/csrc/framework/utils/OpAdapter.h"
 #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include <torch/csrc/autograd/custom_function.h>
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 
 namespace at_npu {
 namespace native {
-using namespace at::native::npu;
+using namespace torch::autograd;
 
 c10::SmallVector<int64_t, SIZE> giou_output_size(
     const at::Tensor& self,
     const at::Tensor& gtboxes,
@@ -62,7 +63,7 @@ at::Tensor& giou_inner_out_npu(
   return result;
 }
 
-at::Tensor NPUNativeFunctions::npu_giou(
+at::Tensor giou_kernel_npu(
     const at::Tensor& self,
     const at::Tensor& gtboxes,
     bool trans,
@@ -78,11 +79,11 @@
   // Note: temporary workaround! It will be removed once the op handles fp16 itself.
   at::Tensor selfCp = self.permute({1, 0});
   if(selfCp.scalar_type() == at::kHalf){
-    selfCp = selfCp.npu_dtype_cast(at::kFloat);
+    selfCp = NPUNativeFunctions::npu_dtype_cast(selfCp, at::kFloat);
   }
   at::Tensor gtboxesCp = gtboxes.permute({1, 0});
   if(gtboxesCp.scalar_type() == at::kHalf){
-    gtboxesCp = gtboxesCp.npu_dtype_cast(at::kFloat);
+    gtboxesCp = NPUNativeFunctions::npu_dtype_cast(gtboxesCp, at::kFloat);
   }
   auto output_size = giou_output_size(selfCp, gtboxesCp, is_cross);
   at::Tensor result = OpPreparation::ApplyTensor(selfCp, output_size);
@@ -90,10 +91,117 @@
   giou_inner_out_npu(result, selfCp, gtboxesCp, trans, is_cross, mode);
   result = result.permute({1, 0});
   if(self.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf){
-    result = result.npu_dtype_cast(at::kHalf);
+    result = NPUNativeFunctions::npu_dtype_cast(result, at::kHalf);
   }
   return result;
 }
 
+std::tuple<at::Tensor&, at::Tensor&> giou_backward_inner_out_npu(
+    at::Tensor& dbboxes,
+    at::Tensor& dgtboxes,
+    const at::Tensor& grad,
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  string mode_str = mode == 1 ? "iof" : "iou";
+
+  OpCommand cmd;
+  cmd.Name("GIoUGrad")
+      .Input(grad)
+      .Input(bboxes)
+      .Input(gtboxes)
+      .Output(dbboxes)
+      .Output(dgtboxes)
+      .Attr("trans", trans)
+      .Attr("is_cross", is_cross)
+      .Attr("mode", mode_str)
+      .Run();
+  return std::tie(dbboxes, dgtboxes);
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_giou_backward(
+    const at::Tensor& grad,
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  TORCH_CHECK(trans && !is_cross && mode == 0,
+      "npu_giou_backward currently only supports trans==True, ",
+      "is_cross==False, ",
+      "mode==0 ('iou'); ",
+      "if you need backpropagation, ",
+      "please make sure your parameters match these constraints!");
+  // The op expects grad in the form [n]
+  // Note: temporary workaround! It will be removed once the op handles fp16 itself.
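+  // grad arrives as [1, n] (matching npu_giou's output), so it is squeezed
+  // below to the [n] shape that the GIoUGrad operator expects.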
+  at::Tensor gradCp = at::squeeze(grad, 0);
+  if(gradCp.scalar_type() == at::kHalf){
+    gradCp = NPUNativeFunctions::npu_dtype_cast(gradCp, at::kFloat);
+  }
+  at::Tensor bboxesCp = bboxes;
+  if(bboxesCp.scalar_type() == at::kHalf){
+    bboxesCp = NPUNativeFunctions::npu_dtype_cast(bboxesCp, at::kFloat);
+  }
+  at::Tensor gtboxesCp = gtboxes;
+  if(gtboxesCp.scalar_type() == at::kHalf){
+    gtboxesCp = NPUNativeFunctions::npu_dtype_cast(gtboxesCp, at::kFloat);
+  }
+  at::Tensor dbboxes = OpPreparation::ApplyTensor(bboxesCp);
+  at::Tensor dgtboxes = OpPreparation::ApplyTensor(gtboxesCp);
+
+  giou_backward_inner_out_npu(dbboxes, dgtboxes, gradCp, bboxesCp, gtboxesCp, trans, is_cross, mode);
+  if(bboxes.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf){
+    dbboxes = NPUNativeFunctions::npu_dtype_cast(dbboxes, at::kHalf);
+    dgtboxes = NPUNativeFunctions::npu_dtype_cast(dgtboxes, at::kHalf);
+  }
+  return std::tie(dbboxes, dgtboxes);
+}
+
+class NPUGiouFunction : public torch::autograd::Function<NPUGiouFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self,
+      const at::Tensor& gtboxes,
+      bool trans,
+      bool is_cross,
+      int64_t mode) {
+    ctx->saved_data["trans"] = trans;
+    ctx->saved_data["is_cross"] = is_cross;
+    ctx->saved_data["mode"] = mode;
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({self, gtboxes});
+    return giou_kernel_npu(self, gtboxes, trans, is_cross, mode);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto trans = ctx->saved_data["trans"].toBool();
+    auto is_cross = ctx->saved_data["is_cross"].toBool();
+    auto mode = ctx->saved_data["mode"].toInt();
+    auto saved = ctx->get_saved_variables();
+    auto bboxes = saved[0];
+    auto gtboxes = saved[1];
+
+    std::tuple<at::Tensor, at::Tensor> result = NPUNativeFunctions::npu_giou_backward(grad_outputs[0], bboxes, gtboxes, trans, is_cross, mode);
+
+    tensor_list output = {std::get<0>(result),
+        std::get<1>(result),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor()};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_giou(const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode) {
+  return NPUGiouFunction::apply(self, gtboxes, trans, is_cross, mode);
+}
+
 } // namespace native
 } // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp b/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp
new file mode 100644
index 00000000000..8befc312256
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp
@@ -0,0 +1,66 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
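+
+// npu_normalize_batch normalizes each batch element over its valid length
+// seq_len[i]: normalize_type 0 selects "per_feature" statistics, 1 selects
+// "all_features", with EPSILON (1e-5) added to the standard deviation.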
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+constexpr float_t EPSILON = 1e-5;
+
+namespace at_npu {
+namespace native {
+
+static inline void normalize_batch_check(
+    const at::Tensor& self,
+    const at::Tensor& seq_len,
+    int64_t normalize_type){
+  TORCH_CHECK(
+      seq_len.dim() == 1,
+      "Non-empty 1D seq_len tensor expected but got a tensor with sizes ",
+      seq_len.sizes());
+  TORCH_CHECK(
+      seq_len.size(0) == self.size(0),
+      "seq_len's length should equal self's batch size, but got seq_len length ",
+      seq_len.size(0),
+      " and self batch size ",
+      self.size(0));
+  TORCH_CHECK(
+      normalize_type >= 0 && normalize_type <= 1,
+      "normalize_type expected to be in range [0, 1], but got ",
+      normalize_type);
+}
+
+at::Tensor NPUNativeFunctions::npu_normalize_batch(
+    const at::Tensor& self,
+    const at::Tensor& seq_len,
+    int64_t normalize_type){
+  normalize_batch_check(self, seq_len, normalize_type);
+  // allocate the output tensor
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  string normalizeType = normalize_type == 0 ? "per_feature" : "all_features";
+
+  OpCommand cmd;
+  cmd.Name("NormalizeBatch")
+      .Input(self)
+      .Input(seq_len)
+      .Output(result)
+      .Attr("normalize_type", normalizeType)
+      .Attr("epsilon", EPSILON)
+      .Run();
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
new file mode 100644
index 00000000000..3b318388723
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp
@@ -0,0 +1,47 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& scatter_out_npu(
+    at::Tensor& output,
+    const at::Tensor& self,
+    const at::Tensor& indices,
+    const at::Tensor& updates,
+    int64_t dim) {
+  OpCommand cmd;
+  cmd.Name("ArgMaxGrad")
+      .Input(self)
+      .Input(indices)
+      .Input(updates)
+      .Output(output)
+      .Attr("dimension", dim)
+      .Run();
+  return output;
+}
+
+at::Tensor NPUNativeFunctions::npu_scatter(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim) {
+  at::Tensor outputs = OpPreparation::ApplyTensor(self);
+  scatter_out_npu(outputs, self, indices, updates, dim);
+  return outputs;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp
new file mode 100644
index 00000000000..607e8a8d390
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp
@@ -0,0 +1,76 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+static inline void yolo_boxes_encode_check(
+    const at::Tensor& anchor_boxes,
+    const at::Tensor& gt_bboxes,
+    const at::Tensor& stride){
+  TORCH_CHECK(
+      anchor_boxes.dim() == 2 && anchor_boxes.size(1) == 4,
+      "Non-empty 2D anchor_boxes tensor expected but got a tensor with sizes ",
+      anchor_boxes.sizes());
+  TORCH_CHECK(
+      anchor_boxes.size(0) <= 20480,
+      "anchor_boxes supports at most 20480 boxes, but got ",
+      anchor_boxes.size(0));
+  TORCH_CHECK(
+      gt_bboxes.dim() == 2 && gt_bboxes.size(1) == 4,
+      "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ",
+      gt_bboxes.sizes());
+  TORCH_CHECK(
+      stride.dim() == 1,
+      "Non-empty 1D stride tensor expected but got a tensor with sizes ",
+      stride.sizes());
+  TORCH_CHECK(
+      stride.size(0) == gt_bboxes.size(0),
+      "stride's length should equal the number of gt_bboxes, but got stride length ",
+      stride.size(0),
+      " and gt_bboxes num ",
+      gt_bboxes.size(0));
+  TORCH_CHECK(
+      at::isIntegralType(stride.scalar_type(), true) && stride.scalar_type() != at::ScalarType::Long,
+      "int32 stride tensor expected but got a tensor with dtype: ",
+      stride.scalar_type());
+}
+
+at::Tensor NPUNativeFunctions::npu_yolo_boxes_encode(
+    const at::Tensor& anchor_boxes,
+    const at::Tensor& gt_bboxes,
+    const at::Tensor& stride,
+    bool performance_mode){
+  yolo_boxes_encode_check(anchor_boxes, gt_bboxes, stride);
+  at::Tensor result = OpPreparation::ApplyTensor(gt_bboxes);
+  string implModeStr = performance_mode ? 
"high_performance" : "high_precision"; + at::Tensor strideCp = NPUNativeFunctions::npu_dtype_cast(stride, at::ScalarType::Int); + OpCommand cmd; + cmd.Name("YoloBoxesEncode") + .Input(anchor_boxes) + .Input(gt_bboxes) + .Input(strideCp) + .Output(result) + .Attr("performance_mode", implModeStr) + .Run(); + return result; +} + +} // namespace native +} // namespace at_npu -- Gitee From d1d35dc81c8e09f922fa4ec8dccf604e887bac9c Mon Sep 17 00:00:00 2001 From: wuxiankun Date: Thu, 27 Jan 2022 15:36:11 +0800 Subject: [PATCH 3/3] ifmr,mish,deformable, subsample && clean code smell && format --- .../test_deformable_conv2d.py | 71 ++++++ .../test_deformable_conv2d_backward.py | 130 +++++++++++ test/test_network_ops/test_ifmr.py | 138 +++++++++++ test/test_network_ops/test_mish.py | 63 +++++ test/test_network_ops/test_mish_backward.py | 60 +++++ test/test_network_ops/test_normalize_batch.py | 3 +- test/test_network_ops/test_npu_giou.py | 1 + .../test_npu_giou_backward.py | 1 + test/test_network_ops/test_scatterv1.py | 3 +- .../test_yolo_boxes_encode.py | 3 +- torch_npu/csrc/aten/npu_native_functions.yaml | 9 + torch_npu/csrc/aten/ops/GiouKernelNpu.cpp | 3 +- torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp | 63 +++++ torch_npu/csrc/aten/ops/MishKernelNpu.cpp | 74 ++++++ .../csrc/aten/ops/SubSampleKernelNpu.cpp | 37 +++ .../convolution/DeformableConv2dKernelNpu.cpp | 217 ++++++++++++++++++ 16 files changed, 872 insertions(+), 4 deletions(-) create mode 100644 test/test_network_ops/test_deformable_conv2d.py create mode 100644 test/test_network_ops/test_deformable_conv2d_backward.py create mode 100644 test/test_network_ops/test_ifmr.py create mode 100644 test/test_network_ops/test_mish.py create mode 100644 test/test_network_ops/test_mish_backward.py create mode 100644 torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/MishKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/convolution/DeformableConv2dKernelNpu.cpp diff --git a/test/test_network_ops/test_deformable_conv2d.py b/test/test_network_ops/test_deformable_conv2d.py new file mode 100644 index 00000000000..892c3b2113d --- /dev/null +++ b/test/test_network_ops/test_deformable_conv2d.py @@ -0,0 +1,71 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestDeformableConv2d(TestCase): + def create_single_npu_tensor(self, item, minvalue, maxvalue): + dtype = item[0] + format_ = item[1] + shape = item[2] + input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype) + npu_input = torch.from_numpy(input1).to("npu") + if format_ != -1: + npu_input = npu_input.npu_format_cast(format_) + return npu_input + + def test_deformable_conv2d_fp32(self, device): + input1 = self.create_single_npu_tensor([np.float32, 0, (16, 32, 32, 32)], 0, 10) + weight = self.create_single_npu_tensor([np.float32, 0, (32, 32, 5, 5)], 0, 10) + offset = self.create_single_npu_tensor([np.float32, 0, (16, 75, 32, 32)], 0, 10) + npu_output, offset_out = torch_npu.npu_deformable_conv2d( + input1, weight, offset, None, kernel_size=[5, 5], stride = [1, 1, 1, 1], padding = [2, 2, 2, 2]) + npu_output = npu_output.cpu().detach() + output = npu_output.select(1, 2).select(1, 2).select(1, 2) + expect_output = torch.tensor([65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504.]) + self.assertRtolEqual(expect_output, output) + + offset_out_select = offset_out.select(1, 2).select(1, 2).select(1, 2) + expect_offset_out = torch.tensor([13.6374, 6.0295, 79.6111, 33.5996, 53.6248, 62.2289, 14.3497, 47.6463, + 52.3312, 34.1246, 8.6705, 3.3515, 9.9513, 15.3604, 38.9772, 57.1306]) + self.assertRtolEqual(expect_offset_out, offset_out_select.cpu().detach()) + + def test_deformable_conv2d_fp16(self, device): + input_fp16 = self.create_single_npu_tensor([np.float16, 0, (16, 32, 32, 32)], 0, 10) + weight = self.create_single_npu_tensor([np.float16, 0, (32, 32, 5, 5)], 0, 10) + offset = self.create_single_npu_tensor([np.float16, 0, (16, 75, 32, 32)], 0, 10) + npu_output, offset_out = torch_npu.npu_deformable_conv2d( + input_fp16, weight, offset, None, kernel_size=[5, 5], stride = [1, 1, 1, 1], padding = [2, 2, 2, 2]) + npu_output = npu_output.cpu().detach() + output = npu_output.select(1, 2).select(1, 2).select(1, 2) + expect_output = torch.tensor([65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504.], dtype=torch.float16) + self.assertRtolEqual(expect_output, output) + + offset_out_select = offset_out.select(1, 2).select(1, 2).select(1, 2) + expect_offset_out = torch.tensor([13.6562, 6.0352, 79.4375, 33.5938, 53.6875, 62.2188, 14.3438, 47.6562, + 52.3750, 34.0938, 8.6797, 3.3516, 9.9531, 15.3750, 39.0625, 57.1875], + dtype=torch.float16) + self.assertRtolEqual(expect_offset_out, offset_out_select.cpu().detach()) + + +instantiate_device_type_tests(TestDeformableConv2d, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_deformable_conv2d_backward.py b/test/test_network_ops/test_deformable_conv2d_backward.py new file mode 100644 index 00000000000..fe4a154bbb2 --- /dev/null +++ b/test/test_network_ops/test_deformable_conv2d_backward.py @@ -0,0 +1,130 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + + +class TestDeformableConv2dBackward(TestCase): + def create_single_npu_tensor(self, item, minvalue, maxvalue): + dtype = item[0] + format_ = item[1] + shape = item[2] + input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype) + npu_input = torch.from_numpy(input1).to("npu") + if format_ != -1: + npu_input = npu_input.npu_format_cast(format_) + return npu_input + + def test_deformable_conv2d_backward_fp32(self, device): + input1 = self.create_single_npu_tensor( + [np.float32, 0, (16, 32, 32, 32)], 0, 10) + weight = self.create_single_npu_tensor( + [np.float32, 0, (32, 32, 5, 5)], 0, 10) + offset = self.create_single_npu_tensor( + [np.float32, 0, (16, 75, 32, 32)], 0, 10) + input1.requires_grad = True + weight.requires_grad = True + offset.requires_grad = True + + npu_output, offset_out = torch_npu.npu_deformable_conv2d( + input1, weight, offset, None, kernel_size=[5, 5], stride=[1, 1, 1, 1], padding=[2, 2, 2, 2]) + npu_output.backward(torch.ones_like(npu_output)) + npu_output = npu_output.cpu().detach() + output = npu_output.select(1, 2).select(1, 2).select(1, 2) + expect_output = torch.tensor([65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504.]) + self.assertRtolEqual(expect_output, output) + + offset_out_select = offset_out.select(1, 2).select(1, 2).select(1, 2) + expect_offset_out = torch.tensor([13.6374, 6.0295, 79.6111, 33.5996, 53.6248, 62.2289, 14.3497, 47.6463, + 52.3312, 34.1246, 8.6705, 3.3515, 9.9513, 15.3604, 38.9772, 57.1306]) + self.assertRtolEqual(expect_offset_out, offset_out_select.cpu().detach()) + + input_grad = input1.grad.select(1, 2).select(1, 2).select(1, 3) + expect_input_grad = torch.tensor([1018.3082, 1080.2413, 2533.7673, 1305.1088, 3977.8022, 2363.5249, + 1414.8381, 2117.0735, 1400.9083, 2064.4629, 1945.1212, 2338.8213, + 300.2924, 2646.9910, 1898.9320, 2165.8921]) + self.assertRtolEqual(expect_input_grad, input_grad.cpu()) + + offest_grad = offset.grad.select(1, 2).select(1, 2).select(1, 3) + expect_offest_grad = torch.tensor([-4707.2891, -139.2936, -2391.8394, 31024.4375, 19856.1621, + -1205.9329, -24091.1953, -3947.4133, -31050.9805, 4570.9854, + 582.9227, -5515.2178, 78396.9062, -1778.6783, -14314.9893, + -2066.9614]) + self.assertRtolEqual(expect_offest_grad, offest_grad.cpu()) + + weight_grad = weight.grad.select(1, 2).select(1, 2).select(1, 3) + expect_weight_grad = torch.tensor([279501.8438, 279501.8438, 279501.8438, 279501.8438, 279501.8438, + 279501.8438, 279501.8438, 279501.8438, 279501.8750, 279501.8750, + 279501.8750, 279501.8750, 279501.8750, 279501.8750, 279501.8750, + 279501.8750, 279501.8438, 279501.8438, 279501.8438, 279501.8438, + 279501.8125, 279501.8125, 279501.8125, 279501.8125, 279501.8750, + 279501.8750, 279501.8750, 279501.8750, 279501.8750, 279501.8750, + 279501.8750, 279501.8750]) + self.assertRtolEqual(expect_weight_grad, 
weight_grad.cpu()) + + def test_deformable_conv2d_backward_fp16(self, device): + input_fp16 = self.create_single_npu_tensor( + [np.float16, 0, (16, 32, 32, 32)], 0, 10) + weight = self.create_single_npu_tensor( + [np.float16, 0, (32, 32, 5, 5)], 0, 10) + offset = self.create_single_npu_tensor( + [np.float16, 0, (16, 75, 32, 32)], 0, 10) + input_fp16.requires_grad = True + weight.requires_grad = True + offset.requires_grad = True + npu_output, offset_out = torch_npu.npu_deformable_conv2d( + input_fp16, weight, offset, None, kernel_size=[5, 5], stride=[1, 1, 1, 1], padding=[2, 2, 2, 2]) + npu_output.backward(torch.ones_like(npu_output)) + npu_output = npu_output.cpu().detach() + output = npu_output.select(1, 2).select(1, 2).select(1, 2) + expect_output = torch.tensor([65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504.], dtype=torch.float16) + self.assertRtolEqual(expect_output, output) + + offset_out_select = offset_out.select(1, 2).select(1, 2).select(1, 2) + expect_offset_out = torch.tensor([13.6562, 6.0352, 79.4375, 33.5938, 53.6875, 62.2188, 14.3438, 47.6562, + 52.3750, 34.0938, 8.6797, 3.3516, 9.9531, 15.3750, 39.0625, 57.1875], + dtype=torch.float16) + self.assertRtolEqual(expect_offset_out, offset_out_select.cpu().detach()) + + input_grad = input_fp16.grad.select(1, 2).select(1, 2).select(1, 3) + expect_input_grad = torch.tensor([1019.0000, 1082.0000, 2534.0000, 1304.0000, 3978.0000, 2364.0000, + 1413.0000, 2114.0000, 1402.0000, 2064.0000, 1944.0000, 2340.0000, + 299.7500, 2648.0000, 1895.0000, 2164.0000], dtype=torch.float16) + self.assertRtolEqual(expect_input_grad, input_grad.cpu()) + + offest_grad = offset.grad.select(1, 2).select(1, 2).select(1, 2) + expect_offest_grad = torch.tensor([141.2500, 34784.0000, -12384.0000, -1885.0000, -5440.0000, + 4416.0000, 13920.0000, -19952.0000, 4160.0000, 24848.0000, + -1464.0000, -21088.0000, -1060.0000, -22544.0000, 9152.0000, + -4312.0000], dtype=torch.float16) + self.assertRtolEqual(expect_offest_grad, offest_grad.cpu()) + + weight_grad = weight.grad.select(1, 2).select(1, 2).select(1, 3) + expect_weight_grad = torch.tensor([65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., 65504., + 65504., 65504., 65504., 65504., 65504.], dtype=torch.float16) + self.assertRtolEqual(expect_weight_grad, weight_grad.cpu()) + + +instantiate_device_type_tests(TestDeformableConv2dBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_ifmr.py b/test/test_network_ops/test_ifmr.py new file mode 100644 index 00000000000..ef96ee5c305 --- /dev/null +++ b/test/test_network_ops/test_ifmr.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
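+
+# IFMR calibration: percentile-clip the histogram CDF to get initial min/max,
+# scan scale candidates over [search_start, search_end) in search_step
+# increments, and keep the scale/offset pair with the smallest int8
+# quantization MSE; cpu_op_exec mirrors this search in numpy.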
+
+from functools import reduce
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+
+class TestIFMR(TestCase):
+    def cpu_op_exec(self,
+                    input_data,
+                    with_offset,
+                    bins_num=128,
+                    min_percentile=0.999999,
+                    max_percentile=0.999999,
+                    search_range=(0.7, 1.3),
+                    search_step=0.01):
+        pre_mode = np.float32
+        input_data = input_data.numpy().astype(pre_mode)
+        data_min = input_data.min()
+        data_max = input_data.max()
+        data_num = reduce(lambda x, y: x * y, input_data.shape)
+        data_num = np.array(data_num, pre_mode)
+
+        bins, threshold = np.histogram(input_data, bins_num)
+        cumsum = np.cumsum(bins).astype(np.int32)
+        bins_num = np.array(bins_num, pre_mode)
+        cdf = cumsum.astype(pre_mode) / data_num
+        max_index = np.where(cdf >= np.array(max_percentile, pre_mode), 0,
+                             1).sum().astype(pre_mode)
+        min_index = np.where(cdf >= np.array(1 - min_percentile, pre_mode), 0,
+                             1).sum().astype(pre_mode)
+        max_init = max_index / bins_num * (data_max - data_min) + data_min
+        min_init = min_index / bins_num * (data_max - data_min) + data_min
+        step = np.arange(search_range[0],
+                         search_range[1],
+                         search_step,
+                         dtype=pre_mode)
+        if with_offset:
+            if max_init < 0:
+                max_init = np.array(0, pre_mode)
+            if min_init > 0:
+                min_init = np.array(0, pre_mode)
+            min_list = min_init * np.ones(step.shape, dtype=pre_mode)
+        else:
+            max_init = np.max([np.abs(max_init), np.abs(min_init)])
+        max_list = max_init * step
+
+        if with_offset:
+            scale = (max_list - min_list) / 255
+            scale = np.where(scale < 1.192092896e-07, 1, scale)
+            offset = np.round(min_list / scale)
+            offset = -(offset + 128)
+        else:
+            scale = max_list / 127
+            offset = np.round(scale * 0)
+
+        loss_list = np.zeros(step.shape, dtype=pre_mode)
+        for i in range(step.size):
+            quant_data = np.round(input_data / scale[i]) + offset[i]
+            np.clip(quant_data, -128, 127, out=quant_data)
+            quant_data = (quant_data - offset[i]) * scale[i]
+            loss_list[i] = np.sum(np.square(quant_data - input_data))
+        index = np.argmin(loss_list)
+        return scale[index], offset[index]
+
+    def npu_op_exec(self, input_data, with_offset):
+        min_value = torch.min(input_data.cpu()).npu()
+        max_value = torch.max(input_data.cpu()).npu()
+        min_value = torch.reshape(min_value, (1, ))
+        max_value = torch.reshape(max_value, (1, ))
+        hist = torch.histc(input_data.to('cpu'),
+                           bins=128,
+                           min=min_value[0].to('cpu'),
+                           max=max_value[0].to('cpu'))
+        cdf = torch.cumsum(hist, dim=0).int()
+        cdf = cdf.to('npu')
+        scale, offset = torch_npu.npu_ifmr(input_data,
+                                           min_value,
+                                           max_value,
+                                           cdf,
+                                           min_percentile=0.999999,
+                                           max_percentile=0.999999,
+                                           search_start=0.7,
+                                           search_end=1.3,
+                                           search_step=0.01,
+                                           with_offset=with_offset)
+
+        return scale, offset
+
+    def test_ifmr_with_offset(self, device):
+        format_list = [0, 3]
+        shape_list = [(2, 2, 3, 4), (5, 5)]
+        shape_format = [[np.float32, i, j] for i in format_list
+                        for j in shape_list]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, -1, 1)
+            scale_cpu, offset_cpu = self.cpu_op_exec(cpu_input,
+                                                     with_offset=True)
+            scale_npu, offset_npu = self.npu_op_exec(npu_input,
+                                                     with_offset=True)
+            self.assertTrue((scale_cpu - scale_npu[0].cpu()) / scale_cpu < 0.0001)
+            self.assertRtolEqual(offset_cpu, offset_npu[0].cpu().numpy())
+
+    def test_ifmr_without_offset(self, device):
+        format_list = [0, 3]
+        shape_list = [(2, 2, 3, 4),
(5, 5)] + shape_format = [[np.float32, i, j] for i in format_list + for j in shape_list] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -1, 1) + scale_cpu, offset_cpu = self.cpu_op_exec(cpu_input, + with_offset=False) + scale_npu, offset_npu = self.npu_op_exec(npu_input, + with_offset=False) + self.assertTrue((scale_cpu - scale_npu[0].cpu()) / scale_cpu < 0.0001) + self.assertRtolEqual(offset_cpu, offset_npu[0].cpu().numpy()) + + +instantiate_device_type_tests(TestIFMR, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_mish.py b/test/test_network_ops/test_mish.py new file mode 100644 index 00000000000..ee60ba4faa8 --- /dev/null +++ b/test/test_network_ops/test_mish.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestMish(TestCase): + def npu_op_exec(self, input1): + output = torch_npu.npu_mish(input1) + output = output.cpu().numpy() + return output + + def cpu_op_exec(self, input1): + output = input1 * (torch.tanh(F.softplus(input1))) + output = output.numpy() + return output + + def test_mish_fp32(self, device): + shape_format = [ + [[np.float32, -1, [10, 30, 10]]], + [[np.float32, -1, [20, 30, 20]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + def test_mish_fp16(self, device): + shape_format = [ + [[np.float16, -1, [10, 30, 10]]], + [[np.float16, -1, [20, 30, 20]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input.float()).astype(np.float16) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestMish, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_mish_backward.py b/test/test_network_ops/test_mish_backward.py new file mode 100644 index 00000000000..e21a95e6091 --- /dev/null +++ b/test/test_network_ops/test_mish_backward.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestMishBackward(TestCase): + def npu_op_exec(self, input1): + input1.requires_grad = True + output = torch_npu.npu_mish(input1) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.cpu().detach().numpy() + return output_grad, output + + def cpu_op_exec(self, input1): + input1.requires_grad = True + output = input1 * (torch.tanh(F.softplus(input1))) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.detach().numpy() + return output_grad, output + + def test_mish_fp32(self, device): + npu_input = torch.tensor( + [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]).npu() + cpu_input = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + output_grad, npu_output = self.npu_op_exec(npu_input) + ep_output_grad, ep_npu_output = self.cpu_op_exec(cpu_input) + self.assertRtolEqual(ep_output_grad, output_grad) + self.assertRtolEqual(ep_npu_output, npu_output) + + +instantiate_device_type_tests(TestMishBackward, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_normalize_batch.py b/test/test_network_ops/test_normalize_batch.py index 8756a8f4ae8..c86c973129e 100644 --- a/test/test_network_ops/test_normalize_batch.py +++ b/test/test_network_ops/test_normalize_batch.py @@ -14,8 +14,9 @@ import torch import torch_npu import numpy as np + from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor diff --git a/test/test_network_ops/test_npu_giou.py b/test/test_network_ops/test_npu_giou.py index 85b4fae5c73..ef0d5aab37c 100644 --- a/test/test_network_ops/test_npu_giou.py +++ b/test/test_network_ops/test_npu_giou.py @@ -18,6 +18,7 @@ import math import torch import torch_npu import numpy as np + from torch_npu.testing.common_utils import TestCase, run_tests from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor diff --git a/test/test_network_ops/test_npu_giou_backward.py b/test/test_network_ops/test_npu_giou_backward.py index 41cdfb3d480..e249f3c307c 100644 --- a/test/test_network_ops/test_npu_giou_backward.py +++ b/test/test_network_ops/test_npu_giou_backward.py @@ -18,6 +18,7 @@ import math import torch import torch_npu import numpy as np + from torch_npu.testing.common_utils import TestCase, run_tests from torch_npu.testing.common_device_type import instantiate_device_type_tests from 
torch_npu.testing.util_test import create_common_tensor diff --git a/test/test_network_ops/test_scatterv1.py b/test/test_network_ops/test_scatterv1.py index e0f0f8fa4c3..512b0094185 100644 --- a/test/test_network_ops/test_scatterv1.py +++ b/test/test_network_ops/test_scatterv1.py @@ -14,7 +14,8 @@ import torch import torch_npu import numpy as np -from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests + +from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor from torch_npu.testing.common_utils import TestCase, run_tests diff --git a/test/test_network_ops/test_yolo_boxes_encode.py b/test/test_network_ops/test_yolo_boxes_encode.py index 6c32695c20f..056a429a53e 100644 --- a/test/test_network_ops/test_yolo_boxes_encode.py +++ b/test/test_network_ops/test_yolo_boxes_encode.py @@ -14,8 +14,9 @@ import torch import torch_npu import numpy as np + from torch_npu.testing.common_utils import TestCase, run_tests -from torch_npu.testing.common_device_type import dtypes, instantiate_device_type_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests from torch_npu.testing.util_test import create_common_tensor diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 300dfc212c1..7a84532e9d4 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1922,7 +1922,16 @@ custom: variants: function, method - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor variants: function, method + - func: npu_ifmr(Tensor data, Tensor data_min, Tensor data_max, Tensor cumsum, float min_percentile, float max_percentile, float search_start, float search_end, float search_step, bool with_offset) -> (Tensor, Tensor) + - func: npu_mish_backward(Tensor grad, Tensor input) -> Tensor + - func: npu_deformable_conv2dbk(Tensor input, Tensor grad_output, Tensor offset_out, Tensor weight, Tensor offset, int[2] kernel_size, int[] stride, int[] padding, int[] dilation=[1,1,1,1], int groups=1, int deformable_groups=1, bool modulated=True) -> (Tensor, Tensor, Tensor, Tensor) + - func: npu_sub_sample(Tensor self, int per_images, float positive_fraction) -> Tensor + custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor - func: npu_giou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor + - func: npu_deformable_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor? bias, int[2] kernel_size, int[] stride, int[] padding, int[] dilation=[1,1,1,1], int groups=1, int deformable_groups=1, bool modulated=True) -> (Tensor, Tensor) + - func: npu_mish(Tensor self) -> Tensor + variants: function, method + diff --git a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp index a7956d338ea..0fd4da964e9 100644 --- a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp @@ -14,9 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include <torch/csrc/autograd/custom_function.h>
+
 #include "torch_npu/csrc/framework/utils/OpAdapter.h"
 #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
-#include <torch/csrc/autograd/custom_function.h>
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 
 namespace at_npu {
diff --git a/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp b/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp
new file mode 100644
index 00000000000..22c0c388546
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_ifmr(
+    const at::Tensor& data,
+    const at::Tensor& data_min,
+    const at::Tensor& data_max,
+    const at::Tensor& cumsum,
+    const double min_percentile=0.999999,
+    const double max_percentile=0.999999,
+    const double search_start=0.7,
+    const double search_end=1.3,
+    const double search_step=0.01,
+    const bool with_offset=true) {
+  at::Tensor scale = OpPreparation::ApplyTensorWithFormat(
+      data_min, ACL_FORMAT_NCHW);
+  at::Tensor offset = OpPreparation::ApplyTensorWithFormat(
+      data_min, ACL_FORMAT_NCHW);
+
+  std::vector<float> tmp;
+  tmp.push_back(static_cast<float>(search_start));
+  tmp.push_back(static_cast<float>(search_end));
+  at::ArrayRef<float> searchRange(tmp);
+
+  OpCommand cmd;
+  cmd.Name("IFMR")
+      .Input(data)
+      .Input(data_min)
+      .Input(data_max)
+      .Input(cumsum)
+      .Attr("min_percentile", static_cast<float>(min_percentile))
+      .Attr("max_percentile", static_cast<float>(max_percentile))
+      .Attr("search_range", searchRange)
+      .Attr("search_step", static_cast<float>(search_step))
+      .Attr("with_offset", with_offset)
+      .Output(scale)
+      .Output(offset)
+      .Run();
+
+  return std::tie(scale, offset);
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/MishKernelNpu.cpp b/torch_npu/csrc/aten/ops/MishKernelNpu.cpp
new file mode 100644
index 00000000000..dd210d3459f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/MishKernelNpu.cpp
@@ -0,0 +1,74 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
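+
+// Reference note for this file (added commentary, not from the original
+// patch): Mish(x) = x * tanh(softplus(x)), with softplus(x) = ln(1 + e^x).
+// A known closed form for the derivative, writing
+//   w = 4(x + 1) + 4e^(2x) + e^(3x) + e^x(4x + 6) and
+//   d = 2e^x + e^(2x) + 2,
+// is Mish'(x) = e^x * w / d^2. The kernels below delegate to the CANN
+// "Mish" and "MishGrad" operators, so the exact numerics are whatever
+// those operators implement.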
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using namespace torch::autograd;
+
+at::Tensor mish_kernel_npu(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  OpCommand cmd;
+  cmd.Name("Mish")
+      .Input(self)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::npu_mish_backward(const at::Tensor& grad, const at::Tensor& input) {
+  at::Tensor result = OpPreparation::ApplyTensor(input);
+
+  OpCommand cmd;
+  cmd.Name("MishGrad")
+      .Input(grad)
+      .Input(input)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+class NPUMishFunction : public torch::autograd::Function<NPUMishFunction> {
+public:
+  static at::Tensor forward(AutogradContext *ctx,
+      const at::Tensor& self) {
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({self});
+    return mish_kernel_npu(self);
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+
+    at::Tensor result = NPUNativeFunctions::npu_mish_backward(grad_outputs[0], input);
+    tensor_list output = {result};
+    return output;
+  }
+};
+
+at::Tensor NPUNativeFunctions::npu_mish(const at::Tensor& self) {
+  return NPUMishFunction::apply(self);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp b/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp
new file mode 100644
index 00000000000..a7acaeecf39
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp
@@ -0,0 +1,37 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::npu_sub_sample(
+    const at::Tensor &self,
+    int64_t per_images,
+    double positive_fraction) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  OpCommand cmd;
+  cmd.Name("SubSample")
+      .Input(self)
+      .Output(result)
+      .Attr("batch_size_per_images", per_images)
+      .Attr("positive_fraction", (float)positive_fraction)
+      .Run();
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
diff --git a/torch_npu/csrc/aten/ops/convolution/DeformableConv2dKernelNpu.cpp b/torch_npu/csrc/aten/ops/convolution/DeformableConv2dKernelNpu.cpp
new file mode 100644
index 00000000000..5a2acf44eb9
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/convolution/DeformableConv2dKernelNpu.cpp
@@ -0,0 +1,217 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
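+
+// Implementation note (added commentary): there is no fused deformable
+// convolution on the NPU, so the forward pass is decomposed into two ops.
+// "DeformableOffsets" bilinearly samples the input at the learned offset
+// positions into an enlarged intermediate buffer, and an ordinary conv2d
+// whose stride equals the kernel size then reduces that buffer. The buffer
+// (offset_out) is also returned to the caller so that the backward pass can
+// run conv2d_backward and "DeformableOffsetsGrad" without resampling.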
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using namespace torch::autograd;
+
+tuple<at::Tensor, at::Tensor> deformable_conv2d_kernel_npu(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    int64_t deformable_groups,
+    bool modulated) {
+  const at::Tensor& bias = c10::value_or_else(bias_opt, [] {return at::Tensor();});
+  auto outputSize = deformable_conv2d_npu_output_size(
+      input, weight, offset, bias, kernel_size, stride, padding, dilation, groups, deformable_groups, modulated);
+
+  at::Tensor deformableOffsetsOutput = OpPreparation::ApplyTensorWithFormat(outputSize, input.options(), ACL_FORMAT_NCHW);
+
+  string dataFormat = "NCHW";
+  OpCommand cmd;
+  cmd.Name("DeformableOffsets")
+      .Input(input)
+      .Input(offset)
+      .Output(deformableOffsetsOutput)
+      .Attr("ksize", kernel_size)
+      .Attr("strides", stride)
+      .Attr("pads", padding)
+      .Attr("dilations", dilation)
+      .Attr("deformable_groups", deformable_groups)
+      .Attr("data_format", dataFormat)
+      .Attr("modulated", modulated)
+      .Run();
+
+  c10::SmallVector<int64_t, SIZE> conv2dStride = array_to_small_vector(kernel_size);
+  c10::SmallVector<int64_t, SIZE> conv2dPadding = {0, 0, 0, 0};
+  c10::SmallVector<int64_t, SIZE> conv2dDilation = {1, 1};
+  at::Tensor conv2dOutput = NPUNativeFunctions::npu_conv2d(
+      deformableOffsetsOutput, weight, bias, conv2dStride, conv2dPadding, conv2dDilation, groups);
+
+  return std::tie(conv2dOutput, deformableOffsetsOutput);
+}
+
+tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::npu_deformable_conv2dbk(
+    const at::Tensor& input,
+    const at::Tensor& grad_output,
+    const at::Tensor& offset_out,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    int64_t deformable_groups,
+    bool modulated) {
+  at::Tensor grad_input = OpPreparation::ApplyTensorWithFormat(input, ACL_FORMAT_NCHW);
+  at::Tensor grad_offset = OpPreparation::ApplyTensorWithFormat(offset, ACL_FORMAT_NCHW);
+
+  // deformable_conv2d_backward includes conv2d_backward and DeformableOffsetsGrad
+  c10::SmallVector<int64_t, SIZE> conv2dStride = array_to_small_vector(kernel_size);
+  c10::SmallVector<int64_t, SIZE> conv2dPadding = {0, 0, 0, 0};
+  c10::SmallVector<int64_t, SIZE> conv2dDilation = {1, 1};
+  auto conv2dBackwardOutput = NPUNativeFunctions::npu_conv2d_backward(
+      offset_out, grad_output, weight, conv2dStride, conv2dPadding, conv2dDilation, groups, {true, true, true});
+
+  // DeformableOffsetsGrad's input 'grad' is the output[0] of conv2d_backward
+  at::Tensor deformableOffsetsBackwardInput = std::get<0>(conv2dBackwardOutput);
+  at::Tensor grad_weight = std::get<1>(conv2dBackwardOutput);
+  at::Tensor grad_bias = std::get<2>(conv2dBackwardOutput);
+
+  string dataFormat = "NCHW";
+  OpCommand cmd;
+  cmd.Name("DeformableOffsetsGrad")
+      .Input(deformableOffsetsBackwardInput)
+      .Input(input)
+      .Input(offset)
+      .Output(grad_input)
+      .Output(grad_offset)
+      .Attr("strides", stride)
+      .Attr("pads", padding)
+      .Attr("ksize", kernel_size)
+      .Attr("dilations", dilation)
+      .Attr("data_format", dataFormat)
+      .Attr("deformable_groups", deformable_groups)
+      .Attr("modulated", modulated)
+      .Run();
+
+  return std::tie(grad_input, grad_weight, grad_offset, grad_bias);
+}
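+
+// Note on the autograd wrapper below (added commentary): backward() must
+// return one entry per forward() argument, so the real gradients are padded
+// with empty at::Tensor() placeholders for the non-differentiable arguments
+// (kernel_size, stride, padding, dilation, groups, deformable_groups,
+// modulated), giving eleven entries either way. offset_out is carried in
+// ctx->saved_data, alongside the scalar attributes, so backward() can reuse
+// the sampled buffer produced by the forward pass.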
+
+class NPUDeformableConv2dFunction : public torch::autograd::Function<NPUDeformableConv2dFunction> {
+public:
+  static tensor_list forward(AutogradContext *ctx,
+      const at::Tensor& input,
+      const at::Tensor& weight,
+      const at::Tensor& offset,
+      const c10::optional<at::Tensor>& bias_opt,
+      at::IntArrayRef kernel_size,
+      at::IntArrayRef stride,
+      at::IntArrayRef padding,
+      at::IntArrayRef dilation,
+      int64_t groups,
+      int64_t deformable_groups,
+      bool modulated) {
+    ctx->saved_data["kernel_size"] = kernel_size;
+    ctx->saved_data["stride"] = stride;
+    ctx->saved_data["padding"] = padding;
+    ctx->saved_data["dilation"] = dilation;
+    ctx->saved_data["groups"] = groups;
+    ctx->saved_data["deformable_groups"] = deformable_groups;
+    ctx->saved_data["modulated"] = modulated;
+    ctx->saved_data["bias_has_value"] = bias_opt.has_value() ? bias_opt.value().requires_grad() : false;
+
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({input, weight, offset});
+    auto result = deformable_conv2d_kernel_npu(input, weight, offset, bias_opt, kernel_size, stride,
+        padding, dilation, groups, deformable_groups, modulated);
+    auto result1 = std::get<1>(result);
+    ctx->saved_data["offset_out"] = result1;
+    tensor_list result_list = {std::get<0>(result), result1};
+    return result_list;
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto kernel_size = ctx->saved_data["kernel_size"].toIntVector();
+    auto stride = ctx->saved_data["stride"].toIntVector();
+    auto padding = ctx->saved_data["padding"].toIntVector();
+    auto dilation = ctx->saved_data["dilation"].toIntVector();
+    auto groups = ctx->saved_data["groups"].toInt();
+    auto deformable_groups = ctx->saved_data["deformable_groups"].toInt();
+    auto modulated = ctx->saved_data["modulated"].toBool();
+    auto bias_has_value = ctx->saved_data["bias_has_value"].toBool();
+    auto offset_out = ctx->saved_data["offset_out"].toTensor();
+
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0];
+    auto weight = saved[1];
+    auto offset = saved[2];
+    tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> result = NPUNativeFunctions::npu_deformable_conv2dbk(input, grad_outputs[0],
+        offset_out, weight, offset, kernel_size,
+        stride, padding, dilation, groups, deformable_groups, modulated);
+
+    tensor_list output;
+    if (bias_has_value) {
+      output = {std::get<0>(result),
+          std::get<1>(result),
+          std::get<2>(result),
+          std::get<3>(result),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor()};
+    } else {
+      output = {std::get<0>(result),
+          std::get<1>(result),
+          std::get<2>(result),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor(),
+          at::Tensor()};
+    }
+    return output;
+  }
+};
+
+tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_deformable_conv2d(const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef kernel_size,
+    at::IntArrayRef stride,
+    at::IntArrayRef padding,
+    at::IntArrayRef dilation,
+    int64_t groups,
+    int64_t deformable_groups,
+    bool modulated) {
+  auto result = NPUDeformableConv2dFunction::apply(input, weight, offset, bias_opt, kernel_size, stride,
+      padding, dilation, groups, deformable_groups, modulated);
+  std::tuple<at::Tensor, at::Tensor> output(result[0], result[1]);
+  return output;
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee