From 68c95af3c317b00c8c3b10382d8a80cbae3371d1 Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Tue, 1 Mar 2022 17:49:58 +0800 Subject: [PATCH] Add YoloBugfix Operator --- .../test_network_ops/test_yolo_boxes_encode.py | 18 +++++++++--------- .../csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp | 1 + 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_network_ops/test_yolo_boxes_encode.py b/test/test_network_ops/test_yolo_boxes_encode.py index 23656d7918..50b4fbe728 100644 --- a/test/test_network_ops/test_yolo_boxes_encode.py +++ b/test/test_network_ops/test_yolo_boxes_encode.py @@ -25,23 +25,23 @@ class TestYoloBoxesEncode(TestCase): return out.detach().numpy() def test_yolo_boxes_encode(self, device="npu"): - anchor_boxes = [(2, 4)] - gt_bboxes = [(2 ,4)] - stride = [[2, 2]] - exoutput_list = [[[0.7921727, 0.5314963, -0.74224466, -13.815511], + anchor_boxes_list = [(2, 4)] + gt_bboxes_list = [(2 ,4)] + stride_list = [[2, 2]] + expect_cpu_list = [[[0.7921727, 0.5314963, -0.74224466, -13.815511], [0.7360072, 0.58343244, 4.3334002, -0.51378196]]] - shape_format = [[i, j, k, h] for i in anchor_boxes - for j in gt_bboxes for k in stride for h in exoutput_list] + shape_format = [[i, j, k, h] for i in anchor_boxes_list + for j in gt_bboxes_list for k in stride_list for h in expect_cpu_list] for item in shape_format: anchor_boxes_tensor = torch.rand(item[0], dtype=torch.float32).to("npu") gt_bboxes_tensor = torch.rand(item[1], dtype=torch.float32).to("npu") - stride_tensor = torch.tensor(item[2], dtype=torch.float32).to("npu") - exoutput_cpu_tensor = torch.tensor(item[3], dtype=torch.float32) + stride_tensor = torch.tensor(item[2], dtype=torch.int32).to("npu") + expect_cpu_tensor = torch.tensor(item[3], dtype=torch.float32) npu_output = self.npu_op_exec(anchor_boxes_tensor, gt_bboxes_tensor, stride_tensor, False) - self.assertRtolEqual(exoutput_cpu_tensor.numpy(), npu_output) + self.assertRtolEqual(expect_cpu_tensor.numpy(), npu_output) if __name__ == "__main__": diff --git a/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp index 877c0a79c0..c3263a8190 100644 --- a/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp @@ -71,5 +71,6 @@ at::Tensor NPUNativeFunctions::npu_yolo_boxes_encode( .Run(); return result; } + } // namespace native } // namespace at_npu \ No newline at end of file -- Gitee