From 52ded69adf8f6485d5811b03d27950f7cced422a Mon Sep 17 00:00:00 2001
From: "xinchi.tian"
Date: Thu, 9 Jan 2025 18:36:58 +0800
Subject: [PATCH] Add FSAF

---
 models/cv/detection/fsaf/ixrt/README.md       |  64 ++++
 models/cv/detection/fsaf/ixrt/build_engine.py |  62 ++++
 models/cv/detection/fsaf/ixrt/common.py       |  69 +++++
 .../cv/detection/fsaf/ixrt/deploy_default.py  |  41 +++
 models/cv/detection/fsaf/ixrt/export.py       |  72 +++++
 .../fsaf/ixrt/fsaf_r50_fpn_1x_coco.py         | 276 ++++++++++++++++++
 models/cv/detection/fsaf/ixrt/inference.py    | 187 ++++++++++++
 .../cv/detection/fsaf/ixrt/requirements.txt   |   6 +
 .../ixrt/scripts/infer_fsaf_fp16_accuracy.sh  |  34 +++
 .../scripts/infer_fsaf_fp16_performance.sh    |  36 +++
 10 files changed, 847 insertions(+)
 create mode 100644 models/cv/detection/fsaf/ixrt/README.md
 create mode 100644 models/cv/detection/fsaf/ixrt/build_engine.py
 create mode 100644 models/cv/detection/fsaf/ixrt/common.py
 create mode 100644 models/cv/detection/fsaf/ixrt/deploy_default.py
 create mode 100644 models/cv/detection/fsaf/ixrt/export.py
 create mode 100644 models/cv/detection/fsaf/ixrt/fsaf_r50_fpn_1x_coco.py
 create mode 100644 models/cv/detection/fsaf/ixrt/inference.py
 create mode 100644 models/cv/detection/fsaf/ixrt/requirements.txt
 create mode 100644 models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_accuracy.sh
 create mode 100644 models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_performance.sh

diff --git a/models/cv/detection/fsaf/ixrt/README.md b/models/cv/detection/fsaf/ixrt/README.md
new file mode 100644
index 00000000..5d75a2e6
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/README.md
@@ -0,0 +1,64 @@
+# FSAF
+
+## Description
+
+The FSAF (Feature Selective Anchor-Free) module is an innovative component for single-shot object detection that enhances performance through online feature selection and anchor-free branches. The FSAF module dynamically selects the most suitable feature level for each object instance, rather than relying on traditional anchor-based heuristics. This significantly boosts detection accuracy, especially for small objects and complex scenes, while adding negligible inference overhead compared to existing anchor-based detectors.
+
+## Setup
+
+### Install
+
+```bash
+# Install libGL
+## CentOS
+yum install -y mesa-libGL
+## Ubuntu
+apt install -y libgl1-mesa-dev
+
+pip3 install -r requirements.txt
+```
+
+### Download
+
+Pretrained model: <https://download.openmmlab.com/mmdetection/v2.0/fsaf/fsaf_r50_fpn_1x_coco/fsaf_r50_fpn_1x_coco-94ccc51f.pth>
+
+Dataset: <https://cocodataset.org/#download> to download the COCO val2017 validation dataset.
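+
+The evaluation scripts assume the standard COCO layout below (a sketch inferred from the paths hard-coded in `fsaf_r50_fpn_1x_coco.py` and `inference.py`):
+
+```
+coco/
+├── annotations/
+│   └── instances_val2017.json
+└── images/
+    └── val2017/
+```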
+
+Download the pretrained weights:
+
+```bash
+wget https://download.openmmlab.com/mmdetection/v2.0/fsaf/fsaf_r50_fpn_1x_coco/fsaf_r50_fpn_1x_coco-94ccc51f.pth
+```
+
+### Model Conversion
+
+```bash
+# export the onnx model
+python3 export.py --weight fsaf_r50_fpn_1x_coco-94ccc51f.pth --cfg fsaf_r50_fpn_1x_coco.py --output fsaf.onnx
+
+# use onnxsim to optimize the onnx model
+onnxsim fsaf.onnx fsaf_opt.onnx
+```
+
+## Inference
+
+```bash
+export DATASETS_DIR=/Path/to/coco/
+```
+
+### FP16
+
+```bash
+# Accuracy
+bash scripts/infer_fsaf_fp16_accuracy.sh
+# Performance
+bash scripts/infer_fsaf_fp16_performance.sh
+```
+
+## Results
+
+| Model | BatchSize | Precision | FPS    | IOU@0.5 | IOU@0.5:0.95 |
+|-------|-----------|-----------|--------|---------|--------------|
+| FSAF  | 32        | FP16      | 133.85 | 0.530   | 0.345        |
+
+## Reference
+
+mmdetection: <https://github.com/open-mmlab/mmdetection>
diff --git a/models/cv/detection/fsaf/ixrt/build_engine.py b/models/cv/detection/fsaf/ixrt/build_engine.py
new file mode 100644
index 00000000..d7ad10e5
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/build_engine.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import argparse
+
+import tensorrt
+from tensorrt import Dims
+
+def main(config):
+    IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING)
+    builder = tensorrt.Builder(IXRT_LOGGER)
+    EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(EXPLICIT_BATCH)
+    build_config = builder.create_builder_config()
+
+    # The engine is built for a fixed 32x3x800x800 input, matching the batch
+    # size and resize scale used by the inference scripts.
+    profile = builder.create_optimization_profile()
+    profile.set_shape("input", Dims([32, 3, 800, 800]), Dims([32, 3, 800, 800]), Dims([32, 3, 800, 800]))
+    build_config.add_optimization_profile(profile)
+
+    parser = tensorrt.OnnxParser(network, IXRT_LOGGER)
+    if not parser.parse_from_file(config.model):
+        raise RuntimeError(f"Failed to parse ONNX model: {config.model}")
+
+    # FP32 is the builder default; only int8/fp16 need an explicit flag.
+    if config.precision == "int8":
+        build_config.set_flag(tensorrt.BuilderFlag.INT8)
+    elif config.precision == "float16":
+        build_config.set_flag(tensorrt.BuilderFlag.FP16)
+
+    # Pin every network input to the static shape of the optimization profile.
+    for i in range(network.num_inputs):
+        input_tensor = network.get_input(i)
+        input_tensor.shape = Dims([32, 3, 800, 800])
+
+    plan = builder.build_serialized_network(network, build_config)
+    assert plan, "Failed to build the serialized engine."
+    with open(config.engine, "wb") as f:
+        f.write(plan)
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, required=True, help="Path to the (simplified) ONNX model.")
+    parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="float16",
+                        help="Precision to build the engine with.")
+    parser.add_argument("--engine", type=str, default=None, help="Output engine path.")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/models/cv/detection/fsaf/ixrt/common.py b/models/cv/detection/fsaf/ixrt/common.py
new file mode 100644
index 00000000..ef92a6ba
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/common.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tensorrt
+from cuda import cudart
+
+def create_engine_context(engine_path, logger):
+    # Deserialize a prebuilt engine and create an execution context for it.
+    with open(engine_path, "rb") as f:
+        runtime = tensorrt.Runtime(logger)
+        assert runtime
+        engine = runtime.deserialize_cuda_engine(f.read())
+        assert engine
+        context = engine.create_execution_context()
+        assert context
+
+    return engine, context
+
+def get_io_bindings(engine):
+    # Allocate one device buffer per binding and collect its metadata, so the
+    # caller can copy data in/out and run context.execute_v2(allocations).
+    inputs = []
+    outputs = []
+    allocations = []
+
+    for i in range(engine.num_bindings):
+        is_input = engine.binding_is_input(i)
+        name = engine.get_binding_name(i)
+        dtype = engine.get_binding_dtype(i)
+        shape = engine.get_binding_shape(i)
+
+        size = np.dtype(tensorrt.nptype(dtype)).itemsize
+        for s in shape:
+            size *= s
+
+        err, allocation = cudart.cudaMalloc(size)
+        assert err == cudart.cudaError_t.cudaSuccess
+        binding = {
+            "index": i,
+            "name": name,
+            "dtype": np.dtype(tensorrt.nptype(dtype)),
+            "shape": list(shape),
+            "allocation": allocation,
+            "nbytes": size,
+        }
+        print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}")
+        allocations.append(allocation)
+        if is_input:
+            inputs.append(binding)
+        else:
+            outputs.append(binding)
+    return inputs, outputs, allocations
diff --git a/models/cv/detection/fsaf/ixrt/deploy_default.py b/models/cv/detection/fsaf/ixrt/deploy_default.py
new file mode 100644
index 00000000..b8d8e43d
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/deploy_default.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
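+
+# mmdeploy "deploy config" consumed by export.py via load_config() and
+# build_task_processor(): onnx_config controls the ONNX export settings,
+# codebase_config configures mmdet's post-processing (score/NMS thresholds,
+# top-k), and backend_config names the backend the exported model targets.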
+
+onnx_config = dict(
+    type='onnx',
+    export_params=True,
+    keep_initializers_as_inputs=False,
+    opset_version=11,
+    save_file='end2end.onnx',
+    input_names=['input'],
+    output_names=['output'],
+    input_shape=None,
+    optimize=True)
+
+codebase_config = dict(
+    type='mmdet',
+    task='ObjectDetection',
+    model_type='end2end',
+    post_processing=dict(
+        score_threshold=0.05,
+        confidence_threshold=0.005,
+        iou_threshold=0.5,
+        max_output_boxes_per_class=200,
+        pre_top_k=5000,
+        keep_top_k=100,
+        background_label_id=-1,
+    ))
+
+backend_config = dict(type='onnxruntime')
\ No newline at end of file
diff --git a/models/cv/detection/fsaf/ixrt/export.py b/models/cv/detection/fsaf/ixrt/export.py
new file mode 100644
index 00000000..13573c9d
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/export.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import argparse
+
+import torch
+from mmdeploy.utils import load_config
+from mmdeploy.apis import build_task_processor
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--weight",
+                        type=str,
+                        required=True,
+                        help="pytorch model weight.")
+
+    parser.add_argument("--cfg",
+                        type=str,
+                        required=True,
+                        help="model config file.")
+
+    parser.add_argument("--output",
+                        type=str,
+                        required=True,
+                        help="export onnx model path.")
+
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+
+    deploy_cfg = 'deploy_default.py'
+    model_cfg = args.cfg
+    model_checkpoint = args.weight
+
+    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
+
+    task_processor = build_task_processor(model_cfg, deploy_cfg, device='cpu')
+
+    model = task_processor.build_pytorch_model(model_checkpoint)
+    # Export in eval mode (defensive: the task processor should already do this).
+    model.eval()
+
+    input_names = ['input']
+    # Mark the batch dimension as dynamic; the dimension name is arbitrary.
+    dynamic_axes = {'input': {0: 'batch'}}
+    dummy_input = torch.randn(1, 3, 800, 800)
+
+    torch.onnx.export(
+        model,
+        dummy_input,
+        args.output,
+        input_names=input_names,
+        dynamic_axes=dynamic_axes,
+        opset_version=13
+    )
+
+    print("Exported ONNX model successfully!")
+
+if __name__ == '__main__':
+    main()
diff --git a/models/cv/detection/fsaf/ixrt/fsaf_r50_fpn_1x_coco.py b/models/cv/detection/fsaf/ixrt/fsaf_r50_fpn_1x_coco.py
new file mode 100644
index 00000000..d511321f
--- /dev/null
+++ b/models/cv/detection/fsaf/ixrt/fsaf_r50_fpn_1x_coco.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. + +auto_scale_lr = dict(base_batch_size=16, enable=False) +backend_args = None +data_root = 'data/coco/' +dataset_type = 'CocoDataset' +default_hooks = dict( + checkpoint=dict(interval=1, type='CheckpointHook'), + logger=dict(interval=50, type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + sampler_seed=dict(type='DistSamplerSeedHook'), + timer=dict(type='IterTimerHook'), + visualization=dict(type='DetVisualizationHook')) +default_scope = 'mmdet' +env_cfg = dict( + cudnn_benchmark=False, + dist_cfg=dict(backend='nccl'), + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) +load_from = None +log_level = 'ERROR' +log_processor = dict(by_epoch=True, type='LogProcessor', window_size=50) +model = dict( + backbone=dict( + depth=50, + frozen_stages=1, + init_cfg=dict(checkpoint='torchvision://resnet50', type='Pretrained'), + norm_cfg=dict(requires_grad=True, type='BN'), + norm_eval=True, + num_stages=4, + out_indices=( + 0, + 1, + 2, + 3, + ), + style='pytorch', + type='ResNet'), + bbox_head=dict( + anchor_generator=dict( + octave_base_scale=1, + ratios=[ + 1.0, + ], + scales_per_octave=1, + strides=[ + 8, + 16, + 32, + 64, + 128, + ], + type='AnchorGenerator'), + bbox_coder=dict(normalizer=4.0, type='TBLRBBoxCoder'), + feat_channels=256, + in_channels=256, + loss_bbox=dict( + eps=1e-06, loss_weight=1.0, reduction='none', type='IoULoss'), + loss_cls=dict( + alpha=0.25, + gamma=2.0, + loss_weight=1.0, + reduction='none', + type='FocalLoss', + use_sigmoid=True), + num_classes=80, + reg_decoded_bbox=True, + stacked_convs=4, + type='FSAFHead'), + data_preprocessor=dict( + bgr_to_rgb=True, + mean=[ + 123.675, + 116.28, + 103.53, + ], + pad_size_divisor=32, + std=[ + 58.395, + 57.12, + 57.375, + ], + type='DetDataPreprocessor'), + neck=dict( + add_extra_convs='on_input', + in_channels=[ + 256, + 512, + 1024, + 2048, + ], + num_outs=5, + out_channels=256, + start_level=1, + type='FPN'), + test_cfg=dict( + max_per_img=100, + min_bbox_size=0, + nms=dict(iou_threshold=0.5, type='nms'), + nms_pre=1000, + score_thr=0.05), + train_cfg=dict( + allowed_border=-1, + assigner=dict( + min_pos_iof=0.01, + neg_scale=0.2, + pos_scale=0.2, + type='CenterRegionAssigner'), + debug=False, + pos_weight=-1, + sampler=dict(type='PseudoSampler')), + type='FSAF') +optim_wrapper = dict( + optimizer=dict(lr=0.02, momentum=0.9, type='SGD', weight_decay=0.0001), + type='OptimWrapper') +param_scheduler = [ + dict( + begin=0, by_epoch=False, end=500, start_factor=0.001, type='LinearLR'), + dict( + begin=0, + by_epoch=True, + end=12, + gamma=0.1, + milestones=[ + 8, + 11, + ], + type='MultiStepLR'), +] +resume = False +test_cfg = dict(type='TestLoop') +test_dataloader = dict( + batch_size=32, + dataset=dict( + ann_file='annotations/instances_val2017.json', + backend_args=None, + data_prefix=dict(img='images/val2017/'), + data_root='data/coco/', + pipeline=[ + dict(backend_args=None, type='LoadImageFromFile'), + dict(keep_ratio=False, scale=( + 800, + 800, + ), type='Resize'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'scale_factor', + ), + type='PackDetInputs'), + ], + test_mode=True, + type='CocoDataset'), + drop_last=False, + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +test_evaluator = dict( + ann_file='data/coco/annotations/instances_val2017.json', + 
backend_args=None, + format_only=False, + metric='bbox', + type='CocoMetric') +test_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict(keep_ratio=True, scale=( + 1333, + 800, + ), type='Resize'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'scale_factor', + ), + type='PackDetInputs'), +] +train_cfg = dict(max_epochs=12, type='EpochBasedTrainLoop', val_interval=1) +train_dataloader = dict( + batch_sampler=dict(type='AspectRatioBatchSampler'), + batch_size=2, + dataset=dict( + ann_file='annotations/instances_train2017.json', + backend_args=None, + data_prefix=dict(img='train2017/'), + data_root='data/coco/', + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=[ + dict(backend_args=None, type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(keep_ratio=True, scale=( + 1333, + 800, + ), type='Resize'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackDetInputs'), + ], + type='CocoDataset'), + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=True, type='DefaultSampler')) +train_pipeline = [ + dict(backend_args=None, type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(keep_ratio=True, scale=( + 1333, + 800, + ), type='Resize'), + dict(prob=0.5, type='RandomFlip'), + dict(type='PackDetInputs'), +] +val_cfg = dict(type='ValLoop') +val_dataloader = dict( + batch_size=1, + dataset=dict( + ann_file='annotations/instances_val2017.json', + backend_args=None, + data_prefix=dict(img='val2017/'), + data_root='data/coco/', + pipeline=[ + dict(backend_args=None, type='LoadImageFromFile'), + dict(keep_ratio=True, scale=( + 1333, + 800, + ), type='Resize'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + meta_keys=( + 'img_id', + 'img_path', + 'ori_shape', + 'img_shape', + 'scale_factor', + ), + type='PackDetInputs'), + ], + test_mode=True, + type='CocoDataset'), + drop_last=False, + num_workers=2, + persistent_workers=True, + sampler=dict(shuffle=False, type='DefaultSampler')) +val_evaluator = dict( + ann_file='data/coco/annotations/instances_val2017.json', + backend_args=None, + format_only=False, + metric='bbox', + type='CocoMetric') +vis_backends = [ + dict(type='LocalVisBackend'), +] +visualizer = dict( + name='visualizer', + type='DetLocalVisualizer', + vis_backends=[ + dict(type='LocalVisBackend'), + ]) \ No newline at end of file diff --git a/models/cv/detection/fsaf/ixrt/inference.py b/models/cv/detection/fsaf/ixrt/inference.py new file mode 100644 index 00000000..5d940cd5 --- /dev/null +++ b/models/cv/detection/fsaf/ixrt/inference.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
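+
+# ixrt inference harness for FSAF: deserialize the engine, allocate device
+# buffers for every binding, then either run a timing-only loop (--perf_only)
+# or a full COCO evaluation where the raw class/box outputs are decoded by
+# mmdet's FSAFHead.predict_by_feat and scored with CocoMetric.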
+
+import os
+import time
+import argparse
+import tensorrt
+import torch
+import numpy as np
+from cuda import cudart
+from tqdm import tqdm
+from mmdet.registry import RUNNERS
+from mmengine.config import Config
+
+from common import create_engine_context, get_io_bindings
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--engine",
+                        type=str,
+                        required=True,
+                        help="ixrt engine path.")
+
+    parser.add_argument("--batchsize",
+                        type=int,
+                        required=True,
+                        help="inference batch size.")
+
+    parser.add_argument("--datasets",
+                        type=str,
+                        required=True,
+                        help="datasets path.")
+
+    parser.add_argument("--input_name",
+                        type=str,
+                        required=True,
+                        help="input name of the model.")
+
+    parser.add_argument("--warmup",
+                        type=int,
+                        default=3,
+                        help="number of warmup runs before the timed test.")
+
+    parser.add_argument("--acc_target",
+                        type=float,
+                        default=None,
+                        help="Model inference accuracy target.")
+
+    parser.add_argument("--fps_target",
+                        type=float,
+                        default=None,
+                        help="Model inference FPS target.")
+
+    parser.add_argument("--perf_only",
+                        action="store_true",
+                        help="Run the performance test only.")
+
+    args = parser.parse_args()
+
+    return args
+
+def main():
+    args = parse_args()
+
+    batch_size = args.batchsize
+
+    logger = tensorrt.Logger(tensorrt.Logger.ERROR)
+
+    # Load engine and set up I/O bindings
+    engine, context = create_engine_context(args.engine, logger)
+    inputs, outputs, allocations = get_io_bindings(engine)
+
+    if args.perf_only:
+        # Timing-only path: run the engine on whatever is in the input buffers.
+        for _ in range(args.warmup):
+            context.execute_v2(allocations)
+
+        torch.cuda.synchronize()
+        start_time = time.time()
+
+        for i in range(10):
+            context.execute_v2(allocations)
+
+        torch.cuda.synchronize()
+        end_time = time.time()
+        forward_time = end_time - start_time
+        num_samples = 10 * args.batchsize
+        fps = num_samples / forward_time
+
+        print("FPS : ", fps)
+        print(f"Performance Check : Test {fps} >= target {args.fps_target}")
+        # With no target given, just report the measured FPS and pass.
+        if args.fps_target is None or fps >= args.fps_target:
+            print("pass!")
+            exit()
+        else:
+            print("failed!")
+            exit(1)
+    else:
+        # runner config
+        cfg = Config.fromfile("fsaf_r50_fpn_1x_coco.py")
+
+        cfg.work_dir = "./workspace"
+        cfg['test_dataloader']['batch_size'] = batch_size
+        cfg['test_dataloader']['dataset']['data_root'] = args.datasets
+        cfg['test_dataloader']['dataset']['data_prefix']['img'] = 'images/val2017/'
+        cfg['test_evaluator']['ann_file'] = os.path.join(args.datasets, 'annotations/instances_val2017.json')
+        cfg['log_level'] = 'ERROR'
+
+        # build runner
+        runner = RUNNERS.build(cfg)
+
+        for data in tqdm(runner.test_dataloader):
+            cls_score = []
+            box_reg = []
+
+            input_data = runner.model.data_preprocessor(data, False)
+            image = input_data['inputs'].cpu()
+            image = image.numpy().astype(inputs[0]["dtype"])
+            pad_batch = len(image) != batch_size
+
+            # Pad the last (smaller) batch up to the static engine batch size.
+            if pad_batch:
+                origin_size = len(image)
+                image = np.resize(image, (batch_size, *image.shape[1:]))
+
+            image = np.ascontiguousarray(image)
+
+            (err,) = cudart.cudaMemcpy(
+                inputs[0]["allocation"],
+                image,
+                image.nbytes,
+                cudart.cudaMemcpyKind.cudaMemcpyHostToDevice,
+            )
+            assert err == cudart.cudaError_t.cudaSuccess
+            context.execute_v2(allocations)
+
+            for i in range(len(outputs)):
+                output = np.zeros(outputs[i]["shape"], outputs[i]["dtype"])
+                (err,) = cudart.cudaMemcpy(
+                    output,
+                    outputs[i]["allocation"],
+                    outputs[i]["nbytes"],
+                    cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost,
+                )
+                assert err == cudart.cudaError_t.cudaSuccess
+
+                # Drop the padded rows before decoding.
+                if pad_batch:
+                    output 
= output[:origin_size] + + output = torch.from_numpy(output) + + if output.shape[1] == 80: + cls_score.append(output) + elif output.shape[1] == 4: + box_reg.append(output) + + batch_img_metas = [ + data_samples.metainfo for data_samples in data['data_samples'] + ] + + preds = runner.model.bbox_head.predict_by_feat( + cls_score, box_reg, batch_img_metas=batch_img_metas, rescale=True + ) + + batch_data_samples = runner.model.add_pred_to_datasample(input_data['data_samples'], preds) + + runner.test_evaluator.process(data_samples=batch_data_samples, data_batch=data) + + metrics = runner.test_evaluator.evaluate(len(runner.test_dataloader.dataset)) + + +if __name__ == "__main__": + main() diff --git a/models/cv/detection/fsaf/ixrt/requirements.txt b/models/cv/detection/fsaf/ixrt/requirements.txt new file mode 100644 index 00000000..a26706ef --- /dev/null +++ b/models/cv/detection/fsaf/ixrt/requirements.txt @@ -0,0 +1,6 @@ +onnx +tqdm +onnxsim +mmdet==3.3.0 +mmdeploy==1.3.1 +mmengine==0.10.4 diff --git a/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_accuracy.sh b/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_accuracy.sh new file mode 100644 index 00000000..ed3132c6 --- /dev/null +++ b/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_accuracy.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +batchsize=32 +model_path="fsaf_opt.onnx" +datasets_path=${DATASETS_DIR} + +# build engine +python3 build_engine.py \ + --model ${model_path} \ + --precision float16 \ + --engine fsaf.engine + + +# inference +python3 inference.py \ + --engine fsaf.engine \ + --batchsize ${batchsize} \ + --input_name input \ + --datasets ${datasets_path} \ No newline at end of file diff --git a/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_performance.sh b/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_performance.sh new file mode 100644 index 00000000..65fad0c7 --- /dev/null +++ b/models/cv/detection/fsaf/ixrt/scripts/infer_fsaf_fp16_performance.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
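+
+# Build the FP16 engine from the simplified ONNX model, then run the
+# timing-only path of inference.py (--perf_only) against the FPS target below.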
+
+batchsize=32
+model_path="fsaf_opt.onnx"
+datasets_path=${DATASETS_DIR}
+
+# build engine
+python3 build_engine.py \
+    --model ${model_path} \
+    --precision float16 \
+    --engine fsaf.engine
+
+# inference (performance only)
+python3 inference.py \
+    --engine fsaf.engine \
+    --batchsize ${batchsize} \
+    --input_name input \
+    --datasets ${datasets_path} \
+    --perf_only \
+    --fps_target 130
-- 
Gitee