From 440286e3e1404ef318b791727eb7b0e174668745 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Thu, 10 Mar 2022 19:22:14 +0800 Subject: [PATCH] =?UTF-8?q?npu=5Fgiou,npu=5Fgiou=5Fbackward,=20npu=5Fiou,?= =?UTF-8?q?=20npu=5Fnms=5Fv4,=20npu=5Fpad,=20npu=5Frandom=5Fchoice=5Fwith?= =?UTF-8?q?=5Fmask,=20npu=5Fnormalize=5Fbatch=E7=AE=97=E5=AD=90=E9=80=82?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_iou.py | 58 +++++ test/test_network_ops/test_nms_v4.py | 45 ++++ test/test_network_ops/test_normalize_batch.py | 67 ++++++ test/test_network_ops/test_npu_giou.py | 129 +++++++++++ .../test_npu_giou_backward.py | 83 +++++++ test/test_network_ops/test_npu_pad.py | 34 +++ .../test_random_choice_with_mask.py | 31 +++ torch_npu/csrc/aten/npu_native_functions.yaml | 10 +- torch_npu/csrc/aten/ops/GiouKernelNpu.cpp | 206 ++++++++++++++++++ torch_npu/csrc/aten/ops/IouKernelNpu.cpp | 68 ++++++ torch_npu/csrc/aten/ops/NmsV4KernelNpu.cpp | 80 +++++++ .../csrc/aten/ops/NormalizeBatchKernelNpu.cpp | 67 ++++++ torch_npu/csrc/aten/ops/PadKernelNpu.cpp | 48 ++++ .../ops/RandomChoiceWithMaskKernelNpu.cpp | 53 +++++ 14 files changed, 978 insertions(+), 1 deletion(-) create mode 100644 test/test_network_ops/test_iou.py create mode 100644 test/test_network_ops/test_nms_v4.py create mode 100644 test/test_network_ops/test_normalize_batch.py create mode 100644 test/test_network_ops/test_npu_giou.py create mode 100644 test/test_network_ops/test_npu_giou_backward.py create mode 100644 test/test_network_ops/test_npu_pad.py create mode 100644 test/test_network_ops/test_random_choice_with_mask.py create mode 100644 torch_npu/csrc/aten/ops/GiouKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/IouKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/NmsV4KernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/PadKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/RandomChoiceWithMaskKernelNpu.cpp diff --git a/test/test_network_ops/test_iou.py b/test/test_network_ops/test_iou.py new file mode 100644 index 00000000000..bdb1914a601 --- /dev/null +++ b/test/test_network_ops/test_iou.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
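+
+# Reference note (illustrative only): npu_iou(bboxes, gtboxes, mode) computes
+# pairwise overlaps between boxes given in (x1, y1, x2, y2) form. mode=0
+# selects "iou" (intersection over union); mode=1 selects "iof", which, going
+# by the expected values below, divides the intersection by the ground-truth
+# box area instead of the union. The expected tensors differ slightly from the
+# exact ratios because the NPU kernel adds a small eps and works in float16.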
+ +import torch +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestIou(TestCase): + def test_iou_fp16(self, device="npu"): + bboxes = torch.tensor([[0, 0, 10, 10], + [10, 10, 20, 20], + [32, 32, 38, 42]], dtype=torch.float16).to("npu") + gtboxes = torch.tensor([[0, 0, 10, 20], + [0, 10, 10, 10], + [10, 10, 20, 20]], dtype=torch.float16).to("npu") + expect_iof = torch.tensor([[0.4990, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.9980, 0.0000]], dtype=torch.float16) + output_iof = torch_npu.npu_iou(bboxes, gtboxes, 1) + self.assertRtolEqual(expect_iof, output_iof.cpu()) + + expect_iou = torch.tensor([[0.4985, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000], + [0.0000, 0.9961, 0.0000]], dtype=torch.float16) + output_iou = torch_npu.npu_iou(bboxes, gtboxes, 0) + self.assertRtolEqual(expect_iou, output_iou.cpu()) + + def test_iou_fp16_pt(self, device="npu"): + bboxs = torch.tensor([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16]], dtype = torch.float16).npu() + gtboxes = torch.tensor([[1, 2, 3, 4], + [5, 6, 7, 8]], dtype = torch.float16).npu() + expect_iof = torch.tensor([[0.9902, 0.0000, 0.0000, 0.0000], + [0.0000, 0.9902, 0.0000, 0.0000]], dtype = torch.float16) + output_iof = torch_npu.npu_ptiou(bboxs, gtboxes, 1) + self.assertRtolEqual(expect_iof, output_iof.cpu(), 1.e-3) + + expect_iou = torch.tensor([[0.9805, 0.0000, 0.0000, 0.0000], + [0.0000, 0.9805, 0.0000, 0.0000]], dtype = torch.float16) + output_iou = torch_npu.npu_ptiou(bboxs, gtboxes, 0) + self.assertRtolEqual(expect_iou, output_iou.cpu(), 1.e-3) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_nms_v4.py b/test/test_network_ops/test_nms_v4.py new file mode 100644 index 00000000000..c5aa47a9d84 --- /dev/null +++ b/test/test_network_ops/test_nms_v4.py @@ -0,0 +1,45 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
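+
+# Behaviour sketch (an assumption based on the NonMaxSuppressionV4 operator the
+# kernel in this patch delegates to): npu_nms_v4 returns a tuple of
+# (selected_indices, valid_outputs) -- the indices of the boxes kept after
+# greedy NMS under the given IoU and score thresholds, and the number of valid
+# entries. This test is only a smoke test: it checks that the call succeeds on
+# random boxes and does not assert on the returned values.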
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestNmsV4(TestCase): + def generate_data(self, min1, max1, shape, dtype): + input1 = np.random.uniform(min1, max1, shape).astype(dtype) + npu_input = torch.from_numpy(input1) + return npu_input + + def npu_op_exec(self, boxes, scores, max_output_size, iou_threshold, scores_threshold): + boxes = boxes.to("npu") + scores = scores.to("npu") + iou_threshold = iou_threshold.to("npu") + scores_threshold = scores_threshold.to("npu") + npu_output = torch_npu.npu_nms_v4(boxes, scores, max_output_size, iou_threshold, scores_threshold) + return npu_output + + def test_nms_v4_float32(self, device="npu"): + boxes = self.generate_data(0, 100, (100, 4), np.float32) + scores = self.generate_data(0, 1, (100), np.float32) + max_output_size = 20 + iou_threshold = torch.tensor(0.5) + scores_threshold = torch.tensor(0.3) + + npu_output = self.npu_op_exec(boxes, scores, max_output_size, iou_threshold, scores_threshold) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_normalize_batch.py b/test/test_network_ops/test_normalize_batch.py new file mode 100644 index 00000000000..2c84fc8a91c --- /dev/null +++ b/test/test_network_ops/test_normalize_batch.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
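+
+# Reference note (mirrors cpu_op_exec below): npu_normalize_batch normalizes
+# each sequence in a [batch, feature, time] tensor using only its first
+# seq_len[i] valid frames. normalize_type=0 ("per_feature") keeps one mean/std
+# per feature channel, normalize_type=1 ("all_features") uses a single
+# mean/std per batch element; a small epsilon is added to the std.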
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestNormalizeBatch(TestCase):
+    def cpu_op_exec(self, input1, seq_len, normalize_type):
+        if normalize_type == 0:
+            x_mean = torch.zeros((seq_len.shape[0], input1.shape[1]), dtype=input1.dtype,
+                                 device=input1.device)
+            x_std = torch.zeros((seq_len.shape[0], input1.shape[1]), dtype=input1.dtype,
+                                device=input1.device)
+            for i in range(input1.shape[0]):
+                x_mean[i, :] = input1[i, :, :seq_len[i]].mean(dim=1)
+                x_std[i, :] = input1[i, :, :seq_len[i]].std(dim=1)
+            x_std += 1e-5
+            result = (input1 - x_mean.unsqueeze(2)) / x_std.unsqueeze(2)
+        else:
+            x_mean = torch.zeros(seq_len.shape, dtype=input1.dtype,
+                                 device=input1.device)
+            x_std = torch.zeros(seq_len.shape, dtype=input1.dtype,
+                                device=input1.device)
+            for i in range(input1.shape[0]):
+                x_mean[i] = input1[i, :, :int(seq_len[i])].mean()
+                x_std[i] = input1[i, :, :int(seq_len[i])].std()
+            x_std += 1e-5
+            result = (input1 - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1)
+        return result.numpy()
+
+    def npu_op_exec(self, input1, seq_len, normalize_type):
+        out = torch_npu.npu_normalize_batch(input1, seq_len, normalize_type)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def test_normalize_batch(self, device="npu"):
+        shape_format = [
+            [[np.float32, -1, [32, 3, 6]], [np.int32, -1, [32]], 0],
+            [[np.float32, -1, [32, 3, 6]], [np.int32, -1, [32]], 1],
+            [[np.float32, -1, [16, 6, 1000]], [np.int32, -1, [16]], 0],
+            [[np.float32, -1, [16, 6, 1000]], [np.int32, -1, [16]], 1]
+        ]
+        for item in shape_format:
+            right_range = item[0][-1][-1]
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 10)
+            cpu_seqlen, npu_seqlen = create_common_tensor(item[1], 3, right_range)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_seqlen, item[-1])
+            npu_output = self.npu_op_exec(npu_input1, npu_seqlen, item[-1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_npu_giou.py b/test/test_network_ops/test_npu_giou.py
new file mode 100644
index 00000000000..2581c892347
--- /dev/null
+++ b/test/test_network_ops/test_npu_giou.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
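+
+# Reference note (illustrative only): the CPU check implemented in
+# cpu_op_exec() below is the standard generalized IoU. For boxes a and b with
+# smallest enclosing box c,
+#     GIoU(a, b) = IoU(a, b) - (area(c) - area(union(a, b))) / area(c)
+# With trans=True the boxes are interpreted in centre form (x, y, w, h) and
+# converted to corner form first, mirroring generate_giou_data() and
+# cpu_op_exec() below.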
+ +import math +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestNpuGiou(TestCase): + def generate_giou_data(self, n, m, dtype): + data_bboxes = np.array([]).astype(dtype) + for i in range(4): + data_bboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, n).astype(dtype) + data_bboxes = np.append(data_bboxes, data_bboxes_array) + data_bboxes = data_bboxes.reshape([4, n]) + data_gtboxes = np.array([]).astype(dtype) + for i in range(4): + data_gtboxes_array = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, m).astype(dtype) + data_gtboxes = np.append(data_gtboxes, data_gtboxes_array) + data_gtboxes = data_gtboxes.reshape([4, m]) + cpu_input1 = torch.from_numpy(data_bboxes) + cpu_input2 = torch.from_numpy(data_gtboxes) + npu_input1 = cpu_input1.npu() + npu_input2 = cpu_input2.npu() + list1 = [cpu_input1, cpu_input2, npu_input1, npu_input2] + return list1 + + def cpu_op_exec(self, box1, box2, trans=False, is_cross=False, mode="iou"): + box1 = box1.numpy() + box2 = box2.numpy() + dtype = box1.dtype + _, n = box1.shape + _, m = box2.shape + if trans: + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 + else: + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + area1 = w1 * h1 + area2 = w2 * h2 + giou_res = np.array([], dtype=dtype) + + for i in range(n): + for j in range(m): + inter_x1 = max(b1_x1[i], b2_x1[j]) + inter_x2 = min(b1_x2[i], b2_x2[j]) + inter_y1 = max(b1_y1[i], b2_y1[j]) + inter_y2 = min(b1_y2[i], b2_y2[j]) + outer_x1 = min(b1_x1[i], b2_x1[j]) + outer_x2 = max(b1_x2[i], b2_x2[j]) + outer_y1 = min(b1_y1[i], b2_y1[j]) + outer_y2 = max(b1_y2[i], b2_y2[j]) + inter_area = max(0, (inter_x2 - inter_x1)) * max(0, (inter_y2 - inter_y1)) + outer_area = abs(outer_x2 - outer_x1) * abs(outer_y2 - outer_y1) + union_area = area1[i] + area2[j] - inter_area + 1e-16 + other_area = outer_area - union_area + giou_ij = inter_area / union_area - other_area / outer_area + if not is_cross: + if i == j: + giou_res = np.append(giou_res, giou_ij) + else: + giou_res = np.append(giou_res, giou_ij) + + if not is_cross: + res = giou_res.reshape(1, n) + else: + res = giou_res.reshape(n, m) + res = np.transpose(res) + res = np.transpose(res) + return res + + def npu_op_exec(self, box1, box2, trans=False, is_cross=False, mode=0): + output = torch_npu.npu_giou(box1, box2, trans, is_cross, mode) + output = output.detach().cpu().numpy() + return output + + def test_npu_giou_shape_format_fp32(self, device="npu"): + self._test_npu_giou_shape_format(np.float32) + + def test_npu_giou_shape_format_fp16(self, device="npu"): + self._test_npu_giou_shape_format(np.float16) + + def _test_npu_giou_shape_format(self, dtype): + shape_list = [ + [10, 10], + [12, 12], + [100, 100] + ] + is_trans_list = [True] + mode_list = ["iou"] + shape_format = [[j, k, m] + for j in shape_list + for k in is_trans_list + for m in mode_list] + + for item in shape_format: + mode_digit = 0 if item[-1] == "iou" else 1 + is_cross = False if item[0][0] == item[0][1] else True + list1 = self.generate_giou_data(*item[0], dtype) + cpu_output = self.cpu_op_exec(list1[0], list1[1], item[1], is_cross, item[-1]) 
+ npu_output = self.npu_op_exec(list1[2], list1[3], item[1], is_cross, mode_digit) + cpu_output = cpu_output.astype(npu_output.dtype) + if dtype == np.float16: + self.assertRtolEqual(cpu_output, npu_output, prec16=1e-2) + else: + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_npu_giou_backward.py b/test/test_network_ops/test_npu_giou_backward.py new file mode 100644 index 00000000000..1e6cde0bd69 --- /dev/null +++ b/test/test_network_ops/test_npu_giou_backward.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestNpuGiouBackward(TestCase): + def generate_giou_data(self, n, m, dtype): + data_bboxes1 = np.array([]).astype(dtype) + for i in range(4): + data_bboxes_array1 = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, n).astype(dtype) + data_bboxes1 = np.append(data_bboxes1, data_bboxes_array1) + data_bboxes1 = data_bboxes1.reshape([4, n]) + data_gtboxes1 = np.array([]).astype(dtype) + for i in range(4): + data_gtboxes_array1 = i // 2 + math.pow(-1, i // 2) * 0.5 * np.random.rand(1, m).astype(dtype) + data_gtboxes1 = np.append(data_gtboxes1, data_gtboxes_array1) + data_gtboxes1 = data_gtboxes1.reshape([4, m]) + cpu_input1 = torch.from_numpy(data_bboxes1) + cpu_input2 = torch.from_numpy(data_gtboxes1) + npu_input1 = cpu_input1.npu() + npu_input2 = cpu_input2.npu() + list1 = [cpu_input1, cpu_input2, npu_input1, npu_input2] + return list1 + + def npu_op_exec(self, box1, box2, trans=False, is_cross=False, mode=0): + box1.requires_grad = True + box2.requires_grad = True + output = torch_npu.npu_giou(box1, box2, trans, is_cross, mode) + output.backward(torch.ones_like(output)) + box1_grad = box1.grad + box2_grad = box2.grad + box1_grad = box1_grad.detach().cpu().numpy() + box2_grad = box2_grad.detach().cpu().numpy() + output = output.detach().cpu().numpy() + return output, box1_grad, box2_grad + + def test_npu_giou_backward_shape_format(self): + shape_list1 = [ + [1, 1] + ] + is_trans_list1 = [True] + mode_list1 = ["iou"] + shape_format1 = [[j, k, m] + for j in shape_list1 + for k in is_trans_list1 + for m in mode_list1] + + for item in shape_format1: + mode_digit = 0 if item[-1] == "iou" else 1 + is_cross = False if item[0][0] == item[0][1] else True + expected_cpu_grad1 = np.array([[1.0218241], + [-1.4181931], + [0.18631615], + [0.1747725]], dtype=np.float32) + expected_cpu_grad2 = np.array([[-1.0218241], + [1.4181931], + [0.17999186], + [0.23653218]], dtype=np.float32) + list1 = self.generate_giou_data(*item[0], np.float32) + _, npu_grad1, npu_grad2 = self.npu_op_exec(list1[2], list1[3], item[1], is_cross, mode_digit) + self.assertRtolEqual(expected_cpu_grad1, npu_grad1) + self.assertRtolEqual(expected_cpu_grad2, 
npu_grad2) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_npu_pad.py b/test/test_network_ops/test_npu_pad.py new file mode 100644 index 00000000000..5ff36c7a999 --- /dev/null +++ b/test/test_network_ops/test_npu_pad.py @@ -0,0 +1,34 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestNpuPad(TestCase): + def test_npu_pad(self, device="npu"): + npu_input = torch.ones(2, 2).npu() + pads = (1, 1, 1, 1) + benchmark = torch.tensor([[0., 0., 0., 0.], + [0., 1., 1., 0.], + [0., 1., 1., 0.], + [0., 0., 0., 0.]]) + npu_output = torch_npu.npu_pad(npu_input, pads) + npu_output = npu_output.cpu().detach() + self.assertRtolEqual(benchmark, npu_output) + +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_random_choice_with_mask.py b/test/test_network_ops/test_random_choice_with_mask.py new file mode 100644 index 00000000000..9d79606f92e --- /dev/null +++ b/test/test_network_ops/test_random_choice_with_mask.py @@ -0,0 +1,31 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
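+
+# Behaviour sketch (an assumption drawn from the kernel added in this patch):
+# npu_random_choice_with_mask(x, count, seed, seed2) samples up to `count`
+# coordinates of True elements from the bool tensor x and returns
+# (indices, mask), where indices has shape [count, x.dim()] and mask marks
+# which returned rows are valid picks. The case below is deterministic because
+# the input holds exactly `count` True values.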
+ +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestRandomChoiceWithMask(TestCase): + def test_random_choice_with_mask_fp32(self, device="npu"): + input_bool = torch.tensor([1, 0, 1, 0], dtype=torch.bool).npu() + expect_ret = torch.tensor([[0], [2]], dtype=torch.int32) + expect_mask = torch.tensor([True, True]) + result, mask = torch_npu.npu_random_choice_with_mask(input_bool, 2, 1, 0) + self.assertRtolEqual(expect_ret, result.cpu()) + self.assertRtolEqual(expect_mask, mask.cpu()) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index a93fad392f8..087ee04c215 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1934,6 +1934,13 @@ custom: - func: npu_batch_nms(Tensor self, Tensor scores, float score_threshold, float iou_threshold, int max_size_per_class, int max_total_size, bool change_coordinate_frame=False, bool transpose_box=False) -> (Tensor, Tensor, Tensor, Tensor) - func: npu_bounding_box_encode(Tensor anchor_box, Tensor ground_truth_box, float means0, float means1, float means2, float means3, float stds0, float stds1, float stds2, float stds3) -> Tensor - func: npu_bounding_box_decode(Tensor rois, Tensor deltas, float means0, float means1, float means2, float means3, float stds0, float stds1, float stds2, float stds3, int[1] max_shape, float wh_ratio_clip) -> Tensor + - func: npu_giou_backward(Tensor grad, Tensor bboxes, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> (Tensor, Tensor) + - func: npu_iou(Tensor bboxes, Tensor gtboxes, int mode=0) -> Tensor + - func: npu_nms_v4(Tensor self, Tensor scores, Scalar max_output_size, Tensor iou_threshold, Tensor scores_threshold, bool pad_to_max_output_size=False) -> (Tensor, Tensor) + - func: npu_pad(Tensor input, int[] paddings) -> Tensor + - func: npu_random_choice_with_mask(Tensor x, int count=256, int seed=0, int seed2=0) -> (Tensor, Tensor) + - func: npu_normalize_batch(Tensor self, Tensor seq_len, int normalize_type=0) -> Tensor + - func: npu_ptiou(Tensor bboxes, Tensor gtboxes, int mode=0) -> Tensor custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor @@ -1960,4 +1967,5 @@ custom_autograd: - func: npu_mish(Tensor self) -> Tensor variants: function, method - func: npu_min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) - - func: npu_min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) \ No newline at end of file + - func: npu_min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + - func: npu_giou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=False, int mode=0) -> Tensor \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp new file mode 100644 index 00000000000..06a47f35e95 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GiouKernelNpu.cpp @@ -0,0 +1,206 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+c10::SmallVector<int64_t, SIZE> giou_output_size(
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool is_cross){
+  c10::SmallVector<int64_t, SIZE> output_size;
+  if(is_cross){
+    output_size = {gtboxes.size(0), self.size(0)};
+  } else {
+    output_size = {1, self.size(0)};
+  }
+  return output_size;
+}
+
+at::Tensor& giou_inner_out_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  auto output_size = giou_output_size(self, gtboxes, is_cross);
+  string mode_str = mode == 1 ? "iof" : "iou";
+
+  OpCommand cmd;
+  cmd.Name("GIoU")
+      .Input(self)
+      .Input(gtboxes)
+      .Output(result)
+      .Attr("trans", trans)
+      .Attr("is_cross", is_cross)
+      .Attr("mode", mode_str)
+      .Run();
+  return result;
+}
+
+at::Tensor giou_npu(
+    const at::Tensor& self,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode) {
+  TORCH_CHECK(trans && !is_cross && mode == 0,
+      "npu_giou only supports trans==True, ",
+      "is_cross==False, mode==0 ('iou') in the current version; ",
+      "please check the parameters if back propagation is required!");
+  // The op expects boxes of shape [n, 4], but the inputs arrive as [4, n], so permute them here.
+  // Note: temporary workaround; it will be removed once the op handles fp16 directly.
+  at::Tensor selfCp = self.permute({1, 0});
+  if (selfCp.scalar_type() == at::kHalf) {
+    selfCp = NPUNativeFunctions::npu_dtype_cast(selfCp, at::kFloat);
+  }
+  at::Tensor gtboxesCp = gtboxes.permute({1, 0});
+  if (gtboxesCp.scalar_type() == at::kHalf) {
+    gtboxesCp = NPUNativeFunctions::npu_dtype_cast(gtboxesCp, at::kFloat);
+  }
+  auto output_size = giou_output_size(selfCp, gtboxesCp, is_cross);
+  at::Tensor result = OpPreparation::ApplyTensor(selfCp, output_size);
+
+  giou_inner_out_npu(result, selfCp, gtboxesCp, trans, is_cross, mode);
+  result = result.permute({1, 0});
+  if (self.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf) {
+    result = NPUNativeFunctions::npu_dtype_cast(result, at::kHalf);
+  }
+  return result;
+}
+
+std::tuple<at::Tensor&, at::Tensor&> giou_backward_inner_out_npu(
+    at::Tensor& dbboxes,
+    at::Tensor& dgtboxes,
+    const at::Tensor& grad,
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    bool trans,
+    bool is_cross,
+    int64_t mode){
+  string mode_str = mode == 1 ? "iof" : "iou";
"iof" : "iou"; + + OpCommand cmd; + cmd.Name("GIoUGrad") + .Input(grad) + .Input(bboxes) + .Input(gtboxes) + .Output(dbboxes) + .Output(dgtboxes) + .Attr("trans", trans) + .Attr("is_cross", is_cross) + .Attr("mode", mode_str) + .Run(); + return std::tie(dbboxes, dgtboxes); +} + +std::tuple NPUNativeFunctions::npu_giou_backward( + const at::Tensor& grad, + const at::Tensor& bboxes, + const at::Tensor& gtboxes, + bool trans, + bool is_cross, + int64_t mode){ + TORCH_CHECK(trans && !is_cross && mode == 0, + "giou backward only support trans==True, ", + "is_cross==False, ", + "mode==0('iou') current version ", + "if you need to back propagation, ", + "please ensure your parameter is correct!"); + // Op need form of [n] grad + // Note: temp avoid! it'll be remove while op deal with fp16 issue! + at::Tensor gradCp = at::squeeze(grad, 0); + if (gradCp.scalar_type() == at::kHalf) { + gradCp = NPUNativeFunctions::npu_dtype_cast(gradCp, at::kFloat); + } + at::Tensor bboxesCp = bboxes; + if (bboxesCp.scalar_type() == at::kHalf) { + bboxesCp = NPUNativeFunctions::npu_dtype_cast(bboxesCp, at::kFloat); + } + at::Tensor gtboxesCp = gtboxes; + if (gtboxesCp.scalar_type() == at::kHalf) { + gtboxesCp = NPUNativeFunctions::npu_dtype_cast(gtboxesCp, at::kFloat); + } + at::Tensor dbboxes = OpPreparation::ApplyTensor(bboxesCp); + at::Tensor dgtboxes = OpPreparation::ApplyTensor(gtboxesCp); + + giou_backward_inner_out_npu(dbboxes, dgtboxes, gradCp, bboxesCp, gtboxesCp, trans, is_cross, mode); + if (bboxes.scalar_type() == at::kHalf || gtboxes.scalar_type() == at::kHalf) { + dbboxes = NPUNativeFunctions::npu_dtype_cast(dbboxes, at::kHalf); + dgtboxes = NPUNativeFunctions::npu_dtype_cast(dgtboxes, at::kHalf); + } + return std::tie(dbboxes, dgtboxes); +} + +class NPUGiouFunction : public torch::autograd::Function { +public: + static at::Tensor forward(AutogradContext *ctx, + const at::Tensor& self, + const at::Tensor& gtboxes, + bool trans, + bool is_cross, + int64_t mode) { + ctx->saved_data["trans"] = trans; + ctx->saved_data["is_cross"] = is_cross; + ctx->saved_data["mode"] = mode; + at::AutoNonVariableTypeMode g; + ctx->save_for_backward({self, gtboxes}); + return giou_npu(self, gtboxes, trans, is_cross, mode); + } + + static tensor_list backward(AutogradContext *ctx, + tensor_list grad_outputs) { + auto trans = ctx->saved_data["trans"].toBool(); + auto is_cross = ctx->saved_data["is_cross"].toBool(); + auto mode = ctx->saved_data["mode"].toInt(); + auto saved = ctx->get_saved_variables(); + auto bboxes = saved[0]; + auto gtboxes = saved[1]; + + tuple result = NPUNativeFunctions::npu_giou_backward(grad_outputs[0], + bboxes, gtboxes, trans, is_cross, mode); + + tensor_list output = {std::get<0>(result), + std::get<1>(result), + at::Tensor(), + at::Tensor(), + at::Tensor()}; + return output; + } +}; + +at::Tensor NPUNativeFunctions::npu_giou(const at::Tensor& self, + const at::Tensor& gtboxes, + bool trans, + bool is_cross, + int64_t mode) { + return NPUGiouFunction::apply(self, gtboxes, trans, is_cross, mode); +} + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/IouKernelNpu.cpp b/torch_npu/csrc/aten/ops/IouKernelNpu.cpp new file mode 100644 index 00000000000..30b70c7536c --- /dev/null +++ b/torch_npu/csrc/aten/ops/IouKernelNpu.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. 
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::npu_iou(
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    int64_t mode) {
+  at::Tensor bboxesFP16 = bboxes;
+  if (bboxes.scalar_type() != at::ScalarType::Half) {
+    bboxesFP16 = bboxes.to(at::kHalf);
+  }
+  at::Tensor gtboxesFP16 = gtboxes;
+  if (gtboxes.scalar_type() != at::ScalarType::Half) {
+    gtboxesFP16 = gtboxes.to(at::kHalf);
+  }
+
+  auto outputSize = {gtboxes.size(0), bboxes.size(0)};
+  at::Tensor overlap = OpPreparation::ApplyTensorWithFormat(
+      bboxesFP16,
+      outputSize,
+      CalcuOpUtil::get_tensor_npu_format(bboxes));
+  string modeStr = "iou";
+  if (mode == 1) {
+    modeStr = "iof";
+  }
+  OpCommand cmd;
+  cmd.Name("Iou")
+      .Input(bboxesFP16)
+      .Input(gtboxesFP16)
+      .Output(overlap)
+      .Attr("mode", modeStr)
+      .Attr("eps", static_cast<float>(0.01))
+      .Run();
+  if (overlap.scalar_type() != bboxes.scalar_type()) {
+    overlap = overlap.to(bboxes.scalar_type());
+  }
+  return overlap;
+}
+
+at::Tensor NPUNativeFunctions::npu_ptiou(
+    const at::Tensor& bboxes,
+    const at::Tensor& gtboxes,
+    int64_t mode) {
+  return NPUNativeFunctions::npu_iou(bboxes, gtboxes, mode);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/NmsV4KernelNpu.cpp b/torch_npu/csrc/aten/ops/NmsV4KernelNpu.cpp
new file mode 100644
index 00000000000..ac777eaf973
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/NmsV4KernelNpu.cpp
@@ -0,0 +1,80 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
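+
+// Shape note (an assumption inferred from nms_v4_npu_output_size and the
+// NonMaxSuppressionV4 operator wrapped below): selected_indices is an int32
+// tensor holding the indices of the kept boxes, sized by max_output_size, and
+// valid_outputs reports how many of those entries are actually valid; with
+// pad_to_max_output_size=true the tail of selected_indices is padding.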
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor> nms_v4_npu_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& scores,
+    at::Scalar max_output_size,
+    const at::Tensor& iou_threshold,
+    const at::Tensor& scores_threshold,
+    bool pad_to_max_output_size,
+    at::Tensor& selected_indices,
+    at::Tensor& valid_outputs) {
+  at::Tensor max_output_size_tensor = OpPreparation::ApplyTensor(
+      {}, self.options().dtype(at::kInt), self).fill_(max_output_size);
+  OpCommand cmd;
+  cmd.Name("NonMaxSuppressionV4")
+      .Input(self)
+      .Input(scores)
+      .Input(max_output_size_tensor)
+      .Input(iou_threshold)
+      .Input(scores_threshold)
+      .Output(selected_indices)
+      .Output(valid_outputs)
+      .Attr("pad_to_max_output_size", pad_to_max_output_size)
+      .Run();
+
+  return std::tuple<at::Tensor, at::Tensor>(selected_indices, valid_outputs);
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_nms_v4(
+    const at::Tensor& self,
+    const at::Tensor& scores,
+    at::Scalar max_output_size,
+    const at::Tensor& iou_threshold,
+    const at::Tensor& scores_threshold,
+    bool pad_to_max_output_size) {
+  auto outputSizes = nms_v4_npu_output_size(max_output_size);
+
+  at::Tensor selected_indices = OpPreparation::ApplyTensor(
+      std::get<0>(outputSizes),
+      self.options().dtype(at::kInt),
+      self);
+  at::Tensor valid_outputs = OpPreparation::ApplyTensor(
+      std::get<1>(outputSizes),
+      self.options().dtype(at::kInt),
+      self);
+
+  nms_v4_npu_nocheck(
+      self,
+      scores,
+      max_output_size,
+      iou_threshold,
+      scores_threshold,
+      pad_to_max_output_size,
+      selected_indices,
+      valid_outputs);
+
+  return std::tuple<at::Tensor, at::Tensor>(selected_indices, valid_outputs);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp b/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp
new file mode 100644
index 00000000000..2231fe96fe5
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/NormalizeBatchKernelNpu.cpp
@@ -0,0 +1,67 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
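+
+// Usage note (restating the kernel below and test_normalize_batch.py): the
+// attribute passed to the NormalizeBatch op is "per_feature" when
+// normalize_type == 0 and "all_features" when normalize_type == 1, and
+// seq_len must be a 1-D tensor whose length equals self.size(0); see
+// normalize_batch_check().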
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+constexpr float_t EPSILON = 1e-5;
+
+namespace at_npu {
+namespace native {
+
+static inline void normalize_batch_check(
+    const at::Tensor& self,
+    const at::Tensor& seq_len,
+    int64_t normalize_type){
+  TORCH_CHECK(
+      seq_len.dim() == 1,
+      "Non-empty 1D seq_len tensor expected but got a tensor with sizes ",
+      seq_len.sizes());
+  TORCH_CHECK(
+      seq_len.size(0) == self.size(0),
+      "seq_len's length should equal self's batch size, but got seq_len length ",
+      seq_len.size(0),
+      " and batch size ",
+      self.size(0));
+  TORCH_CHECK(
+      normalize_type >= 0 && normalize_type <= 1,
+      "normalize_type expected to be in range [0, 1], but got ",
+      normalize_type);
+}
+
+at::Tensor NPUNativeFunctions::npu_normalize_batch(
+    const at::Tensor& self,
+    const at::Tensor& seq_len,
+    int64_t normalize_type){
+  normalize_batch_check(self, seq_len, normalize_type);
+  // apply output tensor
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  string normalizeType = normalize_type == 0 ? "per_feature" : "all_features";
+
+  OpCommand cmd;
+  cmd.Name("NormalizeBatch")
+      .Input(self)
+      .Input(seq_len)
+      .Output(result)
+      .Attr("normalize_type", normalizeType)
+      .Attr("epsilon", EPSILON)
+      .Run();
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/PadKernelNpu.cpp b/torch_npu/csrc/aten/ops/PadKernelNpu.cpp
new file mode 100644
index 00000000000..43383b0ef23
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/PadKernelNpu.cpp
@@ -0,0 +1,48 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
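+
+// Layout note (an assumption based on pad_npu_nocheck below and
+// test_npu_pad.py): `paddings` is a flat list of per-dimension pad amounts
+// that gets zero-extended to 2 * input.dim() entries before reaching the Pad
+// op, i.e. one (before, after) pair per dimension; for example
+// npu_pad(torch.ones(2, 2).npu(), (1, 1, 1, 1)) pads one element on every
+// side and returns a 4x4 tensor.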
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& pad_npu_nocheck(
+    at::Tensor& output,
+    const at::Tensor& input,
+    at::IntArrayRef paddings) {
+  c10::SmallVector<int64_t, SIZE> paddingsVector = array_to_small_vector(paddings);
+  paddingsVector.resize(2 * input.dim(), 0);
+
+  OpCommand cmd;
+  cmd.Name("Pad")
+      .Input(input)
+      .Input(paddingsVector)
+      .Output(output)
+      .Run();
+  return output;
+}
+
+at::Tensor NPUNativeFunctions::npu_pad(const at::Tensor& input, at::IntArrayRef paddings) {
+  auto outputSize = pad_npu_output_size(input, paddings);
+  at::Tensor output = OpPreparation::ApplyTensor(input, outputSize);
+  pad_npu_nocheck(output, input, paddings);
+  return output;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RandomChoiceWithMaskKernelNpu.cpp b/torch_npu/csrc/aten/ops/RandomChoiceWithMaskKernelNpu.cpp
new file mode 100644
index 00000000000..77d73f98de5
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RandomChoiceWithMaskKernelNpu.cpp
@@ -0,0 +1,53 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::npu_random_choice_with_mask(
+    const at::Tensor& self,
+    int64_t count,
+    int64_t seed,
+    int64_t seed2) {
+  TORCH_CHECK(
+      self.scalar_type() == at::ScalarType::Bool,
+      "The input dtype should be bool, but got ",
+      self.scalar_type());
+  TORCH_CHECK(
+      self.dim() <= 5 && self.dim() >= 1,
+      "The input dim should be in [1, 5], but got ",
+      self.dim());
+  TORCH_CHECK(count > 0, "The count must be greater than 0, but got ", count);
+
+  at::Tensor result = OpPreparation::ApplyTensor({count, self.dim()}, self.options().dtype(at::kInt), self);
+  at::Tensor mask = OpPreparation::ApplyTensor(self, {count});
+  OpCommand cmd;
+  cmd.Name("RandomChoiceWithMask")
+      .Input(self)
+      .Output(result)
+      .Output(mask)
+      .Attr("count", count)
+      .Attr("seed", seed)
+      .Attr("seed2", seed2)
+      .Run();
+
+  return std::tie(result, mask);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee
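
Quick usage sketch (illustrative only; the calls and shapes mirror the tests added
above and assume an NPU device plus this patch's build of torch_npu):

    import torch
    import torch_npu

    boxes = torch.tensor([[0., 0., 10., 10.],
                          [10., 10., 20., 20.]], dtype=torch.float16).npu()
    gt = torch.tensor([[0., 0., 10., 20.]], dtype=torch.float16).npu()

    # Pairwise overlaps between detection boxes and ground-truth boxes.
    iou = torch_npu.npu_iou(boxes, gt, 0)   # mode 0: intersection over union
    iof = torch_npu.npu_iou(boxes, gt, 1)   # mode 1: iof (see test_iou.py)

    # Zero-pad a 2x2 tensor by one element on every side -> 4x4.
    padded = torch_npu.npu_pad(torch.ones(2, 2).npu(), (1, 1, 1, 1))

    # Randomly pick up to two coordinates of True entries plus a validity mask.
    idx, mask = torch_npu.npu_random_choice_with_mask(
        torch.tensor([1, 0, 1, 0], dtype=torch.bool).npu(), 2, 1, 0)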