diff --git a/models/cv/pose_estimation/rtmpose/ixrt/README.md b/models/cv/pose_estimation/rtmpose/ixrt/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..417a73dbb652b1e252ce920d4afbfa16e9c1798b
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/README.md
@@ -0,0 +1,52 @@
+# RTMPose
+
+## Description
+
+RTMPose, a state-of-the-art framework developed by Shanghai AI Laboratory, excels at real-time multi-person pose estimation by pairing an innovative model architecture with the efficiency of the MMPose foundation. The architecture is designed to raise accuracy while keeping latency low, making it suitable for applications where real-time analysis is crucial.
+
+## Setup
+
+### Install
+
+```bash
+# Install libGL
+## CentOS
+yum install -y mesa-libGL
+## Ubuntu
+apt install -y libgl1-mesa-dev
+
+pip3 install onnx
+pip3 install tqdm
+pip3 install onnxsim
+pip3 install mmdet==3.3.0
+pip3 install mmpose==1.3.1
+pip3 install mmdeploy==1.3.1
+pip3 install mmengine==0.10.4
+```
+
+### Download
+
+Pretrained model: <https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth>
+
+Dataset: COCO val2017 is only needed for validation; the test sample below uses the bundled demo image.
+
+## Model Conversion
+
+```bash
+# export onnx model
+mkdir -p data/rtmpose
+
+wget -P data/rtmpose/ https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth
+
+python3 export.py --weight data/rtmpose/rtmpose-m_simcc-aic-coco_pt-aic-coco_420e-256x192-63eb25f7_20230126.pth --cfg rtmpose-m_8xb256-420e_coco-256x192.py --input 1,3,256,192 --output data/rtmpose/rtmpose.onnx
+
+# use onnxsim to optimize the onnx model
+onnxsim data/rtmpose/rtmpose.onnx data/rtmpose/rtmpose_opt.onnx
+```
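+
+Optionally, you can sanity-check the simplified model before building an engine. This is a minimal sketch, assuming `onnxruntime` is installed; with a 1x3x256x192 input and a SimCC split ratio of 2.0, the two outputs are expected to have shapes (1, 17, 384) and (1, 17, 512):
+
+```python
+import numpy as np
+import onnxruntime as ort
+
+# load the simplified model and run a random input through it
+sess = ort.InferenceSession("data/rtmpose/rtmpose_opt.onnx",
+                            providers=["CPUExecutionProvider"])
+dummy = np.random.randn(1, 3, 256, 192).astype(np.float32)
+outs = sess.run(None, {"input": dummy})  # input name is set by export.py
+print([o.shape for o in outs])  # expect simcc_x (1, 17, 384), simcc_y (1, 17, 512)
+```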
+
+## Test Sample
+
+```bash
+python3 predict.py --model data/rtmpose/rtmpose_opt.onnx --precision fp16 --img_path demo/demo.jpg
+```
+
+The script builds an FP16 engine next to the ONNX file and writes the visualized keypoints to `result.jpg`.
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/demo/demo.jpg b/models/cv/pose_estimation/rtmpose/ixrt/demo/demo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..78718f5f5c031d1fed853b55878083058f289755
Binary files /dev/null and b/models/cv/pose_estimation/rtmpose/ixrt/demo/demo.jpg differ
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/deploy_default.py b/models/cv/pose_estimation/rtmpose/ixrt/deploy_default.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0c613591e34f7502a09fb2f6e65a1229b315a5f
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/deploy_default.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+onnx_config = dict(
+    type='onnx',
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='end2end.onnx',
+    input_names=['input'],
+    output_names=['output'],
+    input_shape=None,
+    optimize=True)
+
+codebase_config = dict(type='mmpose', task='PoseDetection')
+
+backend_config = dict(type='onnxruntime')
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/export.py b/models/cv/pose_estimation/rtmpose/ixrt/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..4af5f647845e5c04e95d51d1d6030ddb613efdab
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/export.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+import argparse
+
+import torch
+from mmdeploy.utils import load_config
+from mmdeploy.apis import build_task_processor
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--weight",
+                        type=str,
+                        required=True,
+                        help="pytorch model weight.")
+
+    parser.add_argument("--cfg",
+                        type=str,
+                        required=True,
+                        help="model config file.")
+
+    parser.add_argument("--input",
+                        required=True,
+                        help="model input shape, e.g. 1,3,256,192.")
+
+    parser.add_argument("--output",
+                        type=str,
+                        required=True,
+                        help="path of the exported onnx model.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    deploy_cfg = 'deploy_default.py'
+    model_cfg = args.cfg
+    model_checkpoint = args.weight
+
+    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
+
+    task_processor = build_task_processor(model_cfg, deploy_cfg, device='cpu')
+
+    model = task_processor.build_pytorch_model(model_checkpoint)
+
+    input_names = ['input']
+    dynamic_axes = {'input': {0: '-1'}}
+
+    input_shape = [int(item) for item in args.input.split(",")]
+    dummy_input = torch.randn(input_shape)
+
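+    # dynamic_axes is kept for reference but not passed to the export call
+    # below: the ixrt engine in this repo is built from a static
+    # 1x3x256x192 graph, so a fixed batch dimension is sufficient.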
+    torch.onnx.export(
+        model,
+        dummy_input,
+        args.output,
+        input_names=input_names,
+        # dynamic_axes=dynamic_axes,
+        opset_version=13
+    )
+
+    print("Exported onnx model successfully!")
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/predict.py b/models/cv/pose_estimation/rtmpose/ixrt/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d11f889c2c1e4dc941eded74776cae39946f757
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/predict.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import argparse
+
+import numpy as np
+import torch
+from PIL import Image
+
+from mmcv.image import imread
+from mmengine.dataset import Compose, pseudo_collate
+from mmengine.registry import init_default_scope
+
+from mmpose.apis import init_model
+from mmpose.registry import VISUALIZERS
+from mmpose.structures import merge_data_samples
+
+from tensorrt_common import create_engine_from_onnx, create_context, get_ixrt_output
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--model",
+                        type=str,
+                        required=True,
+                        help="the path of the onnx model.")
+
+    parser.add_argument("--img_path",
+                        type=str,
+                        required=True,
+                        help="image path.")
+
+    parser.add_argument("--conf",
+                        type=float,
+                        default=0.25,
+                        help="confidence threshold.")
+
+    parser.add_argument("--iou",
+                        type=float,
+                        default=0.65,
+                        help="iou threshold.")
+
+    parser.add_argument("--precision",
+                        type=str,
+                        choices=["fp32", "fp16", "int8"],
+                        required=True,
+                        help="model inference precision.")
+
+    args = parser.parse_args()
+    return args
+
+
+def preprocess(model, img, bboxes=None):
+    scope = model.cfg.get('default_scope', 'mmpose')
+    if scope is not None:
+        init_default_scope(scope)
+
+    pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)
+
+    if bboxes is None:
+        # fall back to a single bbox covering the whole image
+        if isinstance(img, str):
+            w, h = Image.open(img).size
+        else:
+            h, w = img.shape[:2]
+        bboxes = np.array([[0, 0, w, h]], dtype=np.float32)
+
+    # construct batch data samples
+    data_list = []
+    for bbox in bboxes:
+        if isinstance(img, str):
+            data_info = dict(img_path=img)
+        else:
+            data_info = dict(img=img)
+        data_info['bbox'] = bbox[None]  # shape (1, 4)
+        data_info['bbox_score'] = np.ones(1, dtype=np.float32)  # shape (1,)
+        data_info.update(model.dataset_meta)
+        data_list.append(pipeline(data_info))
+
+    data = pseudo_collate(data_list)
+
+    return data
+
+
+def main():
+    args = parse_args()
+
+    # build the ixrt engine next to the onnx file
+    # NOTE: create_engine_from_onnx currently builds an FP16 engine
+    # regardless of the --precision flag.
+    engine_file = args.model.replace(".onnx", ".engine")
+    create_engine_from_onnx(args.model, engine_file)
+
+    engine, context = create_context(engine_file)
+
+    model = init_model('rtmpose-m_8xb256-420e_coco-256x192.py')
+    model.cfg.visualizer.radius = 3
+    model.cfg.visualizer.alpha = 0.8
+    model.cfg.visualizer.line_width = 1
+
+    visualizer = VISUALIZERS.build(model.cfg.visualizer)
+    visualizer.set_dataset_meta(model.dataset_meta, skeleton_style="mmpose")
+
+    # get inputs
+    inputs = preprocess(model, args.img_path)
+    input_data = model.data_preprocessor(inputs, False)
+    input_data = input_data['inputs'].cpu().numpy()
+
+    outputs = get_ixrt_output(engine, context, input_data)
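+    # the engine returns the two SimCC classification maps produced by the
+    # RTMCC head: outputs[0] is simcc_x and outputs[1] is simcc_y (for this
+    # config, shapes (1, 17, 384) and (1, 17, 512)); model.head.decode turns
+    # them back into keypoint coordinates and scores.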
+    preds = model.head.decode((torch.from_numpy(outputs[0]), torch.from_numpy(outputs[1])))
+
+    if isinstance(preds, tuple):
+        batch_pred_instances, batch_pred_fields = preds
+    else:
+        batch_pred_instances = preds
+        batch_pred_fields = None
+
+    batch_data_samples = model.add_pred_to_datasample(batch_pred_instances, batch_pred_fields, inputs['data_samples'])
+    results = merge_data_samples(batch_data_samples)
+
+    img = imread(args.img_path, channel_order='rgb')
+    visualizer.add_datasample(
+        'result',
+        img,
+        data_sample=results,
+        draw_gt=False,
+        draw_bbox=True,
+        out_file="./result.jpg")
+
+    print("Results saved as result.jpg.")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/result.jpg b/models/cv/pose_estimation/rtmpose/ixrt/result.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dbe8a434f24bad24564f7817f12247fdec75858e
Binary files /dev/null and b/models/cv/pose_estimation/rtmpose/ixrt/result.jpg differ
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/rtmpose-m_8xb256-420e_coco-256x192.py b/models/cv/pose_estimation/rtmpose/ixrt/rtmpose-m_8xb256-420e_coco-256x192.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bcafcc5816181f5f937808a12a979a458173828
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/rtmpose-m_8xb256-420e_coco-256x192.py
@@ -0,0 +1,465 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+auto_scale_lr = dict(base_batch_size=1024)
+backend_args = dict(backend='local')
+base_lr = 0.004
+codec = dict(
+    input_size=(192, 256),
+    normalize=False,
+    sigma=(4.9, 5.66),
+    simcc_split_ratio=2.0,
+    type='SimCCLabel',
+    use_dark=False)
+custom_hooks = [
+    dict(
+        ema_type='ExpMomentumEMA',
+        momentum=0.0002,
+        priority=49,
+        type='EMAHook',
+        update_buffers=True),
+    dict(
+        switch_epoch=390,
+        switch_pipeline=[
+            dict(backend_args=dict(backend='local'), type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(direction='horizontal', type='RandomFlip'),
+            dict(type='RandomHalfBody'),
+            dict(
+                rotate_factor=60,
+                scale_factor=[0.75, 1.25],
+                shift_factor=0.0,
+                type='RandomBBoxTransform'),
+            dict(input_size=(192, 256), type='TopdownAffine'),
+            dict(type='mmdet.YOLOXHSVRandomAug'),
+            dict(
+                transforms=[
+                    dict(p=0.1, type='Blur'),
+                    dict(p=0.1, type='MedianBlur'),
+                    dict(
+                        max_height=0.4,
+                        max_holes=1,
+                        max_width=0.4,
+                        min_height=0.2,
+                        min_holes=1,
+                        min_width=0.2,
+                        p=0.5,
+                        type='CoarseDropout'),
+                ],
+                type='Albumentation'),
+            dict(
+                encoder=dict(
+                    input_size=(192, 256),
+                    normalize=False,
+                    sigma=(4.9, 5.66),
+                    simcc_split_ratio=2.0,
+                    type='SimCCLabel',
+                    use_dark=False),
+                type='GenerateTarget'),
+            dict(type='PackPoseInputs'),
+        ],
+        type='mmdet.PipelineSwitchHook'),
+]
+data_mode = 'topdown'
+data_root = 'data/coco/'
+dataset_type = 'CocoDataset'
+default_hooks = dict(
+    badcase=dict(
+        _scope_='mmpose',
+        badcase_thr=5,
+        enable=False,
+        metric_type='loss',
+        out_dir='badcase',
+        type='BadCaseAnalysisHook'),
+    checkpoint=dict(
+        _scope_='mmpose',
+        interval=10,
+        max_keep_ckpts=1,
+        rule='greater',
+        save_best='coco/AP',
+        type='CheckpointHook'),
+    logger=dict(_scope_='mmpose', interval=50, type='LoggerHook'),
+    param_scheduler=dict(_scope_='mmpose', type='ParamSchedulerHook'),
+    sampler_seed=dict(_scope_='mmpose', type='DistSamplerSeedHook'),
+    timer=dict(_scope_='mmpose', type='IterTimerHook'),
+    visualization=dict(
+        _scope_='mmpose', enable=False, type='PoseVisualizationHook'))
+default_scope = 'mmpose'
+env_cfg = dict(
+    cudnn_benchmark=False,
+    dist_cfg=dict(backend='nccl'),
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
+input_size = (192, 256)
+load_from = None
+log_level = 'ERROR'
+log_processor = dict(
+    _scope_='mmpose',
+    by_epoch=True,
+    num_digits=6,
+    type='LogProcessor',
+    window_size=50)
+max_epochs = 420
+model = dict(
+    backbone=dict(
+        _scope_='mmdet',
+        act_cfg=dict(type='SiLU'),
+        arch='P5',
+        channel_attention=True,
+        deepen_factor=0.67,
+        expand_ratio=0.5,
+        init_cfg=dict(
+            checkpoint=
+            'https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/cspnext-m_udp-aic-coco_210e-256x192-f2f7d6f6_20230130.pth',
+            prefix='backbone.',
+            type='Pretrained'),
+        norm_cfg=dict(type='SyncBN'),
+        out_indices=(4, ),
+        type='CSPNeXt',
+        widen_factor=0.75),
+    data_preprocessor=dict(
+        bgr_to_rgb=True,
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        type='PoseDataPreprocessor'),
+    head=dict(
+        decoder=dict(
+            input_size=(192, 256),
+            normalize=False,
+            sigma=(4.9, 5.66),
+            simcc_split_ratio=2.0,
+            type='SimCCLabel',
+            use_dark=False),
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            act_fn='SiLU',
+            drop_path=0.0,
+            dropout_rate=0.0,
+            expansion_factor=2,
+            hidden_dims=256,
+            pos_enc=False,
+            s=128,
+            use_rel_bias=False),
+        in_channels=768,
+        in_featuremap_size=(6, 8),
+        input_size=(192, 256),
+        loss=dict(
+            beta=10.0,
+            label_softmax=True,
+            type='KLDiscretLoss',
+            use_target_weight=True),
+        out_channels=17,
+        simcc_split_ratio=2.0,
+        type='RTMCCHead'),
+    test_cfg=dict(flip_test=True),
+    type='TopdownPoseEstimator')
+num_keypoints = 17
+optim_wrapper = dict(
+    clip_grad=dict(max_norm=35, norm_type=2),
+    optimizer=dict(lr=0.004, type='AdamW', weight_decay=0.05),
+    paramwise_cfg=dict(
+        bias_decay_mult=0, bypass_duplicate=True, norm_decay_mult=0),
+    type='OptimWrapper')
+param_scheduler = [
+    dict(
+        begin=0, by_epoch=False, end=1000, start_factor=1e-05,
+        type='LinearLR'),
+    dict(
+        T_max=210,
+        begin=210,
+        by_epoch=True,
+        convert_to_iter_based=True,
+        end=420,
+        eta_min=0.0002,
+        type='CosineAnnealingLR'),
+]
+randomness = dict(seed=21)
+resume = False
+stage2_num_epochs = 30
+test_cfg = dict()
+test_dataloader = dict(
+    batch_size=32,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_val2017.json',
+        data_mode='topdown',
+        data_prefix=dict(img='images/val2017/'),
+        data_root='../../../data/datasets/coco/',
+        pipeline=[
+            dict(backend_args=dict(backend='local'), type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(input_size=(192, 256), type='TopdownAffine'),
+            dict(type='PackPoseInputs'),
+        ],
+        test_mode=True,
+        type='CocoDataset'),
+    drop_last=False,
+    num_workers=10,
+    persistent_workers=True,
+    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
+test_evaluator = dict(
+    ann_file=
+    '../../../data/datasets/coco/annotations/person_keypoints_val2017.json',
+    type='CocoMetric')
+train_batch_size = 256
+train_cfg = dict(by_epoch=True, max_epochs=420, val_interval=10)
+train_dataloader = dict(
+    batch_size=256,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_train2017.json',
+        data_mode='topdown',
+        data_prefix=dict(img='train2017/'),
+        data_root='data/coco/',
+        pipeline=[
+            dict(backend_args=dict(backend='local'), type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(direction='horizontal', type='RandomFlip'),
+            dict(type='RandomHalfBody'),
+            dict(
+                rotate_factor=80,
+                scale_factor=[0.6, 1.4],
+                type='RandomBBoxTransform'),
+            dict(input_size=(192, 256), type='TopdownAffine'),
+            dict(type='mmdet.YOLOXHSVRandomAug'),
+            dict(
+                transforms=[
+                    dict(p=0.1, type='Blur'),
+                    dict(p=0.1, type='MedianBlur'),
+                    dict(
+                        max_height=0.4,
+                        max_holes=1,
+                        max_width=0.4,
+                        min_height=0.2,
+                        min_holes=1,
+                        min_width=0.2,
+                        p=1.0,
+                        type='CoarseDropout'),
+                ],
+                type='Albumentation'),
+            dict(
+                encoder=dict(
+                    input_size=(192, 256),
+                    normalize=False,
+                    sigma=(4.9, 5.66),
+                    simcc_split_ratio=2.0,
+                    type='SimCCLabel',
+                    use_dark=False),
+                type='GenerateTarget'),
+            dict(type='PackPoseInputs'),
+        ],
+        type='CocoDataset'),
+    num_workers=10,
+    persistent_workers=True,
+    sampler=dict(shuffle=True, type='DefaultSampler'))
+train_pipeline = [
+    dict(backend_args=dict(backend='local'), type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(direction='horizontal', type='RandomFlip'),
+    dict(type='RandomHalfBody'),
+    dict(
+        rotate_factor=80,
+        scale_factor=[0.6, 1.4],
+        type='RandomBBoxTransform'),
+    dict(input_size=(192, 256), type='TopdownAffine'),
+    dict(type='mmdet.YOLOXHSVRandomAug'),
+    dict(
+        transforms=[
+            dict(p=0.1, type='Blur'),
+            dict(p=0.1, type='MedianBlur'),
+            dict(
+                max_height=0.4,
+                max_holes=1,
+                max_width=0.4,
+                min_height=0.2,
+                min_holes=1,
+                min_width=0.2,
+                p=1.0,
+                type='CoarseDropout'),
+        ],
+        type='Albumentation'),
+    dict(
+        encoder=dict(
+            input_size=(192, 256),
+            normalize=False,
+            sigma=(4.9, 5.66),
+            simcc_split_ratio=2.0,
+            type='SimCCLabel',
+            use_dark=False),
+        type='GenerateTarget'),
+    dict(type='PackPoseInputs'),
+]
+train_pipeline_stage2 = [
+    dict(backend_args=dict(backend='local'), type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(direction='horizontal', type='RandomFlip'),
+    dict(type='RandomHalfBody'),
+    dict(
+        rotate_factor=60,
+        scale_factor=[0.75, 1.25],
+        shift_factor=0.0,
+        type='RandomBBoxTransform'),
+    dict(input_size=(192, 256), type='TopdownAffine'),
+    dict(type='mmdet.YOLOXHSVRandomAug'),
+    dict(
+        transforms=[
+            dict(p=0.1, type='Blur'),
+            dict(p=0.1, type='MedianBlur'),
+            dict(
+                max_height=0.4,
+                max_holes=1,
+                max_width=0.4,
+                min_height=0.2,
+                min_holes=1,
+                min_width=0.2,
+                p=0.5,
+                type='CoarseDropout'),
+        ],
+        type='Albumentation'),
+    dict(
+        encoder=dict(
+            input_size=(192, 256),
+            normalize=False,
+            sigma=(4.9, 5.66),
+            simcc_split_ratio=2.0,
+            type='SimCCLabel',
+            use_dark=False),
+        type='GenerateTarget'),
+    dict(type='PackPoseInputs'),
+]
+val_batch_size = 64
+val_cfg = dict()
+val_dataloader = dict(
+    batch_size=64,
+    dataset=dict(
+        ann_file='annotations/person_keypoints_val2017.json',
+        data_mode='topdown',
+        data_prefix=dict(img='images/val2017/'),
+        data_root='data/coco/',
+        pipeline=[
+            dict(backend_args=dict(backend='local'), type='LoadImage'),
+            dict(type='GetBBoxCenterScale'),
+            dict(input_size=(192, 256), type='TopdownAffine'),
+            dict(type='PackPoseInputs'),
+        ],
+        test_mode=True,
+        type='CocoDataset'),
+    drop_last=False,
+    num_workers=10,
+    persistent_workers=True,
+    sampler=dict(round_up=False, shuffle=False, type='DefaultSampler'))
+val_evaluator = dict(
+    ann_file='data/coco/annotations/person_keypoints_val2017.json',
+    type='CocoMetric')
+val_pipeline = [
+    dict(backend_args=dict(backend='local'), type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(input_size=(192, 256), type='TopdownAffine'),
+    dict(type='PackPoseInputs'),
+]
+vis_backends = [
+    dict(_scope_='mmpose', type='LocalVisBackend'),
+]
+visualizer = dict(
+    _scope_='mmpose',
+    name='visualizer',
+    type='PoseLocalVisualizer',
+    vis_backends=[
+        dict(type='LocalVisBackend'),
+    ])
+work_dir = './'
diff --git a/models/cv/pose_estimation/rtmpose/ixrt/tensorrt_common.py b/models/cv/pose_estimation/rtmpose/ixrt/tensorrt_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..6407dabe4d03d506d339e3e41cf7dacb8d88c54c
--- /dev/null
+++ b/models/cv/pose_estimation/rtmpose/ixrt/tensorrt_common.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import cuda.cudart as cudart
+import numpy as np
+import tensorrt
+
+
+def create_engine_from_onnx(onnx_file, engine_file):
+    IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.ERROR)
+    builder = tensorrt.Builder(IXRT_LOGGER)
+    EXPLICIT_BATCH = 1 << (int)(
+        tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
+    )
+    network = builder.create_network(EXPLICIT_BATCH)
+    build_config = builder.create_builder_config()
+    parser = tensorrt.OnnxParser(network, IXRT_LOGGER)
+    assert parser.parse_from_file(onnx_file), f"failed to parse {onnx_file}"
+    # NOTE: the engine is always built in FP16 here.
+    build_config.set_flag(tensorrt.BuilderFlag.FP16)
+    plan = builder.build_serialized_network(network, build_config)
+    with open(engine_file, "wb") as f:
+        f.write(plan)
+
+
+def create_context(engine_file):
+    logger = tensorrt.Logger(tensorrt.Logger.ERROR)
+    engine, context = create_engine_context(engine_file, logger)
+    return engine, context
+
+
+def get_ixrt_output(engine, context, input_x):
+    inputs, outputs, allocations = get_io_bindings(engine)
+
+    # copy the input to device memory
+    input_data = input_x.astype(inputs[0]["dtype"])
+    input_data = np.ascontiguousarray(input_data)
+    assert inputs[0]["nbytes"] == input_data.nbytes
+    (err,) = cudart.cudaMemcpy(
+        inputs[0]["allocation"],
+        input_data,
+        inputs[0]["nbytes"],
+        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
+    )
+    assert err == cudart.cudaError_t.cudaSuccess
+
+    output0 = np.zeros(outputs[0]["shape"], outputs[0]["dtype"])
+    output1 = np.zeros(outputs[1]["shape"], outputs[1]["dtype"])
+
+    context.execute_v2(allocations)
+
+    # copy both outputs back to the host
+    assert outputs[0]["nbytes"] == output0.nbytes
+    (err,) = cudart.cudaMemcpy(
+        output0,
+        outputs[0]["allocation"],
+        outputs[0]["nbytes"],
+        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
+    )
+    assert err == cudart.cudaError_t.cudaSuccess
+
+    assert outputs[1]["nbytes"] == output1.nbytes
+    (err,) = cudart.cudaMemcpy(
+        output1,
+        outputs[1]["allocation"],
+        outputs[1]["nbytes"],
+        cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
+    )
+    assert err == cudart.cudaError_t.cudaSuccess
+
+    # free the device buffers allocated by get_io_bindings
+    for alloc in allocations:
+        (err,) = cudart.cudaFree(alloc)
+        assert err == cudart.cudaError_t.cudaSuccess
+    return output0, output1
+
+
+def create_engine_context(engine_path, logger):
+    with open(engine_path, "rb") as f:
+        runtime = tensorrt.Runtime(logger)
+        assert runtime
+        engine = runtime.deserialize_cuda_engine(f.read())
+        assert engine
+        context = engine.create_execution_context()
+        assert context
+
+    return engine, context
+
+
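+# get_io_bindings and setup_io_bindings walk the engine bindings, cudaMalloc
+# a device buffer for each one, and return per-binding dicts
+# (index/name/dtype/shape/allocation/nbytes). The only difference is that
+# setup_io_bindings reads shapes from the execution context, which matters
+# once shapes are set dynamically.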
+def get_io_bindings(engine):
+    # Setup I/O bindings
+    inputs = []
+    outputs = []
+    allocations = []
+
+    for i in range(engine.num_bindings):
+        is_input = engine.binding_is_input(i)
+        name = engine.get_binding_name(i)
+        dtype = engine.get_binding_dtype(i)
+        shape = engine.get_binding_shape(i)
+        size = np.dtype(tensorrt.nptype(dtype)).itemsize
+        for s in shape:
+            size *= s
+        err, allocation = cudart.cudaMalloc(size)
+        assert err == cudart.cudaError_t.cudaSuccess
+        binding = {
+            "index": i,
+            "name": name,
+            "dtype": np.dtype(tensorrt.nptype(dtype)),
+            "shape": list(shape),
+            "allocation": allocation,
+            "nbytes": size,
+        }
+        print(
+            f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}"
+        )
+        allocations.append(allocation)
+        if is_input:
+            inputs.append(binding)
+        else:
+            outputs.append(binding)
+    return inputs, outputs, allocations
+
+
+def setup_io_bindings(engine, context):
+    # Setup I/O bindings
+    inputs = []
+    outputs = []
+    allocations = []
+
+    for i in range(engine.num_bindings):
+        is_input = engine.binding_is_input(i)
+        name = engine.get_binding_name(i)
+        dtype = engine.get_binding_dtype(i)
+        shape = context.get_binding_shape(i)
+        size = np.dtype(tensorrt.nptype(dtype)).itemsize
+        for s in shape:
+            size *= s
+        err, allocation = cudart.cudaMalloc(size)
+        assert err == cudart.cudaError_t.cudaSuccess
+        binding = {
+            "index": i,
+            "name": name,
+            "dtype": np.dtype(tensorrt.nptype(dtype)),
+            "shape": list(shape),
+            "allocation": allocation,
+            "nbytes": size,
+        }
+        allocations.append(allocation)
+        if is_input:
+            inputs.append(binding)
+        else:
+            outputs.append(binding)
+    return inputs, outputs, allocations
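+
+
+# A minimal usage sketch (assuming a static-shape engine built from
+# rtmpose_opt.onnx, as predict.py does):
+#
+#   create_engine_from_onnx("rtmpose_opt.onnx", "rtmpose_opt.engine")
+#   engine, context = create_context("rtmpose_opt.engine")
+#   simcc_x, simcc_y = get_ixrt_output(engine, context, input_tensor)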