From a0a9461165115799e23b6fa9b241e15cf8fe20bf Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Fri, 15 Dec 2023 16:22:27 +0800 Subject: [PATCH 1/6] 1 --- .../built-in/cv/detection/ssd/README.md | 217 +++++ .../built-in/cv/detection/ssd/acc_dataset.py | 87 ++ .../built-in/cv/detection/ssd/coco_eval.py | 98 +++ .../built-in/cv/detection/ssd/export.py | 83 ++ .../built-in/cv/detection/ssd/get_info.py | 64 ++ .../built-in/cv/detection/ssd/perf.py | 105 +++ .../cv/detection/ssd/ssd-requirements.txt | 7 + .../cv/detection/ssd/ssd_postprocess.py | 294 +++++++ .../cv/detection/ssd/ssd_preprocess.py | 72 ++ .../built-in/cv/detection/ssd/txt_to_json.py | 115 +++ .../cv/detection/ssd/update_ssd_mmdet.diff | 798 ++++++++++++++++++ 11 files changed, 1940 insertions(+) create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md new file mode 100644 index 0000000000..3035c50b37 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md @@ -0,0 +1,217 @@ +# SSD模型-推理指导 + +- [概述](#ZH-CN_TOPIC_0000001172161501) + + - [输入输出数据](#ZH-CN_TOPIC_0000001126281702) + +- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) + +- [快速上手](#ZH-CN_TOPIC_0000001126281700) + + - [获取源码](#section4622531142816) + - [准备数据集](#section183221994411) + - [模型推理](#section741711594517) + +- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) + + ****** + +# 概述 + +SSD将detection转化为regression的思路,可以一次完成目标定位与分类。该算法基于Faster RCNN中的Anchor,提出了相似的Prior box;该算法修改了传统的SSD网络:将SSD的FC6和FC7层转化为卷积层,去掉所有的Dropout层和FC8层。同时加入基于特征金字塔的检测方式,在不同感受野的feature map上预测目标。 + +- 参考实现: + + ```shell + url=https://github.com/open-mmlab/mmdetection.git + branch=master + commit_id=a21eb25535f31634cef332b09fc27d28956fb24b + model_name=ssd + ``` + +## 输入输出数据 + +- 输入数据 + + | 输入数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | ------------------------- | ------------ | + | input | RGB_FP32 | batchsize x 3 x 300 x 300 | NCHW | + +- 输出数据 + + | 输出数据 | 数据类型 | 大小 | 数据排布格式 | + | -------- | -------- | --------------------- | ------------ | + | boxes | FLOAT32 | batchsize x 8732 x 4 | ND | + | labels | FLOAT32 | batchsize x 8732 x 80 | ND | + +# 推理环境准备 + +- 该模型需要两套环境切换运行,用于执行推理的环境(包括插件与驱动)如下 + + **表 1** 版本配套表 + + | 配套 | 版本 | 环境准备指导 | + |---------| ------- | ------------------------------------------------------------ | + | 固件与驱动 | 23.0.rc1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) | + | CANN | 7.0.RC1.alpha003 | - | + | Python | 3.9.11 | - | + | PyTorch | 2.0.1 | - | + | Torch_AIE | 6.3.rc2 | - | + +- 用于执行前后处理以及模型导出,则需要另一套环境,建议使用conda命令构建虚拟环境,并安装相应的包 + +``` +conda create --name ssd python=3.7.16 +``` + +# 快速上手 + +## 获取源码 + +1. 
获取SSD源代码并修改mmdetection。 + ```shell + git clone https://github.com/open-mmlab/mmdetection.git + cd mmdetection + git reset --hard a21eb25535f31634cef332b09fc27d28956fb24b + patch -p1 < ../update_ssd_mmdet.diff + pip install -v -e . + cd .. + ``` + +2. 安装依赖。 + ```shell + pip3 install -r ssd-requirements.txt + ``` + +## 准备数据集 + +1. 获取原始数据集。(解压命令参考tar –xvf \*.tar与 unzip \*.zip) + + 推理数据集采用 [coco_val_2017](http://images.cocodataset.org),数据集下载后存放路径:`dataset=/root/datasets/coco` + + 目录结构: + + ``` + ├── coco + │ ├── val2017 + │ ├── annotations + │ ├──instances_val2017.json + ``` + +2. 数据预处理(使用torch 1.8环境)。 + + 将原始数据集转换为模型输入的二进制数据。执行 `ssd_preprocess.py` 脚本。 + + ```shell + python ssd_preprocess.py \ + --image_folder_path $dataset/val2017 \ + --bin_folder_path val2017_ssd_bin + ``` + + - 参数说明: + + - --image_folder_path:原始数据验证集(.jpg)所在路径。 + - --bin_folder_path:输出的二进制文件(.bin)所在路径。 + + 每个图像对应生成一个二进制文件。 + +3. 生成数据集info文件(使用torch 1.8环境)。 + + 运行 `get_info.py` 脚本,生成图片数据info文件。 + ```shell + python get_info.py jpg $dataset/val2017 coco2017_ssd_jpg.info + ``` + + - 参数说明: + + - 第一个参数:生成的数据集文件格式。 + - 第二个参数:预处理后的数据文件相对路径。 + - 第三个参数:即将生成的info文件名。 + + 运行成功后,在当前目录中生成 `coco2017_ssd_jpg.info`。 + +## 模型推理 + +1. 模型转换。 + + 使用PyTorch将模型权重文件.pth转换为.ts文件。 + 1. 获取权重文件。 + + 获取经过训练的权重文件:[ssd300_coco_20200307-a92d2092.pth](http://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth) + + 2. 导出onnx文件。 + + 使用`export.py`导出ts文件(使用torch 1.8环境) + + ``` + python export.py \ + --checkpoint ./ssd300_coco_20200307-a92d2092.pth \ + --mmdet_path ./mmdetection \ + --shape=300 \ + --mean 123.675 116.28 103.53 \ + --std 1 1 1 + ``` + + - 参数说明: + + - checkpoint:原始pth文件所在路径 + - mmdet_path:github拉入文件夹路径 + - shape:图像尺寸 + +2. 开始推理验证。 + + 1. 执行推理。 + ```shell + python3 acc_dataset.py --ts_path ./ssd300_coco.ts --img_bin_path ./coco2017_bin --save_dir ./pyinfer_res_npu + ``` + + - 参数说明: + + - ts_path:导出ts文件路径 + - img_bin_path:图片预处理得到的bin文件夹所在路径 + - save_dir:保存推理结果的路径 + + 2. 
精度验证。 + + 调用coco_eval.py评测map精度: + + ```shell + det_path=postprocess_out + python ssd_postprocess.py \ + --bin_data_path=out/2022_*/ \ + --score_threshold=0.02 \ + --test_annotation=coco2017_ssd_jpg.info \ + --nms_pre 200 \ + --det_results_path ${det_path} + python txt_to_json.py --npu_txt_path ${det_path} + python coco_eval.py --ground_truth /root/datasets/coco/annotations/instances_val2017.json + ``` + + - 参数说明: + + - --bin_data_path:为推理结果存放的路径。 + - --score_threshold:得分阈值。 + - --test_annotation:原始图片信息文件。 + - --nms_pre:每张图片获取框数量的阈值。 + - --det_results_path:后处理输出路径。 + - --npu_txt_path:后处理输出路径。 + - --ground_truth:instances_val2017.json文件路径。 + +# 模型推理性能&精度 + +调用ACL接口推理计算,性能参考下列数据。 + +| | mAP | +| --------- | -------- | +| 310P3精度 | mAP=25.4 | + + +| Throughput | 310*4 | 310P3 | 310B1 | +| ---------- | -------- | -------- | ----- | +| bs1 | 179.194 | 298.5514 | 75.42 | +| bs4 | 207.596 | 337.0112 | 77.9 | +| bs8 | 211.7312 | 323.5662 | 79.77 | +| bs16 | 211.288 | 318.1392 | 77.84 | +| bs32 | 200.2948 | 318.7303 | 79.78 | +| bs64 | 196.4192 | 313.0790 | 48.36 | +| 最优batch | 211.7312 | 337.0112 | 79.77 | \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py new file mode 100644 index 0000000000..650772ebd3 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py @@ -0,0 +1,87 @@ +import argparse +from tqdm import tqdm +import os + +import torch +import numpy as np + + +USE_NPU = True +INPUT_WIDTH = 300 +INPUT_HEIGHT = 300 + + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str, + default='./ssd300_coco.ts' + ) + args.add_argument("--batch_size", type=int, default=1, help="batch size.") + args.add_argument('--img_bin_path',help='image bin path', type=str, + default='./coco2017_bin' + ) + args.add_argument('--save_dir',help='result save dir', type=str, + default='./pyinfer_res_npu' + ) + return args.parse_args() + + +if __name__ == '__main__': + infer_times = 100 + om_cost = 0 + pt_cost = 0 + opts = parse_args() + batch_size = opts.batch_size + directory_path = opts.img_bin_path + save_dir = opts.save_dir + + model = torch.jit.load(opts.ts_path) + + if USE_NPU: + + import torch_aie + from torch_aie import _enums + + input_info = [torch_aie.Input((batch_size, 3, INPUT_WIDTH, INPUT_HEIGHT))] + torch_aie.set_device(0) + print("start compile") + model = torch_aie.compile( + model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP32, + soc_version='Ascend310P3', + optimization_level=0 + ) + print("end compile") + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + print(f"Directory '{save_dir}' created.") + + # Iterate through each file in the directory + for filename in tqdm(os.listdir(directory_path)): + filepath = os.path.join(directory_path, filename) + if os.path.isfile(filepath): + try: + with open(filepath, 'rb') as file: + array_data = np.fromfile(file, dtype=np.float32).reshape((batch_size, 3, INPUT_WIDTH, INPUT_HEIGHT)) + torch_tensor = torch.tensor(array_data) + + if USE_NPU: + input_tensor_npu = torch_tensor.to("npu:0") + aieout_npu = model(input_tensor_npu) + first_out = aieout_npu[0].to("cpu").detach().numpy() + second_out = aieout_npu[1].to("cpu").detach().numpy() + else: + tsout = model(torch_tensor) + first_out = tsout[0].detach().numpy() + second_out = tsout[1].detach().numpy() + + 
first_out.tofile(os.path.join(save_dir, filename.split(".")[0] + "_0.bin")) + second_out.tofile(os.path.join(save_dir, filename.split(".")[0] + "_1.bin")) + except Exception as e: + print(f'Error reading {filename}: {e}') + + + + diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py new file mode 100644 index 0000000000..2a3ab83f15 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/coco_eval.py @@ -0,0 +1,98 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from tqdm import tqdm + + +import numpy as np +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + + +CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') + + +def coco_evaluation(annotation_json, result_json): + cocoGt = COCO(annotation_json) + cocoDt = cocoGt.loadRes(result_json) + iou_thrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + iou_type = 'bbox' + + cocoEval = COCOeval(cocoGt, cocoDt, iou_type) + cocoEval.params.catIds = cocoGt.get_cat_ids(cat_names=CLASSES) + cocoEval.params.imgIds = cocoGt.get_img_ids() + cocoEval.params.maxDets = [100, 300, 1000] # proposal number for evaluating recalls/mAPs. 
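    # Note: iou_thrs computed above spans 0.50:0.95 in 0.05 steps (the standard
    # COCO protocol), so cocoEval.stats[0] read below corresponds to mAP@[.5:.95],
    # matching the 'mAP' entry of coco_metric_names.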
+ cocoEval.params.iouThrs = iou_thrs + + cocoEval.evaluate() + cocoEval.accumulate() + cocoEval.summarize() + + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + + metric_items = ['mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'] + eval_results = {} + + for metric_item in tqdm(metric_items): + key = f'bbox_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' + ) + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results['bbox_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + + return eval_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--ground_truth", default="instances_val2017.json") + parser.add_argument("--detection_result", default="coco_detection_result.json") + args = parser.parse_args() + result = coco_evaluation(args.ground_truth, args.detection_result) + print(result) + with open('./coco_detection_result.txt', 'w') as f: + for key, value in result.items(): + f.write(key + ': ' + str(value) + '\n') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py new file mode 100644 index 0000000000..b7dab1fd16 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py @@ -0,0 +1,83 @@ +import os +import argparse + +import torch + +from mmdet.core import (build_model_from_cfg, generate_inputs_and_wrap_model) + + +def pytorch2onnx(config_path, + checkpoint_path, + input_img, + input_shape, + normalize_cfg=None): + + input_config = { + 'input_shape': input_shape, + 'input_path': input_img, + 'normalize_cfg': normalize_cfg + } + + # prepare original model and meta for verifying the onnx model + orig_model = build_model_from_cfg(config_path, checkpoint_path) + print("type of orig_model:", type(orig_model)) + model, tensor_data = generate_inputs_and_wrap_model( + config_path, checkpoint_path, input_config) + + ts_model = torch.jit.trace(model, tensor_data) + ts_model.save("./ssd300_coco_torch201.ts") + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert MMDetection models to ONNX') + parser.add_argument('--checkpoint', help='checkpoint file', type=str, default='./ssd300_coco_20200307-a92d2092.pth') + parser.add_argument('--mmdet_path',help='mmdetection repo folder path', type=str, + default='./mmdetection' + ) + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[800, 1216], + help='input image size') + parser.add_argument( + '--mean', + type=float, + nargs='+', + default=[123.675, 116.28, 103.53], + help='mean value used for preprocess input data') + parser.add_argument( + '--std', + type=float, + nargs='+', + default=[58.395, 57.12, 57.375], + help='variance value used for preprocess input data') + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = parse_args() + cfg = os.path.join(args.mmdet_path, "configs/ssd/ssd300_coco.py") + input_img = os.path.join(args.mmdet_path, "tests/data/color.jpg") + + if len(args.shape) == 1: + input_shape = (1, 3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (1, 3) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + assert len(args.mean) == 3 + assert len(args.std) == 3 + + normalize_cfg = {'mean': args.mean, 
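    # These values are forwarded to mmdet's preprocessing wrapper and should match
    # img_norm_cfg in configs/ssd/ssd300_coco.py; the README's export command passes
    # --mean 123.675 116.28 103.53 --std 1 1 1 for SSD300.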
'std': args.std} + + # convert model to onnx file + pytorch2onnx( + cfg, + args.checkpoint, + input_img, + input_shape, + normalize_cfg=normalize_cfg) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py new file mode 100644 index 0000000000..806398b3cc --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py @@ -0,0 +1,64 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""get info""" + +import os +import sys +from glob import glob + +import cv2 + + +def get_bin_info(bin_file_path, bin_info_name, bin_width, bin_height): + """get_bin_info""" + bin_images = glob(os.path.join(bin_file_path, '*.bin')) + with open(bin_info_name, 'w') as info_file: + for index, img in enumerate(bin_images): + content = ' '.join([str(index), img, bin_width, bin_height]) + info_file.write(content) + info_file.write('\n') + + +def get_jpg_info(jpg_file_path, jpg_info_name): + """get_jpg_info""" + extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] + image_names = [] + for extension in extensions: + image_names.append(glob(os.path.join(jpg_file_path, '*.' + extension))) + with open(jpg_info_name, 'w') as jpg_file: + for image_name in image_names: + if len(image_name) == 0: + continue + else: + for index, img in enumerate(image_name): + img_cv = cv2.imread(img) + shape = img_cv.shape + jpg_width, jpg_height = shape[1], shape[0] + content = ' '.join([str(index), img, str(jpg_width), str(jpg_height)]) + jpg_file.write(content) + jpg_file.write('\n') + + +if __name__ == '__main__': + file_type = sys.argv[1] + file_path = sys.argv[2] + info_name = sys.argv[3] + if file_type == 'bin': + width = sys.argv[4] + height = sys.argv[5] + assert len(sys.argv) == 6, 'The number of input parameters must be equal to 5' + get_bin_info(file_path, info_name, width, height) + elif file_type == 'jpg': + assert len(sys.argv) == 4, 'The number of input parameters must be equal to 3' + get_jpg_info(file_path, info_name) \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py new file mode 100644 index 0000000000..58ebef8a50 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py @@ -0,0 +1,105 @@ +import argparse +import time +from tqdm import tqdm + +import torch +import numpy as np + +import torch_aie +from torch_aie import _enums + + +INPUT_WIDTH = 300 +INPUT_HEIGHT = 300 + +def parse_args(): + args = argparse.ArgumentParser(description="A program that operates in 'om' or 'ts' mode.") + args.add_argument("--mode", choices=["om", "ts"], required=True, help="Specify the mode ('om' or 'ts').") + args.add_argument('--om_path',help='MobilenetV1 om file path', type=str, + default='/onnx/mobilenetv1/mobilenet-v1_bs1.om' + ) + args.add_argument('--ts_path',help='MobilenetV1 ts file path', type=str, + default='/onnx/ssd/ssd300_coco.ts' + ) + args.add_argument("--batch-size", type=int, default=4, help="batch 
size.") + return args.parse_args() + +if __name__ == '__main__': + infer_times = 100 + om_cost = 0 + pt_cost = 0 + opts = parse_args() + TS_PATH = opts.ts_path + OM_PATH = opts.om_path + BATCH_SIZE = opts.batch_size + + if opts.mode == "om": + om_model = InferSession(0, OM_PATH) + for _ in tqdm(range(0, infer_times)): + dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.uint8) + start = time.time() + output = om_model.infer([dummy_input], 'static', custom_sizes=90000000) # revise static + # output = om_model.infer([dummy_input], 'dymshape', custom_sizes=4000) # revise dynm fp32为4个字节,输出为1x1000 + cost = time.time() - start + om_cost += cost + + if opts.mode == "ts": + ts_model = torch.jit.load(TS_PATH) + + input_info = [torch_aie.Input((BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT))] + + torch_aie.set_device(0) + print("start compile") + torchaie_model = torch_aie.compile( + ts_model, + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3', + ) + print("end compile") + torchaie_model.eval() + + print("start export") + torch_aie.export_engine(ts_model, + "forward", + "ssd.om", + inputs=input_info, + precision_policy=_enums.PrecisionPolicy.FP16, + soc_version='Ascend310P3') + print("end export") + + dummy_input = np.random.randn(BATCH_SIZE, 3, INPUT_WIDTH, INPUT_HEIGHT).astype(np.float32) + input_tensor = torch.Tensor(dummy_input) + loops = 100 + warm_ctr = 10 + + default_stream = torch_aie.npu.default_stream() + time_cost = 0 + + input_tensor = input_tensor.to("npu") + while warm_ctr: + _ = torchaie_model(input_tensor) + default_stream.synchronize() + warm_ctr -= 1 + + print("send to npu") + input_tensor = input_tensor.to("npu") + print("finish sent") + for i in range(loops): + t0 = time.time() + _ = torchaie_model(input_tensor) # tuple of 2 lists of len 6 + default_stream.synchronize() + t1 = time.time() + time_cost += (t1 - t0) + print(i) + + print(f"fps: {loops} * {BATCH_SIZE} / {time_cost : .3f} samples/s") + print("torch_aie fps: ", loops * BATCH_SIZE / time_cost) + + from datetime import datetime + current_time = datetime.now() + formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") + print("Current Time:", formatted_time) + + + diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt new file mode 100644 index 0000000000..d3f8a8897c --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt @@ -0,0 +1,7 @@ +protobuf==3.20.0 +Cython==0.29.35 +matplotlib==3.5.3 +mmpycocotools==12.0.3 +mmcv-full==1.2.7 +torch==1.8.1 +tqdm==4.66.1 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py new file mode 100644 index 0000000000..09dc47196e --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py @@ -0,0 +1,294 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""coco postprocess""" + +import os +import numpy as np +import argparse +import cv2 +import warnings +import torch +import time +try: + from torch import npu_batch_nms as NMSOp + NMS_ON_NPU = True +except: + from torchvision.ops import batched_nms as NMSOp + NMS_ON_NPU = False + +CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + + +def coco_postprocess(bbox, image_size, net_input_width, net_input_height): + """ + This function is postprocessing for FasterRCNN output. + + Before calling this function, reshape the raw output of FasterRCNN to + following form + numpy.ndarray: + [x, y, width, height, confidence, probability of 80 classes] + shape: (100,) + The postprocessing restore the bounding rectangles of FasterRCNN output + to origin scale and filter with non-maximum suppression. + + :param bbox: a numpy array of the FasterRCNN output + :param image_path: a string of image path + :return: three list for best bound, class and score + """ + w = image_size[0] + h = image_size[1] + scale_w = net_input_width / w + scale_h = net_input_height / h + + # cal predict box on the image src + pbox = bbox.copy() + pbox[:, 0] = (bbox[:, 0]) / scale_w + pbox[:, 1] = (bbox[:, 1]) / scale_h + pbox[:, 2] = (bbox[:, 2]) / scale_w + pbox[:, 3] = (bbox[:, 3]) / scale_h + return pbox + + +def np_clip_bbox(bboxes, max_shape): + x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] + h, w = max_shape + x1 = x1.clip(min=0, max=w) + y1 = y1.clip(min=0, max=h) + x2 = x2.clip(min=0, max=w) + y2 = y2.clip(min=0, max=h) + bboxes = np.stack([x1, y1, x2, y2], axis=-1) + return bboxes + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--bin_data_path", + default="/onnx/ssd/cpp_infer_res_cpu") + parser.add_argument("--test_annotation", default="./coco2017_ssd_jpg.info") + parser.add_argument("--det_results_path", default="./postprocess_out_cpu/") + parser.add_argument("--net_out_num", default=2, type=int) + parser.add_argument("--num_pred_box", default=8732, type=int) + parser.add_argument("--nms_pre", default=200, type=int) + parser.add_argument("--net_input_width", default=300, type=int) + parser.add_argument("--net_input_height", default=300, type=int) + parser.add_argument("--min_bbox_size", default=0.01, type=float) + parser.add_argument("--score_threshold", default=0.02, type=float) + parser.add_argument("--nms", default=True, type=bool) + parser.add_argument("--iou_threshold", default=0.45, type=float) + parser.add_argument("--max_per_img", default=200, type=int) + parser.add_argument("--ifShowDetObj", 
action="store_true", default=True, + help="if input the para means True, neither False.") + parser.add_argument("--start", default=0, type=float) + parser.add_argument("--end", default=1, type=float) + parser.add_argument("--device", default=0, type=int) + parser.add_argument("--clear_cache", action='store_true') + flags = parser.parse_args() + + # generate dict according to annotation file for query resolution + # load width and height of input images + img_size_dict = dict() + with open(flags.test_annotation)as f: + for line in f: + temp = line.split(" ") + img_file_path = temp[1] + img_name = temp[1].split("/")[-1].split(".")[0] + img_width = int(temp[2]) + img_height = int(temp[3]) + img_size_dict[img_name] = (img_width, img_height, img_file_path) + + # read bin file for generate predict result + bin_path = flags.bin_data_path # 推理结果保存路径 + det_results_path = flags.det_results_path + os.makedirs(det_results_path, exist_ok=True) + total_img = set([name[:name.rfind('_')] + for name in os.listdir(bin_path) if "bin" in name]) + total_img = sorted(total_img) # list of img names (str) + num_img = len(total_img) # 5000 + start = int(flags.start * num_img) + end = int(flags.end * num_img) + task_len = end - start + 1 + + finished = 0 + time_start = time.time() + for img_id in range(start, end): + # for img_id, bin_file in enumerate(sorted(total_img)): + bin_file = total_img[img_id] + path_base = os.path.join(bin_path, bin_file) + det_results_file = os.path.join(det_results_path, bin_file + ".txt") + if os.path.exists(det_results_file) and not flags.clear_cache: + continue + + # load all detected output tensor + bbox_file = path_base + "_" + str(0) + ".bin" + score_file = path_base + "_" + str(1) + ".bin" + assert os.path.exists( + bbox_file), '[ERROR] file `{}` not exist'.format(bbox_file) + assert os.path.exists( + score_file), '[ERROR] file `{}` not exist'.format(score_file) + bboxes = np.fromfile(bbox_file, dtype="float32").reshape( + flags.num_pred_box, 4) + scores = np.fromfile(score_file, dtype="float32").reshape( + flags.num_pred_box, 80) + + bboxes = torch.from_numpy(bboxes) + scores = torch.from_numpy(scores) + try: + bboxes = bboxes.npu(flags.device) + scores = scores.npu(flags.device) + except: + warnings.warn('npu is not available, running on cpu') + + max_scores, _ = scores.max(-1) # shape of [8732], torch.float32 + keep_inds = (max_scores > flags.score_threshold).nonzero( + as_tuple=False).view(-1) + bboxes = bboxes[keep_inds, :] + scores = scores[keep_inds, :] + + if flags.nms_pre > 0 and flags.nms_pre < bboxes.shape[0]: + max_scores, _ = scores.max(-1) # shape: torch.Size([2738]) dtype:torch.float32 + _, topk_inds = max_scores.topk(flags.nms_pre) + bboxes = bboxes[topk_inds, :] # shape: torch.Size([200, 4]) + scores = scores[topk_inds, :] # shape: torch.Size([200, 80]) + + # clip bbox border + bboxes[:, 0::2].clamp_(min=0, max=flags.net_input_width - 1) + bboxes[:, 1::2].clamp_(min=0, max=flags.net_input_height - 1) + + # remove small bbox + bboxes_width_height = bboxes[:, 2:] - bboxes[:, :2] + valid_bboxes = bboxes_width_height > flags.min_bbox_size + keep_inds = (valid_bboxes[:, 0] & valid_bboxes[:, 1] + ).nonzero(as_tuple=False).view(-1) + bboxes = bboxes[keep_inds, :] + scores = scores[keep_inds, :] + + # rescale bbox to original image size + original_img_info = img_size_dict[bin_file] + rescale_factor = torch.tensor([ + original_img_info[0] / flags.net_input_width, + original_img_info[1] / flags.net_input_height] * 2, + dtype=bboxes.dtype, device=bboxes.device) + bboxes *= 
rescale_factor + + if flags.nms: + if NMS_ON_NPU: + # repeat bbox for each class + # (N, 4) -> (B, N, 80, 4), where B = 1 is the batchsize + bboxes = bboxes[None, :, None, :].repeat(1, 1, 80, 1) + # (N, 80) -> (B, N, 80), where B = 1 is the batchsize + scores = scores[None, :, :] + + # bbox batched nms + bboxes, scores, labels, num_total_bboxes = \ + NMSOp( + bboxes.half(), scores.half(), + score_threshold=flags.score_threshold, + iou_threshold=flags.iou_threshold, + max_size_per_class=flags.max_per_img, + max_total_size=flags.max_per_img) + bboxes = bboxes[0, :num_total_bboxes, :] + scores = scores[0, :num_total_bboxes] + class_idxs = labels[0, :num_total_bboxes] + else: + # repeat bbox and class idx for each class + bboxes = bboxes[:, None, :].repeat( + 1, 80, 1) # (N, 4) -> (N, 80, 4) + class_idxs = torch.arange(80, dtype=torch.long, device=bboxes.device + )[None, :].repeat(bboxes.shape[0], 1) # (80) -> (N, 80) + + # reshape bbox for torch nms + bboxes = bboxes.view(-1, 4) + scores = scores.view(-1) + class_idxs = class_idxs.view(-1) + + # bbox batched nms + keep_inds = NMSOp(bboxes, scores, class_idxs, + flags.iou_threshold) + bboxes = bboxes[keep_inds] + scores = scores[keep_inds] + class_idxs = class_idxs[keep_inds] + else: + # repeat bbox and class idx for each class + bboxes = bboxes[:, None, :].repeat( + 1, 80, 1) # (N, 4) -> (N, 80, 4) + class_idxs = torch.arange(80, dtype=torch.long, device=bboxes.device + )[None, :].repeat(bboxes.shape[0], 1) # (80) -> (N, 80) + + # reshape bbox for torch nms + bboxes = bboxes.view(-1, 4) + scores = scores.view(-1) + class_idxs = class_idxs.view(-1) + + # keep topk max_per_img bbox + if flags.max_per_img > 0 and flags.max_per_img < bboxes.shape[0]: + _, topk_inds = scores.topk(flags.max_per_img) + bboxes = bboxes[topk_inds, :] + scores = scores[topk_inds] + class_idxs = class_idxs[topk_inds] + + # move to cpu if running on npu + if bboxes.device != 'cpu': + bboxes = bboxes.cpu() + scores = scores.cpu() + class_idxs = class_idxs.cpu() + + # convert to numpy.ndarray + bboxes = bboxes.numpy() + scores = scores.numpy() + class_idxs = class_idxs.numpy() + + # make det result file + if flags.ifShowDetObj == True: + imgCur = cv2.imread(original_img_info[2]) + + det_results_str = '' + for idx in range(bboxes.shape[0]): + x1, y1, x2, y2 = bboxes[idx, :] + predscore = scores[idx] + class_ind = class_idxs[idx] + + class_name = CLASSES[int(class_ind)] + det_results_str += "{} {} {} {} {} {}\n".format( + class_name, predscore, x1, y1, x2, y2) + if flags.ifShowDetObj == True: + imgCur = cv2.rectangle(imgCur, (int(x1), int( + y1)), (int(x2), int(y2)), (0, 255, 0), 1) + imgCur = cv2.putText(imgCur, class_name + '|' + str(predscore), + (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, + 0.5, (0, 0, 255), 1) + + if flags.ifShowDetObj == True: + cv2.imwrite(os.path.join(det_results_path, bin_file + + '.jpg'), imgCur, [int(cv2.IMWRITE_JPEG_QUALITY), 70]) + + with open(det_results_file, "w") as detf: + detf.write(det_results_str) + + finished += 1 + speed = finished / (time.time() - time_start) + print('processed {:5d}/{:<5d} images, speed: {:.2f}FPS'.format( + finished, task_len, speed), end='\r') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py new file mode 100644 index 0000000000..5e6dc9dec4 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""coco preprocess""" +""" +使用openmmlab环境 +""" + +import os +import argparse +from tqdm import tqdm + +import numpy as np +import mmcv + + +dataset_config = { + 'resize': (300, 300), + 'mean': [123.675, 116.28, 103.53], + 'std': [1, 1, 1], +} + +tensor_height = 300 +tensor_width = 300 + + +def coco_preprocess(input_image, output_bin_path): + """coco_preprocess""" + # define the output file name + img_name = input_image.split('/')[-1] + bin_name = img_name.split('.')[0] + ".bin" + bin_fl = os.path.join(output_bin_path, bin_name) + + one_img = mmcv.imread(input_image, backend='cv2') + one_img = mmcv.imresize(one_img, (tensor_height, tensor_width)) + # calculate padding + mean = np.array(dataset_config['mean'], dtype=np.float32) + std = np.array(dataset_config['std'], dtype=np.float32) + one_img = mmcv.imnormalize(one_img, mean, std) + one_img = one_img.transpose(2, 0, 1) + print("one_img.dtype: ", one_img.dtype) + print("one_img.shape: ", one_img.shape) + one_img.tofile(bin_fl) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='preprocess of FasterRCNN pytorch model') + parser.add_argument("--image_folder_path", + default="/home/ascend/coco2017/val2017", help='image of dataset') + parser.add_argument( + "--bin_folder_path", default="/home/ascend/coco2017_bin/", help='Preprocessed image buffer') + flags = parser.parse_args() + + if not os.path.exists(flags.bin_folder_path): + os.makedirs(flags.bin_folder_path) + images = os.listdir(flags.image_folder_path) + for image_name in tqdm(images, desc="Starting to process image..."): + if not (image_name.endswith(".jpeg") or image_name.endswith(".JPEG") or image_name.endswith(".jpg")): + continue + path_image = os.path.join(flags.image_folder_path, image_name) + coco_preprocess(path_image, flags.bin_folder_path) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py new file mode 100644 index 0000000000..6a27b1aee3 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py @@ -0,0 +1,115 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""txt to json""" + +import glob +import os +import sys +import argparse +import mmcv + +CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', + 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', + 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', + 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', + 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', + 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', + 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', + 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', + 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', + 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', + 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] + +cat_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, +24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, +48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, +72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] + +''' + 0,0 ------> x (width) + | + | (Left,Top) + | *_________ + | | | + | | + y |_________| + (height) * + (Right,Bottom) +''' + +def file_lines_to_list(path): + """file_lines_to_list""" + # open txt file lines to a list + with open(path) as f: + content = f.readlines() + # remove whitespace characters like `\n` at the end of each line + content = [x.strip() for x in content] + return content + + +def error(msg): + """error""" + print(msg) + sys.exit(0) + + +def get_predict_list(file_path): + """get_predict_list""" + dr_files_list = glob.glob(file_path + '/*.txt') + dr_files_list.sort() + + bounding_boxes = [] + for txt_file in dr_files_list: + file_id = txt_file.split(".txt", 1)[0] + file_id = os.path.basename(os.path.normpath(file_id)) + lines = file_lines_to_list(txt_file) + for line in lines: + try: + sl = line.split() + if len(sl) > 6: + class_name = sl[0] + ' ' + sl[1] + scores, left, top, right, bottom = sl[2:] + else: + class_name, scores, left, top, right, bottom = sl + if float(scores) < 0.02: + continue + except ValueError: + error_msg = "Error: File " + txt_file + " wrong format.\n" + error_msg += " Expected: \n" + error_msg += " Received: " + line + error(error_msg) + + # bbox = left + " " + top + " " + right + " " + bottom + left = float(left) + right = float(right) + top = float(top) + bottom = float(bottom) + bbox = [left, top, right - left, bottom - top] + bounding_boxes.append({"image_id": int(file_id), "bbox": bbox, "score": float(scores), + "category_id": cat_ids[CLASSES.index(class_name)]}) + return bounding_boxes + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('mAp calculate') + parser.add_argument('--npu_txt_path', default="detection-results", + help='the path of the predict result') + parser.add_argument("--json_output_file", default="coco_detection_result") + args = parser.parse_args() + + res_bbox = get_predict_list(args.npu_txt_path) + mmcv.dump(res_bbox, args.json_output_file + '.json') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff b/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff new file mode 100644 index 0000000000..cd1d46528d --- /dev/null +++ 
b/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff @@ -0,0 +1,798 @@ +diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +index e9eb3579..066e90fe 100644 +--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py ++++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py +@@ -1,3 +1,7 @@ ++# Copyright (c) OpenMMLab. All rights reserved. ++import warnings ++ ++import mmcv + import numpy as np + import torch + +@@ -20,16 +24,25 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): + target for delta coordinates + clip_border (bool, optional): Whether clip the objects outside the + border of the image. Defaults to True. ++ add_ctr_clamp (bool): Whether to add center clamp, when added, the ++ predicted box is clamped is its center is too far away from ++ the original anchor's center. Only used by YOLOF. Default False. ++ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. ++ Default 32. + """ + + def __init__(self, + target_means=(0., 0., 0., 0.), + target_stds=(1., 1., 1., 1.), +- clip_border=True): ++ clip_border=True, ++ add_ctr_clamp=False, ++ ctr_clamp=32): + super(BaseBBoxCoder, self).__init__() + self.means = target_means + self.stds = target_stds + self.clip_border = clip_border ++ self.add_ctr_clamp = add_ctr_clamp ++ self.ctr_clamp = ctr_clamp + + def encode(self, bboxes, gt_bboxes): + """Get box regression transformation deltas that can be used to +@@ -57,10 +70,16 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): + """Apply transformation `pred_bboxes` to `boxes`. + + Args: +- boxes (torch.Tensor): Basic boxes. +- pred_bboxes (torch.Tensor): Encoded boxes with shape +- max_shape (tuple[int], optional): Maximum shape of boxes. +- Defaults to None. ++ bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) ++ pred_bboxes (Tensor): Encoded offsets with respect to each roi. ++ Has shape (B, N, num_classes * 4) or (B, N, 4) or ++ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H ++ when rois is a grid of anchors.Offset encoding follows [1]_. ++ max_shape (Sequence[int] or torch.Tensor or Sequence[ ++ Sequence[int]],optional): Maximum bounds for boxes, specifies ++ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then ++ the max_shape should be a Sequence[Sequence[int]] ++ and the length of max_shape should also be B. + wh_ratio_clip (float, optional): The allowed ratio between + width and height. + +@@ -69,8 +88,28 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): + """ + + assert pred_bboxes.size(0) == bboxes.size(0) +- decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds, +- max_shape, wh_ratio_clip, self.clip_border) ++ if pred_bboxes.ndim == 3: ++ assert pred_bboxes.size(1) == bboxes.size(1) ++ ++ if pred_bboxes.ndim == 2 and not True: ++ # single image decode ++ decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, ++ self.stds, max_shape, wh_ratio_clip, ++ self.clip_border, self.add_ctr_clamp, ++ self.ctr_clamp) ++ else: ++ if pred_bboxes.ndim == 3 and not True: ++ warnings.warn( ++ 'DeprecationWarning: onnx_delta2bbox is deprecated ' ++ 'in the case of batch decoding and non-ONNX, ' ++ 'please use “delta2bbox” instead. In order to improve ' ++ 'the decoding speed, the batch function will no ' ++ 'longer be supported. 
') ++ decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means, ++ self.stds, max_shape, ++ wh_ratio_clip, self.clip_border, ++ self.add_ctr_clamp, ++ self.ctr_clamp) + + return decoded_bboxes + +@@ -126,7 +165,108 @@ def delta2bbox(rois, + stds=(1., 1., 1., 1.), + max_shape=None, + wh_ratio_clip=16 / 1000, +- clip_border=True): ++ clip_border=True, ++ add_ctr_clamp=False, ++ ctr_clamp=32): ++ """Apply deltas to shift/scale base boxes. ++ ++ Typically the rois are anchor or proposed bounding boxes and the deltas are ++ network outputs used to shift/scale those boxes. ++ This is the inverse function of :func:`bbox2delta`. ++ ++ Args: ++ rois (Tensor): Boxes to be transformed. Has shape (N, 4). ++ deltas (Tensor): Encoded offsets relative to each roi. ++ Has shape (N, num_classes * 4) or (N, 4). Note ++ N = num_base_anchors * W * H, when rois is a grid of ++ anchors. Offset encoding follows [1]_. ++ means (Sequence[float]): Denormalizing means for delta coordinates. ++ Default (0., 0., 0., 0.). ++ stds (Sequence[float]): Denormalizing standard deviation for delta ++ coordinates. Default (1., 1., 1., 1.). ++ max_shape (tuple[int, int]): Maximum bounds for boxes, specifies ++ (H, W). Default None. ++ wh_ratio_clip (float): Maximum aspect ratio for boxes. Default ++ 16 / 1000. ++ clip_border (bool, optional): Whether clip the objects outside the ++ border of the image. Default True. ++ add_ctr_clamp (bool): Whether to add center clamp. When set to True, ++ the center of the prediction bounding box will be clamped to ++ avoid being too far away from the center of the anchor. ++ Only used by YOLOF. Default False. ++ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. ++ Default 32. ++ ++ Returns: ++ Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4 ++ represent tl_x, tl_y, br_x, br_y. ++ ++ References: ++ .. 
[1] https://arxiv.org/abs/1311.2524 ++ ++ Example: ++ >>> rois = torch.Tensor([[ 0., 0., 1., 1.], ++ >>> [ 0., 0., 1., 1.], ++ >>> [ 0., 0., 1., 1.], ++ >>> [ 5., 5., 5., 5.]]) ++ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], ++ >>> [ 1., 1., 1., 1.], ++ >>> [ 0., 0., 2., -1.], ++ >>> [ 0.7, -1.9, -0.5, 0.3]]) ++ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) ++ tensor([[0.0000, 0.0000, 1.0000, 1.0000], ++ [0.1409, 0.1409, 2.8591, 2.8591], ++ [0.0000, 0.3161, 4.1945, 0.6839], ++ [5.0000, 5.0000, 5.0000, 5.0000]]) ++ """ ++ num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 ++ if num_bboxes == 0: ++ return deltas ++ ++ deltas = deltas.reshape(-1, 4) ++ ++ means = deltas.new_tensor(means).view(1, -1) ++ stds = deltas.new_tensor(stds).view(1, -1) ++ denorm_deltas = deltas * stds + means ++ ++ dxy = denorm_deltas[:, :2] ++ dwh = denorm_deltas[:, 2:] ++ ++ # Compute width/height of each roi ++ rois_ = rois.repeat(1, num_classes).reshape(-1, 4) ++ pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5) ++ pwh = (rois_[:, 2:] - rois_[:, :2]) ++ ++ dxy_wh = pwh * dxy ++ ++ max_ratio = np.abs(np.log(wh_ratio_clip)) ++ if add_ctr_clamp: ++ dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) ++ dwh = torch.clamp(dwh, max=max_ratio) ++ else: ++ dwh = dwh.clamp(min=-max_ratio, max=max_ratio) ++ ++ gxy = pxy + dxy_wh ++ gwh = pwh * dwh.exp() ++ x1y1 = gxy - (gwh * 0.5) ++ x2y2 = gxy + (gwh * 0.5) ++ bboxes = torch.cat([x1y1, x2y2], dim=-1) ++ if clip_border and max_shape is not None: ++ bboxes[..., 0::2].clamp_(min=0, max=max_shape[1]) ++ bboxes[..., 1::2].clamp_(min=0, max=max_shape[0]) ++ bboxes = bboxes.reshape(num_bboxes, -1) ++ return bboxes ++ ++ ++def onnx_delta2bbox(rois, ++ deltas, ++ means=(0., 0., 0., 0.), ++ stds=(1., 1., 1., 1.), ++ max_shape=None, ++ wh_ratio_clip=16 / 1000, ++ clip_border=True, ++ add_ctr_clamp=False, ++ ctr_clamp=32): + """Apply deltas to shift/scale base boxes. + + Typically the rois are anchor or proposed bounding boxes and the deltas are +@@ -134,21 +274,34 @@ def delta2bbox(rois, + This is the inverse function of :func:`bbox2delta`. + + Args: +- rois (Tensor): Boxes to be transformed. Has shape (N, 4) ++ rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4) + deltas (Tensor): Encoded offsets with respect to each roi. +- Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when +- rois is a grid of anchors. Offset encoding follows [1]_. +- means (Sequence[float]): Denormalizing means for delta coordinates ++ Has shape (B, N, num_classes * 4) or (B, N, 4) or ++ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H ++ when rois is a grid of anchors.Offset encoding follows [1]_. ++ means (Sequence[float]): Denormalizing means for delta coordinates. ++ Default (0., 0., 0., 0.). + stds (Sequence[float]): Denormalizing standard deviation for delta +- coordinates +- max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) ++ coordinates. Default (1., 1., 1., 1.). ++ max_shape (Sequence[int] or torch.Tensor or Sequence[ ++ Sequence[int]],optional): Maximum bounds for boxes, specifies ++ (H, W, C) or (H, W). If rois shape is (B, N, 4), then ++ the max_shape should be a Sequence[Sequence[int]] ++ and the length of max_shape should also be B. Default None. + wh_ratio_clip (float): Maximum aspect ratio for boxes. ++ Default 16 / 1000. + clip_border (bool, optional): Whether clip the objects outside the +- border of the image. Defaults to True. ++ border of the image. Default True. 
++ add_ctr_clamp (bool): Whether to add center clamp, when added, the ++ predicted box is clamped is its center is too far away from ++ the original anchor's center. Only used by YOLOF. Default False. ++ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. ++ Default 32. + + Returns: +- Tensor: Boxes with shape (N, 4), where columns represent +- tl_x, tl_y, br_x, br_y. ++ Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or ++ (N, num_classes * 4) or (N, 4), where 4 represent ++ tl_x, tl_y, br_x, br_y. + + References: + .. [1] https://arxiv.org/abs/1311.2524 +@@ -162,43 +315,76 @@ def delta2bbox(rois, + >>> [ 1., 1., 1., 1.], + >>> [ 0., 0., 2., -1.], + >>> [ 0.7, -1.9, -0.5, 0.3]]) +- >>> delta2bbox(rois, deltas, max_shape=(32, 32)) ++ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) + tensor([[0.0000, 0.0000, 1.0000, 1.0000], + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) + """ +- means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4) +- stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4) ++ means = deltas.new_tensor(means).view(1, ++ -1).repeat(1, ++ deltas.size(-1) // 4) ++ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4) + denorm_deltas = deltas * stds + means +- dx = denorm_deltas[:, 0::4] +- dy = denorm_deltas[:, 1::4] +- dw = denorm_deltas[:, 2::4] +- dh = denorm_deltas[:, 3::4] +- max_ratio = np.abs(np.log(wh_ratio_clip)) +- dw = dw.clamp(min=-max_ratio, max=max_ratio) +- dh = dh.clamp(min=-max_ratio, max=max_ratio) ++ dx = denorm_deltas[..., 0::4] ++ dy = denorm_deltas[..., 1::4] ++ dw = denorm_deltas[..., 2::4] ++ dh = denorm_deltas[..., 3::4] ++ ++ x1, y1 = rois[..., 0], rois[..., 1] ++ x2, y2 = rois[..., 2], rois[..., 3] + # Compute center of each roi +- px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) +- py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) ++ px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx) ++ py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy) + # Compute width/height of each roi +- pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) +- ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) ++ pw = (x2 - x1).unsqueeze(-1).expand_as(dw) ++ ph = (y2 - y1).unsqueeze(-1).expand_as(dh) ++ ++ dx_width = pw * dx ++ dy_height = ph * dy ++ ++ max_ratio = np.abs(np.log(wh_ratio_clip)) ++ if add_ctr_clamp: ++ dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) ++ dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) ++ dw = torch.clamp(dw, max=max_ratio) ++ dh = torch.clamp(dh, max=max_ratio) ++ else: ++ dw = dw.clamp(min=-max_ratio, max=max_ratio) ++ dh = dh.clamp(min=-max_ratio, max=max_ratio) + # Use exp(network energy) to enlarge/shrink each roi + gw = pw * dw.exp() + gh = ph * dh.exp() + # Use network energy to shift the center of each roi +- gx = px + pw * dx +- gy = py + ph * dy ++ gx = px + dx_width ++ gy = py + dy_height + # Convert center-xy/width/height to top-left, bottom-right + x1 = gx - gw * 0.5 + y1 = gy - gh * 0.5 + x2 = gx + gw * 0.5 + y2 = gy + gh * 0.5 +- if clip_border and max_shape is not None: +- x1 = x1.clamp(min=0, max=max_shape[1]) +- y1 = y1.clamp(min=0, max=max_shape[0]) +- x2 = x2.clamp(min=0, max=max_shape[1]) +- y2 = y2.clamp(min=0, max=max_shape[0]) ++ + bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) ++ ++ if clip_border and max_shape is not None: ++ # clip bboxes with dynamic `min` and `max` 
for onnx ++ if True: ++ from mmdet.core.export.onnx_helper import dynamic_clip_for_onnx ++ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape) ++ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) ++ return bboxes ++ if not isinstance(max_shape, torch.Tensor): ++ max_shape = x1.new_tensor(max_shape) ++ max_shape = max_shape[..., :2].type_as(x1) ++ if max_shape.ndim == 2: ++ assert bboxes.ndim == 3 ++ assert max_shape.size(0) == bboxes.size(0) ++ ++ min_xy = x1.new_tensor(0) ++ max_xy = torch.cat( ++ [max_shape] * (deltas.size(-1) // 2), ++ dim=-1).flip(-1).unsqueeze(-2) ++ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) ++ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) ++ + return bboxes +diff --git a/mmdet/core/export/pytorch2onnx.py b/mmdet/core/export/pytorch2onnx.py +index 8f9309df..b9f43d48 100644 +--- a/mmdet/core/export/pytorch2onnx.py ++++ b/mmdet/core/export/pytorch2onnx.py +@@ -39,6 +39,7 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config): + + model = build_model_from_cfg(config_path, checkpoint_path) + one_img, one_meta = preprocess_example_input(input_config) ++ one_meta['img_shape_for_onnx'] = one_img.shape[-2:] + tensor_data = [one_img] + model.forward = partial( + model.forward, img_metas=[[one_meta]], return_loss=False) +diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py +index 463fe2e4..72ca09d3 100644 +--- a/mmdet/core/post_processing/bbox_nms.py ++++ b/mmdet/core/post_processing/bbox_nms.py +@@ -55,7 +55,7 @@ def multiclass_nms(multi_bboxes, + inds = valid_mask.nonzero(as_tuple=False).squeeze(1) + bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] + if inds.numel() == 0: +- if torch.onnx.is_in_onnx_export(): ++ if True: + raise RuntimeError('[ONNX Error] Can not record NMS ' + 'as it has not been executed this time') + if return_inds: +diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py +index cbc4fbb2..4bb7e37a 100644 +--- a/mmdet/models/backbones/ssd_vgg.py ++++ b/mmdet/models/backbones/ssd_vgg.py +@@ -162,8 +162,14 @@ class L2Norm(nn.Module): + + def forward(self, x): + """Forward function.""" +- # normalization layer convert to FP32 in FP16 training ++ # # normalization layer convert to FP32 in FP16 training ++ # x_float = x.float() ++ # norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps ++ # return (self.weight[None, :, None, None].float().expand_as(x_float) * ++ # x_float / norm).type_as(x) ++ + x_float = x.float() +- norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps ++ x_mul = x_float * x_float ++ norm = x_mul.sum(1, keepdim=True).sqrt() + self.eps + return (self.weight[None, :, None, None].float().expand_as(x_float) * + x_float / norm).type_as(x) +diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py +index a5bb4137..6e0d892e 100644 +--- a/mmdet/models/dense_heads/anchor_head.py ++++ b/mmdet/models/dense_heads/anchor_head.py +@@ -487,6 +487,162 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): + num_total_samples=num_total_samples) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + ++ @force_fp32(apply_to=('cls_scores', 'bbox_preds')) ++ def onnx_export(self, ++ cls_scores, ++ bbox_preds, ++ score_factors=None, ++ img_metas=None, ++ with_nms=True): ++ """Transform network output for a batch into bbox predictions. 
++ ++ Args: ++ cls_scores (list[Tensor]): Box scores for each scale level ++ with shape (N, num_points * num_classes, H, W). ++ bbox_preds (list[Tensor]): Box energies / deltas for each scale ++ level with shape (N, num_points * 4, H, W). ++ score_factors (list[Tensor]): score_factors for each s ++ cale level with shape (N, num_points * 1, H, W). ++ Default: None. ++ img_metas (list[dict]): Meta information of each image, e.g., ++ image size, scaling factor, etc. Default: None. ++ with_nms (bool): Whether apply nms to the bboxes. Default: True. ++ ++ Returns: ++ tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True, ++ it is tuple[Tensor, Tensor], first tensor bboxes with shape ++ [N, num_det, 5], 5 arrange as (x1, y1, x2, y2, score) ++ and second element is class labels of shape [N, num_det]. ++ When `with_nms` is False, first tensor is bboxes with ++ shape [N, num_det, 4], second tensor is raw score has ++ shape [N, num_det, num_classes]. ++ """ ++ assert len(cls_scores) == len(bbox_preds) ++ ++ num_levels = len(cls_scores) ++ ++ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] ++ ++ mlvl_priors = self.anchor_generator.grid_anchors( ++ featmap_sizes, device=bbox_preds[0].device) ++ ++ mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] ++ mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] ++ ++ assert len( ++ img_metas ++ ) == 1, 'Only support one input image while in exporting to ONNX' ++ img_shape = torch.tensor( ++ img_metas[0]['img_shape_for_onnx'], ++ dtype=torch.long, ++ device=bbox_preds[0].device) ++ ++ cfg = self.test_cfg ++ assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) ++ device = cls_scores[0].device ++ batch_size = cls_scores[0].shape[0] ++ # convert to tensor to keep tracing ++ nms_pre_tensor = torch.tensor( ++ cfg.get('nms_pre', -1), device=device, dtype=torch.long) ++ ++ # e.g. Retina, FreeAnchor, etc. ++ if score_factors is None: ++ with_score_factors = False ++ mlvl_score_factor = [None for _ in range(num_levels)] ++ else: ++ # e.g. FCOS, PAA, ATSS, etc. ++ with_score_factors = True ++ mlvl_score_factor = [ ++ score_factors[i].detach() for i in range(num_levels) ++ ] ++ mlvl_score_factors = [] ++ ++ mlvl_batch_bboxes = [] ++ mlvl_scores = [] ++ ++ for cls_score, bbox_pred, score_factors, priors in zip( ++ mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, ++ mlvl_priors): ++ assert cls_score.size()[-2:] == bbox_pred.size()[-2:] ++ ++ scores = cls_score.permute(0, 2, 3, ++ 1).reshape(batch_size, -1, ++ self.cls_out_channels) ++ if self.use_sigmoid_cls: ++ scores = scores.sigmoid() ++ nms_pre_score = scores ++ else: ++ scores = scores.softmax(-1) ++ nms_pre_score = scores ++ ++ if with_score_factors: ++ score_factors = score_factors.permute(0, 2, 3, 1).reshape( ++ batch_size, -1).sigmoid() ++ bbox_pred = bbox_pred.permute(0, 2, 3, ++ 1).reshape(batch_size, -1, 4) ++ priors = priors.expand(batch_size, -1, priors.size(-1)) ++ # Get top-k predictions ++ from mmdet.core.export.onnx_helper import get_k_for_topk ++ nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1]) ++ if nms_pre > 0: ++ ++ if with_score_factors: ++ nms_pre_score = (nms_pre_score * score_factors[..., None]) ++ else: ++ nms_pre_score = nms_pre_score ++ ++ # Get maximum scores for foreground classes. 
++ if self.use_sigmoid_cls: ++ max_scores, _ = nms_pre_score.max(-1) ++ else: ++ # remind that we set FG labels to [0, num_class-1] ++ # since mmdet v2.0 ++ # BG cat_id: num_class ++ max_scores, _ = nms_pre_score[..., :-1].max(-1) ++ _, topk_inds = max_scores.topk(nms_pre) ++ ++ batch_inds = torch.arange( ++ batch_size, device=bbox_pred.device).view( ++ -1, 1).expand_as(topk_inds).long() ++ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 ++ # transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds ++ transformed_inds = (bbox_pred.shape[1] * batch_inds).int() + topk_inds.int() ++ transformed_inds = transformed_inds.long() ++ priors = priors.reshape( ++ -1, priors.size(-1))[transformed_inds, :].reshape( ++ batch_size, -1, priors.size(-1)) ++ bbox_pred = bbox_pred.reshape(-1, ++ 4)[transformed_inds, :].reshape( ++ batch_size, -1, 4) ++ scores = scores.reshape( ++ -1, self.cls_out_channels)[transformed_inds, :].reshape( ++ batch_size, -1, self.cls_out_channels) ++ if with_score_factors: ++ score_factors = score_factors.reshape( ++ -1, 1)[transformed_inds].reshape(batch_size, -1) ++ ++ bboxes = self.bbox_coder.decode( ++ priors, bbox_pred, max_shape=img_shape) ++ ++ mlvl_batch_bboxes.append(bboxes) ++ mlvl_scores.append(scores) ++ if with_score_factors: ++ mlvl_score_factors.append(score_factors) ++ ++ batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1) ++ batch_scores = torch.cat(mlvl_scores, dim=1) ++ if with_score_factors: ++ batch_score_factors = torch.cat(mlvl_score_factors, dim=1) ++ ++ if not self.use_sigmoid_cls: ++ batch_scores = batch_scores[..., :self.num_classes] ++ ++ if with_score_factors: ++ batch_scores = batch_scores * (batch_score_factors.unsqueeze(2)) ++ ++ # directly return bboxes without NMS ++ return batch_bboxes, batch_scores ++ + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def get_bboxes(self, + cls_scores, +@@ -545,38 +701,45 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): + >>> assert det_bboxes.shape[1] == 5 + >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img + """ +- assert len(cls_scores) == len(bbox_preds) +- num_levels = len(cls_scores) +- +- device = cls_scores[0].device +- featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] +- mlvl_anchors = self.anchor_generator.grid_anchors( +- featmap_sizes, device=device) +- +- result_list = [] +- for img_id in range(len(img_metas)): +- cls_score_list = [ +- cls_scores[i][img_id].detach() for i in range(num_levels) +- ] +- bbox_pred_list = [ +- bbox_preds[i][img_id].detach() for i in range(num_levels) +- ] +- img_shape = img_metas[img_id]['img_shape'] +- scale_factor = img_metas[img_id]['scale_factor'] +- if with_nms: +- # some heads don't support with_nms argument +- proposals = self._get_bboxes_single(cls_score_list, +- bbox_pred_list, +- mlvl_anchors, img_shape, +- scale_factor, cfg, rescale) +- else: +- proposals = self._get_bboxes_single(cls_score_list, +- bbox_pred_list, +- mlvl_anchors, img_shape, +- scale_factor, cfg, rescale, +- with_nms) +- result_list.append(proposals) +- return result_list ++ if True: ++ return self.onnx_export(cls_scores, ++ bbox_preds, ++ score_factors=None, ++ img_metas=img_metas, ++ with_nms=with_nms) ++ else: ++ assert len(cls_scores) == len(bbox_preds) ++ num_levels = len(cls_scores) ++ ++ device = cls_scores[0].device ++ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] ++ mlvl_anchors = self.anchor_generator.grid_anchors( ++ featmap_sizes, device=device) ++ ++ result_list = [] 
++ for img_id in range(len(img_metas)): ++ cls_score_list = [ ++ cls_scores[i][img_id].detach() for i in range(num_levels) ++ ] ++ bbox_pred_list = [ ++ bbox_preds[i][img_id].detach() for i in range(num_levels) ++ ] ++ img_shape = img_metas[img_id]['img_shape'] ++ scale_factor = img_metas[img_id]['scale_factor'] ++ if with_nms: ++ # some heads don't support with_nms argument ++ proposals = self._get_bboxes_single(cls_score_list, ++ bbox_pred_list, ++ mlvl_anchors, img_shape, ++ scale_factor, cfg, rescale) ++ else: ++ proposals = self._get_bboxes_single(cls_score_list, ++ bbox_pred_list, ++ mlvl_anchors, img_shape, ++ scale_factor, cfg, rescale, ++ with_nms) ++ result_list.append(proposals) ++ return result_list + + def _get_bboxes_single(self, + cls_score_list, +@@ -612,6 +775,7 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): + are bounding box positions (tl_x, tl_y, br_x, br_y) and the + 5-th column is a score between 0 and 1. + """ ++ print('in _get_bboxes_single') + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) + mlvl_bboxes = [] +diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py +index 93d051e7..94e496b5 100644 +--- a/mmdet/models/dense_heads/yolo_head.py ++++ b/mmdet/models/dense_heads/yolo_head.py +@@ -281,7 +281,7 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin): + # Get top-k prediction + nms_pre = cfg.get('nms_pre', -1) + if 0 < nms_pre < conf_pred.size(0) and ( +- not torch.onnx.is_in_onnx_export()): ++ not True): + _, topk_inds = conf_pred.topk(nms_pre) + bbox_pred = bbox_pred[topk_inds, :] + cls_pred = cls_pred[topk_inds, :] +diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py +index 96c4acac..f5be2641 100644 +--- a/mmdet/models/detectors/single_stage.py ++++ b/mmdet/models/detectors/single_stage.py +@@ -114,7 +114,7 @@ class SingleStageDetector(BaseDetector): + bbox_list = self.bbox_head.get_bboxes( + *outs, img_metas, rescale=rescale) + # skip post-processing when exporting to ONNX +- if torch.onnx.is_in_onnx_export(): ++ if True: + return bbox_list + + bbox_results = [ +diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py +index 45b6f36a..1199a443 100644 +--- a/mmdet/models/roi_heads/cascade_roi_head.py ++++ b/mmdet/models/roi_heads/cascade_roi_head.py +@@ -349,7 +349,7 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin): + det_bboxes.append(det_bbox) + det_labels.append(det_label) + +- if torch.onnx.is_in_onnx_export(): ++ if True: + return det_bboxes, det_labels + bbox_results = [ + bbox2result(det_bboxes[i], det_labels[i], +diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py +index 0cba3cda..d69054b6 100644 +--- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py ++++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py +@@ -195,7 +195,7 @@ class FCNMaskHead(nn.Module): + scale_factor = bboxes.new_tensor(scale_factor) + bboxes = bboxes / scale_factor + +- if torch.onnx.is_in_onnx_export(): ++ if True: + # TODO: Remove after F.grid_sample is supported. 
+ from torchvision.models.detection.roi_heads \ + import paste_masks_in_image +@@ -316,7 +316,7 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + +- if torch.onnx.is_in_onnx_export(): ++ if True: + raise RuntimeError( + 'Exporting F.grid_sample from Pytorch to ONNX is not supported.') + img_masks = F.grid_sample( +diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py +index c0eebc4a..534b1c9b 100644 +--- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py ++++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py +@@ -55,7 +55,7 @@ class SingleRoIExtractor(BaseRoIExtractor): + """Forward function.""" + out_size = self.roi_layers[0].output_size + num_levels = len(feats) +- if torch.onnx.is_in_onnx_export(): ++ if True: + # Work around to export mask-rcnn to onnx + roi_feats = rois[:, :1].clone().detach() + roi_feats = roi_feats.expand( +@@ -82,7 +82,7 @@ class SingleRoIExtractor(BaseRoIExtractor): + mask = target_lvls == i + inds = mask.nonzero(as_tuple=False).squeeze(1) + # TODO: make it nicer when exporting to onnx +- if torch.onnx.is_in_onnx_export(): ++ if True: + # To keep all roi_align nodes exported to onnx + rois_ = rois[inds] + roi_feats_t = self.roi_layers[i](feats[i], rois_) +diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py +index c530f2a5..85f95e0c 100644 +--- a/mmdet/models/roi_heads/standard_roi_head.py ++++ b/mmdet/models/roi_heads/standard_roi_head.py +@@ -246,7 +246,7 @@ class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin): + + det_bboxes, det_labels = self.simple_test_bboxes( + x, img_metas, proposal_list, self.test_cfg, rescale=rescale) +- if torch.onnx.is_in_onnx_export(): ++ if True: + if self.with_mask: + segm_results = self.simple_test_mask( + x, img_metas, det_bboxes, det_labels, rescale=rescale) +diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py +index 0e675d6e..ecc08cf6 100644 +--- a/mmdet/models/roi_heads/test_mixins.py ++++ b/mmdet/models/roi_heads/test_mixins.py +@@ -197,7 +197,7 @@ class MaskTestMixin(object): + torch.from_numpy(scale_factor).to(det_bboxes[0].device) + for scale_factor in scale_factors + ] +- if torch.onnx.is_in_onnx_export(): ++ if True: + # avoid mask_pred.split with static number of prediction + mask_preds = [] + _bboxes = [] +diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py +index a8e7487b..97ed2d09 100644 +--- a/tools/pytorch2onnx.py ++++ b/tools/pytorch2onnx.py +@@ -33,23 +33,32 @@ def pytorch2onnx(config_path, + one_img, one_meta = preprocess_example_input(input_config) + model, tensor_data = generate_inputs_and_wrap_model( + config_path, checkpoint_path, input_config) ++ ++ input_names = ['input'] ++ dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}} ++ + output_names = ['boxes'] ++ dynamic_axes['boxes'] = {0: 'batch'} + if model.with_bbox: + output_names.append('labels') ++ dynamic_axes['labels'] = {0: 'batch'} + if model.with_mask: + output_names.append('masks') ++ dynamic_axes['masks'] = {0: 'batch'} + + torch.onnx.export( + model, + tensor_data, + output_file, +- input_names=['input'], ++ input_names=input_names, + output_names=output_names, ++ dynamic_axes=dynamic_axes, + export_params=True, + keep_initializers_as_inputs=True, + do_constant_folding=True, 
+ verbose=show, +- opset_version=opset_version) ++ opset_version=opset_version, ++ enable_onnx_checker=False) + + model.forward = orig_model.forward + print(f'Successfully exported ONNX model: {output_file}') +@@ -67,6 +76,7 @@ def pytorch2onnx(config_path, + tensor_data = [one_img] + # check the numerical value + # get pytorch output ++ one_meta['img_shape_for_onnx'] = one_img.shape[-2:] + pytorch_results = model(tensor_data, [[one_meta]], return_loss=False) + pytorch_results = pytorch_results[0] + # get onnx output -- Gitee From c2e61f13e149da647e2cd800ce2f2f695ceddccb Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Fri, 15 Dec 2023 16:44:10 +0800 Subject: [PATCH 2/6] 1 --- .../built-in/cv/detection/ssd/README.md | 17 +++++++++++------ .../cv/detection/ssd/ssd-requirements.txt | 2 +- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md index 3035c50b37..a63c7f9b09 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md @@ -68,7 +68,12 @@ conda create --name ssd python=3.7.16 ## 获取源码 -1. 获取SSD源代码并修改mmdetection。 +1. 安装依赖。 + ```shell + pip3 install -r ssd-requirements.txt + ``` + +2. 获取SSD源代码并修改mmdetection。 ```shell git clone https://github.com/open-mmlab/mmdetection.git cd mmdetection @@ -78,9 +83,9 @@ conda create --name ssd python=3.7.16 cd .. ``` -2. 安装依赖。 - ```shell - pip3 install -r ssd-requirements.txt +3. 随后执行(该步骤如初次安装,可能需等待较长时间): + ``` + pip3 install --no-cache-dir mmcv-full==1.2.7 ``` ## 准备数据集 @@ -160,7 +165,7 @@ conda create --name ssd python=3.7.16 2. 开始推理验证。 - 1. 执行推理。 + 1. 执行推理。(使用torch2.0.1版本) ```shell python3 acc_dataset.py --ts_path ./ssd300_coco.ts --img_bin_path ./coco2017_bin --save_dir ./pyinfer_res_npu ``` @@ -173,7 +178,7 @@ conda create --name ssd python=3.7.16 2. 
精度验证。 - 调用coco_eval.py评测map精度: + 调用coco_eval.py评测map精度(使用torch1.8版本): ```shell det_path=postprocess_out diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt index d3f8a8897c..27c4ebd5e9 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd-requirements.txt @@ -2,6 +2,6 @@ protobuf==3.20.0 Cython==0.29.35 matplotlib==3.5.3 mmpycocotools==12.0.3 -mmcv-full==1.2.7 torch==1.8.1 +torchvision==0.9.1 tqdm==4.66.1 \ No newline at end of file -- Gitee From 82cda182b86802800b1f057e2c4d79235915b4bc Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Fri, 15 Dec 2023 17:25:48 +0800 Subject: [PATCH 3/6] 1 --- .../built-in/cv/detection/ssd/export.py | 2 +- .../built-in/cv/detection/ssd/onnx_helper.py | 245 ++++++++++++++++++ .../cv/detection/ssd/update_ssd_mmdet.diff | 13 + 3 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py index b7dab1fd16..8d17a38aa4 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py @@ -25,7 +25,7 @@ def pytorch2onnx(config_path, config_path, checkpoint_path, input_config) ts_model = torch.jit.trace(model, tensor_data) - ts_model.save("./ssd300_coco_torch201.ts") + ts_model.save("./ssd300_coco.ts") def parse_args(): diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py new file mode 100644 index 0000000000..9abd220baf --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py @@ -0,0 +1,245 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os + +import torch + + +def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape): + """Clip boxes dynamically for onnx. + + Since torch.clamp cannot have dynamic `min` and `max`, we scale the + boxes by 1/max_shape and clamp in the range [0, 1]. + + Args: + x1 (Tensor): The x1 for bounding boxes. + y1 (Tensor): The y1 for bounding boxes. + x2 (Tensor): The x2 for bounding boxes. + y2 (Tensor): The y2 for bounding boxes. + max_shape (Tensor or torch.Size): The (H,W) of original image. + Returns: + tuple(Tensor): The clipped x1, y1, x2, y2. + """ + # assert isinstance( + # max_shape, + # torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx, got {}'.format(max_shape.__class__.__name__) + + assert isinstance(max_shape, (torch.Tensor, torch.Size, list, tuple)), '`max_shape` should be ' + \ + 'torch.Tensor/torch.Size/list/tuple of (h, w) for onnx, got {}'.format(max_shape.__class__.__name__) + if not isinstance(max_shape, torch.Tensor): + max_shape = torch.tensor(max_shape, dtype=x1.dtype, device=x1.device) + else: + max_shape = max_shape.type_as(x1) + + # scale by 1/max_shape + x1 = x1 / max_shape[1] + y1 = y1 / max_shape[0] + x2 = x2 / max_shape[1] + y2 = y2 / max_shape[0] + + # clamp [0, 1] + x1 = torch.clamp(x1, 0, 1) + y1 = torch.clamp(y1, 0, 1) + x2 = torch.clamp(x2, 0, 1) + y2 = torch.clamp(y2, 0, 1) + + # scale back + x1 = x1 * max_shape[1] + y1 = y1 * max_shape[0] + x2 = x2 * max_shape[1] + y2 = y2 * max_shape[0] + return x1, y1, x2, y2 + + +def get_k_for_topk(k, size): + """Get k of TopK for onnx exporting. 
+ + The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime + it could be a Tensor.Due to dynamic shape feature, we have to decide + whether to do TopK and what K it should be while exporting to ONNX. + If returned K is less than zero, it means we do not have to do + TopK operation. + + Args: + k (int or Tensor): The set k value for nms from config file. + size (Tensor or torch.Size): The number of elements of \ + TopK's input tensor + Returns: + tuple: (int or Tensor): The final K for TopK. + """ + ret_k = -1 + if k <= 0 or size <= 0: + return ret_k + if torch.onnx.is_in_onnx_export(): + is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT' + if is_trt_backend: + # TensorRT does not support dynamic K with TopK op + if 0 < k < size: + ret_k = k + else: + # Always keep topk op for dynamic input in onnx for ONNX Runtime + ret_k = torch.where(k < size, k, size) + elif k < size: + ret_k = k + else: + # ret_k is -1 + pass + return ret_k + + +def add_dummy_nms_for_onnx(boxes, + scores, + max_output_boxes_per_class=1000, + iou_threshold=0.5, + score_threshold=0.05, + pre_top_k=-1, + after_top_k=-1, + labels=None): + """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. + + This function helps exporting to onnx with batch and multiclass NMS op. + It only supports class-agnostic detection results. That is, the scores + is of shape (N, num_bboxes, num_classes) and the boxes is of shape + (N, num_boxes, 4). + + Args: + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4] + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes] + max_output_boxes_per_class (int): Maximum number of output + boxes per class of nms. Defaults to 1000. + iou_threshold (float): IOU threshold of nms. Defaults to 0.5 + score_threshold (float): score threshold of nms. + Defaults to 0.05. + pre_top_k (bool): Number of top K boxes to keep before nms. + Defaults to -1. + after_top_k (int): Number of top K boxes to keep after nms. + Defaults to -1. + labels (Tensor, optional): It not None, explicit labels would be used. + Otherwise, labels would be automatically generated using + num_classed. Defaults to None. + + Returns: + tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] + and class labels of shape [N, num_det]. 
+ """ + max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) + iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) + score_threshold = torch.tensor([score_threshold], dtype=torch.float32) + batch_size = scores.shape[0] + num_class = scores.shape[2] + + if pre_top_k > 0: + nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long) + nms_pre = get_k_for_topk(nms_pre, boxes.shape[1]) + + if nms_pre > 0: + max_scores, _ = scores.max(-1) + _, topk_inds = max_scores.topk(nms_pre) + batch_inds = torch.arange(batch_size).view( + -1, 1).expand_as(topk_inds).long() + # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 + # transformed_inds = boxes.shape[1] * batch_inds + topk_inds + transformed_inds = (boxes.shape[1] * batch_inds.int()) + topk_inds.int() + transformed_inds = transformed_inds.long() + boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape( + batch_size, -1, 4) + scores = scores.reshape(-1, num_class)[transformed_inds, :].reshape( + batch_size, -1, num_class) + if labels is not None: + labels = labels.reshape(-1, 1)[transformed_inds].reshape( + batch_size, -1) + + scores = scores.permute(0, 2, 1) + num_box = boxes.shape[1] + # turn off tracing to create a dummy output of nms + state = torch._C._get_tracing_state() + # dummy indices of nms's output + num_fake_det = 2 + batch_inds = torch.randint(batch_size, (num_fake_det, 1)) + cls_inds = torch.randint(num_class, (num_fake_det, 1)) + box_inds = torch.randint(num_box, (num_fake_det, 1)) + indices = torch.cat([batch_inds, cls_inds, box_inds], dim=1) + output = indices + setattr(DummyONNXNMSop, 'output', output) + + # open tracing + torch._C._set_tracing_state(state) + selected_indices = DummyONNXNMSop.apply(boxes, scores, + max_output_boxes_per_class, + iou_threshold, score_threshold) + + batch_inds, cls_inds = selected_indices[:, 0], selected_indices[:, 1] + box_inds = selected_indices[:, 2] + if labels is None: + labels = torch.arange(num_class, dtype=torch.long).to(scores.device) + labels = labels.view(1, num_class, 1).expand_as(scores) + scores = scores.reshape(-1, 1) + boxes = boxes.reshape(batch_size, -1).repeat(1, num_class).reshape(-1, 4) + # pos_inds = (num_class * batch_inds + cls_inds) * num_box + box_inds # original + pos_inds = (num_class * batch_inds.int()) + cls_inds.int() + pos_inds = (pos_inds * num_box.int()) + box_inds.int() + pos_inds = pos_inds.long() + # pos_inds = (batch_inds.new_tensor(num_class) * batch_inds + cls_inds) * batch_inds.new_tensor(num_box) + box_inds + mask = scores.new_zeros(scores.shape) + # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 + # PyTorch style code: mask[batch_inds, box_inds] += 1 + mask[pos_inds, :] += 1 + scores = scores * mask + boxes = boxes * mask + + scores = scores.reshape(batch_size, -1) + boxes = boxes.reshape(batch_size, -1, 4) + labels = labels.reshape(batch_size, -1) + + if boxes.dtype != torch.float: + boxes = boxes.float() + scores = scores.float() + + if after_top_k > 0: + nms_after = torch.tensor( + after_top_k, device=scores.device, dtype=torch.long) + nms_after = get_k_for_topk(nms_after, num_box * num_class) + + if nms_after > 0: + _, topk_inds = scores.topk(nms_after) + batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds).long() + # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 + batch_inds = scores.shape[1] * batch_inds + # transformed_inds = batch_inds + topk_inds + 
transformed_inds = batch_inds.int() + topk_inds.int() + transformed_inds = transformed_inds.long() + scores = scores.reshape(-1, 1)[transformed_inds, :].reshape( + batch_size, -1) + boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape( + batch_size, -1, 4) + labels = labels.reshape(-1, 1)[transformed_inds, :].reshape( + batch_size, -1) + + scores = scores.unsqueeze(2) + dets = torch.cat([boxes, scores], dim=2) + return dets, labels + + +class DummyONNXNMSop(torch.autograd.Function): + """DummyONNXNMSop. + + This class is only for creating onnx::NonMaxSuppression. + """ + + @staticmethod + def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold, + score_threshold): + + return DummyONNXNMSop.output + + @staticmethod + def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, + score_threshold): + return g.op( + 'NonMaxSuppression', + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + outputs=1) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff b/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff index cd1d46528d..d3ceb38444 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff @@ -653,6 +653,19 @@ index 93d051e7..94e496b5 100644 _, topk_inds = conf_pred.topk(nms_pre) bbox_pred = bbox_pred[topk_inds, :] cls_pred = cls_pred[topk_inds, :] +diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py +index 7c6d5e96..8bd74238 100644 +--- a/mmdet/models/detectors/base.py ++++ b/mmdet/models/detectors/base.py +@@ -179,6 +179,8 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): + if return_loss: + return self.forward_train(img, img_metas, **kwargs) + else: ++ if not isinstance(img, list): ++ img = [img] + return self.forward_test(img, img_metas, **kwargs) + + def _parse_losses(self, losses): diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py index 96c4acac..f5be2641 100644 --- a/mmdet/models/detectors/single_stage.py -- Gitee From 747e44e09fea61c7085715e66b5f5d0b74459d7b Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Fri, 15 Dec 2023 21:00:28 +0800 Subject: [PATCH 4/6] 1 --- AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md | 1 + AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md index a63c7f9b09..bdb8fefd76 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md @@ -82,6 +82,7 @@ conda create --name ssd python=3.7.16 pip install -v -e . cd .. ``` + 将`onnx_helper.py`文件,放置在`mmdetection/mmdet/core/export`目录下 3. 
随后执行(该步骤如初次安装,可能需等待较长时间): ``` diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py index 8d17a38aa4..527acd040a 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py @@ -39,7 +39,7 @@ def parse_args(): '--shape', type=int, nargs='+', - default=[800, 1216], + default=[300], help='input image size') parser.add_argument( '--mean', @@ -51,7 +51,7 @@ def parse_args(): '--std', type=float, nargs='+', - default=[58.395, 57.12, 57.375], + default=[1, 1, 1], help='variance value used for preprocess input data') args = parser.parse_args() return args -- Gitee From 26e8f562c4a6c37c1b90356d43c3b5a77caffedd Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Tue, 19 Dec 2023 16:19:51 +0800 Subject: [PATCH 5/6] 1 --- .../built-in/cv/detection/ssd/README.md | 223 ----- .../built-in/cv/detection/ssd/onnx_helper.py | 245 ------ .../cv/detection/ssd/ssd_postprocess.py | 294 ------- .../cv/detection/ssd/update_ssd_mmdet.diff | 811 ------------------ 4 files changed, 1573 deletions(-) delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py delete mode 100644 AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md b/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md deleted file mode 100644 index bdb8fefd76..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/README.md +++ /dev/null @@ -1,223 +0,0 @@ -# SSD模型-推理指导 - -- [概述](#ZH-CN_TOPIC_0000001172161501) - - - [输入输出数据](#ZH-CN_TOPIC_0000001126281702) - -- [推理环境准备](#ZH-CN_TOPIC_0000001126281702) - -- [快速上手](#ZH-CN_TOPIC_0000001126281700) - - - [获取源码](#section4622531142816) - - [准备数据集](#section183221994411) - - [模型推理](#section741711594517) - -- [模型推理性能&精度](#ZH-CN_TOPIC_0000001172201573) - - ****** - -# 概述 - -SSD将detection转化为regression的思路,可以一次完成目标定位与分类。该算法基于Faster RCNN中的Anchor,提出了相似的Prior box;该算法修改了传统的SSD网络:将SSD的FC6和FC7层转化为卷积层,去掉所有的Dropout层和FC8层。同时加入基于特征金字塔的检测方式,在不同感受野的feature map上预测目标。 - -- 参考实现: - - ```shell - url=https://github.com/open-mmlab/mmdetection.git - branch=master - commit_id=a21eb25535f31634cef332b09fc27d28956fb24b - model_name=ssd - ``` - -## 输入输出数据 - -- 输入数据 - - | 输入数据 | 数据类型 | 大小 | 数据排布格式 | - | -------- | -------- | ------------------------- | ------------ | - | input | RGB_FP32 | batchsize x 3 x 300 x 300 | NCHW | - -- 输出数据 - - | 输出数据 | 数据类型 | 大小 | 数据排布格式 | - | -------- | -------- | --------------------- | ------------ | - | boxes | FLOAT32 | batchsize x 8732 x 4 | ND | - | labels | FLOAT32 | batchsize x 8732 x 80 | ND | - -# 推理环境准备 - -- 该模型需要两套环境切换运行,用于执行推理的环境(包括插件与驱动)如下 - - **表 1** 版本配套表 - - | 配套 | 版本 | 环境准备指导 | - |---------| ------- | ------------------------------------------------------------ | - | 固件与驱动 | 23.0.rc1 | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) | - | CANN | 7.0.RC1.alpha003 | - | - | Python | 3.9.11 | - | - | PyTorch | 2.0.1 | - | - | Torch_AIE | 6.3.rc2 | - | - -- 用于执行前后处理以及模型导出,则需要另一套环境,建议使用conda命令构建虚拟环境,并安装相应的包 - -``` -conda create --name ssd python=3.7.16 -``` - -# 快速上手 - -## 获取源码 - -1. 安装依赖。 - ```shell - pip3 install -r ssd-requirements.txt - ``` - -2. 
获取SSD源代码并修改mmdetection。 - ```shell - git clone https://github.com/open-mmlab/mmdetection.git - cd mmdetection - git reset --hard a21eb25535f31634cef332b09fc27d28956fb24b - patch -p1 < ../update_ssd_mmdet.diff - pip install -v -e . - cd .. - ``` - 将`onnx_helper.py`文件,放置在`mmdetection/mmdet/core/export`目录下 - -3. 随后执行(该步骤如初次安装,可能需等待较长时间): - ``` - pip3 install --no-cache-dir mmcv-full==1.2.7 - ``` - -## 准备数据集 - -1. 获取原始数据集。(解压命令参考tar –xvf \*.tar与 unzip \*.zip) - - 推理数据集采用 [coco_val_2017](http://images.cocodataset.org),数据集下载后存放路径:`dataset=/root/datasets/coco` - - 目录结构: - - ``` - ├── coco - │ ├── val2017 - │ ├── annotations - │ ├──instances_val2017.json - ``` - -2. 数据预处理(使用torch 1.8环境)。 - - 将原始数据集转换为模型输入的二进制数据。执行 `ssd_preprocess.py` 脚本。 - - ```shell - python ssd_preprocess.py \ - --image_folder_path $dataset/val2017 \ - --bin_folder_path val2017_ssd_bin - ``` - - - 参数说明: - - - --image_folder_path:原始数据验证集(.jpg)所在路径。 - - --bin_folder_path:输出的二进制文件(.bin)所在路径。 - - 每个图像对应生成一个二进制文件。 - -3. 生成数据集info文件(使用torch 1.8环境)。 - - 运行 `get_info.py` 脚本,生成图片数据info文件。 - ```shell - python get_info.py jpg $dataset/val2017 coco2017_ssd_jpg.info - ``` - - - 参数说明: - - - 第一个参数:生成的数据集文件格式。 - - 第二个参数:预处理后的数据文件相对路径。 - - 第三个参数:即将生成的info文件名。 - - 运行成功后,在当前目录中生成 `coco2017_ssd_jpg.info`。 - -## 模型推理 - -1. 模型转换。 - - 使用PyTorch将模型权重文件.pth转换为.ts文件。 - 1. 获取权重文件。 - - 获取经过训练的权重文件:[ssd300_coco_20200307-a92d2092.pth](http://download.openmmlab.com/mmdetection/v2.0/ssd/ssd300_coco/ssd300_coco_20200307-a92d2092.pth) - - 2. 导出onnx文件。 - - 使用`export.py`导出ts文件(使用torch 1.8环境) - - ``` - python export.py \ - --checkpoint ./ssd300_coco_20200307-a92d2092.pth \ - --mmdet_path ./mmdetection \ - --shape=300 \ - --mean 123.675 116.28 103.53 \ - --std 1 1 1 - ``` - - - 参数说明: - - - checkpoint:原始pth文件所在路径 - - mmdet_path:github拉入文件夹路径 - - shape:图像尺寸 - -2. 开始推理验证。 - - 1. 执行推理。(使用torch2.0.1版本) - ```shell - python3 acc_dataset.py --ts_path ./ssd300_coco.ts --img_bin_path ./coco2017_bin --save_dir ./pyinfer_res_npu - ``` - - - 参数说明: - - - ts_path:导出ts文件路径 - - img_bin_path:图片预处理得到的bin文件夹所在路径 - - save_dir:保存推理结果的路径 - - 2. 精度验证。 - - 调用coco_eval.py评测map精度(使用torch1.8版本): - - ```shell - det_path=postprocess_out - python ssd_postprocess.py \ - --bin_data_path=out/2022_*/ \ - --score_threshold=0.02 \ - --test_annotation=coco2017_ssd_jpg.info \ - --nms_pre 200 \ - --det_results_path ${det_path} - python txt_to_json.py --npu_txt_path ${det_path} - python coco_eval.py --ground_truth /root/datasets/coco/annotations/instances_val2017.json - ``` - - - 参数说明: - - - --bin_data_path:为推理结果存放的路径。 - - --score_threshold:得分阈值。 - - --test_annotation:原始图片信息文件。 - - --nms_pre:每张图片获取框数量的阈值。 - - --det_results_path:后处理输出路径。 - - --npu_txt_path:后处理输出路径。 - - --ground_truth:instances_val2017.json文件路径。 - -# 模型推理性能&精度 - -调用ACL接口推理计算,性能参考下列数据。 - -| | mAP | -| --------- | -------- | -| 310P3精度 | mAP=25.4 | - - -| Throughput | 310*4 | 310P3 | 310B1 | -| ---------- | -------- | -------- | ----- | -| bs1 | 179.194 | 298.5514 | 75.42 | -| bs4 | 207.596 | 337.0112 | 77.9 | -| bs8 | 211.7312 | 323.5662 | 79.77 | -| bs16 | 211.288 | 318.1392 | 77.84 | -| bs32 | 200.2948 | 318.7303 | 79.78 | -| bs64 | 196.4192 | 313.0790 | 48.36 | -| 最优batch | 211.7312 | 337.0112 | 79.77 | \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py deleted file mode 100644 index 9abd220baf..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/onnx_helper.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) OpenMMLab. 
All rights reserved. -import os - -import torch - - -def dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape): - """Clip boxes dynamically for onnx. - - Since torch.clamp cannot have dynamic `min` and `max`, we scale the - boxes by 1/max_shape and clamp in the range [0, 1]. - - Args: - x1 (Tensor): The x1 for bounding boxes. - y1 (Tensor): The y1 for bounding boxes. - x2 (Tensor): The x2 for bounding boxes. - y2 (Tensor): The y2 for bounding boxes. - max_shape (Tensor or torch.Size): The (H,W) of original image. - Returns: - tuple(Tensor): The clipped x1, y1, x2, y2. - """ - # assert isinstance( - # max_shape, - # torch.Tensor), '`max_shape` should be tensor of (h,w) for onnx, got {}'.format(max_shape.__class__.__name__) - - assert isinstance(max_shape, (torch.Tensor, torch.Size, list, tuple)), '`max_shape` should be ' + \ - 'torch.Tensor/torch.Size/list/tuple of (h, w) for onnx, got {}'.format(max_shape.__class__.__name__) - if not isinstance(max_shape, torch.Tensor): - max_shape = torch.tensor(max_shape, dtype=x1.dtype, device=x1.device) - else: - max_shape = max_shape.type_as(x1) - - # scale by 1/max_shape - x1 = x1 / max_shape[1] - y1 = y1 / max_shape[0] - x2 = x2 / max_shape[1] - y2 = y2 / max_shape[0] - - # clamp [0, 1] - x1 = torch.clamp(x1, 0, 1) - y1 = torch.clamp(y1, 0, 1) - x2 = torch.clamp(x2, 0, 1) - y2 = torch.clamp(y2, 0, 1) - - # scale back - x1 = x1 * max_shape[1] - y1 = y1 * max_shape[0] - x2 = x2 * max_shape[1] - y2 = y2 * max_shape[0] - return x1, y1, x2, y2 - - -def get_k_for_topk(k, size): - """Get k of TopK for onnx exporting. - - The K of TopK in TensorRT should not be a Tensor, while in ONNX Runtime - it could be a Tensor.Due to dynamic shape feature, we have to decide - whether to do TopK and what K it should be while exporting to ONNX. - If returned K is less than zero, it means we do not have to do - TopK operation. - - Args: - k (int or Tensor): The set k value for nms from config file. - size (Tensor or torch.Size): The number of elements of \ - TopK's input tensor - Returns: - tuple: (int or Tensor): The final K for TopK. - """ - ret_k = -1 - if k <= 0 or size <= 0: - return ret_k - if torch.onnx.is_in_onnx_export(): - is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT' - if is_trt_backend: - # TensorRT does not support dynamic K with TopK op - if 0 < k < size: - ret_k = k - else: - # Always keep topk op for dynamic input in onnx for ONNX Runtime - ret_k = torch.where(k < size, k, size) - elif k < size: - ret_k = k - else: - # ret_k is -1 - pass - return ret_k - - -def add_dummy_nms_for_onnx(boxes, - scores, - max_output_boxes_per_class=1000, - iou_threshold=0.5, - score_threshold=0.05, - pre_top_k=-1, - after_top_k=-1, - labels=None): - """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX. - - This function helps exporting to onnx with batch and multiclass NMS op. - It only supports class-agnostic detection results. That is, the scores - is of shape (N, num_bboxes, num_classes) and the boxes is of shape - (N, num_boxes, 4). - - Args: - boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4] - scores (Tensor): The detection scores of shape - [N, num_boxes, num_classes] - max_output_boxes_per_class (int): Maximum number of output - boxes per class of nms. Defaults to 1000. - iou_threshold (float): IOU threshold of nms. Defaults to 0.5 - score_threshold (float): score threshold of nms. - Defaults to 0.05. - pre_top_k (bool): Number of top K boxes to keep before nms. - Defaults to -1. 
- after_top_k (int): Number of top K boxes to keep after nms. - Defaults to -1. - labels (Tensor, optional): It not None, explicit labels would be used. - Otherwise, labels would be automatically generated using - num_classed. Defaults to None. - - Returns: - tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] - and class labels of shape [N, num_det]. - """ - max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class]) - iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32) - score_threshold = torch.tensor([score_threshold], dtype=torch.float32) - batch_size = scores.shape[0] - num_class = scores.shape[2] - - if pre_top_k > 0: - nms_pre = torch.tensor(pre_top_k, device=scores.device, dtype=torch.long) - nms_pre = get_k_for_topk(nms_pre, boxes.shape[1]) - - if nms_pre > 0: - max_scores, _ = scores.max(-1) - _, topk_inds = max_scores.topk(nms_pre) - batch_inds = torch.arange(batch_size).view( - -1, 1).expand_as(topk_inds).long() - # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 - # transformed_inds = boxes.shape[1] * batch_inds + topk_inds - transformed_inds = (boxes.shape[1] * batch_inds.int()) + topk_inds.int() - transformed_inds = transformed_inds.long() - boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape( - batch_size, -1, 4) - scores = scores.reshape(-1, num_class)[transformed_inds, :].reshape( - batch_size, -1, num_class) - if labels is not None: - labels = labels.reshape(-1, 1)[transformed_inds].reshape( - batch_size, -1) - - scores = scores.permute(0, 2, 1) - num_box = boxes.shape[1] - # turn off tracing to create a dummy output of nms - state = torch._C._get_tracing_state() - # dummy indices of nms's output - num_fake_det = 2 - batch_inds = torch.randint(batch_size, (num_fake_det, 1)) - cls_inds = torch.randint(num_class, (num_fake_det, 1)) - box_inds = torch.randint(num_box, (num_fake_det, 1)) - indices = torch.cat([batch_inds, cls_inds, box_inds], dim=1) - output = indices - setattr(DummyONNXNMSop, 'output', output) - - # open tracing - torch._C._set_tracing_state(state) - selected_indices = DummyONNXNMSop.apply(boxes, scores, - max_output_boxes_per_class, - iou_threshold, score_threshold) - - batch_inds, cls_inds = selected_indices[:, 0], selected_indices[:, 1] - box_inds = selected_indices[:, 2] - if labels is None: - labels = torch.arange(num_class, dtype=torch.long).to(scores.device) - labels = labels.view(1, num_class, 1).expand_as(scores) - scores = scores.reshape(-1, 1) - boxes = boxes.reshape(batch_size, -1).repeat(1, num_class).reshape(-1, 4) - # pos_inds = (num_class * batch_inds + cls_inds) * num_box + box_inds # original - pos_inds = (num_class * batch_inds.int()) + cls_inds.int() - pos_inds = (pos_inds * num_box.int()) + box_inds.int() - pos_inds = pos_inds.long() - # pos_inds = (batch_inds.new_tensor(num_class) * batch_inds + cls_inds) * batch_inds.new_tensor(num_box) + box_inds - mask = scores.new_zeros(scores.shape) - # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 - # PyTorch style code: mask[batch_inds, box_inds] += 1 - mask[pos_inds, :] += 1 - scores = scores * mask - boxes = boxes * mask - - scores = scores.reshape(batch_size, -1) - boxes = boxes.reshape(batch_size, -1, 4) - labels = labels.reshape(batch_size, -1) - - if boxes.dtype != torch.float: - boxes = boxes.float() - scores = scores.float() - - if after_top_k > 0: - nms_after = torch.tensor( - after_top_k, device=scores.device, dtype=torch.long) - nms_after = 
get_k_for_topk(nms_after, num_box * num_class) - - if nms_after > 0: - _, topk_inds = scores.topk(nms_after) - batch_inds = torch.arange(batch_size).view(-1, 1).expand_as(topk_inds).long() - # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 - batch_inds = scores.shape[1] * batch_inds - # transformed_inds = batch_inds + topk_inds - transformed_inds = batch_inds.int() + topk_inds.int() - transformed_inds = transformed_inds.long() - scores = scores.reshape(-1, 1)[transformed_inds, :].reshape( - batch_size, -1) - boxes = boxes.reshape(-1, 4)[transformed_inds, :].reshape( - batch_size, -1, 4) - labels = labels.reshape(-1, 1)[transformed_inds, :].reshape( - batch_size, -1) - - scores = scores.unsqueeze(2) - dets = torch.cat([boxes, scores], dim=2) - return dets, labels - - -class DummyONNXNMSop(torch.autograd.Function): - """DummyONNXNMSop. - - This class is only for creating onnx::NonMaxSuppression. - """ - - @staticmethod - def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold, - score_threshold): - - return DummyONNXNMSop.output - - @staticmethod - def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, - score_threshold): - return g.op( - 'NonMaxSuppression', - boxes, - scores, - max_output_boxes_per_class, - iou_threshold, - score_threshold, - outputs=1) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py deleted file mode 100644 index 09dc47196e..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_postprocess.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""coco postprocess""" - -import os -import numpy as np -import argparse -import cv2 -import warnings -import torch -import time -try: - from torch import npu_batch_nms as NMSOp - NMS_ON_NPU = True -except: - from torchvision.ops import batched_nms as NMSOp - NMS_ON_NPU = False - -CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', - 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', - 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', - 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', - 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', - 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', - 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', - 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', - 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', - 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', - 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', - 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', - 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', - 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] - - -def coco_postprocess(bbox, image_size, net_input_width, net_input_height): - """ - This function is postprocessing for FasterRCNN output. - - Before calling this function, reshape the raw output of FasterRCNN to - following form - numpy.ndarray: - [x, y, width, height, confidence, probability of 80 classes] - shape: (100,) - The postprocessing restore the bounding rectangles of FasterRCNN output - to origin scale and filter with non-maximum suppression. - - :param bbox: a numpy array of the FasterRCNN output - :param image_path: a string of image path - :return: three list for best bound, class and score - """ - w = image_size[0] - h = image_size[1] - scale_w = net_input_width / w - scale_h = net_input_height / h - - # cal predict box on the image src - pbox = bbox.copy() - pbox[:, 0] = (bbox[:, 0]) / scale_w - pbox[:, 1] = (bbox[:, 1]) / scale_h - pbox[:, 2] = (bbox[:, 2]) / scale_w - pbox[:, 3] = (bbox[:, 3]) / scale_h - return pbox - - -def np_clip_bbox(bboxes, max_shape): - x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] - h, w = max_shape - x1 = x1.clip(min=0, max=w) - y1 = y1.clip(min=0, max=h) - x2 = x2.clip(min=0, max=w) - y2 = y2.clip(min=0, max=h) - bboxes = np.stack([x1, y1, x2, y2], axis=-1) - return bboxes - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument("--bin_data_path", - default="/onnx/ssd/cpp_infer_res_cpu") - parser.add_argument("--test_annotation", default="./coco2017_ssd_jpg.info") - parser.add_argument("--det_results_path", default="./postprocess_out_cpu/") - parser.add_argument("--net_out_num", default=2, type=int) - parser.add_argument("--num_pred_box", default=8732, type=int) - parser.add_argument("--nms_pre", default=200, type=int) - parser.add_argument("--net_input_width", default=300, type=int) - parser.add_argument("--net_input_height", default=300, type=int) - parser.add_argument("--min_bbox_size", default=0.01, type=float) - parser.add_argument("--score_threshold", default=0.02, type=float) - parser.add_argument("--nms", default=True, type=bool) - parser.add_argument("--iou_threshold", default=0.45, type=float) - parser.add_argument("--max_per_img", default=200, type=int) - parser.add_argument("--ifShowDetObj", action="store_true", default=True, - help="if input the para means True, neither False.") - parser.add_argument("--start", 
default=0, type=float) - parser.add_argument("--end", default=1, type=float) - parser.add_argument("--device", default=0, type=int) - parser.add_argument("--clear_cache", action='store_true') - flags = parser.parse_args() - - # generate dict according to annotation file for query resolution - # load width and height of input images - img_size_dict = dict() - with open(flags.test_annotation)as f: - for line in f: - temp = line.split(" ") - img_file_path = temp[1] - img_name = temp[1].split("/")[-1].split(".")[0] - img_width = int(temp[2]) - img_height = int(temp[3]) - img_size_dict[img_name] = (img_width, img_height, img_file_path) - - # read bin file for generate predict result - bin_path = flags.bin_data_path # 推理结果保存路径 - det_results_path = flags.det_results_path - os.makedirs(det_results_path, exist_ok=True) - total_img = set([name[:name.rfind('_')] - for name in os.listdir(bin_path) if "bin" in name]) - total_img = sorted(total_img) # list of img names (str) - num_img = len(total_img) # 5000 - start = int(flags.start * num_img) - end = int(flags.end * num_img) - task_len = end - start + 1 - - finished = 0 - time_start = time.time() - for img_id in range(start, end): - # for img_id, bin_file in enumerate(sorted(total_img)): - bin_file = total_img[img_id] - path_base = os.path.join(bin_path, bin_file) - det_results_file = os.path.join(det_results_path, bin_file + ".txt") - if os.path.exists(det_results_file) and not flags.clear_cache: - continue - - # load all detected output tensor - bbox_file = path_base + "_" + str(0) + ".bin" - score_file = path_base + "_" + str(1) + ".bin" - assert os.path.exists( - bbox_file), '[ERROR] file `{}` not exist'.format(bbox_file) - assert os.path.exists( - score_file), '[ERROR] file `{}` not exist'.format(score_file) - bboxes = np.fromfile(bbox_file, dtype="float32").reshape( - flags.num_pred_box, 4) - scores = np.fromfile(score_file, dtype="float32").reshape( - flags.num_pred_box, 80) - - bboxes = torch.from_numpy(bboxes) - scores = torch.from_numpy(scores) - try: - bboxes = bboxes.npu(flags.device) - scores = scores.npu(flags.device) - except: - warnings.warn('npu is not available, running on cpu') - - max_scores, _ = scores.max(-1) # shape of [8732], torch.float32 - keep_inds = (max_scores > flags.score_threshold).nonzero( - as_tuple=False).view(-1) - bboxes = bboxes[keep_inds, :] - scores = scores[keep_inds, :] - - if flags.nms_pre > 0 and flags.nms_pre < bboxes.shape[0]: - max_scores, _ = scores.max(-1) # shape: torch.Size([2738]) dtype:torch.float32 - _, topk_inds = max_scores.topk(flags.nms_pre) - bboxes = bboxes[topk_inds, :] # shape: torch.Size([200, 4]) - scores = scores[topk_inds, :] # shape: torch.Size([200, 80]) - - # clip bbox border - bboxes[:, 0::2].clamp_(min=0, max=flags.net_input_width - 1) - bboxes[:, 1::2].clamp_(min=0, max=flags.net_input_height - 1) - - # remove small bbox - bboxes_width_height = bboxes[:, 2:] - bboxes[:, :2] - valid_bboxes = bboxes_width_height > flags.min_bbox_size - keep_inds = (valid_bboxes[:, 0] & valid_bboxes[:, 1] - ).nonzero(as_tuple=False).view(-1) - bboxes = bboxes[keep_inds, :] - scores = scores[keep_inds, :] - - # rescale bbox to original image size - original_img_info = img_size_dict[bin_file] - rescale_factor = torch.tensor([ - original_img_info[0] / flags.net_input_width, - original_img_info[1] / flags.net_input_height] * 2, - dtype=bboxes.dtype, device=bboxes.device) - bboxes *= rescale_factor - - if flags.nms: - if NMS_ON_NPU: - # repeat bbox for each class - # (N, 4) -> (B, N, 80, 4), where B = 1 
is the batchsize - bboxes = bboxes[None, :, None, :].repeat(1, 1, 80, 1) - # (N, 80) -> (B, N, 80), where B = 1 is the batchsize - scores = scores[None, :, :] - - # bbox batched nms - bboxes, scores, labels, num_total_bboxes = \ - NMSOp( - bboxes.half(), scores.half(), - score_threshold=flags.score_threshold, - iou_threshold=flags.iou_threshold, - max_size_per_class=flags.max_per_img, - max_total_size=flags.max_per_img) - bboxes = bboxes[0, :num_total_bboxes, :] - scores = scores[0, :num_total_bboxes] - class_idxs = labels[0, :num_total_bboxes] - else: - # repeat bbox and class idx for each class - bboxes = bboxes[:, None, :].repeat( - 1, 80, 1) # (N, 4) -> (N, 80, 4) - class_idxs = torch.arange(80, dtype=torch.long, device=bboxes.device - )[None, :].repeat(bboxes.shape[0], 1) # (80) -> (N, 80) - - # reshape bbox for torch nms - bboxes = bboxes.view(-1, 4) - scores = scores.view(-1) - class_idxs = class_idxs.view(-1) - - # bbox batched nms - keep_inds = NMSOp(bboxes, scores, class_idxs, - flags.iou_threshold) - bboxes = bboxes[keep_inds] - scores = scores[keep_inds] - class_idxs = class_idxs[keep_inds] - else: - # repeat bbox and class idx for each class - bboxes = bboxes[:, None, :].repeat( - 1, 80, 1) # (N, 4) -> (N, 80, 4) - class_idxs = torch.arange(80, dtype=torch.long, device=bboxes.device - )[None, :].repeat(bboxes.shape[0], 1) # (80) -> (N, 80) - - # reshape bbox for torch nms - bboxes = bboxes.view(-1, 4) - scores = scores.view(-1) - class_idxs = class_idxs.view(-1) - - # keep topk max_per_img bbox - if flags.max_per_img > 0 and flags.max_per_img < bboxes.shape[0]: - _, topk_inds = scores.topk(flags.max_per_img) - bboxes = bboxes[topk_inds, :] - scores = scores[topk_inds] - class_idxs = class_idxs[topk_inds] - - # move to cpu if running on npu - if bboxes.device != 'cpu': - bboxes = bboxes.cpu() - scores = scores.cpu() - class_idxs = class_idxs.cpu() - - # convert to numpy.ndarray - bboxes = bboxes.numpy() - scores = scores.numpy() - class_idxs = class_idxs.numpy() - - # make det result file - if flags.ifShowDetObj == True: - imgCur = cv2.imread(original_img_info[2]) - - det_results_str = '' - for idx in range(bboxes.shape[0]): - x1, y1, x2, y2 = bboxes[idx, :] - predscore = scores[idx] - class_ind = class_idxs[idx] - - class_name = CLASSES[int(class_ind)] - det_results_str += "{} {} {} {} {} {}\n".format( - class_name, predscore, x1, y1, x2, y2) - if flags.ifShowDetObj == True: - imgCur = cv2.rectangle(imgCur, (int(x1), int( - y1)), (int(x2), int(y2)), (0, 255, 0), 1) - imgCur = cv2.putText(imgCur, class_name + '|' + str(predscore), - (int(x1), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, - 0.5, (0, 0, 255), 1) - - if flags.ifShowDetObj == True: - cv2.imwrite(os.path.join(det_results_path, bin_file + - '.jpg'), imgCur, [int(cv2.IMWRITE_JPEG_QUALITY), 70]) - - with open(det_results_file, "w") as detf: - detf.write(det_results_str) - - finished += 1 - speed = finished / (time.time() - time_start) - print('processed {:5d}/{:<5d} images, speed: {:.2f}FPS'.format( - finished, task_len, speed), end='\r') diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff b/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff deleted file mode 100644 index d3ceb38444..0000000000 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/update_ssd_mmdet.diff +++ /dev/null @@ -1,811 +0,0 @@ -diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py -index e9eb3579..066e90fe 100644 ---- 
a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py -+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py -@@ -1,3 +1,7 @@ -+# Copyright (c) OpenMMLab. All rights reserved. -+import warnings -+ -+import mmcv - import numpy as np - import torch - -@@ -20,16 +24,25 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): - target for delta coordinates - clip_border (bool, optional): Whether clip the objects outside the - border of the image. Defaults to True. -+ add_ctr_clamp (bool): Whether to add center clamp, when added, the -+ predicted box is clamped is its center is too far away from -+ the original anchor's center. Only used by YOLOF. Default False. -+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. -+ Default 32. - """ - - def __init__(self, - target_means=(0., 0., 0., 0.), - target_stds=(1., 1., 1., 1.), -- clip_border=True): -+ clip_border=True, -+ add_ctr_clamp=False, -+ ctr_clamp=32): - super(BaseBBoxCoder, self).__init__() - self.means = target_means - self.stds = target_stds - self.clip_border = clip_border -+ self.add_ctr_clamp = add_ctr_clamp -+ self.ctr_clamp = ctr_clamp - - def encode(self, bboxes, gt_bboxes): - """Get box regression transformation deltas that can be used to -@@ -57,10 +70,16 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): - """Apply transformation `pred_bboxes` to `boxes`. - - Args: -- boxes (torch.Tensor): Basic boxes. -- pred_bboxes (torch.Tensor): Encoded boxes with shape -- max_shape (tuple[int], optional): Maximum shape of boxes. -- Defaults to None. -+ bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4) -+ pred_bboxes (Tensor): Encoded offsets with respect to each roi. -+ Has shape (B, N, num_classes * 4) or (B, N, 4) or -+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H -+ when rois is a grid of anchors.Offset encoding follows [1]_. -+ max_shape (Sequence[int] or torch.Tensor or Sequence[ -+ Sequence[int]],optional): Maximum bounds for boxes, specifies -+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then -+ the max_shape should be a Sequence[Sequence[int]] -+ and the length of max_shape should also be B. - wh_ratio_clip (float, optional): The allowed ratio between - width and height. - -@@ -69,8 +88,28 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder): - """ - - assert pred_bboxes.size(0) == bboxes.size(0) -- decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds, -- max_shape, wh_ratio_clip, self.clip_border) -+ if pred_bboxes.ndim == 3: -+ assert pred_bboxes.size(1) == bboxes.size(1) -+ -+ if pred_bboxes.ndim == 2 and not True: -+ # single image decode -+ decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, -+ self.stds, max_shape, wh_ratio_clip, -+ self.clip_border, self.add_ctr_clamp, -+ self.ctr_clamp) -+ else: -+ if pred_bboxes.ndim == 3 and not True: -+ warnings.warn( -+ 'DeprecationWarning: onnx_delta2bbox is deprecated ' -+ 'in the case of batch decoding and non-ONNX, ' -+ 'please use “delta2bbox” instead. In order to improve ' -+ 'the decoding speed, the batch function will no ' -+ 'longer be supported. ') -+ decoded_bboxes = onnx_delta2bbox(bboxes, pred_bboxes, self.means, -+ self.stds, max_shape, -+ wh_ratio_clip, self.clip_border, -+ self.add_ctr_clamp, -+ self.ctr_clamp) - - return decoded_bboxes - -@@ -126,7 +165,108 @@ def delta2bbox(rois, - stds=(1., 1., 1., 1.), - max_shape=None, - wh_ratio_clip=16 / 1000, -- clip_border=True): -+ clip_border=True, -+ add_ctr_clamp=False, -+ ctr_clamp=32): -+ """Apply deltas to shift/scale base boxes. 
-+ -+ Typically the rois are anchor or proposed bounding boxes and the deltas are -+ network outputs used to shift/scale those boxes. -+ This is the inverse function of :func:`bbox2delta`. -+ -+ Args: -+ rois (Tensor): Boxes to be transformed. Has shape (N, 4). -+ deltas (Tensor): Encoded offsets relative to each roi. -+ Has shape (N, num_classes * 4) or (N, 4). Note -+ N = num_base_anchors * W * H, when rois is a grid of -+ anchors. Offset encoding follows [1]_. -+ means (Sequence[float]): Denormalizing means for delta coordinates. -+ Default (0., 0., 0., 0.). -+ stds (Sequence[float]): Denormalizing standard deviation for delta -+ coordinates. Default (1., 1., 1., 1.). -+ max_shape (tuple[int, int]): Maximum bounds for boxes, specifies -+ (H, W). Default None. -+ wh_ratio_clip (float): Maximum aspect ratio for boxes. Default -+ 16 / 1000. -+ clip_border (bool, optional): Whether clip the objects outside the -+ border of the image. Default True. -+ add_ctr_clamp (bool): Whether to add center clamp. When set to True, -+ the center of the prediction bounding box will be clamped to -+ avoid being too far away from the center of the anchor. -+ Only used by YOLOF. Default False. -+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. -+ Default 32. -+ -+ Returns: -+ Tensor: Boxes with shape (N, num_classes * 4) or (N, 4), where 4 -+ represent tl_x, tl_y, br_x, br_y. -+ -+ References: -+ .. [1] https://arxiv.org/abs/1311.2524 -+ -+ Example: -+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.], -+ >>> [ 0., 0., 1., 1.], -+ >>> [ 0., 0., 1., 1.], -+ >>> [ 5., 5., 5., 5.]]) -+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.], -+ >>> [ 1., 1., 1., 1.], -+ >>> [ 0., 0., 2., -1.], -+ >>> [ 0.7, -1.9, -0.5, 0.3]]) -+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) -+ tensor([[0.0000, 0.0000, 1.0000, 1.0000], -+ [0.1409, 0.1409, 2.8591, 2.8591], -+ [0.0000, 0.3161, 4.1945, 0.6839], -+ [5.0000, 5.0000, 5.0000, 5.0000]]) -+ """ -+ num_bboxes, num_classes = deltas.size(0), deltas.size(1) // 4 -+ if num_bboxes == 0: -+ return deltas -+ -+ deltas = deltas.reshape(-1, 4) -+ -+ means = deltas.new_tensor(means).view(1, -1) -+ stds = deltas.new_tensor(stds).view(1, -1) -+ denorm_deltas = deltas * stds + means -+ -+ dxy = denorm_deltas[:, :2] -+ dwh = denorm_deltas[:, 2:] -+ -+ # Compute width/height of each roi -+ rois_ = rois.repeat(1, num_classes).reshape(-1, 4) -+ pxy = ((rois_[:, :2] + rois_[:, 2:]) * 0.5) -+ pwh = (rois_[:, 2:] - rois_[:, :2]) -+ -+ dxy_wh = pwh * dxy -+ -+ max_ratio = np.abs(np.log(wh_ratio_clip)) -+ if add_ctr_clamp: -+ dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp) -+ dwh = torch.clamp(dwh, max=max_ratio) -+ else: -+ dwh = dwh.clamp(min=-max_ratio, max=max_ratio) -+ -+ gxy = pxy + dxy_wh -+ gwh = pwh * dwh.exp() -+ x1y1 = gxy - (gwh * 0.5) -+ x2y2 = gxy + (gwh * 0.5) -+ bboxes = torch.cat([x1y1, x2y2], dim=-1) -+ if clip_border and max_shape is not None: -+ bboxes[..., 0::2].clamp_(min=0, max=max_shape[1]) -+ bboxes[..., 1::2].clamp_(min=0, max=max_shape[0]) -+ bboxes = bboxes.reshape(num_bboxes, -1) -+ return bboxes -+ -+ -+def onnx_delta2bbox(rois, -+ deltas, -+ means=(0., 0., 0., 0.), -+ stds=(1., 1., 1., 1.), -+ max_shape=None, -+ wh_ratio_clip=16 / 1000, -+ clip_border=True, -+ add_ctr_clamp=False, -+ ctr_clamp=32): - """Apply deltas to shift/scale base boxes. - - Typically the rois are anchor or proposed bounding boxes and the deltas are -@@ -134,21 +274,34 @@ def delta2bbox(rois, - This is the inverse function of :func:`bbox2delta`. 
- - Args: -- rois (Tensor): Boxes to be transformed. Has shape (N, 4) -+ rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4) - deltas (Tensor): Encoded offsets with respect to each roi. -- Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when -- rois is a grid of anchors. Offset encoding follows [1]_. -- means (Sequence[float]): Denormalizing means for delta coordinates -+ Has shape (B, N, num_classes * 4) or (B, N, 4) or -+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H -+ when rois is a grid of anchors.Offset encoding follows [1]_. -+ means (Sequence[float]): Denormalizing means for delta coordinates. -+ Default (0., 0., 0., 0.). - stds (Sequence[float]): Denormalizing standard deviation for delta -- coordinates -- max_shape (tuple[int, int]): Maximum bounds for boxes. specifies (H, W) -+ coordinates. Default (1., 1., 1., 1.). -+ max_shape (Sequence[int] or torch.Tensor or Sequence[ -+ Sequence[int]],optional): Maximum bounds for boxes, specifies -+ (H, W, C) or (H, W). If rois shape is (B, N, 4), then -+ the max_shape should be a Sequence[Sequence[int]] -+ and the length of max_shape should also be B. Default None. - wh_ratio_clip (float): Maximum aspect ratio for boxes. -+ Default 16 / 1000. - clip_border (bool, optional): Whether clip the objects outside the -- border of the image. Defaults to True. -+ border of the image. Default True. -+ add_ctr_clamp (bool): Whether to add center clamp, when added, the -+ predicted box is clamped is its center is too far away from -+ the original anchor's center. Only used by YOLOF. Default False. -+ ctr_clamp (int): the maximum pixel shift to clamp. Only used by YOLOF. -+ Default 32. - - Returns: -- Tensor: Boxes with shape (N, 4), where columns represent -- tl_x, tl_y, br_x, br_y. -+ Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or -+ (N, num_classes * 4) or (N, 4), where 4 represent -+ tl_x, tl_y, br_x, br_y. - - References: - .. 
[1] https://arxiv.org/abs/1311.2524 -@@ -162,43 +315,76 @@ def delta2bbox(rois, - >>> [ 1., 1., 1., 1.], - >>> [ 0., 0., 2., -1.], - >>> [ 0.7, -1.9, -0.5, 0.3]]) -- >>> delta2bbox(rois, deltas, max_shape=(32, 32)) -+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3)) - tensor([[0.0000, 0.0000, 1.0000, 1.0000], - [0.1409, 0.1409, 2.8591, 2.8591], - [0.0000, 0.3161, 4.1945, 0.6839], - [5.0000, 5.0000, 5.0000, 5.0000]]) - """ -- means = deltas.new_tensor(means).view(1, -1).repeat(1, deltas.size(1) // 4) -- stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(1) // 4) -+ means = deltas.new_tensor(means).view(1, -+ -1).repeat(1, -+ deltas.size(-1) // 4) -+ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4) - denorm_deltas = deltas * stds + means -- dx = denorm_deltas[:, 0::4] -- dy = denorm_deltas[:, 1::4] -- dw = denorm_deltas[:, 2::4] -- dh = denorm_deltas[:, 3::4] -- max_ratio = np.abs(np.log(wh_ratio_clip)) -- dw = dw.clamp(min=-max_ratio, max=max_ratio) -- dh = dh.clamp(min=-max_ratio, max=max_ratio) -+ dx = denorm_deltas[..., 0::4] -+ dy = denorm_deltas[..., 1::4] -+ dw = denorm_deltas[..., 2::4] -+ dh = denorm_deltas[..., 3::4] -+ -+ x1, y1 = rois[..., 0], rois[..., 1] -+ x2, y2 = rois[..., 2], rois[..., 3] - # Compute center of each roi -- px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) -- py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) -+ px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx) -+ py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy) - # Compute width/height of each roi -- pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw) -- ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh) -+ pw = (x2 - x1).unsqueeze(-1).expand_as(dw) -+ ph = (y2 - y1).unsqueeze(-1).expand_as(dh) -+ -+ dx_width = pw * dx -+ dy_height = ph * dy -+ -+ max_ratio = np.abs(np.log(wh_ratio_clip)) -+ if add_ctr_clamp: -+ dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) -+ dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) -+ dw = torch.clamp(dw, max=max_ratio) -+ dh = torch.clamp(dh, max=max_ratio) -+ else: -+ dw = dw.clamp(min=-max_ratio, max=max_ratio) -+ dh = dh.clamp(min=-max_ratio, max=max_ratio) - # Use exp(network energy) to enlarge/shrink each roi - gw = pw * dw.exp() - gh = ph * dh.exp() - # Use network energy to shift the center of each roi -- gx = px + pw * dx -- gy = py + ph * dy -+ gx = px + dx_width -+ gy = py + dy_height - # Convert center-xy/width/height to top-left, bottom-right - x1 = gx - gw * 0.5 - y1 = gy - gh * 0.5 - x2 = gx + gw * 0.5 - y2 = gy + gh * 0.5 -- if clip_border and max_shape is not None: -- x1 = x1.clamp(min=0, max=max_shape[1]) -- y1 = y1.clamp(min=0, max=max_shape[0]) -- x2 = x2.clamp(min=0, max=max_shape[1]) -- y2 = y2.clamp(min=0, max=max_shape[0]) -+ - bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) -+ -+ if clip_border and max_shape is not None: -+ # clip bboxes with dynamic `min` and `max` for onnx -+ if True: -+ from mmdet.core.export.onnx_helper import dynamic_clip_for_onnx -+ x1, y1, x2, y2 = dynamic_clip_for_onnx(x1, y1, x2, y2, max_shape) -+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size()) -+ return bboxes -+ if not isinstance(max_shape, torch.Tensor): -+ max_shape = x1.new_tensor(max_shape) -+ max_shape = max_shape[..., :2].type_as(x1) -+ if max_shape.ndim == 2: -+ assert bboxes.ndim == 3 -+ assert max_shape.size(0) == bboxes.size(0) -+ -+ min_xy = x1.new_tensor(0) -+ max_xy = torch.cat( -+ [max_shape] * 
(deltas.size(-1) // 2), -+ dim=-1).flip(-1).unsqueeze(-2) -+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes) -+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes) -+ - return bboxes -diff --git a/mmdet/core/export/pytorch2onnx.py b/mmdet/core/export/pytorch2onnx.py -index 8f9309df..b9f43d48 100644 ---- a/mmdet/core/export/pytorch2onnx.py -+++ b/mmdet/core/export/pytorch2onnx.py -@@ -39,6 +39,7 @@ def generate_inputs_and_wrap_model(config_path, checkpoint_path, input_config): - - model = build_model_from_cfg(config_path, checkpoint_path) - one_img, one_meta = preprocess_example_input(input_config) -+ one_meta['img_shape_for_onnx'] = one_img.shape[-2:] - tensor_data = [one_img] - model.forward = partial( - model.forward, img_metas=[[one_meta]], return_loss=False) -diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py -index 463fe2e4..72ca09d3 100644 ---- a/mmdet/core/post_processing/bbox_nms.py -+++ b/mmdet/core/post_processing/bbox_nms.py -@@ -55,7 +55,7 @@ def multiclass_nms(multi_bboxes, - inds = valid_mask.nonzero(as_tuple=False).squeeze(1) - bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] - if inds.numel() == 0: -- if torch.onnx.is_in_onnx_export(): -+ if True: - raise RuntimeError('[ONNX Error] Can not record NMS ' - 'as it has not been executed this time') - if return_inds: -diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py -index cbc4fbb2..4bb7e37a 100644 ---- a/mmdet/models/backbones/ssd_vgg.py -+++ b/mmdet/models/backbones/ssd_vgg.py -@@ -162,8 +162,14 @@ class L2Norm(nn.Module): - - def forward(self, x): - """Forward function.""" -- # normalization layer convert to FP32 in FP16 training -+ # # normalization layer convert to FP32 in FP16 training -+ # x_float = x.float() -+ # norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps -+ # return (self.weight[None, :, None, None].float().expand_as(x_float) * -+ # x_float / norm).type_as(x) -+ - x_float = x.float() -- norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps -+ x_mul = x_float * x_float -+ norm = x_mul.sum(1, keepdim=True).sqrt() + self.eps - return (self.weight[None, :, None, None].float().expand_as(x_float) * - x_float / norm).type_as(x) -diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py -index a5bb4137..6e0d892e 100644 ---- a/mmdet/models/dense_heads/anchor_head.py -+++ b/mmdet/models/dense_heads/anchor_head.py -@@ -487,6 +487,162 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): - num_total_samples=num_total_samples) - return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) - -+ @force_fp32(apply_to=('cls_scores', 'bbox_preds')) -+ def onnx_export(self, -+ cls_scores, -+ bbox_preds, -+ score_factors=None, -+ img_metas=None, -+ with_nms=True): -+ """Transform network output for a batch into bbox predictions. -+ -+ Args: -+ cls_scores (list[Tensor]): Box scores for each scale level -+ with shape (N, num_points * num_classes, H, W). -+ bbox_preds (list[Tensor]): Box energies / deltas for each scale -+ level with shape (N, num_points * 4, H, W). -+ score_factors (list[Tensor]): score_factors for each s -+ cale level with shape (N, num_points * 1, H, W). -+ Default: None. -+ img_metas (list[dict]): Meta information of each image, e.g., -+ image size, scaling factor, etc. Default: None. -+ with_nms (bool): Whether apply nms to the bboxes. Default: True. 
-+ -+ Returns: -+ tuple[Tensor, Tensor] | list[tuple]: When `with_nms` is True, -+ it is tuple[Tensor, Tensor], first tensor bboxes with shape -+ [N, num_det, 5], 5 arrange as (x1, y1, x2, y2, score) -+ and second element is class labels of shape [N, num_det]. -+ When `with_nms` is False, first tensor is bboxes with -+ shape [N, num_det, 4], second tensor is raw score has -+ shape [N, num_det, num_classes]. -+ """ -+ assert len(cls_scores) == len(bbox_preds) -+ -+ num_levels = len(cls_scores) -+ -+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] -+ -+ mlvl_priors = self.anchor_generator.grid_anchors( -+ featmap_sizes, device=bbox_preds[0].device) -+ -+ mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] -+ mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] -+ -+ assert len( -+ img_metas -+ ) == 1, 'Only support one input image while in exporting to ONNX' -+ img_shape = torch.tensor( -+ img_metas[0]['img_shape_for_onnx'], -+ dtype=torch.long, -+ device=bbox_preds[0].device) -+ -+ cfg = self.test_cfg -+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) -+ device = cls_scores[0].device -+ batch_size = cls_scores[0].shape[0] -+ # convert to tensor to keep tracing -+ nms_pre_tensor = torch.tensor( -+ cfg.get('nms_pre', -1), device=device, dtype=torch.long) -+ -+ # e.g. Retina, FreeAnchor, etc. -+ if score_factors is None: -+ with_score_factors = False -+ mlvl_score_factor = [None for _ in range(num_levels)] -+ else: -+ # e.g. FCOS, PAA, ATSS, etc. -+ with_score_factors = True -+ mlvl_score_factor = [ -+ score_factors[i].detach() for i in range(num_levels) -+ ] -+ mlvl_score_factors = [] -+ -+ mlvl_batch_bboxes = [] -+ mlvl_scores = [] -+ -+ for cls_score, bbox_pred, score_factors, priors in zip( -+ mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, -+ mlvl_priors): -+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:] -+ -+ scores = cls_score.permute(0, 2, 3, -+ 1).reshape(batch_size, -1, -+ self.cls_out_channels) -+ if self.use_sigmoid_cls: -+ scores = scores.sigmoid() -+ nms_pre_score = scores -+ else: -+ scores = scores.softmax(-1) -+ nms_pre_score = scores -+ -+ if with_score_factors: -+ score_factors = score_factors.permute(0, 2, 3, 1).reshape( -+ batch_size, -1).sigmoid() -+ bbox_pred = bbox_pred.permute(0, 2, 3, -+ 1).reshape(batch_size, -1, 4) -+ priors = priors.expand(batch_size, -1, priors.size(-1)) -+ # Get top-k predictions -+ from mmdet.core.export.onnx_helper import get_k_for_topk -+ nms_pre = get_k_for_topk(nms_pre_tensor, bbox_pred.shape[1]) -+ if nms_pre > 0: -+ -+ if with_score_factors: -+ nms_pre_score = (nms_pre_score * score_factors[..., None]) -+ else: -+ nms_pre_score = nms_pre_score -+ -+ # Get maximum scores for foreground classes. 
-+ if self.use_sigmoid_cls: -+ max_scores, _ = nms_pre_score.max(-1) -+ else: -+ # remind that we set FG labels to [0, num_class-1] -+ # since mmdet v2.0 -+ # BG cat_id: num_class -+ max_scores, _ = nms_pre_score[..., :-1].max(-1) -+ _, topk_inds = max_scores.topk(nms_pre) -+ -+ batch_inds = torch.arange( -+ batch_size, device=bbox_pred.device).view( -+ -1, 1).expand_as(topk_inds).long() -+ # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 -+ # transformed_inds = bbox_pred.shape[1] * batch_inds + topk_inds -+ transformed_inds = (bbox_pred.shape[1] * batch_inds).int() + topk_inds.int() -+ transformed_inds = transformed_inds.long() -+ priors = priors.reshape( -+ -1, priors.size(-1))[transformed_inds, :].reshape( -+ batch_size, -1, priors.size(-1)) -+ bbox_pred = bbox_pred.reshape(-1, -+ 4)[transformed_inds, :].reshape( -+ batch_size, -1, 4) -+ scores = scores.reshape( -+ -1, self.cls_out_channels)[transformed_inds, :].reshape( -+ batch_size, -1, self.cls_out_channels) -+ if with_score_factors: -+ score_factors = score_factors.reshape( -+ -1, 1)[transformed_inds].reshape(batch_size, -1) -+ -+ bboxes = self.bbox_coder.decode( -+ priors, bbox_pred, max_shape=img_shape) -+ -+ mlvl_batch_bboxes.append(bboxes) -+ mlvl_scores.append(scores) -+ if with_score_factors: -+ mlvl_score_factors.append(score_factors) -+ -+ batch_bboxes = torch.cat(mlvl_batch_bboxes, dim=1) -+ batch_scores = torch.cat(mlvl_scores, dim=1) -+ if with_score_factors: -+ batch_score_factors = torch.cat(mlvl_score_factors, dim=1) -+ -+ if not self.use_sigmoid_cls: -+ batch_scores = batch_scores[..., :self.num_classes] -+ -+ if with_score_factors: -+ batch_scores = batch_scores * (batch_score_factors.unsqueeze(2)) -+ -+ # directly return bboxes without NMS -+ return batch_bboxes, batch_scores -+ - @force_fp32(apply_to=('cls_scores', 'bbox_preds')) - def get_bboxes(self, - cls_scores, -@@ -545,38 +701,45 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): - >>> assert det_bboxes.shape[1] == 5 - >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img - """ -- assert len(cls_scores) == len(bbox_preds) -- num_levels = len(cls_scores) -- -- device = cls_scores[0].device -- featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] -- mlvl_anchors = self.anchor_generator.grid_anchors( -- featmap_sizes, device=device) -- -- result_list = [] -- for img_id in range(len(img_metas)): -- cls_score_list = [ -- cls_scores[i][img_id].detach() for i in range(num_levels) -- ] -- bbox_pred_list = [ -- bbox_preds[i][img_id].detach() for i in range(num_levels) -- ] -- img_shape = img_metas[img_id]['img_shape'] -- scale_factor = img_metas[img_id]['scale_factor'] -- if with_nms: -- # some heads don't support with_nms argument -- proposals = self._get_bboxes_single(cls_score_list, -- bbox_pred_list, -- mlvl_anchors, img_shape, -- scale_factor, cfg, rescale) -- else: -- proposals = self._get_bboxes_single(cls_score_list, -- bbox_pred_list, -- mlvl_anchors, img_shape, -- scale_factor, cfg, rescale, -- with_nms) -- result_list.append(proposals) -- return result_list -+ if True: -+ return self.onnx_export(cls_scores, -+ bbox_preds, -+ score_factors=None, -+ img_metas=img_metas, -+ with_nms=with_nms) -+ else: -+ assert len(cls_scores) == len(bbox_preds) -+ num_levels = len(cls_scores) -+ -+ device = cls_scores[0].device -+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] -+ mlvl_anchors = self.anchor_generator.grid_anchors( -+ featmap_sizes, device=device) -+ -+ result_list = [] 
-+ for img_id in range(len(img_metas)): -+ cls_score_list = [ -+ cls_scores[i][img_id].detach() for i in range(num_levels) -+ ] -+ bbox_pred_list = [ -+ bbox_preds[i][img_id].detach() for i in range(num_levels) -+ ] -+ img_shape = img_metas[img_id]['img_shape'] -+ scale_factor = img_metas[img_id]['scale_factor'] -+ if with_nms: -+ # some heads don't support with_nms argument -+ proposals = self._get_bboxes_single(cls_score_list, -+ bbox_pred_list, -+ mlvl_anchors, img_shape, -+ scale_factor, cfg, rescale) -+ else: -+ proposals = self._get_bboxes_single(cls_score_list, -+ bbox_pred_list, -+ mlvl_anchors, img_shape, -+ scale_factor, cfg, rescale, -+ with_nms) -+ result_list.append(proposals) -+ return result_list - - def _get_bboxes_single(self, - cls_score_list, -@@ -612,6 +775,7 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin): - are bounding box positions (tl_x, tl_y, br_x, br_y) and the - 5-th column is a score between 0 and 1. - """ -+ print('in _get_bboxes_single') - cfg = self.test_cfg if cfg is None else cfg - assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) - mlvl_bboxes = [] -diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py -index 93d051e7..94e496b5 100644 ---- a/mmdet/models/dense_heads/yolo_head.py -+++ b/mmdet/models/dense_heads/yolo_head.py -@@ -281,7 +281,7 @@ class YOLOV3Head(BaseDenseHead, BBoxTestMixin): - # Get top-k prediction - nms_pre = cfg.get('nms_pre', -1) - if 0 < nms_pre < conf_pred.size(0) and ( -- not torch.onnx.is_in_onnx_export()): -+ not True): - _, topk_inds = conf_pred.topk(nms_pre) - bbox_pred = bbox_pred[topk_inds, :] - cls_pred = cls_pred[topk_inds, :] -diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py -index 7c6d5e96..8bd74238 100644 ---- a/mmdet/models/detectors/base.py -+++ b/mmdet/models/detectors/base.py -@@ -179,6 +179,8 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): - if return_loss: - return self.forward_train(img, img_metas, **kwargs) - else: -+ if not isinstance(img, list): -+ img = [img] - return self.forward_test(img, img_metas, **kwargs) - - def _parse_losses(self, losses): -diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py -index 96c4acac..f5be2641 100644 ---- a/mmdet/models/detectors/single_stage.py -+++ b/mmdet/models/detectors/single_stage.py -@@ -114,7 +114,7 @@ class SingleStageDetector(BaseDetector): - bbox_list = self.bbox_head.get_bboxes( - *outs, img_metas, rescale=rescale) - # skip post-processing when exporting to ONNX -- if torch.onnx.is_in_onnx_export(): -+ if True: - return bbox_list - - bbox_results = [ -diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py -index 45b6f36a..1199a443 100644 ---- a/mmdet/models/roi_heads/cascade_roi_head.py -+++ b/mmdet/models/roi_heads/cascade_roi_head.py -@@ -349,7 +349,7 @@ class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin): - det_bboxes.append(det_bbox) - det_labels.append(det_label) - -- if torch.onnx.is_in_onnx_export(): -+ if True: - return det_bboxes, det_labels - bbox_results = [ - bbox2result(det_bboxes[i], det_labels[i], -diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py -index 0cba3cda..d69054b6 100644 ---- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py -+++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py -@@ -195,7 +195,7 @@ class FCNMaskHead(nn.Module): - scale_factor = bboxes.new_tensor(scale_factor) - 
bboxes = bboxes / scale_factor - -- if torch.onnx.is_in_onnx_export(): -+ if True: - # TODO: Remove after F.grid_sample is supported. - from torchvision.models.detection.roi_heads \ - import paste_masks_in_image -@@ -316,7 +316,7 @@ def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True): - gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) - grid = torch.stack([gx, gy], dim=3) - -- if torch.onnx.is_in_onnx_export(): -+ if True: - raise RuntimeError( - 'Exporting F.grid_sample from Pytorch to ONNX is not supported.') - img_masks = F.grid_sample( -diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py -index c0eebc4a..534b1c9b 100644 ---- a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py -+++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py -@@ -55,7 +55,7 @@ class SingleRoIExtractor(BaseRoIExtractor): - """Forward function.""" - out_size = self.roi_layers[0].output_size - num_levels = len(feats) -- if torch.onnx.is_in_onnx_export(): -+ if True: - # Work around to export mask-rcnn to onnx - roi_feats = rois[:, :1].clone().detach() - roi_feats = roi_feats.expand( -@@ -82,7 +82,7 @@ class SingleRoIExtractor(BaseRoIExtractor): - mask = target_lvls == i - inds = mask.nonzero(as_tuple=False).squeeze(1) - # TODO: make it nicer when exporting to onnx -- if torch.onnx.is_in_onnx_export(): -+ if True: - # To keep all roi_align nodes exported to onnx - rois_ = rois[inds] - roi_feats_t = self.roi_layers[i](feats[i], rois_) -diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py -index c530f2a5..85f95e0c 100644 ---- a/mmdet/models/roi_heads/standard_roi_head.py -+++ b/mmdet/models/roi_heads/standard_roi_head.py -@@ -246,7 +246,7 @@ class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin): - - det_bboxes, det_labels = self.simple_test_bboxes( - x, img_metas, proposal_list, self.test_cfg, rescale=rescale) -- if torch.onnx.is_in_onnx_export(): -+ if True: - if self.with_mask: - segm_results = self.simple_test_mask( - x, img_metas, det_bboxes, det_labels, rescale=rescale) -diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py -index 0e675d6e..ecc08cf6 100644 ---- a/mmdet/models/roi_heads/test_mixins.py -+++ b/mmdet/models/roi_heads/test_mixins.py -@@ -197,7 +197,7 @@ class MaskTestMixin(object): - torch.from_numpy(scale_factor).to(det_bboxes[0].device) - for scale_factor in scale_factors - ] -- if torch.onnx.is_in_onnx_export(): -+ if True: - # avoid mask_pred.split with static number of prediction - mask_preds = [] - _bboxes = [] -diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py -index a8e7487b..97ed2d09 100644 ---- a/tools/pytorch2onnx.py -+++ b/tools/pytorch2onnx.py -@@ -33,23 +33,32 @@ def pytorch2onnx(config_path, - one_img, one_meta = preprocess_example_input(input_config) - model, tensor_data = generate_inputs_and_wrap_model( - config_path, checkpoint_path, input_config) -+ -+ input_names = ['input'] -+ dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}} -+ - output_names = ['boxes'] -+ dynamic_axes['boxes'] = {0: 'batch'} - if model.with_bbox: - output_names.append('labels') -+ dynamic_axes['labels'] = {0: 'batch'} - if model.with_mask: - output_names.append('masks') -+ dynamic_axes['masks'] = {0: 'batch'} - - torch.onnx.export( - model, - tensor_data, - output_file, -- input_names=['input'], -+ input_names=input_names, - 
output_names=output_names, -+ dynamic_axes=dynamic_axes, - export_params=True, - keep_initializers_as_inputs=True, - do_constant_folding=True, - verbose=show, -- opset_version=opset_version) -+ opset_version=opset_version, -+ enable_onnx_checker=False) - - model.forward = orig_model.forward - print(f'Successfully exported ONNX model: {output_file}') -@@ -67,6 +76,7 @@ def pytorch2onnx(config_path, - tensor_data = [one_img] - # check the numerical value - # get pytorch output -+ one_meta['img_shape_for_onnx'] = one_img.shape[-2:] - pytorch_results = model(tensor_data, [[one_meta]], return_loss=False) - pytorch_results = pytorch_results[0] - # get onnx output -- Gitee From 0c3fcff39fd11554b58aa1f3a398493c9b8e1e72 Mon Sep 17 00:00:00 2001 From: Guanzhong Chen Date: Tue, 19 Dec 2023 16:57:34 +0800 Subject: [PATCH 6/6] 1 --- .../built-in/cv/detection/ssd/acc_dataset.py | 18 ++++++++++++++---- .../built-in/cv/detection/ssd/export.py | 14 ++++++++++++++ .../built-in/cv/detection/ssd/get_info.py | 1 - .../TorchAIE/built-in/cv/detection/ssd/perf.py | 17 ++++++++++++++--- .../cv/detection/ssd/ssd_preprocess.py | 4 ---- .../built-in/cv/detection/ssd/txt_to_json.py | 1 - 6 files changed, 42 insertions(+), 13 deletions(-) diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py index 650772ebd3..d43094efb1 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/acc_dataset.py @@ -1,3 +1,17 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse from tqdm import tqdm import os @@ -81,7 +95,3 @@ if __name__ == '__main__': second_out.tofile(os.path.join(save_dir, filename.split(".")[0] + "_1.bin")) except Exception as e: print(f'Error reading {filename}: {e}') - - - - diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py index 527acd040a..50610187dd 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/export.py @@ -1,3 +1,17 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import os import argparse diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py index 806398b3cc..7bfd9c7515 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/get_info.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""get info""" import os import sys diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py index 58ebef8a50..d815e1ec7b 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/perf.py @@ -1,3 +1,17 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import time from tqdm import tqdm @@ -100,6 +114,3 @@ if __name__ == '__main__': current_time = datetime.now() formatted_time = current_time.strftime("%Y-%m-%d %H:%M:%S") print("Current Time:", formatted_time) - - - diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py index 5e6dc9dec4..8fc86186d8 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/ssd_preprocess.py @@ -11,10 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""coco preprocess""" -""" -使用openmmlab环境 -""" import os import argparse diff --git a/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py index 6a27b1aee3..328b2319c8 100644 --- a/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py +++ b/AscendIE/TorchAIE/built-in/cv/detection/ssd/txt_to_json.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""txt to json""" import glob import os -- Gitee