diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/README.md b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..e2b0b01f2599a1afa961d1e97379ce97526d11fa
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/README.md
@@ -0,0 +1,277 @@
+# InternImage Detection Inference Guide
+
+- [Overview](#overview)
+
+  - [Input Data](#input-data)
+
+- [Inference Environment Setup](#inference-environment-setup)
+
+- [Quick Start](#quick-start)
+
+  - [Get the Source Code](#get-the-source-code)
+
+  - [Download the Dataset](#download-the-dataset)
+
+  - [Model Inference](#model-inference)
+
+- [Model Inference Performance & Accuracy](#model-inference-performance--accuracy)
+
+# Overview
+
+InternImage is a CNN-based vision foundation model proposed by researchers from Shanghai AI Laboratory, Tsinghua University, and other institutions. Unlike Transformer-based networks, InternImage takes the deformable convolution DCNv3 as its core operator, which gives the model not only the dynamic effective receptive field required by downstream tasks such as detection and segmentation, but also the ability to perform adaptive spatial aggregation. This guide covers only the InternImage model with the InternImage-XL backbone, the Cascade method, and the 3x schedule. The model is evaluated with box mAP and mask mAP.
+
+- Version info:
+
+  ```
+  url=https://github.com/OpenGVLab/InternImage
+  commit_id=41b18fd85f20a4f85c0a1e6b1d5f97303aab1800
+  model_name=InternImage
+  ```
+
+## Input Data
+
+InternImage runs inference on the public COCO dataset.
+
+| Input | Data Type | Size        | Layout |
+|:-----:|:--------:|:-----------:|:------:|
+| img   | RGB_FP32 | (1,3,-1,-1) | NCHW   |
+
+# Inference Environment Setup
+
+The model requires the following dependencies.
+
+Table 1 **Version compatibility**
+
+| Dependency        | Version | Setup Guide |
+| ----------------- | ------- |:-----------:|
+| Firmware & driver | 24.1.0  | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies/pies_00001.html) |
+| PyTorch           | 2.1.0   | - |
+| CANN              | 8.0.RC2 | - |
+| Python            | 3.9     | - |
+
+# Quick Start
+
+## Get the Source Code
+
+1. Clone this repository
+
+   ```
+   git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+   cd ModelZoo-PyTorch/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch
+   ```
+
+2. Clone the **InternImage** model repository and the **mmdetection** dependency
+
+   ```
+   git clone https://github.com/open-mmlab/mmdetection.git
+   git clone https://github.com/OpenGVLab/InternImage.git
+   cd mmdetection
+   git reset --hard cfd5d3a985b0249de009b67d04f37263e11cdf3d
+   cd ../InternImage
+   git reset --hard 41b18fd85f20a4f85c0a1e6b1d5f97303aab1800
+   cd ..
+   ```
+
+3. Move the patch and script files into place
+
+   ```
+   mv internimage_det.patch InternImage/detection/
+   mv *.py InternImage/detection/
+   mv exceptionlist.cfg InternImage/detection/
+   mv mmdet.patch mmdetection/mmdet/
+   ```
+
+4. Install dependencies
+
+   ```
+   pip3 install -r requirement.txt
+   ```
+
+5. Apply the patches, then install mmdet from the patched source
+
+   ```
+   cd mmdetection/mmdet/
+   patch -p2 < mmdet.patch
+   cd ..
+   pip3 install -v -e .
+
+   cd ../InternImage/detection/
+   patch -p2 < internimage_det.patch
+   ```
+
+## Download the Dataset
+
+Download the dataset from the link below and extract it into InternImage/detection/data.
+
+> [COCO dataset download](https://cocodataset.org/#download)
+
+Make sure the layout under data looks like this:
+
+```
+├── data
+│   ├── coco
+│   │   ├── annotations
+│   │   ├── val2017
+```
+
+## Model Inference
+
+1. Use PyTorch to convert the .pth weight file into a .onnx file, then use the ATC tool to convert the .onnx file into an offline .om model
+
+   1. Download the weight file and place it under InternImage/detection/ckpt
+
+      > [ckpt file download](https://huggingface.co/OpenGVLab/InternImage/resolve/main/cascade_internimage_xl_fpn_3x_coco.pth)
+
+   2. Preprocess the data
+
+      Run the following command to start preprocessing. For each image, the script also writes an img_shape .npy file that is consumed later during offline inference. In the original source code, images of different sizes can send their multi-scale features down different processing branches, while ONNX can only record one of those branches. To give every image an identical path, preprocessing forcibly rescales all images to 1216x1216. A quick way to sanity-check one preprocessed sample is sketched after the parameter list.
+
+      ```
+      python3 preprocess.py --config configs/coco/cascade_internimage_xl_fpn_3x_coco.py --data_output data_after_preprocess --force_img_shape 1216,1216
+      ```
+
+      - Parameters
+
+        - --config: config file path
+
+        - --data_output: output directory for the preprocessed data
+
+        - --force_img_shape: size the original images are forcibly rescaled to
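+      The sketch below loads one preprocessed sample and its recorded shape. It is a minimal example; it assumes preprocessing has already run with the directories used above (data_after_preprocess and the default img_shape), and the expected values reflect --force_img_shape 1216,1216:
+
+      ```
+      import os
+      import numpy as np
+
+      # pick any preprocessed image and its recorded shape file
+      name = os.listdir('data_after_preprocess')[0]
+      img = np.load(os.path.join('data_after_preprocess', name))
+      shape = np.load(os.path.join('img_shape', name))
+
+      print(img.shape)  # expected (1, 3, 1216, 1216): NCHW, forced to 1216x1216
+      print(shape)      # expected [1216 1216]: fed to the model as the img_shape input
+      ```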
+   3. Export the ONNX file
+
+      Make sure the current directory is InternImage/detection, then run the following command to export the ONNX file
+
+      ```
+      python export2onnx.py --config configs/coco/cascade_internimage_xl_fpn_3x_coco.py --ckpt ckpt/cascade_internimage_xl_fpn_3x_coco.pth --export onnx/cascade_internimage_xl_fpn_3x_coco.onnx --data ./data_after_preprocess --img_shape_path ./img_shape
+      ```
+
+      - Parameters
+
+        - --config: config file path
+
+        - --ckpt: weight file path
+
+        - --export: path of the exported ONNX model
+
+        - --data: path to the preprocessed image data
+
+        - --img_shape_path: directory holding the preprocessed image shapes
+
+   4. Visit the [msit inference tools](https://gitee.com/ascend/msit/tree/master/msit/) repository and install the benchmark and surgeon components following its README
+
+   5. Use the ATC tool to convert the ONNX model into an OM model
+
+      1. Set up the environment variables
+
+         ```
+         source /usr/local/Ascend/ascend-toolkit/set_env.sh
+         ```
+
+      2. Run the following command to look up the chip name (${chip_name})
+
+         > ```
+         > npu-smi info
+         > # In the sample output below the chip name is Ascend310P3; replace it with your own.
+         > +-------------------+-----------------+------------------------------------------------------+
+         > | NPU Name          | Health          | Power(W)  Temp(C)           Hugepages-Usage(page)    |
+         > | Chip Device       | Bus-Id          | AICore(%) Memory-Usage(MB)                           |
+         > +===================+=================+======================================================+
+         > | 0     310P3       | OK              | 15.8      42                0   / 0                  |
+         > | 0     0           | 0000:82:00.0    | 0         1074 / 21534                               |
+         > +===================+=================+======================================================+
+         > | 1     310P3       | OK              | 15.4      43                0   / 0                  |
+         > | 0     1           | 0000:89:00.0    | 0         1070 / 21534                               |
+         > +===================+=================+======================================================+
+         > ```
+
+      3. Run the ATC command to convert the ONNX model into an OM model
+
+         ```
+         atc --model=onnx/cascade_internimage_xl_fpn_3x_coco.onnx --framework=5 --output=om/cascade_internimage_xl_fpn_3x_coco --input_format=NCHW --input_shape="data:1,3,1216,1216" --soc_version=Ascend${chip_name} --keep_dtype=exceptionlist.cfg
+         ```
+
+         - Parameters
+
+           - --model: ONNX model file path
+
+           - --framework: 5 for an ONNX model
+
+           - --output: output OM model path
+
+           - --input_format: input data format
+
+           - --input_shape: input data shape
+
+           - --keep_dtype: keeps the operators listed in exceptionlist.cfg (here GridSampler2D) running in FP32
+
+           - --soc_version: target chip model
+
+2. Run inference and validation
+
+   1. Run offline inference
+
+      ```
+      python3 -m ais_bench --model om/cascade_internimage_xl_fpn_3x_coco.om --input ./data_after_preprocess,./img_shape --output ./ --output_dirname om_output --outfmt NPY
+      ```
+
+      - Parameters:
+
+        - --model: OM model used for offline inference
+
+        - --input: dataset inputs used for offline inference
+
+        - --output: directory where inference results are saved
+
+        - --output_dirname: subdirectory where inference results are saved
+
+        - --outfmt: output format, one of "NPY", "BIN", "TXT"; this project currently supports NPY output only
+
+   2. Postprocess the data
+
+      Run the following command to postprocess the results and obtain the OM model accuracy. The expected layout of the om_output directory is sketched after the parameter list.
+
+      ```
+      python3 postprocess.py --config configs/coco/cascade_internimage_xl_fpn_3x_coco.py --ckpt ckpt/cascade_internimage_xl_fpn_3x_coco.pth --om_output om_output --eval bbox segm --batch_size 100 --force_img_shape 1216,1216
+      ```
+
+      - Parameters:
+
+        - --config: config file path
+
+        - --ckpt: weight file path
+
+        - --om_output: OM model output directory, which is the input of the postprocessing script
+
+        - --eval: accuracy metrics to evaluate
+
+        - --batch_size: number of images loaded at a time during postprocessing, bounded by host memory; each image needs roughly 120 MB; defaults to 100
+
+        - --force_img_shape: forced rescale size used during preprocessing
+
+      The script prints the accuracy metrics when it finishes.
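+      For reference, ais_bench writes 8 NPY files per input image, matching the output names declared in export2onnx.py: _0 holds cls_scores, _1 the box predictions, _2 to _6 the five FPN feature maps, and _7 the rois. The minimal sketch below prints the shapes for one image; the image id is a hypothetical placeholder, replace it with a real basename from om_output:
+
+      ```
+      import numpy as np
+
+      name = 'om_output/000000000139'  # hypothetical basename, replace with a real one
+      for i in range(8):
+          arr = np.load(f'{name}_{i}.npy')
+          print(f'{name}_{i}.npy', arr.shape, arr.dtype)
+      ```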
+   3. Measure performance
+
+      Run the following command to obtain the OM model performance figures
+
+      ```
+      python3 -m ais_bench --model om/cascade_internimage_xl_fpn_3x_coco.om --loop 100
+      ```
+
+      * Parameters:
+
+        * --model: OM model path
+
+        * --loop: number of offline inference iterations
+
+# Model Inference Performance & Accuracy
+
+Postprocessing prints the accuracy metrics; reference values are listed below.
+
+| Chip     | Model                              | box mAP | seg mAP | Performance |
+| -------- | ---------------------------------- | ------- | ------- | ----------- |
+| 300I PRO | cascade_internimage_xl_fpn_3x_coco | 0.556   | 0.486   | 1614.70 ms  |
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/exceptionlist.cfg b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/exceptionlist.cfg
new file mode 100755
index 0000000000000000000000000000000000000000..634282ebbebe17a62b3fb6f2bd5d6e7ac777b2ac
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/exceptionlist.cfg
@@ -0,0 +1 @@
+OpType::GridSampler2D
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/export2onnx.py b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/export2onnx.py
new file mode 100755
index 0000000000000000000000000000000000000000..92b24b60ed9bf75ed56c655c0fc55808fd26964a
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/export2onnx.py
@@ -0,0 +1,68 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+# [Software Name] is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+# http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+import os
+import argparse
+
+import torch
+import numpy as np
+from mmengine.runner import load_checkpoint, load_state_dict
+from mmdet.apis.inference import init_detector
+from auto_optimizer import OnnxGraph
+
+import mmdet_custom  # noqa: F401,F403
+
+
+def delete_domain(graph):
+    # clear custom domains so the graph only references the default ONNX opset
+    for node in graph.nodes:
+        if node.domain != '':
+            node.domain = ''
+    while len(graph.opset_imports) > 1:
+        graph.opset_imports.pop(1)
+
+
+def main(cfg_path, ckpt_path, data_dir, export_path, img_shape_path="./img_shape"):
+    export_dir = os.path.dirname(export_path)
+    os.makedirs(export_dir, exist_ok=True)
+
+    model = init_detector(cfg_path, checkpoint=None, device='cpu')
+    checkpoint = load_checkpoint(model, ckpt_path, map_location='cpu')
+    load_state_dict(model, checkpoint['state_dict'], strict=False)
+    model.eval()
+
+    # trace with one preprocessed sample and its recorded shape
+    file_path = os.path.join(data_dir, os.listdir(data_dir)[0])
+    data_input = torch.from_numpy(np.load(file_path))
+    img_shape_file_path = os.path.join(img_shape_path, os.path.basename(file_path))
+    img_shape = torch.from_numpy(np.load(img_shape_file_path))
+
+    torch.onnx.export(model, (data_input, img_shape), export_path,
+                      opset_version=16, verbose=False,
+                      input_names=['data', 'img_shape'],
+                      output_names=['cls_scores', 'bboxes', 'feature_map', 'rois'],
+                      keep_initializers_as_inputs=False)
+    print('successfully export onnx')
+    onnx_graph = OnnxGraph.parse(export_path)
+    delete_domain(onnx_graph)
+    onnx_graph.save(export_path)
+    print('successfully delete domain')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Export model to ONNX")
+    parser.add_argument('--config', type=str, help='config file path', required=True)
+    parser.add_argument('--ckpt', type=str, help='checkpoint file path', required=True)
+    parser.add_argument('--data', type=str, help='directory of preprocessed data', required=True)
+    parser.add_argument('--img_shape_path', type=str, default="./img_shape", help='directory that saves the img shape')
+    parser.add_argument('--export', type=str, help='ONNX file path to be exported', required=True)
+    args = parser.parse_args()
+
+    main(args.config, args.ckpt, args.data, args.export, args.img_shape_path)
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/postprocess.py b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..37bfbbd695e5a3e48d65887049280b234d9cc61d
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/postprocess.py
@@ -0,0 +1,118 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+# [Software Name] is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+# http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
+
+import argparse
+import os
+
+import torch
+import numpy as np
+from mmengine.config import Config
+from mmengine.registry import EVALUATOR, build_from_cfg
+from mmdet.apis import init_detector
+from mmdet.registry import DATASETS
+from tqdm import tqdm
+
+from preprocess import parse_shape, adjust_cfg
+import mmdet_custom  # noqa: F401,F403
+
+# for each img input, 8 NPY files will be output
+NUM_OM_OUTPUT_FILE = 8
+# the output NPY files with postfix 2-6 are feature maps from the FPN
+FEAT_IDX_START = 2
+FEAT_IDX_END = 7
+
+
+def process_batch(cfg, om_output_path, evaluator, basename_metas, model):
+    for filename in tqdm(basename_metas.keys(), desc='post-processing'):
+        img_meta = basename_metas.get(filename)
+        try:
+            # load all the om output files for post-processing
+            file_path_prefix = os.path.join(om_output_path, filename)
+            cls_scores = np.load(f'{file_path_prefix}_0.npy')
+            bbox_preds = np.load(f'{file_path_prefix}_1.npy')
+            feature_map = []
+            for i in range(FEAT_IDX_START, FEAT_IDX_END):
+                feature_map.append(np.load(f'{file_path_prefix}_{i}.npy'))
+            rois = np.load(f'{file_path_prefix}_7.npy')
+
+            cls_scores = torch.from_numpy(cls_scores)
+            bbox_preds = torch.from_numpy(bbox_preds)
+            feature_map = [torch.from_numpy(fm) for fm in feature_map]
+            rois = torch.from_numpy(rois)
+
+            bbox_results = model.roi_head.bbox_head[-1].predict_by_feat(
+                rois=[rois],
+                cls_scores=[cls_scores],
+                bbox_preds=[bbox_preds],
+                batch_img_metas=[img_meta],
+                rescale=False,
+                rcnn_test_cfg=cfg.model.test_cfg.rcnn)
+            mask_results = model.roi_head.predict_mask(
+                feature_map, [img_meta], bbox_results, rescale=True)
+
+            # construct the legal input for evaluator
+            result = {
+                'pred_instances': {
+                    'bboxes': mask_results[0]['bboxes'],
+                    'scores': mask_results[0]['scores'],
+                    'labels': mask_results[0]['labels'],
+                    'masks': mask_results[0]['masks']
+                },
+                **img_meta
+            }
+            evaluator.process(data_batch={}, data_samples=[result])
+        except Exception as e:
+            print(f'Error processing {filename}: {str(e)}')
+            continue
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="post process data")
+    parser.add_argument('--config', type=str, required=True, help='config file path')
+    parser.add_argument('--ckpt', type=str, required=True, help='ckpt file path')
+    parser.add_argument('--om_output', type=str, required=True, help='om output for post-process')
+    parser.add_argument(
+        '--force_img_shape', type=parse_shape, default=None, help='Rescale image to shape (e.g., "256,256")')
+    parser.add_argument('--batch_size', type=int, default=100, help='number of processed imgs at the same time')
+    parser.add_argument('--eval', nargs='+', type=str, help='evaluation types, e.g., bbox, segm')
+    args = parser.parse_args()
+
+    cfg = adjust_cfg(Config.fromfile(args.config), force_img_shape=args.force_img_shape)
+    dataset = build_from_cfg(cfg.data.test, DATASETS, None)
+
+    basename_metas = {}
+    for data in tqdm(dataset, desc='loading metainfo'):
+        img_meta = data['data_sample'][0].metainfo
+        basename = os.path.basename(img_meta['img_path']).split('.')[0]
+        basename_metas[basename] = img_meta
+
+    model = init_detector(args.config, args.ckpt, device='cpu')
+
+    # construct evaluator
+    eval_cfg = dict(
+        type='mmdet.evaluation.CocoMetric',
+        ann_file=dataset.ann_file,
+        metric=args.eval,
+        classwise=True,
+        format_only=False,
+        _scope_='mmdet.evaluation',
+    )
+    eval_cfg.update(cfg.get('evaluation', {}))
+    evaluator = EVALUATOR.build(eval_cfg)
+    evaluator.dataset_meta = dataset.metainfo
+
+    print('Start post-processing')
+    process_batch(cfg, args.om_output, evaluator, basename_metas, model)
+
+    print('Evaluating final results')
+    metrics = evaluator.evaluate(size=len(dataset))
+    print(f"metric: bbox_mAP: {metrics['coco/bbox_mAP']}, segm_mAP: {metrics['coco/segm_mAP']}")
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/preprocess.py b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/preprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..0ac4e6fb904a40ec63b8467b3993deb783dc141c
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/preprocess.py
@@ -0,0 +1,77 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+# [Software Name] is licensed under Mulan PSL v2.
+# You can use this software according to the terms and conditions of the Mulan PSL v2.
+# You may obtain a copy of Mulan PSL v2 at:
+# http://license.coscl.org.cn/MulanPSL2
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+# See the Mulan PSL v2 for more details.
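+
+# This script runs the adjusted test pipeline over the COCO val set: each
+# image is resized (to --force_img_shape when given, otherwise to the scale
+# taken from the config), normalized, and saved as an NCHW .npy file under
+# --data_output; the matching image shape is saved under --img_shape_output
+# so that offline inference can feed it as the model's second input.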
+
+import argparse
+import os
+
+import numpy as np
+from mmengine.config import Config
+from mmengine.registry import build_from_cfg
+from mmdet.registry import DATASETS
+from tqdm import tqdm
+
+
+def preprocess_data(dataset, data_output_dir, force_img_shape=None, img_shape_output_dir="./img_shape"):
+    os.makedirs(data_output_dir, exist_ok=True)
+    os.makedirs(img_shape_output_dir, exist_ok=True)
+
+    for data in tqdm(dataset):
+        img = data['inputs'][0].unsqueeze(0)
+        file_path = data['data_sample'][0].metainfo['img_path']
+        new_filename = os.path.splitext(os.path.basename(file_path))[0] + '.npy'
+        data_output_path = os.path.join(data_output_dir, new_filename)
+        img_shape_output_path = os.path.join(img_shape_output_dir, new_filename)
+        np.save(data_output_path, img)
+        if force_img_shape:
+            np.save(img_shape_output_path, force_img_shape)
+        else:
+            np.save(img_shape_output_path, data['data_sample'][0].metainfo['img_shape'][:2])
+
+
+def parse_shape(s):
+    try:
+        return [int(x) for x in s.split(',')]
+    except Exception as e:
+        raise argparse.ArgumentTypeError("Shape must be 'width,height' (e.g., '256,256')") from e
+
+
+def adjust_cfg(cfg: Config, force_img_shape=None):
+    img_norm_cfg = cfg.img_norm_cfg
+    if force_img_shape:
+        scale = tuple(force_img_shape)
+        keep_ratio = False
+    else:
+        scale = cfg.data.test.pipeline[1].transforms[0].scale
+        keep_ratio = cfg.data.test.pipeline[1].transforms[0].keep_ratio
+    cfg.data.test.pipeline[1] = dict(type='MultiScaleFlipAug',
+                                     transforms=[
+                                         dict(type='Resize', keep_ratio=keep_ratio, scale=scale),
+                                         dict(type='Normalize', **img_norm_cfg),
+                                         dict(type='mmdet.PackDetInputs',
+                                              meta_keys=(
+                                                  'img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
+                                     ])
+    return cfg
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="preprocess data")
+    parser.add_argument('--config', type=str, required=True, help='config path')
+    parser.add_argument('--data_output', type=str, required=True, help='output path for preprocessed data')
+    parser.add_argument(
+        '--force_img_shape', type=parse_shape, help='Rescale image to shape (e.g., "256,256")')
+    parser.add_argument('--img_shape_output', type=str, default="./img_shape",
+                        help='output path for preprocessed img shape')
+    args = parser.parse_args()
+
+    cfg = Config.fromfile(args.config)
+    cfg = adjust_cfg(cfg, args.force_img_shape)
+    dataset = build_from_cfg(cfg.data.test, DATASETS, None)
+    preprocess_data(dataset, args.data_output, args.force_img_shape, args.img_shape_output)
+    print('\033[92m' + 'data preprocessing finished' + '\033[0m')
\ No newline at end of file
diff --git a/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/requirement.txt b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/requirement.txt
new file mode 100755
index 0000000000000000000000000000000000000000..8ddcce8ef92d9e6830cfc1a0dda4ce152c7da7e1
--- /dev/null
+++ b/ACL_PyTorch/built-in/cv/InternImage_detection_for_Pytorch/requirement.txt
@@ -0,0 +1,21 @@
+onnx
+torch==2.1.0
+tqdm
+torchvision==0.16.0
+mmcv==2.1.0
+timm==0.6.11
+mmdet==3.0.0
+mmengine==0.10.6
+opencv-python
+termcolor
+yacs
+pyyaml
+scipy
+pydantic==1.10.13
+yapf==0.40.1
+numpy==1.26.4
+prettytable
+ftfy
+regex
+decorator
+psutil
\ No newline at end of file