class TesBatchNms(TestCase):
    """Checks the class output of ``torch_npu.npu_batch_nms`` for fp32 and fp16 inputs."""

    def test_batch_nms_shape_format(self):
        """Run the op on float32 and float16 inputs and compare ``nmsed_classes``."""
        # NOTE(review): the inputs are drawn from an unseeded torch.randn while
        # the reference tensor below is hard-coded -- confirm that the class
        # output really is independent of the sampled values.
        boxes_fp32 = torch.randn(8, 2, 4, 4, dtype=torch.float32).to("npu")
        scores_fp32 = torch.randn(3, 2, 4, dtype=torch.float32).to("npu")
        boxes_half = boxes_fp32.to(torch.half)
        scores_half = scores_fp32.to(torch.half)

        # score_threshold, iou_threshold, max_size_per_class, max_total_size
        nms_args = (0.3, 0.5, 3, 4)
        # The op returns (boxes, scores, classes, num); only classes are checked.
        _, _, classes_fp32, _ = torch_npu.npu_batch_nms(
            boxes_fp32, scores_fp32, *nms_args)
        _, _, classes_half, _ = torch_npu.npu_batch_nms(
            boxes_half, scores_half, *nms_args)

        expected_classes = torch.tensor(
            [
                [0.0000, 2.1250, 0.0000, 1.8750],
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 1.8750, 0.0000, 2.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 2.0000, 0.0000, 2.1250],
                [0.0000, 2.1250, 0.0000, 1.8750],
                [0.0000, 0.0000, 0.0000, 0.0000],
                [0.0000, 0.0000, 0.0000, 0.0000],
            ],
            dtype=torch.float16,
        )

        self.assertRtolEqual(expected_classes, classes_fp32.cpu())
        self.assertRtolEqual(expected_classes, classes_half.cpu())
class TesBoundingBoxDecode(TestCase):
    """Checks ``torch_npu.npu_bounding_box_decode`` against fixed reference boxes."""

    def test_decode_shape_format_fp32(self):
        """Decode two boxes in float32 and float16 and compare with the reference."""
        rois = torch.tensor(
            [[1., 2., 3., 4.], [3., 4., 5., 6.]], dtype=torch.float32).to("npu")
        deltas = torch.tensor(
            [[5., 6., 7., 8.], [7., 8., 9., 6.]], dtype=torch.float32).to("npu")
        rois_half = rois.to(torch.half)
        deltas_half = deltas.to(torch.half)

        expected = torch.tensor(
            [[2.5000, 6.5000, 9.0000, 9.0000],
             [9.0000, 9.0000, 9.0000, 9.0000]],
            dtype=torch.float32,
        )

        # means0..3, stds0..3, max_shape, wh_ratio_clip
        decode_args = (0, 0, 0, 0, 1, 1, 1, 1, (10, 10), 0.1)
        decoded = torch_npu.npu_bounding_box_decode(rois, deltas, *decode_args)
        decoded_half = torch_npu.npu_bounding_box_decode(
            rois_half, deltas_half, *decode_args)

        self.assertRtolEqual(expected, decoded.cpu())
        self.assertRtolEqual(expected.to(torch.half), decoded_half.cpu())
class TesBoundingBoxEncode(TestCase):
    """Checks ``torch_npu.npu_bounding_box_encode`` against fixed reference deltas."""

    def test_encode_shape_format(self):
        """Encode two box pairs in float32 and float16 and compare with the reference."""
        anchor_box = torch.tensor(
            [[1., 2., 3., 4.], [3., 4., 5., 6.]], dtype=torch.float32).to("npu")
        ground_truth_box = torch.tensor(
            [[5., 6., 7., 8.], [7., 8., 9., 6.]], dtype=torch.float32).to("npu")
        anchor_box_half = anchor_box.to(torch.half)
        ground_truth_box_half = ground_truth_box.to(torch.half)

        expected = torch.tensor(
            [[13.3281, 13.3281, 0.0000, 0.0000],
             [13.3281, 6.6641, 0.0000, -5.4922]],
            dtype=torch.float32,
        )

        # means0..3 followed by stds0..3
        encode_args = (0, 0, 0, 0, 0.1, 0.1, 0.2, 0.2)
        encoded = torch_npu.npu_bounding_box_encode(
            anchor_box, ground_truth_box, *encode_args)
        encoded_half = torch_npu.npu_bounding_box_encode(
            anchor_box_half, ground_truth_box_half, *encode_args)

        self.assertRtolEqual(expected, encoded.cpu())
        self.assertRtolEqual(expected.to(torch.half), encoded_half.cpu())
+ - func: npu_batch_nms(Tensor self, Tensor scores, float score_threshold, float iou_threshold, int max_size_per_class, int max_total_size, bool change_coordinate_frame=False, bool transpose_box=False) -> (Tensor, Tensor, Tensor, Tensor) + - func: npu_bounding_box_encode(Tensor anchor_box, Tensor ground_truth_box, float means0, float means1, float means2, float means3, float stds0, float stds1, float stds2, float stds3) -> Tensor + - func: npu_bounding_box_decode(Tensor rois, Tensor deltas, float means0, float means1, float means2, float means3, float stds0, float stds1, float stds2, float stds3, int[1] max_shape, float wh_ratio_clip) -> Tensor custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/BatchNMSKernelNpu.cpp b/torch_npu/csrc/aten/ops/BatchNMSKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c75431c8bfecfc0e6933dbab2730c97e21c0f304 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BatchNMSKernelNpu.cpp @@ -0,0 +1,65 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +std::tuple NPUNativeFunctions::npu_batch_nms( + const at::Tensor& self, + const at::Tensor& scores, + double score_threshold, + double iou_threshold, + int64_t max_size_per_class, + int64_t max_total_size, + bool change_coordinate_frame, + bool transpose_box) { + at::Tensor nmsed_boxes = OpPreparation::ApplyTensor( + {self.size(0), max_total_size, 4}, + self.options().dtype(at::kHalf), + self); + at::Tensor nmsed_scores = OpPreparation::ApplyTensor( + {self.size(0), max_total_size}, + self.options().dtype(at::kHalf), + self); + at::Tensor nmsed_classes = OpPreparation::ApplyTensor( + {self.size(0), max_total_size}, + self.options().dtype(at::kHalf), + self); + at::Tensor nmsed_num = OpPreparation::ApplyTensor( + {self.size(0)}, + self.options().dtype(at::kInt), + self); + OpCommand cmd; + cmd.Name("BatchMultiClassNonMaxSuppression") + .Input(self) + .Input(scores) + .Output(nmsed_boxes) + .Output(nmsed_scores) + .Output(nmsed_classes) + .Output(nmsed_num) + .Attr("score_threshold", static_cast(score_threshold)) + .Attr("iou_threshold", static_cast(iou_threshold)) + .Attr("max_size_per_class", max_size_per_class) + .Attr("max_total_size", max_total_size) + .Attr("change_coordinate_frame", change_coordinate_frame) + .Attr("transpose_box", transpose_box) + .Run(); + return std::tie(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/BoundingBoxDecodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/BoundingBoxDecodeKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..58c51e909466198c0e23b1d40eaec147627771d4 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BoundingBoxDecodeKernelNpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. 
+// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& bounding_box_decode_npu_nocheck( + const at::Tensor& rois, + const at::Tensor& deltas, + c10::SmallVector means, + c10::SmallVector stds, + at::IntArrayRef max_shape, + double wh_ratio_clip, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("BoundingBoxDecode") + .Input(rois) + .Input(deltas) + .Output(result) + .Attr("means", means) + .Attr("stds", stds) + .Attr("max_shape", max_shape) + .Attr("wh_ratio_clip", static_cast(wh_ratio_clip)) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::npu_bounding_box_decode( + const at::Tensor& rois, + const at::Tensor& deltas, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3, + at::IntArrayRef max_shape, + double wh_ratio_clip) { + c10::SmallVector outputSize = {rois.size(0), 4}; + at::Tensor result = OpPreparation::ApplyTensor(rois, outputSize); + c10::SmallVector means = { + static_cast(means0), + static_cast(means1), + static_cast(means2), + static_cast(means3)}; + c10::SmallVector stds = { + static_cast(stds0), + static_cast(stds1), + static_cast(stds2), + static_cast(stds3)}; + bounding_box_decode_npu_nocheck( + rois, deltas, means, stds, max_shape, wh_ratio_clip, result); + return 
result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/BoundingBoxEncodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/BoundingBoxEncodeKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3406e6b99eed918f5ce9de0db5769a2af2b8339d --- /dev/null +++ b/torch_npu/csrc/aten/ops/BoundingBoxEncodeKernelNpu.cpp @@ -0,0 +1,68 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& bounding_box_encode_npu_nocheck( + const at::Tensor& anchor_box, + const at::Tensor& ground_truth_box, + c10::SmallVector means, + c10::SmallVector stds, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("BoundingBoxEncode") + .Input(anchor_box) + .Input(ground_truth_box) + .Output(result) + .Attr("means", means) + .Attr("stds", stds) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::npu_bounding_box_encode( + const at::Tensor& anchor_box, + const at::Tensor& ground_truth_box, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3) { + at::Tensor result = OpPreparation::ApplyTensor(anchor_box, {anchor_box.size(0), 4}); + c10::SmallVector means = { + static_cast(means0), + static_cast(means1), + static_cast(means2), + static_cast(means3)}; + c10::SmallVector stds = { + static_cast(stds0), + static_cast(stds1), + static_cast(stds2), + static_cast(stds3)}; + bounding_box_encode_npu_nocheck( + anchor_box, ground_truth_box, means, stds, result); + return result; +} + +} // namespace native +} // namespace at_npu