From ca679327c4c97c9877f2faa711b1ee0ab215ba72 Mon Sep 17 00:00:00 2001
From: pipihugh
Date: Tue, 8 Mar 2022 10:15:21 +0800
Subject: [PATCH] Port the nms_with_mask operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_nms_with_mask.py   | 43 +++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  2 +
 .../csrc/aten/ops/NmsWithMaskKernelNpu.cpp    | 54 +++++++++++++++++++
 3 files changed, 99 insertions(+)
 create mode 100644 test/test_network_ops/test_nms_with_mask.py
 create mode 100644 torch_npu/csrc/aten/ops/NmsWithMaskKernelNpu.cpp

diff --git a/test/test_network_ops/test_nms_with_mask.py b/test/test_network_ops/test_nms_with_mask.py
new file mode 100644
index 0000000000..6b367f1f29
--- /dev/null
+++ b/test/test_network_ops/test_nms_with_mask.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestNmsWithMask(TestCase):
+    def npu_op_exec(self, input1, iou_threshold):
+        npu_output1, npu_output2, npu_output3 = torch_npu.npu_nms_with_mask(input1, iou_threshold)
+        npu_output1 = npu_output1.to("cpu")
+        npu_output2 = npu_output2.to("cpu")
+        npu_output3 = npu_output3.to("cpu")
+
+        return npu_output1, npu_output2, npu_output3
+
+    def test_nms_with_mask_float32(self):
+        input1 = torch.tensor([[0.0, 1.0, 2.0, 3.0, 0.6], [6.0, 7.0, 8.0, 9.0, 0.4]]).npu()
+        iou_threshold = 0.5
+        eq_output1 = torch.tensor([[0.0000, 1.0000, 2.0000, 3.0000, 0.6001],
+                                   [6.0000, 7.0000, 8.0000, 9.0000, 0.3999]])
+        eq_output2 = torch.tensor([0, 1], dtype=torch.int32)
+        eq_output3 = torch.tensor([1, 1], dtype=torch.uint8)
+        npu_output1, npu_output2, npu_output3 = self.npu_op_exec(input1, iou_threshold)
+        self.assertRtolEqual(eq_output1, npu_output1)
+        self.assertRtolEqual(eq_output2, npu_output2)
+        self.assertRtolEqual(eq_output3, npu_output3)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index f794be4e55..a97b73a237 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1928,6 +1928,8 @@ custom:
   - func: npu_rotated_overlaps(Tensor self, Tensor query_boxes, bool trans=False) -> Tensor
   - func: npu_silu_backward(Tensor grad_output, Tensor x0, Tensor x1) -> Tensor
   - func: npu_rotated_iou(Tensor self, Tensor query_boxes, bool trans=False, int mode=0, bool is_cross=True, float v_threshold=0.0, float e_threshold=0.0) -> Tensor
+  - func: npu_nms_with_mask(Tensor input, Scalar iou_threshold) -> (Tensor, Tensor, Tensor)
+
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/NmsWithMaskKernelNpu.cpp b/torch_npu/csrc/aten/ops/NmsWithMaskKernelNpu.cpp
new file mode 100644
index 0000000000..1870a286d7
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/NmsWithMaskKernelNpu.cpp
@@ -0,0 +1,54 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> nms_with_mask_npu_nocheck(
+    const at::Tensor& input,
+    at::Scalar iou_threshold,
+    at::Tensor& boxes,
+    at::Tensor& idx,
+    at::Tensor& mask) {
+  float iouThresholdValue = CalcuOpUtil::get_scalar_float_value(iou_threshold);
+  OpCommand cmd;
+  cmd.Name("NMSWithMask")
+      .Input(input)
+      .Output(boxes)
+      .Output(idx)
+      .Output(mask)
+      .Attr("iou_threshold", iouThresholdValue)
+      .Run();
+  return std::tuple<at::Tensor&, at::Tensor&, at::Tensor&>(boxes, idx, mask);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::npu_nms_with_mask(
+    const at::Tensor& input,
+    at::Scalar iou_threshold) {
+  auto outputSizes = nms_with_mask_npu_output_size(input);
+  at::Tensor boxes = OpPreparation::ApplyTensor(input, std::get<0>(outputSizes));
+  at::Tensor idx = OpPreparation::ApplyTensor(std::get<1>(outputSizes), input.options().dtype(at::kInt), input);
+  at::Tensor mask = OpPreparation::ApplyTensor(std::get<2>(outputSizes), input.options().dtype(at::kByte), input);
+  nms_with_mask_npu_nocheck(input, iou_threshold, boxes, idx, mask);
+  return std::tuple<at::Tensor, at::Tensor, at::Tensor>(boxes, idx, mask);
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee