From d1f7fdb96ed1a891d3ba6766a72fcef78254edca Mon Sep 17 00:00:00 2001
From: zhoufan37
Date: Wed, 16 Feb 2022 13:33:04 +0800
Subject: [PATCH] Add Rotated Operator

---
 test/test_network_ops/test_rotated_box.py     | 66 +++++++++++++
 test/test_network_ops/test_rotated_iou.py     | 82 ++++++++++++++++
 .../test_network_ops/test_rotated_overlaps.py | 96 +++++++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  4 +
 .../aten/ops/RotatedBoxDecodeKernelNpu.cpp    | 41 ++++++++
 .../aten/ops/RotatedBoxEncodeKernelNpu.cpp    | 41 ++++++++
 .../csrc/aten/ops/RotatedIouKernelNpu.cpp     | 81 ++++++++++++++++
 .../aten/ops/RotatedOverlapsKernelNpu.cpp     | 63 ++++++++++++
 8 files changed, 474 insertions(+)
 create mode 100644 test/test_network_ops/test_rotated_box.py
 create mode 100644 test/test_network_ops/test_rotated_iou.py
 create mode 100644 test/test_network_ops/test_rotated_overlaps.py
 create mode 100644 torch_npu/csrc/aten/ops/RotatedBoxDecodeKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/RotatedBoxEncodeKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp
 create mode 100644 torch_npu/csrc/aten/ops/RotatedOverlapsKernelNpu.cpp

diff --git a/test/test_network_ops/test_rotated_box.py b/test/test_network_ops/test_rotated_box.py
new file mode 100644
index 0000000000..91c690fd69
--- /dev/null
+++ b/test/test_network_ops/test_rotated_box.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestRotatedBox(TestCase):  # compares npu_rotated_box_encode / npu_rotated_box_decode with fixed reference values
+    def npu_op_encode_exec(self, anchor_boxes, gt_bboxes, weight):  # run the encode op and return the result as a numpy array
+        out = torch_npu.npu_rotated_box_encode(anchor_boxes, gt_bboxes, weight)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def npu_op_decode_exec(self, anchor_boxes, deltas, weight):  # run the decode op and return the result as a numpy array
+        out = torch_npu.npu_rotated_box_decode(anchor_boxes, deltas, weight)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def test_rotated_boxes_encode_fp32(self, device):  # `device` is supplied by instantiate_device_type_tests and not used in the body
+        anchor_boxes = torch.tensor([[[44.2877], [9.1412], [88.7575], [25.8879], [64.8047]]]).to("npu")
+        gt_bboxes = torch.tensor([[[39.1763], [0.9838], [78.1028], [29.5997], [51.5907]]]).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.]).npu()
+        expect_cpu = torch.tensor([[[-0.1773], [-0.1327], [-0.1331], [0.5358], [-0.8643]]])  # hard-coded reference, not computed here
+        npu_output = self.npu_op_encode_exec(anchor_boxes, gt_bboxes, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_decode_fp32(self, device):
+        anchor_boxes = torch.tensor([[[32.1855], [41.9922], [64.1435], [62.5325], [34.607]]]).to("npu")
+        deltas = torch.tensor([[[1.8725], [-1.8915], [0.2395], [-0.4622], [-34.6539]]]).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.]).npu()
+        expect_cpu = torch.tensor([[[87.70366], [6.9412346], [128.31055], [19.879467], [-88.313515]]])  # hard-coded reference, not computed here
+        npu_output = self.npu_op_decode_exec(anchor_boxes, deltas, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_encode_fp16(self, device):
+        anchor_boxes = torch.tensor([[[30.69], [32.6], [45.94], [59.88], [-44.53]]], dtype=torch.float16).to("npu")
+        gt_bboxes = torch.tensor([[[30.44], [18.72], [33.22], [45.56], [8.5]]], dtype=torch.float16).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
+        expect_cpu = torch.tensor([[[-0.4253], [-0.5166], [-1.702], [-0.0162], [1.133]]], dtype=torch.float16)  # hard-coded reference, not computed here
+        npu_output = self.npu_op_encode_exec(anchor_boxes, gt_bboxes, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_decode_fp16(self, device):
+        anchor_boxes = torch.tensor([[[4.137],[33.72],[29.4], [54.06], [41.28]]], dtype=torch.float16).to("npu")
+        deltas = torch.tensor([[[0.0244], [-1.992], [0.2109], [0.315], [-37.25]]], dtype=torch.float16).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
+        expect_cpu = torch.tensor([[[1.786], [-10.58], [33.], [17.3], [-88.44]]], dtype=torch.float16)  # hard-coded reference, not computed here
+        npu_output = self.npu_op_decode_exec(anchor_boxes, deltas, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+instantiate_device_type_tests(TestRotatedBox, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_rotated_iou.py b/test/test_network_ops/test_rotated_iou.py
new file mode 100644
index 0000000000..63228ebace
--- /dev/null
+++ b/test/test_network_ops/test_rotated_iou.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestRotatedIou(TestCase):  # compares npu_rotated_iou output with fixed reference arrays
+    def generate_rto_data(self, item):  # NOTE(review): inputs are random and unseeded, yet cpu_expect_result is fixed - confirm how this is meant to be deterministic
+        minValue, maxValue = 20, 60
+        scope = 20
+        dtype = item[0][0]
+        shape_one = item[0][-1]
+        shape_two = item[1][-1]
+        # item[-1] (the trans flag) is not needed here; the callers pass it to npu_op_exec themselves.
+
+        boxes_array1 = np.random.uniform(minValue, maxValue, shape_one[:2]+[2]).astype(dtype)
+        boxes_wh = np.random.randint(1, scope, size=shape_one[:2]+[2])
+        boxes_angle = np.random.randint(-180, 180, size=shape_one[:2]+[1])
+        boxes = np.concatenate([boxes_array1, boxes_wh, boxes_angle], dtype=dtype, axis=-1)  # last axis: 2 uniform values, 2 from randint(1, scope), 1 from randint(-180, 180)
+        # query_boxes: built the same way from the second shape
+        query_boxes_array1 = np.random.uniform(minValue, maxValue, shape_two[:2]+[2]).astype(dtype)
+        query_boxes_wh = np.random.randint(1, scope, size=shape_two[:2]+[2] )
+        query_boxes_angle = np.random.randint(-180, 180, size=shape_two[:2]+[1])
+        query_boxes = np.concatenate([query_boxes_array1, query_boxes_wh, query_boxes_angle], dtype=dtype, axis=-1)
+
+        cpu_input1 = torch.from_numpy(boxes)
+        cpu_input2 = torch.from_numpy(query_boxes)
+        npu_input1 = cpu_input1.npu()
+        npu_input2 = cpu_input2.npu()
+        return boxes, query_boxes, npu_input1, npu_input2  # numpy inputs first, then their NPU tensor copies
+
+    def cpu_expect_result(self, dtype):  # same hard-coded values in both branches; only the array dtype differs
+        if dtype == np.float32:
+            output = np.array([[[0., 0.00045966, 0.],[0., 0., 0.]],
+                               [[0., 0., 0.],[0., 0., 0.]],
+                               [[0., 0., 0.],[0.00600622, 0.10504241, 0.]],
+                               [[0., 0., 0.],[0., 0., 0.]]], dtype=np.float32)
+        else:
+            output = np.array([[[0., 0.00045966, 0.],[0., 0., 0.]],
+                               [[0., 0., 0.],[0., 0., 0.]],
+                               [[0., 0., 0.],[0.00600622, 0.10504241, 0.]],
+                               [[0., 0., 0.],[0., 0., 0.]]], dtype=np.float16)
+        return output
+
+    def npu_op_exec(self, box1, box2, trans=False):  # fixed extra args: mode=0, is_cross=True, both thresholds 0.0
+        output = torch_npu.npu_rotated_iou(box1, box2, trans, 0, True, 0.0, 0.0)
+        output = output.detach().cpu().numpy()
+        return output
+
+    def test_rotated_iou_shape_format_fp32(self, device):
+        dtype = np.float32
+        shape_format = [[dtype, -1, [4,2,5]],[dtype, -1, [4,3,5]], False]  # [dtype, -1, shape] for each input, then the trans flag
+        cpu_input1, cpu_input2, npu_input1, npu_input2 = self.generate_rto_data(shape_format)
+        cpu_output = self.cpu_expect_result(dtype)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2, shape_format[-1])
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_rotated_iou_shape_format_fp16(self, device):
+        dtype = np.float16
+        shape_format = [[dtype, -1, [4,2,5]],[dtype, -1, [4,3,5]], False]  # [dtype, -1, shape] for each input, then the trans flag
+        cpu_input1, cpu_input2, npu_input1, npu_input2 = self.generate_rto_data(shape_format)
+        cpu_output = self.cpu_expect_result(dtype)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2, shape_format[-1])
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestRotatedIou, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_rotated_overlaps.py b/test/test_network_ops/test_rotated_overlaps.py
new file mode 100644
index 0000000000..92efac8af0
--- /dev/null
+++ b/test/test_network_ops/test_rotated_overlaps.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestRotatedOverlaps(TestCase):  # compares npu_rotated_overlaps output with fixed reference arrays
+    def generate_rto_data(self, item):  # NOTE(review): inputs are random and unseeded, yet cpu_expect_result is fixed - confirm how this is meant to be deterministic
+        min_value, max_value = 30, 60
+        scope = 20
+        dtype = item[0][0]
+        shape_one = item[0][-1]
+        shape_two = item[1][-1]
+
+        boxes_center = np.random.uniform(min_value, max_value, shape_one[:2] + [2]).astype(dtype)
+        boxes_wh = np.random.randint(1, scope, size=shape_one[:2] + [2])
+        boxes_angle = np.random.randint(-180, 180, size=shape_one[:2] + [1])
+        boxes = np.concatenate([boxes_center, boxes_wh, boxes_angle], axis=-1, dtype=dtype)  # last axis: center (2), width/height (2), angle (1)
+        # query_boxes: built the same way from the second shape
+        query_boxes_center = np.random.uniform(min_value, max_value, shape_two[:2] + [2]).astype(dtype)
+        query_boxes_wh = np.random.randint(1, scope, size=shape_two[:2] + [2])
+        query_boxes_angle = np.random.randint(-180, 180, size=shape_two[:2] + [1])
+        query_boxes = np.concatenate([query_boxes_center, query_boxes_wh, query_boxes_angle], axis=-1, dtype=dtype)
+
+        cpu_input1 = torch.from_numpy(boxes)
+        cpu_input2 = torch.from_numpy(query_boxes)
+        npu_input1 = cpu_input1.npu()
+        npu_input2 = cpu_input2.npu()
+        return cpu_input1, cpu_input2, npu_input1, npu_input2  # CPU tensors first, then their NPU copies
+
+    def cpu_expect_result(self, dtype):  # hard-coded references; the fp16 array is (1, 3, 4), the fp32 array is (1, 3, 2)
+        if dtype == np.float16:
+            output = np.array([[[0., 13.27, 1.022, 0.],
+                                [0., 0., 54.12, 0.],
+                                [0., 0., 0., 19.17]]], dtype=np.float16)
+        else:
+            output = np.array([[[0., 10.289731],
+                                [0., 0.],
+                                [0., 0.]]], dtype=np.float32)
+        return output
+
+    def npu_op_exec(self, box1, box2, trans=False):  # run the op and return the result as a numpy array
+        output = torch_npu.npu_rotated_overlaps(box1, box2, trans)
+        output = output.detach().cpu().numpy()
+        return output
+
+    def test_rotated_overlaps_shape_format_fp32(self, device):
+        dtype = np.float32
+        shape_list = [
+            [[1, 3, 5], [1, 2, 5]],
+        ]
+        is_trans_list = [False]
+        shape_format = [[[dtype, -1, m[0]],[dtype, -1, m[1]], k]
+                        for m in shape_list
+                        for k in is_trans_list]
+
+        for item in shape_format:
+            cpu_input1, cpu_input2, npu_input1, npu_input2 = self.generate_rto_data(item[:-1])  # only item[0] and item[1] are read, so dropping the flag has no effect
+            cpu_output = self.cpu_expect_result(dtype)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, item[-1])
+            # fp32 does not have enough precision here, but it meets the current model's needs.
+            self.assertRtolEqual(cpu_output, npu_output, prec=0.00005)
+
+    def test_rotated_overlaps_shape_format_fp16(self, device):
+        dtype = np.float16
+        shape_list = [
+            [[1, 3, 5], [1, 4, 5]],
+        ]
+        # True is the xyxyt format, False is the xywh format
+        is_trans_list = [False]
+        shape_format = [[[dtype, -1, m[0]],[dtype, -1, m[1]], k]
+                        for m in shape_list
+                        for k in is_trans_list]
+        for item in shape_format:
+            cpu_input1, cpu_input2, npu_input1, npu_input2 = self.generate_rto_data(item)
+            cpu_output = self.cpu_expect_result(dtype)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, item[-1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestRotatedOverlaps, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 0ab0fa4062..6ebbebcf8e 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1918,6 +1918,10 @@ custom:
     variants: function, method
   - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor
   - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!)
+  - func: npu_rotated_box_encode(Tensor self, Tensor gt_bboxes, Tensor weight) -> Tensor
+  - func: npu_rotated_box_decode(Tensor self, Tensor deltas, Tensor weight) -> Tensor
+  - func: npu_rotated_iou(Tensor self, Tensor query_boxes, bool trans=False, int mode=0, bool is_cross=True, float v_threshold=0.0, float e_threshold=0.0) -> Tensor
+  - func: npu_rotated_overlaps(Tensor self, Tensor query_boxes, bool trans=False) -> Tensor
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/RotatedBoxDecodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/RotatedBoxDecodeKernelNpu.cpp
new file mode 100644
index 0000000000..89f1d1ae81
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RotatedBoxDecodeKernelNpu.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// Runs the "RotatedBoxDecode" NPU op on (self, deltas); `weight` is passed to the op as a float-list attribute.
+at::Tensor NPUNativeFunctions::npu_rotated_box_decode(
+    const at::Tensor& self,
+    const at::Tensor& deltas,
+    const at::Tensor& weight){
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  at::Tensor weightContiguous = weight.to(at::Device(at::kCPU), at::kFloat).contiguous();  // contiguous CPU float copy: the raw buffer is read below
+  at::ArrayRef<float> weightList(weightContiguous.data_ptr<float>(), weightContiguous.numel());  // non-owning view; weightContiguous must outlive Run()
+
+  OpCommand cmd;
+  cmd.Name("RotatedBoxDecode")
+      .Input(self)
+      .Input(deltas)
+      .Output(result)
+      .Attr("weight", weightList)
+      .Run();
+  return result;
+}
+} //namespace native
+} //namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RotatedBoxEncodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/RotatedBoxEncodeKernelNpu.cpp
new file mode 100644
index 0000000000..6706622b4f
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RotatedBoxEncodeKernelNpu.cpp
@@ -0,0 +1,41 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// Runs the "RotatedBoxEncode" NPU op on (self, gtBox); `weight` is passed to the op as a float-list attribute.
+at::Tensor NPUNativeFunctions::npu_rotated_box_encode(
+    const at::Tensor& self,
+    const at::Tensor& gtBox,
+    const at::Tensor& weight){
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  at::Tensor weightContiguous = weight.to(at::Device(at::kCPU), at::kFloat).contiguous();  // contiguous CPU float copy: the raw buffer is read below
+  at::ArrayRef<float> weightList(weightContiguous.data_ptr<float>(), weightContiguous.numel());  // non-owning view; weightContiguous must outlive Run()
+
+  OpCommand cmd;
+  cmd.Name("RotatedBoxEncode")
+      .Input(self)
+      .Input(gtBox)
+      .Output(result)
+      .Attr("weight", weightList)
+      .Run();
+  return result;
+}
+} //namespace native
+} //namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp b/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp
new file mode 100644
index 0000000000..3cd330aa2a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RotatedIouKernelNpu.cpp
@@ -0,0 +1,81 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// Issues the "RotatedIou" NPU op into `iou` without validating inputs; mode 0 selects "iou", any other value "iof".
+at::Tensor& rotated_iou_npu_nocheck(
+    at::Tensor& iou,
+    const at::Tensor& boxes,
+    const at::Tensor& query_boxes,
+    bool trans,
+    int64_t mode,
+    bool is_cross,
+    double v_threshold,
+    double e_threshold) {
+  string mode_str = (mode == 0) ? "iou" : "iof";
+
+  OpCommand cmd;
+  cmd.Name("RotatedIou")
+      .Input(boxes)
+      .Input(query_boxes)
+      .Output(iou)
+      .Attr("trans", trans)
+      .Attr("mode", mode_str)
+      .Attr("is_cross", is_cross)
+      .Attr("value", static_cast<float>(v_threshold))  // NOTE(review): both thresholds use the attr name "value" -
+      .Attr("value", static_cast<float>(e_threshold))  // presumably "v_threshold" / "e_threshold" were intended; confirm against the op definition
+      .Run();
+  return iou;
+}
+// Public entry: permutes both inputs with {0, 2, 1}, casts half inputs to float, runs the op and casts the result back to boxes' dtype.
+at::Tensor NPUNativeFunctions::npu_rotated_iou(
+    const at::Tensor& boxes,
+    const at::Tensor& query_boxes,
+    bool trans,
+    int64_t mode,
+    bool is_cross,
+    double v_threshold,
+    double e_threshold) {
+  TORCH_CHECK(boxes.ndimension() == 3 && query_boxes.ndimension() == 3, "boxes and query_boxes must both be 3-dimensional!");
+
+  auto origin_dtype = boxes.scalar_type();
+
+  at::Tensor boxesOk = boxes.permute({0, 2, 1});
+  if (boxesOk.scalar_type() == at::kHalf){
+    boxesOk = NPUNativeFunctions::npu_dtype_cast(boxesOk, at::kFloat);
+  }
+  at::Tensor query_boxesOk = query_boxes.permute({0, 2, 1});
+  if (query_boxesOk.scalar_type() == at::kHalf){
+    query_boxesOk = NPUNativeFunctions::npu_dtype_cast(query_boxesOk, at::kFloat);
+  }
+
+  int64_t B = boxesOk.size(0);
+  int64_t N = boxesOk.size(-1);  // last dim after the permute, i.e. dim 1 of the original boxes
+  int64_t K = query_boxesOk.size(-1);  // last dim after the permute, i.e. dim 1 of the original query_boxes
+
+  c10::SmallVector<int64_t, SIZE> output_size({B, N, K});
+  at::Tensor iou = OpPreparation::ApplyTensor(boxesOk, output_size);
+
+  rotated_iou_npu_nocheck(iou, boxesOk, query_boxesOk, trans, mode, is_cross, v_threshold, e_threshold);
+  iou = NPUNativeFunctions::npu_dtype_cast(iou, origin_dtype);
+  return iou;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RotatedOverlapsKernelNpu.cpp b/torch_npu/csrc/aten/ops/RotatedOverlapsKernelNpu.cpp
new file mode 100644
index 0000000000..df7e916ce0
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RotatedOverlapsKernelNpu.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+// Issues the "RotatedOverlaps" NPU op into `overlaps` without validating inputs.
+at::Tensor& rotated_overlaps_npu_nocheck(
+    at::Tensor& overlaps,
+    const at::Tensor& self,
+    const at::Tensor& query_boxes,
+    bool trans) {
+  OpCommand cmd;
+  cmd.Name("RotatedOverlaps")
+      .Input(self)
+      .Input(query_boxes)
+      .Output(overlaps)
+      .Attr("trans", trans)
+      .Run();
+  return overlaps;
+}
+// Public entry: casts both inputs to float, permutes them with {0, 2, 1}, runs the op and casts the result back to self's dtype.
+at::Tensor NPUNativeFunctions::npu_rotated_overlaps(
+    const at::Tensor& self,
+    const at::Tensor& query_boxes,
+    bool trans) {
+  TORCH_CHECK(self.ndimension() == 3 && query_boxes.ndimension() == 3,
+              "boxes' dim should be equal to query_boxes' ndimension() ",
+              "and equal to 3!");
+  auto origin_dtype = self.scalar_type();
+  // The op only supports fp32 currently, so both inputs are cast to float first.
+  at::Tensor selfCp = NPUNativeFunctions::npu_dtype_cast(self, at::kFloat).permute({0, 2, 1});
+  at::Tensor queryBoxesCp = NPUNativeFunctions::npu_dtype_cast(query_boxes, at::kFloat).permute({0, 2, 1});
+
+  int64_t B = selfCp.size(0);
+  int64_t N = selfCp.size(-1);  // last dim after the permute, i.e. dim 1 of the original self
+  int64_t K = queryBoxesCp.size(-1);  // last dim after the permute, i.e. dim 1 of the original query_boxes
+
+  c10::SmallVector<int64_t, SIZE> output_size({B, N, K});
+  at::Tensor overlaps = OpPreparation::ApplyTensor(selfCp, output_size);
+
+  rotated_overlaps_npu_nocheck(overlaps, selfCp, queryBoxesCp, trans);
+  overlaps = NPUNativeFunctions::npu_dtype_cast(overlaps, origin_dtype);  // assign, do not redeclare: a second `at::Tensor overlaps` in this scope does not compile
+
+  return overlaps;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee