From c87bb6f3318a33df85d2481fb7b4152e0304f362 Mon Sep 17 00:00:00 2001
From: "xiaomei.wang"
Date: Wed, 16 Oct 2024 14:29:17 +0800
Subject: [PATCH 1/2] Add centernet fp16 IxRT inference.

---
 models/cv/detection/centernet/ixrt/README.md  |  63 ++++
 ...ernet_r18-dcnv2_8xb16-crop512-140e_coco.py | 151 ++++++++
 .../centernet/ixrt/base/coco_detection.py     |  75 ++++
 .../centernet/ixrt/base/default_runtime.py    |  39 ++
 .../centernet/ixrt/base/schedule_1x.py        |  43 +++
 .../detection/centernet/ixrt/build_engine.py  |  62 +++
 .../centernet_r18_8xb16-crop512-140e_coco.py  | 354 ++++++++++++++++++
 models/cv/detection/centernet/ixrt/common.py  |  69 ++++
 .../centernet/ixrt/deploy_default.py          |  41 ++
 models/cv/detection/centernet/ixrt/export.py  |  74 ++++
 .../cv/detection/centernet/ixrt/inference.py  | 184 +++++++++
 .../scripts/infer_centernet_fp16_accuracy.sh  |  34 ++
 .../infer_centernet_fp16_performance.sh       |  35 ++
 13 files changed, 1224 insertions(+)
 create mode 100644 models/cv/detection/centernet/ixrt/README.md
 create mode 100644 models/cv/detection/centernet/ixrt/base/centernet_r18-dcnv2_8xb16-crop512-140e_coco.py
 create mode 100644 models/cv/detection/centernet/ixrt/base/coco_detection.py
 create mode 100644 models/cv/detection/centernet/ixrt/base/default_runtime.py
 create mode 100644 models/cv/detection/centernet/ixrt/base/schedule_1x.py
 create mode 100644 models/cv/detection/centernet/ixrt/build_engine.py
 create mode 100644 models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py
 create mode 100644 models/cv/detection/centernet/ixrt/common.py
 create mode 100644 models/cv/detection/centernet/ixrt/deploy_default.py
 create mode 100644 models/cv/detection/centernet/ixrt/export.py
 create mode 100644 models/cv/detection/centernet/ixrt/inference.py
 create mode 100644 models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_accuracy.sh
 create mode 100644 models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_performance.sh

diff --git a/models/cv/detection/centernet/ixrt/README.md b/models/cv/detection/centernet/ixrt/README.md
new file mode 100644
index 00000000..e47ea2f0
--- /dev/null
+++ b/models/cv/detection/centernet/ixrt/README.md
@@ -0,0 +1,63 @@
# CenterNet

## Description

CenterNet is an efficient object detection model that simplifies the traditional detection pipeline by representing each object as the center point of its bounding box and locating that point with keypoint estimation. It runs in real time while maintaining high accuracy, and it extends naturally to related tasks such as 3D object detection and human pose estimation. The architecture pairs optimized fully convolutional networks with effective loss functions, which keeps both training and inference efficient.

## Setup

### Install

```bash
# Install libGL
## CentOS
yum install -y mesa-libGL
## Ubuntu
apt install -y libgl1-mesa-dev

pip3 install onnx
pip3 install tqdm
pip3 install mmdet
pip3 install mmdeploy
pip3 install mmengine
# Contact the Iluvatar administrator to get the mmcv install package.
```

### Download

Pretrained model:

Dataset: to download the validation dataset.
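The inference scripts read validation images from an `images/val2017/` prefix under the dataset root passed via `DATASETS_DIR`, and the annotation file from `annotations/`. A minimal layout that satisfies them (a sketch inferred from the paths set in `inference.py`; only the validation subset is needed for inference) looks like:

```bash
coco/
├── annotations/
│   └── instances_val2017.json
└── images/
    └── val2017/
        ├── 000000000139.jpg
        └── ...
```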
+ +### Model Conversion + +```bash +# export onnx model +python3 export.py --weight centernet_resnet18_140e_coco_20210705_093630-bb5b3bf7.pth --cfg centernet_r18_8xb16-crop512-140e_coco.py --output centernet.onnx +``` + +## Inference + +```bash +export DATASETS_DIR=/Path/to/coco/ +export MODEL_PATH=/Path/to/centernet.onnx +``` + +### FP16 + +```bash +# Accuracy +bash scripts/infer_centernet_fp16_accuracy.sh +# Performance +bash scripts/infer_centernet_fp16_performance.sh +``` + +## Results + +Model |BatchSize |Precision |FPS |IOU@0.5 |IOU@0.5:0.95 | +----------|-----------|----------|----------|----------|---------------| +CenterNet | 32 | FP16 | 879.447 | 0.423 | 0.258 | + +## Reference + +mmdetection: diff --git a/models/cv/detection/centernet/ixrt/base/centernet_r18-dcnv2_8xb16-crop512-140e_coco.py b/models/cv/detection/centernet/ixrt/base/centernet_r18-dcnv2_8xb16-crop512-140e_coco.py new file mode 100644 index 00000000..894e4b4f --- /dev/null +++ b/models/cv/detection/centernet/ixrt/base/centernet_r18-dcnv2_8xb16-crop512-140e_coco.py @@ -0,0 +1,151 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +_base_ = [ + 'coco_detection.py', + 'schedule_1x.py', 'default_runtime.py', +] + +dataset_type = 'CocoDataset' +data_root = 'data/coco/' + +# model settings +model = dict( + type='CenterNet', + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='ResNet', + depth=18, + norm_eval=False, + norm_cfg=dict(type='BN'), + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')), + neck=dict( + type='CTResNetNeck', + in_channels=512, + num_deconv_filters=(256, 128, 64), + num_deconv_kernels=(4, 4, 4), + use_dcn=True), + bbox_head=dict( + type='CenterNetHead', + num_classes=80, + in_channels=64, + feat_channels=64, + loss_center_heatmap=dict(type='GaussianFocalLoss', loss_weight=1.0), + loss_wh=dict(type='L1Loss', loss_weight=0.1), + loss_offset=dict(type='L1Loss', loss_weight=1.0)), + train_cfg=None, + test_cfg=dict(topk=100, local_maximum_kernel=3, max_per_img=100)) + +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PhotoMetricDistortion', + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18), + dict( + type='RandomCenterCropPad', + # The cropped images are padded into squares during training, + # but may be less than crop_size. + crop_size=(512, 512), + ratios=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3), + mean=[0, 0, 0], + std=[1, 1, 1], + to_rgb=True, + test_pad_mode=None), + # Make sure the output is always crop_size. 
+ dict(type='Resize', scale=(512, 512), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') +] +test_pipeline = [ + dict( + type='LoadImageFromFile', + backend_args={{_base_.backend_args}}, + to_float32=True), + # don't need Resize + dict( + type='RandomCenterCropPad', + ratios=None, + border=None, + mean=[0, 0, 0], + std=[1, 1, 1], + to_rgb=True, + test_mode=True, + test_pad_mode=['logical_or', 31], + test_pad_add_pix=1), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'border')) +] + +# Use RepeatDataset to speed up training +train_dataloader = dict( + batch_size=16, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + _delete_=True, + type='RepeatDataset', + times=5, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args={{_base_.backend_args}}, + ))) + +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +# optimizer +# Based on the default settings of modern detectors, the SGD effect is better +# than the Adam in the source code, so we use SGD default settings and +# if you use adam+lr5e-4, the map is 29.1. +optim_wrapper = dict(clip_grad=dict(max_norm=35, norm_type=2)) + +max_epochs = 28 +# learning policy +# Based on the default settings of modern detectors, we added warmup settings. +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[18, 24], # the real step is [18*5, 24*5] + gamma=0.1) +] +train_cfg = dict(max_epochs=max_epochs) # the real epoch is 28*5=140 + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (8 GPUs) x (16 samples per GPU) +auto_scale_lr = dict(base_batch_size=128) diff --git a/models/cv/detection/centernet/ixrt/base/coco_detection.py b/models/cv/detection/centernet/ixrt/base/coco_detection.py new file mode 100644 index 00000000..f58fe67b --- /dev/null +++ b/models/cv/detection/centernet/ixrt/base/coco_detection.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
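# Base COCO detection config consumed by the CenterNet configs via `_base_`.
# Note: the ixrt inference script overrides data_root, batch size and the
# image prefix on top of these defaults at runtime.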
+ +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' + +backend_args = None + +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/instances_val2017.json', + metric='bbox', + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator \ No newline at end of file diff --git a/models/cv/detection/centernet/ixrt/base/default_runtime.py b/models/cv/detection/centernet/ixrt/base/default_runtime.py new file mode 100644 index 00000000..609d8037 --- /dev/null +++ b/models/cv/detection/centernet/ixrt/base/default_runtime.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
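# Default mmdet runtime settings (hooks, logging, visualization), consumed
# unchanged by the CenterNet configs in this directory.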
+ +default_scope = 'mmdet' + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=1), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='DetVisualizationHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer') +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False diff --git a/models/cv/detection/centernet/ixrt/base/schedule_1x.py b/models/cv/detection/centernet/ixrt/base/schedule_1x.py new file mode 100644 index 00000000..9b16d80c --- /dev/null +++ b/models/cv/detection/centernet/ixrt/base/schedule_1x.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# training schedule for 1x +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500), + dict( + type='MultiStepLR', + begin=0, + end=12, + by_epoch=True, + milestones=[8, 11], + gamma=0.1) +] + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)) + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (2 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/models/cv/detection/centernet/ixrt/build_engine.py b/models/cv/detection/centernet/ixrt/build_engine.py new file mode 100644 index 00000000..cf7ba4ee --- /dev/null +++ b/models/cv/detection/centernet/ixrt/build_engine.py @@ -0,0 +1,62 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
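# Builds a serialized IxRT engine from the exported CenterNet ONNX model.
# The input shape is fixed to 32x3x672x672 to match the batch size and
# resolution used by the inference scripts.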
import argparse

import tensorrt
from tensorrt import Dims

def main(config):
    IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING)
    builder = tensorrt.Builder(IXRT_LOGGER)
    EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)
    build_config = builder.create_builder_config()
    profile = builder.create_optimization_profile()
    # Static profile: the inference scripts run a fixed batch of 32 at 672x672.
    profile.set_shape("input", Dims([32, 3, 672, 672]), Dims([32, 3, 672, 672]), Dims([32, 3, 672, 672]))
    build_config.add_optimization_profile(profile)
    parser = tensorrt.OnnxParser(network, IXRT_LOGGER)
    if not parser.parse_from_file(config.model):
        raise RuntimeError(f"Failed to parse ONNX model: {config.model}")

    # float32 builds with the default precision; fp16/int8 set the matching
    # builder flag. (Previously float32 silently fell through to FP16.)
    precision_flags = {
        "float16": tensorrt.BuilderFlag.FP16,
        "int8": tensorrt.BuilderFlag.INT8,
    }
    if config.precision in precision_flags:
        build_config.set_flag(precision_flags[config.precision])
    num_inputs = network.num_inputs

    for i in range(num_inputs):
        input_tensor = network.get_input(i)
        input_tensor.shape = Dims([32, 3, 672, 672])

    plan = builder.build_serialized_network(network, build_config)
    engine_file_path = config.engine
    with open(engine_file_path, "wb") as f:
        f.write(plan)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8",
                        help="The precision of datatype")
    parser.add_argument("--engine", type=str, default=None)
    args = parser.parse_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    main(args)
\ No newline at end of file
diff --git a/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py b/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py
new file mode 100644
index 00000000..cb986c51
--- /dev/null
+++ b/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py
@@ -0,0 +1,354 @@
auto_scale_lr = dict(base_batch_size=128, enable=False)
backend_args = None
data_root = 'data/coco/'
dataset_type = 'CocoDataset'
default_hooks = dict(
    checkpoint=dict(interval=1, type='CheckpointHook'),
    logger=dict(interval=50, type='LoggerHook'),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    timer=dict(type='IterTimerHook'),
    visualization=dict(type='DetVisualizationHook'))
default_scope = 'mmdet'
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
load_from = None
log_level = 'INFO'
log_processor = dict(by_epoch=True, type='LogProcessor', window_size=50)
max_epochs = 28
model = dict(
    backbone=dict(
        depth=18,
        init_cfg=dict(checkpoint='torchvision://resnet18', type='Pretrained'),
        norm_cfg=dict(type='BN'),
        norm_eval=False,
        type='ResNet'),
    bbox_head=dict(
        feat_channels=64,
        in_channels=64,
        loss_center_heatmap=dict(loss_weight=1.0, type='GaussianFocalLoss'),
        loss_offset=dict(loss_weight=1.0, type='L1Loss'),
        loss_wh=dict(loss_weight=0.1, type='L1Loss'),
        num_classes=80,
        type='CenterNetHead'),
    data_preprocessor=dict(
        bgr_to_rgb=True,
        mean=[
            123.675,
            116.28,
            103.53,
        ],
        std=[
            58.395,
            57.12,
            57.375,
        ],
        type='DetDataPreprocessor'),
    neck=dict(
        in_channels=512,
        num_deconv_filters=(
            256,
            128,
            64,
        ),
        num_deconv_kernels=(
            4,
            4,
            4,
        ),
        type='CTResNetNeck',
        use_dcn=False),
    test_cfg=dict(local_maximum_kernel=3, max_per_img=100, topk=100),
train_cfg=None, + type='CenterNet') +optim_wrapper = dict( + clip_grad=dict(max_norm=35, norm_type=2), + optimizer=dict(lr=0.02, momentum=0.9, type='SGD', weight_decay=0.0001), + type='OptimWrapper') +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=1000, start_factor=0.001, + type='LinearLR'), + dict( + begin=0, + by_epoch=True, + end=28, + gamma=0.1, + milestones=[ + 18, + 24, + ], + type='MultiStepLR'), +] +resume = False +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=32, + dataset=dict( + ann_file='annotations/instances_val2017.json', + backend_args=None, + data_prefix=dict(img='images/val2017/'), + data_root='./coco', + pipeline=[ + dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), + dict( + border=None, + mean=[ + 0, + 0, + 0, + ], + ratios=None, + std=[ + 1, + 1, + 1, + ], + test_mode=True, + test_pad_add_pix=1, + test_pad_mode=[ + 'logical_or', + 31, + ], + to_rgb=True, + type='RandomCenterCropPad'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'border', + ), + type='PackDetInputs'), + ], + test_mode=True, + type='CocoDataset'), + drop_last=False, + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + ann_file='./coco/annotations/instances_val2017.json', + backend_args=None, + format_only=False, + metric='bbox', + type='CocoMetric') +test_pipeline = [ + dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), + dict( + border=None, + mean=[ + 0, + 0, + 0, + ], + ratios=None, + std=[ + 1, + 1, + 1, + ], + test_mode=True, + test_pad_add_pix=1, + test_pad_mode=[ + 'logical_or', + 31, + ], + to_rgb=True, + type='RandomCenterCropPad'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'border', + ), + type='PackDetInputs'), +] +train_cfg = dict(max_epochs=28, type='EpochBasedTrainLoop', val_interval=1) +train_dataloader = dict( + batch_sampler=dict(type='AspectRatioBatchSampler'), + batch_size=16, + dataset=dict( + dataset=dict( + ann_file='annotations/instances_train2017.json', + backend_args=None, + data_prefix=dict(img='train2017/'), + data_root='data/coco/', + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=[ + dict(backend_args=None, type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + brightness_delta=32, + contrast_range=( + 0.5, + 1.5, + ), + hue_delta=18, + saturation_range=( + 0.5, + 1.5, + ), + type='PhotoMetricDistortion'), + dict( + crop_size=( + 512, + 512, + ), + mean=[ + 0, + 0, + 0, + ], + ratios=( + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + ), + std=[ + 1, + 1, + 1, + ], + test_pad_mode=None, + to_rgb=True, + type='RandomCenterCropPad'), + dict(keep_ratio=True, scale=( + 512, + 512, + ), type='Resize'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackDetInputs'), + ], + type='CocoDataset'), + times=5, + type='RepeatDataset'), + num_workers=4, + persistent_workers=True, + sampler=dict(shuffle=True, type='DefaultSampler')) +train_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + brightness_delta=32, + contrast_range=( + 0.5, + 1.5, + ), + hue_delta=18, + saturation_range=( + 0.5, + 1.5, + ), + type='PhotoMetricDistortion'), + dict( + crop_size=( + 512, + 512, + ), + mean=[ + 0, + 0, + 0, + ], + ratios=( + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, 
+ 1.2, + 1.3, + ), + std=[ + 1, + 1, + 1, + ], + test_pad_mode=None, + to_rgb=True, + type='RandomCenterCropPad'), + dict(keep_ratio=True, scale=( + 512, + 512, + ), type='Resize'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackDetInputs'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file='annotations/instances_val2017.json', + backend_args=None, + data_prefix=dict(img='val2017/'), + data_root='data/coco/', + pipeline=[ + dict(backend_args=None, to_float32=True, type='LoadImageFromFile'), + dict( + border=None, + mean=[ + 0, + 0, + 0, + ], + ratios=None, + std=[ + 1, + 1, + 1, + ], + test_mode=True, + test_pad_add_pix=1, + test_pad_mode=[ + 'logical_or', + 31, + ], + to_rgb=True, + type='RandomCenterCropPad'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'border', + ), + type='PackDetInputs'), + ], + test_mode=True, + type='CocoDataset'), + drop_last=False, + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + ann_file='data/coco/annotations/instances_val2017.json', + backend_args=None, + format_only=False, + metric='bbox', + type='CocoMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='DetLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) +work_dir = './' diff --git a/models/cv/detection/centernet/ixrt/common.py b/models/cv/detection/centernet/ixrt/common.py new file mode 100644 index 00000000..ef92a6ba --- /dev/null +++ b/models/cv/detection/centernet/ixrt/common.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import os +import cv2 +import glob +import torch +import tensorrt +import numpy as np +from cuda import cuda, cudart + +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + +def get_io_bindings(engine): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = engine.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + err, allocation = cudart.cudaMalloc(size) + assert err == cudart.cudaError_t.cudaSuccess + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + "nbytes": size, + } + print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations \ No newline at end of file diff --git a/models/cv/detection/centernet/ixrt/deploy_default.py b/models/cv/detection/centernet/ixrt/deploy_default.py new file mode 100644 index 00000000..b8d8e43d --- /dev/null +++ b/models/cv/detection/centernet/ixrt/deploy_default.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +onnx_config = dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True) + +codebase_config = dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )) + +backend_config = dict(type='onnxruntime') \ No newline at end of file diff --git a/models/cv/detection/centernet/ixrt/export.py b/models/cv/detection/centernet/ixrt/export.py new file mode 100644 index 00000000..25672ab2 --- /dev/null +++ b/models/cv/detection/centernet/ixrt/export.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse + +import torch +from mmdeploy.utils import load_config +from mmdeploy.apis import build_task_processor + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--weight", + type=str, + required=True, + help="pytorch model weight.") + + parser.add_argument("--cfg", + type=str, + required=True, + help="model config file.") + + parser.add_argument("--output", + type=str, + required=True, + help="export onnx model path.") + + args = parser.parse_args() + return args + +def main(): + args = parse_args() + + deploy_cfg = 'deploy_default.py' + model_cfg = args.cfg + model_checkpoint = args.weight + + deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg) + + task_processor = build_task_processor(model_cfg, deploy_cfg, device='cpu') + + model = task_processor.build_pytorch_model(model_checkpoint) + + input_names = ['input'] + output_names = ['output'] + dynamic_axes = {'input': {0: '-1'}, 'output': {0: '-1'}} + dummy_input = torch.randn(1, 3, 672, 672) + + torch.onnx.export( + model, + dummy_input, + args.output, + input_names = input_names, + dynamic_axes = dynamic_axes, + output_names = output_names, + opset_version=13 + ) + + print("Export onnx model successfully! ") + +if __name__ == '__main__': + main() + diff --git a/models/cv/detection/centernet/ixrt/inference.py b/models/cv/detection/centernet/ixrt/inference.py new file mode 100644 index 00000000..3e7f954f --- /dev/null +++ b/models/cv/detection/centernet/ixrt/inference.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
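# Runs the serialized engine in one of two modes: a perf-only mode that times
# a bare execute_v2 loop, or an accuracy mode that decodes the raw head
# outputs with mmdet and scores them with the CocoMetric evaluator.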
import os
import time
import argparse
import tensorrt
import torch
import numpy as np
from cuda import cudart
from tqdm import tqdm
from mmdet.registry import RUNNERS
from mmengine.config import Config

from common import create_engine_context, get_io_bindings

def str2bool(value):
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so parse the flag explicitly instead.
    return str(value).lower() in ("1", "true", "yes")

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--engine",
                        type=str,
                        required=True,
                        help="ixrt engine path.")

    parser.add_argument("--batchsize",
                        type=int,
                        required=True,
                        help="inference batch size.")

    parser.add_argument("--datasets",
                        type=str,
                        required=True,
                        help="datasets path.")

    parser.add_argument("--input_name",
                        type=str,
                        required=True,
                        help="input name of the model.")

    parser.add_argument("--warmup",
                        type=int,
                        default=3,
                        help="number of warmup before test.")

    parser.add_argument("--acc_target",
                        type=float,
                        default=None,
                        help="Model inference Accuracy target.")

    parser.add_argument("--fps_target",
                        type=float,
                        default=700.0,
                        help="Model inference FPS target.")

    parser.add_argument("--perf_only",
                        type=str2bool,
                        default=False,
                        help="Run performance test only")

    args = parser.parse_args()

    return args

def main():
    args = parse_args()

    batch_size = args.batchsize

    logger = tensorrt.Logger(tensorrt.Logger.ERROR)

    # Load Engine && I/O bindings
    engine, context = create_engine_context(args.engine, logger)
    inputs, outputs, allocations = get_io_bindings(engine)

    if args.warmup > 0:
        print("\nWarm Start.")
        for i in range(args.warmup):
            context.execute_v2(allocations)
        print("Warm Done.")

    # just run perf test
    if args.perf_only:
        torch.cuda.synchronize()
        start_time = time.time()

        for i in range(1000):
            context.execute_v2(allocations)

        torch.cuda.synchronize()
        end_time = time.time()
        forward_time = end_time - start_time
        num_samples = 1000 * args.batchsize
        fps = num_samples / forward_time

        print("FPS : ", fps)
        print(f"Performance Check : Test {fps} >= target {args.fps_target}")
        if fps >= args.fps_target:
            print("pass!")
            exit()
        else:
            print("failed!")
            exit(1)
    else:
        # Runner config
        cfg = Config.fromfile("centernet_r18_8xb16-crop512-140e_coco.py")
        cfg.work_dir = "./"

        cfg['test_dataloader']['batch_size'] = batch_size
        cfg['test_dataloader']['dataset']['data_root'] = args.datasets
        cfg['test_dataloader']['dataset']['data_prefix']['img'] = 'images/val2017/'
        cfg['test_evaluator']['ann_file'] = os.path.join(args.datasets, 'annotations/instances_val2017.json')

        runner = RUNNERS.build(cfg)

        for input_data in tqdm(runner.test_dataloader):

            input_data = runner.model.data_preprocessor(input_data, False)
            image = input_data['inputs'].cpu()
            image = image.numpy().astype(inputs[0]["dtype"])

            # Pad the last (smaller) batch up to the engine's fixed batch size.
            pad_batch = len(image) != batch_size
            if pad_batch:
                origin_size = len(image)
                image = np.resize(image, (batch_size, *image.shape[1:]))

            image = np.ascontiguousarray(image)

            (err,) = cudart.cudaMemcpy(
                inputs[0]["allocation"],
                image,
                image.nbytes,
                cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
            )
            assert err == cudart.cudaError_t.cudaSuccess
            context.execute_v2(allocations)

            results = []

            for i in range(len(outputs)):
                output = np.zeros(outputs[i]["shape"], outputs[i]["dtype"])
                (err,) = cudart.cudaMemcpy(
                    output,
                    outputs[i]["allocation"],
                    outputs[i]["nbytes"],
cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, + ) + assert err == cudart.cudaError_t.cudaSuccess + + if pad_batch: + output = output[:origin_size] + + output = torch.from_numpy(output) + results.append(output) + + batch_img_metas = [ + data_samples.metainfo for data_samples in input_data['data_samples'] + ] + + results_list = runner.model.bbox_head.predict_by_feat([results[0]], [results[1]], [results[2]], batch_img_metas=batch_img_metas, rescale=True) + + batch_data_samples = runner.model.add_pred_to_datasample(input_data['data_samples'], results_list) + + runner.test_evaluator.process(data_samples=batch_data_samples, data_batch=input_data) + + metrics = runner.test_evaluator.evaluate(len(runner.test_dataloader.dataset)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_accuracy.sh b/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_accuracy.sh new file mode 100644 index 00000000..644737b5 --- /dev/null +++ b/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_accuracy.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +batchsize=32 +model_path=${MODEL_PATH} +datasets_path=${DATASETS_DIR} + +# build engine +python3 build_engine.py \ + --model ${model_path} \ + --precision float16 \ + --engine centernet_bs_${batchsize}_fp16.engine + + +# inference +python3 inference.py \ + --engine centernet_bs_${batchsize}_fp16.engine \ + --batchsize ${batchsize} \ + --input_name input \ + --datasets ${datasets_path} \ No newline at end of file diff --git a/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_performance.sh b/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_performance.sh new file mode 100644 index 00000000..9e06e472 --- /dev/null +++ b/models/cv/detection/centernet/ixrt/scripts/infer_centernet_fp16_performance.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
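# Builds a batch-32 FP16 engine from ${MODEL_PATH}, then runs the perf-only
# loop in inference.py (the FPS target defaults to 700 there).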
+ +batchsize=32 +model_path=${MODEL_PATH} +datasets_path=${DATASETS_DIR} + +# build engine +python3 build_engine.py \ + --model ${model_path} \ + --precision float16 \ + --engine centernet_bs_${batchsize}_fp16.engine + + +# inference +python3 inference.py \ + --engine centernet_bs_${batchsize}_fp16.engine \ + --batchsize ${batchsize} \ + --input_name input \ + --datasets ${datasets_path} \ + --perf_only True \ No newline at end of file -- Gitee From f448b8b2d754a2d5cc7cc2f66737c9aba8cf5ac1 Mon Sep 17 00:00:00 2001 From: may Date: Thu, 17 Oct 2024 09:30:21 +0000 Subject: [PATCH 2/2] update models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py. Signed-off-by: may --- .../ixrt/centernet_r18_8xb16-crop512-140e_coco.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py b/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py index cb986c51..10b30dda 100644 --- a/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py +++ b/models/cv/detection/centernet/ixrt/centernet_r18_8xb16-crop512-140e_coco.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. auto_scale_lr = dict(base_batch_size=128, enable=False) backend_args = None data_root = 'data/coco/' -- Gitee