From 573aab2e882f2e19637ce5240950e22aa09f6d9a Mon Sep 17 00:00:00 2001
From: "xinchi.tian"
Date: Mon, 6 May 2024 13:05:08 +0800
Subject: [PATCH] Add Solov1 Model IXRT link #I9FP0V

Add Solov1 Model IXRT

Signed-off-by: xinchi.tian
---
 models/cv/detection/solov1/ixrt/README.md     |  63 +++++++
 .../cv/detection/solov1/ixrt/build_engine.py  |  62 +++++++
 .../cv/detection/solov1/ixrt/coco_instance.py |  64 +++++++
 models/cv/detection/solov1/ixrt/common.py     |  95 ++++++++++
 .../scripts/infer_solov1_fp16_accuracy.sh     |  96 ++++++++++
 .../scripts/infer_solov1_fp16_performance.sh  |  95 ++++++++++
 .../detection/solov1/ixrt/simplify_model.py   |  36 ++++
 .../solov1/ixrt/solo_r50_fpn_3x_coco.py       |  67 +++++++
 .../detection/solov1/ixrt/solo_torch2onnx.py  |  95 ++++++++++
 .../detection/solov1/ixrt/solov1_inference.py | 170 ++++++++++++++++++
 10 files changed, 843 insertions(+)
 create mode 100644 models/cv/detection/solov1/ixrt/README.md
 create mode 100644 models/cv/detection/solov1/ixrt/build_engine.py
 create mode 100755 models/cv/detection/solov1/ixrt/coco_instance.py
 create mode 100644 models/cv/detection/solov1/ixrt/common.py
 create mode 100755 models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_accuracy.sh
 create mode 100644 models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_performance.sh
 create mode 100644 models/cv/detection/solov1/ixrt/simplify_model.py
 create mode 100755 models/cv/detection/solov1/ixrt/solo_r50_fpn_3x_coco.py
 create mode 100755 models/cv/detection/solov1/ixrt/solo_torch2onnx.py
 create mode 100644 models/cv/detection/solov1/ixrt/solov1_inference.py

diff --git a/models/cv/detection/solov1/ixrt/README.md b/models/cv/detection/solov1/ixrt/README.md
new file mode 100644
index 00000000..e4a56a8f
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/README.md
@@ -0,0 +1,63 @@
# Solov1

## Description

SOLO (Segmenting Objects by Locations) is an instance segmentation method that departs from traditional detect-then-segment approaches by introducing the notion of "instance categories": based on the location and size of each instance, every pixel is assigned to a corresponding instance category. This turns instance segmentation into a single-shot classification problem and simplifies the overall pipeline.
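The "instance category" idea can be sketched in a few lines: an object whose center falls into cell (i, j) of an S x S grid is assigned category i * S + j, so predicting masks reduces to per-cell classification. A toy illustration (the values below are made up for the example and are not taken from this repo's config):

```python
import numpy as np

S = 12                                 # one of SOLO's grid sizes (cf. num_grids in the config)
img_size = 800.0                       # this repo exports the model at 800x800
center_xy = np.array([410.0, 300.0])   # hypothetical instance center (x, y) in pixels

# The grid cell that contains the instance center acts as its "instance category".
col, row = np.floor(center_xy / img_size * S).astype(int)
instance_category = row * S + col
print(f"cell (row={row}, col={col}) -> category {instance_category}")
```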
## Setup

### Install

```bash
yum install mesa-libGL

pip3 install tqdm
pip3 install onnx
pip3 install onnxsim
pip3 install tabulate
pip3 install mmdet==2.28.2
pip3 install addict
pip3 install yapf
```

### Dependency

Inference with the Solov1 model depends on an adapted mmcv v1.7.0 library. Please contact the staff to obtain it, then build and install it:

```bash
cd mmcv
sh build_mmcv.sh
sh install_mmcv.sh
```

### Download

Pretrained model: 

Dataset: to download the validation dataset.

### Model Conversion

```bash
mkdir checkpoints
python3 solo_torch2onnx.py --cfg /path/to/solo/solo_r50_fpn_3x_coco.py --checkpoint /path/to/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth --batch_size 1
mv r50_solo_bs1_800x800.onnx checkpoints/r50_solo_bs1_800x800.onnx
```
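Optionally, sanity-check the exported ONNX file before building the engine (a minimal sketch using the `onnx` package installed above):

```python
import onnx

model = onnx.load("checkpoints/r50_solo_bs1_800x800.onnx")
onnx.checker.check_model(model)  # raises if the graph is malformed

# Confirm the expected single 1x3x800x800 input.
for tensor in model.graph.input:
    print("input :", tensor.name)
for tensor in model.graph.output:
    print("output:", tensor.name)
```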
## Inference

```bash
export PROJ_DIR=./
export DATASETS_DIR=/path/to/coco2017/
export CHECKPOINTS_DIR=./checkpoints
export COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json
export EVAL_DIR=${DATASETS_DIR}/val2017
export RUN_DIR=./
```

### FP16

```bash
# Accuracy
bash scripts/infer_solov1_fp16_accuracy.sh
# Performance
bash scripts/infer_solov1_fp16_performance.sh
```
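Once the accuracy script has built the engine, it can be smoke-tested directly with the helpers from `common.py` (a sketch; the engine file name assumes the default model name and FP16 precision used by the scripts):

```python
import numpy as np
import pycuda.autoinit  # noqa: F401  (initializes the CUDA context)
import pycuda.driver as cuda
import tensorrt

from common import create_engine_context, get_io_bindings

logger = tensorrt.Logger(tensorrt.Logger.ERROR)
engine, context = create_engine_context(
    "checkpoints/r50_solo_bs1_800x800_float16.engine", logger)
inputs, outputs, allocations = get_io_bindings(engine)

# Push one dummy 1x3x800x800 batch through the engine.
dummy = np.zeros(inputs[0]["shape"], dtype=inputs[0]["dtype"])
cuda.memcpy_htod(inputs[0]["allocation"], np.ascontiguousarray(dummy))
context.execute_v2(allocations)
print("output shapes:", [o["shape"] for o in outputs])
```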
## Results

| Model  | BatchSize | Precision | FPS   | mAP@0.5 | mAP@0.5:0.95 |
|--------|-----------|-----------|-------|---------|--------------|
| Solov1 | 1         | FP16      | 24.67 | 0.541   | 0.338        |
\ No newline at end of file
diff --git a/models/cv/detection/solov1/ixrt/build_engine.py b/models/cv/detection/solov1/ixrt/build_engine.py
new file mode 100644
index 00000000..08dfb0d5
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/build_engine.py
@@ -0,0 +1,62 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import argparse
import tensorrt

def make_parser():
    parser = argparse.ArgumentParser("Solov1 build engine")
    parser.add_argument("--model", default="", type=str, help="path to the (simplified) ONNX model")
    parser.add_argument("--engine", default="", type=str, help="path to save the serialized engine")
    return parser

def main(config):
    IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING)
    builder = tensorrt.Builder(IXRT_LOGGER)
    EXPLICIT_BATCH = 1 << int(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)
    build_config = builder.create_builder_config()
    parser = tensorrt.OnnxParser(network, IXRT_LOGGER)

    # Parse the ONNX model and build an FP16 engine.
    precision = tensorrt.BuilderFlag.FP16
    parser.parse_from_file(config.model)
    build_config.set_flag(precision)

    plan = builder.build_serialized_network(network, build_config)
    engine_file_path = config.engine
    with open(engine_file_path, "wb") as f:
        f.write(plan)

if __name__ == "__main__":
    config = make_parser().parse_args()
    main(config)
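# Example invocation (file names match those produced by the run scripts in
# ./scripts; shown as a sketch, not required usage):
#
#   python3 build_engine.py \
#       --model checkpoints/r50_solo_bs1_800x800_sim.onnx \
#       --engine checkpoints/r50_solo_bs1_800x800_float16.engine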
diff --git a/models/cv/detection/solov1/ixrt/coco_instance.py b/models/cv/detection/solov1/ixrt/coco_instance.py
new file mode 100755
index 00000000..bb1e046d
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/coco_instance.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(800, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size=(800, 800)),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_train2017.json',
        img_prefix=data_root + 'train2017/',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/instances_val2017.json',
        img_prefix=data_root + 'val2017/',
        pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
diff --git a/models/cv/detection/solov1/ixrt/common.py b/models/cv/detection/solov1/ixrt/common.py
new file mode 100644
index 00000000..3660388a
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/common.py
@@ -0,0 +1,95 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import numpy as np

import tensorrt
import pycuda.driver as cuda

# NOTE: box_class85to6 and save2json come from the shared detection template;
# they are not used by the Solov1 scripts in this directory.
# input : [bsz, box_num, 5(cx, cy, w, h, conf) + class_num(prob[0], prob[1], ...)]
# output : [bsz, box_num, 6(left_top_x, left_top_y, right_bottom_x, right_bottom_y, class_id, max_prob*conf)]
def box_class85to6(input):
    center_x_y = input[:, :2]
    side = input[:, 2:4]
    conf = input[:, 4:5]
    class_id = np.argmax(input[:, 5:], axis=-1)
    class_id = class_id.astype(np.float32).reshape(-1, 1) + 1
    max_prob = np.max(input[:, 5:], axis=-1).reshape(-1, 1)
    x1_y1 = center_x_y - 0.5 * side
    x2_y2 = center_x_y + 0.5 * side
    nms_input = np.concatenate([x1_y1, x2_y2, class_id, max_prob * conf], axis=-1)
    return nms_input

def save2json(batch_img_id, pred_boxes, json_result):
    for i, boxes in enumerate(pred_boxes):
        image_id = int(batch_img_id)
        if boxes is not None:
            x, y, w, h, c, p = boxes
            if image_id != -1:
                x, y, w, h, p = float(x), float(y), float(w), float(h), float(p)
                c = int(c)
                json_result.append(
                    {
                        "image_id": image_id,
                        "category_id": c,
                        "bbox": [x, y, w, h],
                        "score": p,
                    }
                )

def create_engine_context(engine_path, logger):
    # Deserialize the engine from disk and create an execution context.
    with open(engine_path, "rb") as f, tensorrt.Runtime(logger) as runtime:
        assert runtime
        engine = runtime.deserialize_cuda_engine(f.read())
        assert engine
        context = engine.create_execution_context()
        assert context
    return engine, context

def get_io_bindings(engine):
    # Setup I/O bindings: allocate one device buffer per engine binding.
    inputs = []
    outputs = []
    allocations = []

    for i in range(engine.num_bindings):
        is_input = False
        if engine.binding_is_input(i):
            is_input = True
        name = engine.get_binding_name(i)
        dtype = engine.get_binding_dtype(i)
        shape = engine.get_binding_shape(i)
        if is_input:
            batch_size = shape[0]
        size = np.dtype(tensorrt.nptype(dtype)).itemsize
        for s in shape:
            size *= s
        allocation = cuda.mem_alloc(size)
        binding = {
            "index": i,
            "name": name,
            "dtype": np.dtype(tensorrt.nptype(dtype)),
            "shape": list(shape),
            "allocation": allocation,
        }
        # print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}")
        allocations.append(allocation)
        if engine.binding_is_input(i):
            inputs.append(binding)
        else:
            outputs.append(binding)
    return inputs, outputs, allocations
\ No newline at end of file
diff --git a/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_accuracy.sh b/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_accuracy.sh
new file mode 100755
index 00000000..07838d7b
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_accuracy.sh
@@ -0,0 +1,96 @@
#!/bin/bash
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

EXIT_STATUS=0
check_status()
{
    if ((${PIPESTATUS[0]} != 0)); then
        EXIT_STATUS=1
    fi
}

# Run parameters
BSZ=1
WARM_UP=-1
TGT=-1
LOOP_COUNT=-1
RUN_MODE=MAP
PRECISION=float16

# Update arguments
index=0
options=$@
arguments=($options)
for argument in $options
do
    index=`expr $index + 1`
    case $argument in
      --bs) BSZ=${arguments[index]};;
      --tgt) TGT=${arguments[index]};;
    esac
done

MODEL_NAME="r50_solo_bs1_800x800"

echo PROJ_DIR : ${PROJ_DIR}
echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR}
echo DATASETS_DIR : ${DATASETS_DIR}
echo RUN_DIR : ${RUN_DIR}

step=0

# Simplify Model
let step++
echo;
echo [STEP ${step}] : Simplify Model
SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx
if [ -f ${SIM_MODEL} ]; then
    echo "  Simplify model skipped, ${SIM_MODEL} already exists"
else
    python3 ${RUN_DIR}/simplify_model.py \
        --origin_model ${CHECKPOINTS_DIR}/${MODEL_NAME}.onnx \
        --output_model ${SIM_MODEL}
    echo "  Generated ${SIM_MODEL}"
fi

# Build Engine
let step++
echo;
echo [STEP ${step}] : Build Engine
ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}.engine
if [ -f $ENGINE_FILE ]; then
    echo "  Build engine skipped, $ENGINE_FILE already exists"
else
    python3 ${RUN_DIR}/build_engine.py \
        --model ${SIM_MODEL} \
        --engine ${ENGINE_FILE}
    echo "  Generated engine ${ENGINE_FILE}"
fi

# Inference
let step++
echo;
echo [STEP ${step}] : Inference
python3 ${RUN_DIR}/solov1_inference.py \
    --engine ${ENGINE_FILE} \
    --cfg_file ${RUN_DIR}/solo_r50_fpn_3x_coco.py \
    --data_path ${DATASETS_DIR} \
    --task "precision" \
    --batch_size 1 \
    --target_map 0.331 "$@"; check_status
exit ${EXIT_STATUS}
\ No newline at end of file
diff --git a/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_performance.sh b/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_performance.sh
new file mode 100644
index 00000000..50231bf6
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/scripts/infer_solov1_fp16_performance.sh
@@ -0,0 +1,95 @@
#!/bin/bash
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

EXIT_STATUS=0
check_status()
{
    if ((${PIPESTATUS[0]} != 0)); then
        EXIT_STATUS=1
    fi
}

# Run parameters
BSZ=1
WARM_UP=-1
TGT=-1
LOOP_COUNT=-1
RUN_MODE=MAP
PRECISION=float16

# Update arguments
index=0
options=$@
arguments=($options)
for argument in $options
do
    index=`expr $index + 1`
    case $argument in
      --bs) BSZ=${arguments[index]};;
      --tgt) TGT=${arguments[index]};;
    esac
done

MODEL_NAME="r50_solo_bs1_800x800"

echo PROJ_DIR : ${PROJ_DIR}
echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR}
echo DATASETS_DIR : ${DATASETS_DIR}
echo RUN_DIR : ${RUN_DIR}

step=0

# Simplify Model
let step++
echo;
echo [STEP ${step}] : Simplify Model
SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx
if [ -f ${SIM_MODEL} ]; then
    echo "  Simplify model skipped, ${SIM_MODEL} already exists"
else
    python3 ${RUN_DIR}/simplify_model.py \
        --origin_model ${CHECKPOINTS_DIR}/${MODEL_NAME}.onnx \
        --output_model ${SIM_MODEL}
    echo "  Generated ${SIM_MODEL}"
fi

# Build Engine
let step++
echo;
echo [STEP ${step}] : Build Engine
ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}.engine
if [ -f $ENGINE_FILE ]; then
    echo "  Build engine skipped, $ENGINE_FILE already exists"
else
    python3 ${RUN_DIR}/build_engine.py \
        --model ${SIM_MODEL} \
        --engine ${ENGINE_FILE}
    echo "  Generated engine ${ENGINE_FILE}"
fi

# Inference
let step++
echo;
echo [STEP ${step}] : Inference
python3 ${RUN_DIR}/solov1_inference.py \
    --engine ${ENGINE_FILE} \
    --cfg_file ${RUN_DIR}/solo_r50_fpn_3x_coco.py \
    --task "perf" \
    --batch_size 1 \
    --target_fps 15 "$@"; check_status
exit ${EXIT_STATUS}
\ No newline at end of file
diff --git a/models/cv/detection/solov1/ixrt/simplify_model.py b/models/cv/detection/solov1/ixrt/simplify_model.py
new file mode 100644
index 00000000..1400fd81
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/simplify_model.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import onnx
import argparse
from onnxsim import simplify

# Simplify the ONNX graph and re-run shape inference on the result.
def simplify_model(args):
    onnx_model = onnx.load(args.origin_model)
    model_simp, check = simplify(onnx_model)
    model_simp = onnx.shape_inference.infer_shapes(model_simp)
    onnx.save(model_simp, args.output_model)
    print(" Simplify ONNX done.")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--origin_model", type=str)
    parser.add_argument("--output_model", type=str)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    simplify_model(args)
\ No newline at end of file
diff --git a/models/cv/detection/solov1/ixrt/solo_r50_fpn_3x_coco.py b/models/cv/detection/solov1/ixrt/solo_r50_fpn_3x_coco.py
new file mode 100755
index 00000000..96b5f4d5
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/solo_r50_fpn_3x_coco.py
@@ -0,0 +1,67 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

_base_ = [
    'coco_instance.py',
]

# model settings
model = dict(
    type='SOLO',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
        style='pytorch'),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    mask_head=dict(
        type='SOLOHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=7,
        feat_channels=256,
        strides=[8, 8, 16, 32, 32],
        scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
        pos_scale=0.2,
        num_grids=[40, 36, 24, 16, 12],
        cls_down_index=0,
        loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
    # model training and testing settings
    test_cfg=dict(
        nms_pre=500,
        score_thr=0.1,
        mask_thr=0.5,
        filter_thr=0.05,
        kernel='gaussian',  # gaussian/linear
        sigma=2.0,
        max_per_img=100))

# optimizer
optimizer = dict(type='SGD', lr=0.01)
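# A minimal usage sketch (values taken from this file): the inference script
# loads this config with mmcv.Config, and it can be inspected the same way.
#
#   import mmcv
#   cfg = mmcv.Config.fromfile('solo_r50_fpn_3x_coco.py')
#   print(cfg.model.test_cfg.score_thr)   # -> 0.1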
diff --git a/models/cv/detection/solov1/ixrt/solo_torch2onnx.py b/models/cv/detection/solov1/ixrt/solo_torch2onnx.py
new file mode 100755
index 00000000..a9f0c6d0
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/solo_torch2onnx.py
@@ -0,0 +1,95 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import argparse

import torch
import onnx
import onnxsim
from mmcv import digit_version
from mmdet.apis import init_detector

class Model(torch.nn.Module):
    def __init__(self, config_file, checkpoint_file):
        super().__init__()
        self.model = init_detector(config_file, checkpoint_file, device='cuda:0')

    def forward(self, x):
        # Run backbone -> FPN neck -> SOLO mask head and return the raw head outputs.
        feat = self.model.backbone(x)
        out_neck = self.model.neck(feat)
        out_head = self.model.mask_head(out_neck)
        return out_head

def parse_args():
    parser = argparse.ArgumentParser()
    # export args
    parser.add_argument("--cfg", type=str, default="")
    parser.add_argument("--checkpoint", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--input_size", type=int, nargs=2, default=(800, 800))
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    model = Model(args.cfg, args.checkpoint)
    model.eval()
    device = 'cuda:0'

    dummy_input = torch.zeros(args.batch_size, 3, args.input_size[0], args.input_size[1]).to(device)

    output_onnx_name = f'r50_solo_bs{args.batch_size}_{args.input_size[0]}x{args.input_size[1]}.onnx'

    # Export the PyTorch model to ONNX.
    print("Start exporting model to ONNX")
    torch.onnx.export(model,
                      dummy_input,
                      output_onnx_name,
                      input_names=["input"],
                      output_names=["output"],
                      do_constant_folding=True,
                      opset_version=11,
                      keep_initializers_as_inputs=True,
                      # dynamic_axes={'input':{0:'batch'}, 'output':{0:'batch'}}
                      # dynamic_axes={'input':{0:'batch', 2:'h', 3:'w'}, 'output':{0:'batch', 2:'h2', 3:'w2'}}
                      )
    print("Finished exporting model to ONNX")

    output_file = output_onnx_name

    min_required_version = '0.4.0'
    assert digit_version(onnxsim.__version__) >= digit_version(
        min_required_version
    ), f'Requires onnxsim>={min_required_version}'

    model_opt, check_ok = onnxsim.simplify(output_file)
    if check_ok:
        onnx.save(model_opt, output_file)
        print(f'Successfully simplified ONNX model: {output_file}')
    else:
        print('Failed to simplify ONNX model.')


if __name__ == "__main__":
    main()
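# Example invocation (paths as used in the README's Model Conversion step):
#
#   python3 solo_torch2onnx.py \
#       --cfg /path/to/solo/solo_r50_fpn_3x_coco.py \
#       --checkpoint /path/to/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth \
#       --batch_size 1
#
# This writes r50_solo_bs1_800x800.onnx to the current directory; move it into
# ./checkpoints before running the scripts.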
diff --git a/models/cv/detection/solov1/ixrt/solov1_inference.py b/models/cv/detection/solov1/ixrt/solov1_inference.py
new file mode 100644
index 00000000..473bff85
--- /dev/null
+++ b/models/cv/detection/solov1/ixrt/solov1_inference.py
@@ -0,0 +1,170 @@
# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import argparse
import os
import sys
import time

import numpy as np
import torch
import mmcv
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector
from mmdet.core import encode_mask_results
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt
from tqdm import tqdm

from common import create_engine_context, get_io_bindings

def check_target(inference, target):
    # Pass when the measured value reaches the target.
    return inference >= target


def get_dataloader(args):
    cfg_path = args.cfg_file
    cfg = mmcv.Config.fromfile(cfg_path)
    datasets_path = args.data_path
    cfg['data']['val']['img_prefix'] = os.path.join(datasets_path, 'val2017')
    cfg['data']['val']['ann_file'] = os.path.join(datasets_path, 'annotations/instances_val2017.json')
    dataset = build_dataset(cfg.data.val)
    data_loader = build_dataloader(dataset, samples_per_gpu=args.batch_size, workers_per_gpu=args.num_workers, shuffle=False)
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    return dataset, data_loader, model


def eval_coco(args, inputs, outputs, allocations, context):
    dataset, dataloader, model = get_dataloader(args)
    # Host buffers for the ten engine outputs (five mask levels + five class levels).
    output_buffers = [np.zeros(o["shape"], o["dtype"]) for o in outputs]

    results = []
    for batch in tqdm(dataloader):
        image = batch['img'][0].data.numpy()
        img_metas = batch['img_metas'][0].data[0]
        # Set input
        image = np.ascontiguousarray(image)
        cuda.memcpy_htod(inputs[0]["allocation"], image)
        context.execute_v2(allocations)
        # Fetch outputs
        for buffer, output in zip(output_buffers, outputs):
            cuda.memcpy_dtoh(buffer, output["allocation"])

        # The first five outputs are mask predictions, the last five class scores.
        mask_preds = [torch.from_numpy(buffer) for buffer in output_buffers[:5]]
        cls_preds = [torch.from_numpy(buffer) for buffer in output_buffers[5:]]
        mask_preds.sort(key=lambda x: x.shape[1], reverse=True)
        cls_preds.sort(key=lambda x: x.shape[2], reverse=True)
        results_list = model.mask_head.get_results(mask_preds, cls_preds, img_metas)
        format_results_list = []
        for result in results_list:
            format_results_list.append(model.format_results(result))

        if isinstance(format_results_list[0], tuple):
            result = [(bbox_results, encode_mask_results(mask_results))
                      for bbox_results, mask_results in format_results_list]
            results.extend(result)
    eval_results = dataset.evaluate(results, metric=['segm'])
    print(eval_results)
    segm_mAP = eval_results['segm_mAP']
    return segm_mAP


def parse_args():
    parser = argparse.ArgumentParser()
    # engine args
    parser.add_argument("--engine", type=str, default="")
    parser.add_argument("--cfg_file", type=str, default="")
    parser.add_argument("--data_path", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--warm_up", type=int, default=40)
    parser.add_argument("--loop_count", type=int, default=50)
    parser.add_argument("--target_map", default=0.331, type=float, help="target mAP")
    parser.add_argument("--target_fps", default=15, type=float, help="target FPS")
    parser.add_argument("--task", default="precision", type=str, help="precision or perf")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    logger = tensorrt.Logger(tensorrt.Logger.ERROR)

    # Load Engine
    engine, context = create_engine_context(args.engine, logger)
    inputs, outputs, allocations = get_io_bindings(engine)

    if args.task == "precision":
        segm_mAP = eval_coco(args, inputs, outputs, allocations, context)

        print("=" * 40)
        print("segm_mAP:{0}".format(round(segm_mAP, 3)))
        print("=" * 40)
        print(f"Check segm_mAP Test : {round(segm_mAP, 3)} Target : {args.target_map} State : {'Pass' if round(segm_mAP, 3) >= args.target_map else 'Fail'}")
        status_map = check_target(segm_mAP, args.target_map)
        sys.exit(int(not status_map))
    else:
        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(args.loop_count):
            context.execute_v2(allocations)
        torch.cuda.synchronize()
        end_time = time.time()
        forward_time = end_time - start_time
        fps = args.loop_count * args.batch_size / forward_time
        print("=" * 40)
        print("fps:{0}".format(round(fps, 2)))
        print("=" * 40)
        print(f"Check fps Test : {round(fps, 3)} Target : {args.target_fps} State : {'Pass' if fps >= args.target_fps else 'Fail'}")
        status_fps = check_target(fps, args.target_fps)
        sys.exit(int(not status_fps))

if __name__ == "__main__":
    main()
--
Gitee