diff --git a/test/test_network_ops/test_nms_rotated.py b/test/test_network_ops/test_nms_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..86e459501ca8cef1d46f5ef0aada6bdabfe72a42
--- /dev/null
+++ b/test/test_network_ops/test_nms_rotated.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestNmsRotated(TestCase):  # Compares torch_npu.npu_nms_rotated against fixed expected outputs.
+    def npu_op_exec(self, det, score):  # Moves both inputs to the NPU and returns the op's two outputs.
+        output1, output2 = torch_npu.npu_nms_rotated(det.npu(), score.npu(), 0.2, 0, -1, 1)  # positional: iou_threshold=0.2, scores_threshold=0, max_output_size=-1, mode=1
+        return output1, output2
+
+    def test_nms_rotated_float32(self, device):  # Single fixed case with default-dtype (float32) inputs.
+        det = torch.tensor([[1.0382e+03, 3.1657e+02, 1.1556e+03, 4.4303e+02, 2.3674e+00],  # 20 rows x 5 values; presumably one rotated box per row - TODO confirm column layout
+                            [1.1503e+03, 3.0598e+02, 1.2602e+03, 4.3456e+02, 3.2729e-01],
+                            [1.1508e+03, 3.0652e+02, 1.2607e+03, 4.3472e+02, 5.1713e-01],
+                            [1.1518e+03, 3.0781e+02, 1.2622e+03, 4.3448e+02, 3.9718e-01],
+                            [1.1748e+03, 3.0202e+02, 1.2859e+03, 4.3915e+02, 1.8112e+00],
+                            [1.1711e+03, 3.0480e+02, 1.2868e+03, 4.3551e+02, 2.1171e+00],
+                            [1.1673e+03, 3.0675e+02, 1.2889e+03, 4.3194e+02, 2.5968e+00],
+                            [1.2741e+03, 3.0181e+02, 1.3823e+03, 4.3036e+02, 2.0379e+00],
+                            [1.2741e+03, 3.0286e+02, 1.3836e+03, 4.2940e+02, 2.2072e+00],
+                            [1.2733e+03, 3.0382e+02, 1.3855e+03, 4.2846e+02, 2.0921e+00],
+                            [1.2935e+03, 3.0517e+02, 1.3961e+03, 4.3137e+02, 2.9583e+00],
+                            [1.4076e+03, 3.2173e+02, 1.4930e+03, 4.2714e+02, 2.6099e+00],
+                            [1.4097e+03, 3.2496e+02, 1.4934e+03, 4.2651e+02, 3.0967e+00],
+                            [1.4097e+03, 3.2569e+02, 1.4935e+03, 4.2632e+02, 2.5553e+00],
+                            [1.0279e+03, 3.1883e+02, 1.1412e+03, 4.4646e+02, 1.2030e+00],
+                            [1.0275e+03, 3.1776e+02, 1.1408e+03, 4.4641e+02, 1.2732e+00],
+                            [1.0289e+03, 3.1694e+02, 1.1407e+03, 4.4510e+02, 9.4897e-01],
+                            [1.0372e+03, 3.1233e+02, 1.1477e+03, 4.4521e+02, 1.4125e+00],
+                            [1.0370e+03, 3.1564e+02, 1.1487e+03, 4.4317e+02, 1.6109e+00],
+                            [1.0367e+03, 3.1682e+02, 1.1510e+03, 4.4020e+02, 1.4112e+00]])
+        score = torch.tensor([0.9910, 0.9854, 0.9972, 0.9930, 0.4282, 0.5092, 0.6532, 0.9965, 0.9989,  # 20 scores, one per det row
+                              0.9976, 0.3144, 0.9874, 0.9980, 0.9967, 0.9698, 0.9824, 0.9474, 0.9856, 0.9964, 0.9926])
+
+        expect_output1 = torch.tensor([8, 12, 2, 18], dtype=torch.int32)  # expected first output: selected row indices (their scores descend: 0.9989, 0.9980, 0.9972, 0.9964)
+        expect_output2 = torch.tensor([4], dtype=torch.int32)  # expected second output: 4, the length of expect_output1
+        npu_output1, npu_output2 = self.npu_op_exec(det, score)
+        self.assertRtolEqual(expect_output1, npu_output1.cpu())  # results are copied back to the CPU before comparing
+        self.assertRtolEqual(expect_output2, npu_output2.cpu())
+
+instantiate_device_type_tests(TestNmsRotated, globals(), except_for='cpu')  # no CPU variant: the test body calls .npu()
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index d46b08a2fcf38bd543448a7f3b65790f54e39d1d..51ecc22c349d71aed2fe6eb9e6a773855f8732de 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1909,8 +1909,6 @@ custom:
   - func: npu_indexing.out(Tensor self, int[] begin, int[] end, int[] strides, int begin_mask=0, int end_mask=0, int ellipsis_mask=0, int new_axis_mask=0, int shrink_axis_mask=0, *, Tensor(a!) out) -> Tensor(a!)
   - func: npu_bmmV2(Tensor self, Tensor mat2, int[] output_sizes) -> Tensor
     variants: function, method
-  - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor
-    variants: function, method
   - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor
     variants: function, method
   - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor
@@ -1927,6 +1925,8 @@ custom:
   - func: npu_anchor_response_flags(Tensor self, int[2] featmap_size, int[2] stride, int num_base_anchors) -> Tensor
     variants: function, method
   - func: npu_dropout_backward(Tensor grad_output, Tensor mask, float p) -> Tensor
