diff --git a/ACL_PyTorch/contrib/cv/segmentation/Segformer/README.md b/ACL_PyTorch/contrib/cv/segmentation/Segformer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..54c373e85617deb3820650735d01ace26dcbcb47 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/segmentation/Segformer/README.md @@ -0,0 +1,109 @@ +## Segformer 模型离线推理指导 + +### 一、环境准备 + +环境: CANN 5.1.rc1 + +1. 安装依赖 + +``` +pip install -r requirements.txt +``` + +  + +2. 获取开源模型代码仓 + +``` +git clone https://github.com/open-mmlab/mmsegmentation.git +cd mmsegmentation +python setup.py develop +``` + +  + +### 二、转 ONNX + +1. 进入 mmsegmentation/configs 目录下 [segformer](https://github.com/open-mmlab/mmsegmentation/tree/master/configs/segformer) +2. 下载 Cityscapes 栏第一行 model 权重文件 +3. 将权重文件 segformer_mit-b0_8x1_1024x1024_160k_cityscapes_20211208_101857-e7f88502.pth 放到当前工作目录 +4. 执行转 ONNX 脚本 + ``` + python mmsegmentation/tools/pytorch2onnx.py \ + mmsegmentation/configs/segformer/segformer_mit-b0_8x1_1024x1024_160k_cityscapes.py \ + --checkpoint ./segformer_mit-b0_8x1_1024x1024_160k_cityscapes_20211208_101857-e7f88502.pth \ + --output-file segformer_mit-b0_8x1_1024x1024_160k_cityscapes_dynamic_bs.onnx \ + --shape 1024 2048 \ + --verify \ + --dynamic-export + ``` +5. 使用 onnx-simplifier 简化 onnx 模型 + ``` + python -m onnxsim --input-shape="1,3,1024,2048" --dynamic-input-shape segformer_mit-b0_8x1_1024x1024_160k_cityscapes_dynamic_bs.onnx segformer_mit-b0_8x1_1024x1024_160k_cityscapes_dynamic_bs_sim.onnx + ``` + +  + +### 三、转 OM + +``` +# 设置环境变量 +source /usr/local/Ascend/ascend-toolkit/set_env.sh + +# 生成的 onnx 模型为动态 batch size,通过调整参数 --input_shape 的第一个参数修改生成的 om 模型的 batch size +atc --framework=5 --model=segformer_mit-b0_8x1_1024x1024_160k_cityscapes_dynamic_bs_sim.onnx --output=segformer_mit-b0_8x1_1024x1024_160k_cityscapes_bs1_sim --input_format=NCHW --input_shape="input:1,3,1024,2048" --soc_version=Ascend710 --log=debug +``` + +  + +### 四、数据集预处理 + +1. 将数据集存放到指定目录 + 获取 cityscapes 数据集,解压存放到 /opt/npu/ 文件夹内 + +  + +2. 执行数据预处理脚本 + ``` + python ./Segformer_preprocess.py /opt/npu/cityscapes/leftImg8bit/val /opt/npu/prep_dataset + ``` + 处理后的 bin 文件放到 /opt/npu/prep_dataset 位置 + +  + +3. 生成 info 文件 + ``` + python ./get_info.py bin /opt/npu/prep_dataset ./prep_bin.info 1024 2048 + ``` + 生成的 info 文件保存到当前工作目录下 + +  + +### 五、离线推理 + +1. 准备 msame 推理工具 + 将 msame 文件放到当前工作目录 + (chmod +x msame 添加访问权限) + +  + +2. 推理时,使用 npu-smi info 命令查看 device 是否在运行其它推理任务,提前确保 device 空闲 + ``` + ./msame --model "./segformer_mit-b0_8x1_1024x1024_160k_cityscapes_bs1_sim.om" --input "/opt/npu/prep_dataset" --output "./msame_result" --outfmt BIN + ``` + 生成的结果文件存放到当前目录下的 msame_result 文件夹内 + +  + +3. 执行数据后处理脚本 + ``` + bash Segformer_postprocess.sh + ``` + +  + +**精度评测结果:** + +| 模型 | 官网精度 | 710 精度 | T4 性能 | 710 性能 | +| ------- | ------- | -------- | -------- | -------- | +| Segformer bs1 | mIoU = 76.54 | mIoU = 76.53 | 8.861 fps | 5.587 fps | diff --git a/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_postprocess.py b/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..126d420207cc40e72ca624e55832df249f3bca53 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_postprocess.py @@ -0,0 +1,176 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import argparse +import os +from mmsegmentation.mmseg.core.evaluation import metrics +from PIL import Image + + +class GtFineFile(object): + """ + directory: path to gtFine + suffix: suffix of the gtFine + return path List of gtFine files + """ + def __init__(self, directory, suffix='_gtFine_labelTrainIds.png'): + gtFine_list = [] + for root, dirs, files in os.walk(directory): + for special_file in files: + if special_file.endswith(suffix): + gtFine_list.append(os.path.join(root, special_file)) + # print("Found gtFine files:", os.path.join(root, special_file)) + self.gtFine_list = gtFine_list + + def get_file(self, filename): + """ return file path list """ + for f in self.gtFine_list: + if f.endswith(filename): + return f + + +class IntersectAndUnion(object): + """Calculate Total Intersection and Union. + + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + label_map : Mapping old labels to new labels. + reduce_zero_label (bool): Whether ignore zero label. Default: False. + + Returns: + IoU + Acc + """ + + def __init__(self, num_classes, ignore_index, label_map, reduce_zero_label=False): + self.num_classes = num_classes + self.ignore_index = ignore_index + self.label_map = label_map + self.reduce_zero_label = reduce_zero_label + self.total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) + self.total_area_union = torch.zeros((num_classes,), dtype=torch.float64) + self.total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64) + self.total_area_label = torch.zeros((num_classes,), dtype=torch.float64) + + def update(self, output, gt_seg_map): + """ update """ + [area_intersect, area_union, area_pred_label, area_label] = \ + metrics.intersect_and_union( + output, gt_seg_map, self.num_classes, self.ignore_index, + self.label_map, self.reduce_zero_label) + self.total_area_intersect += area_intersect.to(torch.float64) + self.total_area_union += area_union.to(torch.float64) + self.total_area_pred_label += area_pred_label.to(torch.float64) + self.total_area_label += area_label.to(torch.float64) + + def get(self): + """ get result """ + iou = self.total_area_intersect / self.total_area_union + acc = self.total_area_intersect / self.total_area_label + all_acc = self.total_area_intersect.sum() / self.total_area_label.sum() + mIoU = np.round(np.nanmean(iou) * 100, 2) + aAcc = np.round(np.nanmean(all_acc) * 100, 2) + return {'aAcc': aAcc, 'mIoU': mIoU} + + +def eval_metrics(_output_path, + _gt_path, + _out_suffix='_leftImg8bit_output_0.bin', + _gt_suffix='_gtFine_labelTrainIds.png', + _result_path='./postprocess_result', + num_classes=19, + ignore_index=255, + label_map=None, + reduce_zero_label=False): + """Calculate evaluation metrics + Args: + results (list[ndarray] | list[str]): List of prediction segmentation + maps or list of prediction result filenames. + gt_seg_maps (list[ndarray] | list[str]): list of ground truth + segmentation maps or list of label filenames. + num_classes (int): Number of categories. + ignore_index (int): Index that will be ignored in evaluation. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + label_map (dict): Mapping old labels to new labels. Default: dict(). + reduce_zero_label (bool): Wether ignore zero label. Default: False. + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ). + ndarray: Per category evaluation metrics, shape (num_classes, ). + """ + + # initial metric + label_map = dict() + metric = IntersectAndUnion(num_classes, ignore_index, label_map, reduce_zero_label) + + # initial gtFine files list + fileFinder = GtFineFile(_gt_path) + + for root, dirs, files in os.walk(_output_path): + files = [f for f in files if str(f).endswith('bin')] + length = str(files.__len__()) + for i, output_name in enumerate(files): + if not str(output_name).endswith('bin'): + continue + print('Segformer metric [' + str(i + 1) + '/' + length + '] on process: ' + output_name) + seg_map_name = str(output_name).replace(_out_suffix, _gt_suffix) + seg_map_path = fileFinder.get_file(seg_map_name) + if seg_map_name is not None: + seg_map = Image.open(seg_map_path) + seg_map = np.array(seg_map, dtype=np.uint8) + + _output_path = os.path.realpath(os.path.join(root, output_name)) + output = np.fromfile(_output_path, dtype=np.uint64).reshape(1024, 2048) + output = output.astype(np.uint8) + metric.update(output, seg_map) + else: + print("[ERROR] " + seg_map_name + " not find, check the file or make sure --out_suffix") + + # get result + result = metric.get() + print(result) + with open(_result_path + '.txt', 'w') as f: + f.write('aAcc: {}\n'.format(result['aAcc'])) + f.write('mIoU: {}\n'.format(result['mIoU'])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser('mIoU calculate') + parser.add_argument('--output_path', default="./result", + help='path to om/onnx output file, default ./result') + parser.add_argument('--gt_path', default="/opt/npu/cityscapes/gtFine/val", + help='path to gtFine/val, default /opt/npu/cityscapes/gtFine/val') + parser.add_argument('--out_suffix', default="_leftImg8bit_output_0.bin", + help='suffix of the om/onnx output, default "_leftImg8bit_output_0.bin"') + parser.add_argument('--result_path', default="./postprocess_result", + help='path to save the script result, default ./postprocess_result.txt') + + args = parser.parse_args() + + output_path = os.path.realpath(args.output_path) + gt_path = os.path.realpath(args.gt_path) + out_suffix = args.out_suffix + result_path = os.path.realpath(args.result_path) + print("output_path :", output_path) + print("gt_path :", gt_path) + eval_metrics(_output_path = output_path, _gt_path = gt_path, _out_suffix = out_suffix, _result_path = result_path) diff --git a/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_preprocess.py b/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..8a26ab9157516287bdda74de86eed428ad0ff4e7 --- /dev/null +++ b/ACL_PyTorch/contrib/cv/segmentation/Segformer/Segformer_preprocess.py @@ -0,0 +1,80 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import cv2 +import numpy as np +from torchvision import transforms + + +class Normalize(object): + def __init__(self, mean, std, to_rgb=True): + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + self.to_rgb = to_rgb + + def __call__(self, img): + img = img.copy().astype(np.float32) + # cv2 inplace normalization does not accept uint8 + assert img.dtype != np.uint8 + mean = np.float64(self.mean.reshape(1, -1)) + stdinv = 1 / np.float64(self.std.reshape(1, -1)) + if self.to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) + cv2.subtract(img, mean, img) + cv2.multiply(img, stdinv, img) + return img + + +def preprocess(src_path, save_path): + preprocess = transforms.Compose([ + Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True), + transforms.ToTensor(), + ]) + + root = src_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + rel_path = os.path.relpath(entry.path, root) + if suffix is None or rel_path.endswith(suffix): + yield rel_path + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + + in_files = _scandir(src_path, '_leftImg8bit.png', True) + if not os.path.exists(save_path): + os.makedirs(save_path) + + i = 0 + for file in in_files: + i = i + 1 + print(file, "====", i) + input_image = cv2.imread(src_path + '/' + file) + input_tensor = preprocess(input_image) + img = np.array(input_tensor).astype(np.float32) + img.tofile(os.path.join(save_path, file.split('/')[-1].split('.')[0] + ".bin")) + + +if __name__ == '__main__': + if len(sys.argv) < 3: + raise Exception("usage: python xxx.py [src_path] [save_path]") + src_path = sys.argv[1] + save_path = sys.argv[2] + src_path = os.path.realpath(src_path) + save_path = os.path.realpath(save_path) + preprocess(src_path, save_path) diff --git a/ACL_PyTorch/contrib/cv/segmentation/Segformer/modelzoo_level.txt b/ACL_PyTorch/contrib/cv/segmentation/Segformer/modelzoo_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2a4478db713b8c4e75840384718047052d2750c --- /dev/null +++ b/ACL_PyTorch/contrib/cv/segmentation/Segformer/modelzoo_level.txt @@ -0,0 +1,4 @@ +FuncStatus:OK +PerfStatus:POK +PrecisionStatus:OK +ModelConvert:OK \ No newline at end of file diff --git a/ACL_PyTorch/contrib/cv/segmentation/Segformer/requirements.txt b/ACL_PyTorch/contrib/cv/segmentation/Segformer/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f34142973fcecf697acc7c1f9cbde29e094d307c --- /dev/null +++ b/ACL_PyTorch/contrib/cv/segmentation/Segformer/requirements.txt @@ -0,0 +1,8 @@ +numpy == 1.21.5 +torch == 1.7.1 +torchvision == 0.8.2 +onnx == 1.10.1 +onnxruntime == 1.7.2 +onnx-simplifier == 0.3.8 +opencv-python == 4.5.5.64 +mmcv-full == 1.4.8 \ No newline at end of file