+  - func: npu_nms_rotated(Tensor self, Tensor scores, float iou_threshold, float scores_threshold=0, int max_output_size=-1, int mode=0) -> (Tensor, Tensor)  # new op; the two results are (selected indices, selected count)
+    variants: function, method
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
@@ -1935,4 +1935,6 @@ custom_autograd:
     variants: function, method
   - func: npu_ps_roi_pooling(Tensor self, Tensor rois, float spatial_scale, int group_size, int output_dim) -> Tensor
   - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
-  - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor)
\ No newline at end of file
+  - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor)
+  - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor  # now listed under custom_autograd; the entry under custom: is removed by the first hunk
+    variants: function, method
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/NmsRotatedKernelNpu.cpp b/torch_npu/csrc/aten/ops/NmsRotatedKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..046a6d4192a99d550cf84af727d86b7410354824
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/NmsRotatedKernelNpu.cpp
@@ -0,0 +1,62 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+tuple NPUNativeFunctions::npu_nms_rotated(  // Runs the "PolyNMS" op and returns (selected indices, selected count).
+    const at::Tensor& dets,  // first op input; dets.size(0) sizes the index output
+    const at::Tensor& scores,  // second op input
+    double iouThreshold,  // forwarded as attr "iou_threshold"
+    double scoreThreshold,  // forwarded as attr "score_threshold"
+    int64_t maxOutputSize,  // forwarded as attr "max_output_size"
+    int64_t mode) {  // forwarded as attr "mode"
+  c10::SmallVector selectedIndexSize = {dets.size(0)};  // NOTE(review): template arguments look stripped here and on the bare `tuple` return type - confirm against the original commit
+  c10::SmallVector selectedNumSize = {1};  // the count output holds a single element
+
+  at::Tensor selectedIndex = OpPreparation::ApplyTensor(selectedIndexSize, dets.options().dtype(at::kInt), dets);  // int32 buffer with one slot per row of dets
+  at::Tensor selectedNum = OpPreparation::ApplyTensor(selectedNumSize, dets.options().dtype(at::kInt), dets);  // int32 one-element tensor written by the op
+
+  // The op only supports fp32 currently, so inputs are cast to float when dets is not already fp32.
+  auto originDtype = dets.scalar_type();
+  at::Tensor detsCast = dets;
+  at::Tensor scoresCast = scores;
+  if(originDtype != at::ScalarType::Float){  // NOTE(review): only dets' dtype is checked; scores stays uncast when dets is fp32 - confirm both inputs always share a dtype
+    detsCast = NPUNativeFunctions::npu_dtype_cast(dets, at::kFloat);
+    scoresCast = NPUNativeFunctions::npu_dtype_cast(scores, at::kFloat);
+  }
+
+  OpCommand cmd;
+  cmd.Name("PolyNMS")
+      .Input(detsCast)
+      .Input(scoresCast)
+      .Output(selectedIndex)
+      .Output(selectedNum)
+      .Attr("iou_threshold", (float)iouThreshold)  // double argument narrowed to float for the attribute
+      .Attr("score_threshold", (float)scoreThreshold)  // double argument narrowed to float for the attribute
+      .Attr("max_output_size", maxOutputSize)
+      .Attr("mode", mode)
+      .Run();
+
+  at::Tensor selectedInd = selectedIndex.slice(0, 0, selectedNum.item().toLong());  // keep the first selectedNum entries; item() presumably forces a device-to-host read - TODO confirm
+  return std::tie(selectedInd, selectedNum);  // (trimmed indices, count tensor)
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp b/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp
index 73793d896c34ac63ead4299d0e970d3a5a1fb8b2..787469cb2840db54f0553fa2b4858ab4b4b01471 100644
--- a/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp
+++ b/torch_npu/csrc/aten/ops/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp
@@ -53,7 +53,7 @@ tuple softmax_cross_entropy_with_logits_impl_npu(
   return std::make_tuple(result, backprop);
 }
 
-at::Tensor NPUNativeFunctions::npu_softmax_cross_entropy_with_logits(
+at::Tensor softmax_cross_entropy_with_logits_npu(  // forward renamed to a free function; the NPUNativeFunctions name moves to the autograd wrapper in the last hunk
     const at::Tensor& self,
     const at::Tensor& labels) {
   TORCH_CHECK(self.device().type() == c10::DeviceType::NPU);
@@ -76,7 +76,7 @@ public:
     ctx->saved_data["labels"] = labels;
     at::AutoNonVariableTypeMode g;
     ctx->save_for_backward({self});
-    return NPUNativeFunctions::npu_softmax_cross_entropy_with_logits(self, labels);
+    return softmax_cross_entropy_with_logits_npu(self, labels);  // call the renamed forward directly
   }
 
   static tensor_list backward(AutogradContext *ctx,
@@ -94,7 +94,7 @@ public:
   }
 };
 
-at::Tensor npu_softmax_cross_entropy_with_logits_autograd(const at::Tensor& self,
+at::Tensor NPUNativeFunctions::npu_softmax_cross_entropy_with_logits(const at::Tensor& self,  // the registered entry point now goes through NPUSoftmaxCrossEntropyWithLogitsFunction::apply
     const at::Tensor& labels) {
   return NPUSoftmaxCrossEntropyWithLogitsFunction::apply(self, labels);
 }