From b2fa2f828ae8a2bd43da38004553f82d464a9447 Mon Sep 17 00:00:00 2001 From: may Date: Fri, 5 Jul 2024 14:18:46 +0800 Subject: [PATCH 1/2] Add yolov4darknet fp16 and int8 inference. --- models/cv/detection/yolov4/ixrt/README.md | 74 ++++ .../cv/detection/yolov4/ixrt/build_engine.py | 97 +++++ .../cv/detection/yolov4/ixrt/coco_labels.py | 103 ++++++ models/cv/detection/yolov4/ixrt/common.py | 335 ++++++++++++++++++ models/cv/detection/yolov4/ixrt/cut_model.py | 30 ++ models/cv/detection/yolov4/ixrt/deploy.py | 210 +++++++++++ models/cv/detection/yolov4/ixrt/export.py | 56 +++ models/cv/detection/yolov4/ixrt/inference.py | 211 +++++++++++ .../detection/yolov4/ixrt/load_ixrt_plugin.py | 26 ++ models/cv/detection/yolov4/ixrt/quant.py | 105 ++++++ .../infer_yolov4darknet_fp16_accuary.sh | 88 +++++ .../infer_yolov4darknet_fp16_performance.sh | 88 +++++ .../infer_yolov4darknet_int8_accuary.sh | 106 ++++++ .../infer_yolov4darknet_int8_performance.sh | 106 ++++++ 14 files changed, 1635 insertions(+) create mode 100644 models/cv/detection/yolov4/ixrt/README.md create mode 100644 models/cv/detection/yolov4/ixrt/build_engine.py create mode 100644 models/cv/detection/yolov4/ixrt/coco_labels.py create mode 100644 models/cv/detection/yolov4/ixrt/common.py create mode 100644 models/cv/detection/yolov4/ixrt/cut_model.py create mode 100644 models/cv/detection/yolov4/ixrt/deploy.py create mode 100644 models/cv/detection/yolov4/ixrt/export.py create mode 100644 models/cv/detection/yolov4/ixrt/inference.py create mode 100644 models/cv/detection/yolov4/ixrt/load_ixrt_plugin.py create mode 100644 models/cv/detection/yolov4/ixrt/quant.py create mode 100644 models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_accuary.sh create mode 100644 models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_performance.sh create mode 100644 models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_accuary.sh create mode 100644 models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_performance.sh diff --git a/models/cv/detection/yolov4/ixrt/README.md b/models/cv/detection/yolov4/ixrt/README.md new file mode 100644 index 00000000..abbeb2bf --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/README.md @@ -0,0 +1,74 @@ +# YOLOv4 + +## Description + +YOLOv4 employs a two-step process, involving regression for bounding box positioning and classification for object categorization. it amalgamates past YOLO family research contributions with novel features like WRC, CSP, CmBN, SAT, Mish activation, Mosaic data augmentation, DropBlock regularization, and CIoU loss. + +## Setup + +### Install + +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-dev + +pip3 install tqdm +pip3 install onnx +pip3 install onnxsim +pip3 install pycocotools +``` + +### Download + +Pretrained cfg: +Pretrained model: + +Dataset: to download the validation dataset. + +### Model Conversion + +```bash +# clone yolov4 +git clone https://github.com/Tianxiaomo/pytorch-YOLOv4.git yolov4 + +# download weight +mkdir data +wget https://github.com/AlexeyAB/darknet/releases/download/darknet_yolo_v3_optimal/yolov4.weights -P data + +# export onnx model +python3 export.py --cfg yolov4/cfg/yolov4.cfg --weight data/yolov4.weights --batchsize 16 --output data/yolov4.onnx +mv yolov4_16_3_608_608_static.onnx data/yolov4.onnx + +# Use onnxsim optimize onnx model +onnxsim data/yolov4.onnx data/yolov4_sim.onnx + +# Make sure the dataset path is "data/coco" +``` + +## Inference + +### FP16 + +```bash +# Accuracy +bash scripts/infer_yolov4darknet_fp16_accuary.sh +# Performance +bash scripts/infer_yolov4darknet_fp16_performance.sh +``` + +### INT8 + +```bash +# Accuracy +bash scripts/infer_yolov4darknet_int8_accuracy.sh +# Performance +bash scripts/infer_yolov4darknet_int8_performance.sh +``` + +## Reference + +DarkNet: +Pytorch-YOLOv4: diff --git a/models/cv/detection/yolov4/ixrt/build_engine.py b/models/cv/detection/yolov4/ixrt/build_engine.py new file mode 100644 index 00000000..ec4080ed --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/build_engine.py @@ -0,0 +1,97 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import argparse +import numpy as np + +import torch +import tensorrt +from tensorrt import Dims + +from load_ixrt_plugin import load_ixrt_plugin +load_ixrt_plugin() + + +def build_engine_trtapi_staticshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build static shape engine done!") + + +def build_engine_trtapi_dynamicshape(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + + profile = builder.create_optimization_profile() + profile.set_shape("input", + Dims([1, 3, 608, 608]), + Dims([32, 3, 608, 608]), + Dims([64, 3, 608, 608]), + ) + build_config.add_optimization_profile(profile) + + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + # set dynamic + num_inputs = network.num_inputs + for i in range(num_inputs): + input_tensor = network.get_input(i) + input_tensor.shape = Dims([-1, 3, 608, 608]) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + print("Build dynamic shape engine done!") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str) + parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8", + help="The precision of datatype") + # engine args + parser.add_argument("--engine", type=str, default=None) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + build_engine_trtapi_staticshape(args) + # build_engine_trtapi_dynamicshape(args) diff --git a/models/cv/detection/yolov4/ixrt/coco_labels.py b/models/cv/detection/yolov4/ixrt/coco_labels.py new file mode 100644 index 00000000..5fc21282 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/coco_labels.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +labels = [ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush", +] +def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) + return [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] + +__all__ = ["labels"] diff --git a/models/cv/detection/yolov4/ixrt/common.py b/models/cv/detection/yolov4/ixrt/common.py new file mode 100644 index 00000000..dc3c2766 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/common.py @@ -0,0 +1,335 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import glob +import time +import numpy as np +from tqdm import tqdm + +import tensorrt +import pycuda.driver as cuda + + +def load_class_names(namesfile): + class_names = [] + with open(namesfile, 'r') as fp: + lines = fp.readlines() + for line in lines: + line = line.rstrip() + class_names.append(line) + return class_names + +# input : [bsz, box_num, 5(cx, cy, w, h, conf) + class_num(prob[0], prob[1], ...)] +# output : [bsz, box_num, 6(left_top_x, left_top_y, right_bottom_x, right_bottom_y, class_id, max_prob*conf)] +def box_class85to6(input): + center_x_y = input[:, :2] + side = input[:, 2:4] + conf = input[:, 4:5] + class_id = np.argmax(input[:, 5:], axis = -1) + class_id = class_id.astype(np.float32).reshape(-1, 1) + 1 + max_prob = np.max(input[:, 5:], axis = -1).reshape(-1, 1) + x1_y1 = center_x_y - 0.5 * side + x2_y2 = center_x_y + 0.5 * side + nms_input = np.concatenate([x1_y1, x2_y2, class_id, max_prob*conf], axis = -1) + return nms_input + +def save2json(batch_img_id, pred_boxes, json_result, class_trans): + for i, boxes in enumerate(pred_boxes): + if boxes is not None: + image_id = int(batch_img_id[i]) + # have no target + if image_id == -1: + continue + + for x1, y1, x2, y2, _, p, c in boxes: + x1, y1, x2, y2, p = float(x1), float(y1), float(x2), float(y2), float(p) + c = int(c) + x = x1 + y = y1 + w = x2 - x1 + h = y2 - y1 + + json_result.append( + { + "image_id": image_id, + "category_id": class_trans[c - 1], + "bbox": [x, y, w, h], + "score": p, + } + ) + +################## About TensorRT ################# +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + +def setup_io_bindings(engine, context): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = context.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + allocation = cuda.mem_alloc(size) + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + } + # print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations +########################################################## + + +################## About Loading Dataset ################# +def load_images(images_path): + """ + If image path is given, return it directly + For txt file, read it and return each line as image path + In other case, it's a folder, return a list with names of each + jpg, jpeg and png file + """ + input_path_extension = images_path.split('.')[-1] + if input_path_extension in ['jpg', 'jpeg', 'png']: + return [images_path] + elif input_path_extension == "txt": + with open(images_path, "r") as f: + return f.read().splitlines() + else: + return glob.glob( + os.path.join(images_path, "*.jpg")) + \ + glob.glob(os.path.join(images_path, "*.png")) + \ + glob.glob(os.path.join(images_path, "*.jpeg")) + +def prepare_batch(images_path, bs=16, input_size=(608, 608)): + + width, height = input_size + + batch_names = [] + batch_images = [] + batch_shapes = [] + + temp_names = [] + temp_images = [] + temp_shapes = [] + + for i, image_path in tqdm(enumerate(images_path), desc="Loading coco data"): + name = os.path.basename(image_path) + image = cv2.imread(image_path) + h, w, _ = image.shape + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_resized = cv2.resize(image_rgb, (width, height), + interpolation=cv2.INTER_LINEAR) + custom_image = image_resized.transpose(2, 0, 1).astype(np.float32) / 255. + custom_image = np.expand_dims(custom_image, axis=0) + + if i != 0 and i % bs == 0: + batch_names.append(temp_names) + batch_images.append(np.concatenate(temp_images, axis=0)) + batch_shapes.append(temp_shapes) + + temp_names = [name] + temp_images = [custom_image] + temp_shapes = [(h, w)] + else: + temp_names.append(name) + temp_images.append(custom_image) + temp_shapes.append((h, w)) + + return batch_names, batch_images, batch_shapes +########################################################## + + +################## About Operating box ################# +def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): + # Resize and pad image while meeting stride-multiple constraints + shape = im.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding + elif scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return im, ratio, (dw, dh) + +def scale_boxes(net_shape, boxes, ori_shape, use_letterbox=False): + # Rescale boxes (xyxy) from net_shape to ori_shape + + if use_letterbox: + + gain = min( + net_shape[0] / ori_shape[0], net_shape[1] / ori_shape[1] + ) # gain = new / old + pad = (net_shape[1] - ori_shape[1] * gain) / 2, ( + net_shape[0] - ori_shape[0] * gain + ) / 2.0 + + boxes[:, [0, 2]] -= pad[0] # x padding + boxes[:, [1, 3]] -= pad[1] # y padding + boxes[:, :4] /= gain + else: + x_scale, y_scale = net_shape[1] / ori_shape[1], net_shape[0] / ori_shape[0] + + boxes[:, 0] /= x_scale + boxes[:, 1] /= y_scale + boxes[:, 2] /= x_scale + boxes[:, 3] /= y_scale + + clip_boxes(boxes, ori_shape) + return boxes + +def clip_boxes(boxes, shape): + + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 +########################################################## + + +################## About pre and post processing ######### +def pre_processing(src_img, imgsz=608): + resized = cv2.resize(src_img, (imgsz, imgsz), interpolation=cv2.INTER_LINEAR) + in_img = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) + in_img = np.transpose(in_img, (2, 0, 1)).astype(np.float32) + in_img = np.expand_dims(in_img, axis=0) + in_img /= 255.0 + return in_img + +def nms_cpu(boxes, confs, nms_thresh=0.5, min_mode=False): + # print(boxes.shape) + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1) * (y2 - y1) + order = confs.argsort()[::-1] + + keep = [] + while order.size > 0: + idx_self = order[0] + idx_other = order[1:] + + keep.append(idx_self) + + xx1 = np.maximum(x1[idx_self], x1[idx_other]) + yy1 = np.maximum(y1[idx_self], y1[idx_other]) + xx2 = np.minimum(x2[idx_self], x2[idx_other]) + yy2 = np.minimum(y2[idx_self], y2[idx_other]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + + if min_mode: + over = inter / np.minimum(areas[order[0]], areas[order[1:]]) + else: + over = inter / (areas[order[0]] + areas[order[1:]] - inter) + + inds = np.where(over <= nms_thresh)[0] + order = order[inds + 1] + + return np.array(keep) + + +def post_processing(img, conf_thresh, nms_thresh, output, num_classes=80): + + # [batch, num, 1, 4] + box_array = output[:, :, :4] + # [batch, num, 2] + class_confs = output[:, :, 4:] + + max_conf = class_confs[:, :, 1] + max_id = class_confs[:, :, 0] + + bboxes_batch = [] + for i in range(box_array.shape[0]): + + argwhere = max_conf[i] > conf_thresh + l_box_array = box_array[i, argwhere, :] + l_max_conf = max_conf[i, argwhere] + l_max_id = max_id[i, argwhere] + + bboxes = [] + # nms for each class + for j in range(num_classes): + + cls_argwhere = l_max_id == j + ll_box_array = l_box_array[cls_argwhere, :] + ll_max_conf = l_max_conf[cls_argwhere] + ll_max_id = l_max_id[cls_argwhere] + + keep = nms_cpu(ll_box_array, ll_max_conf, nms_thresh) + + if (keep.size > 0): + ll_box_array = ll_box_array[keep, :] + ll_max_conf = ll_max_conf[keep] + ll_max_id = ll_max_id[keep] + + for k in range(ll_box_array.shape[0]): + bboxes.append([ll_box_array[k, 0], ll_box_array[k, 1], ll_box_array[k, 2], + ll_box_array[k, 3], ll_max_conf[k], ll_max_conf[k], ll_max_id[k]]) + + bboxes_batch.append(bboxes) + + return bboxes_batch +########################################################## + diff --git a/models/cv/detection/yolov4/ixrt/cut_model.py b/models/cv/detection/yolov4/ixrt/cut_model.py new file mode 100644 index 00000000..cf4f88da --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/cut_model.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import onnx +import argparse +from onnxsim import simplify + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_model", type=str) + parser.add_argument("--output_model", type=str) + parser.add_argument("--input_names", nargs='+', type=str) + parser.add_argument("--output_names", nargs='+', type=str) + args = parser.parse_args() + return args + +args = parse_args() +onnx.utils.extract_model(args.input_model, args.output_model, args.input_names, args.output_names) +print(" Cut Model Done.") diff --git a/models/cv/detection/yolov4/ixrt/deploy.py b/models/cv/detection/yolov4/ixrt/deploy.py new file mode 100644 index 00000000..084356ec --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/deploy.py @@ -0,0 +1,210 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse +import copy + +from typing import Union, Callable, List + +from tensorrt.deploy.api import * +from tensorrt.deploy.backend.onnx.converter import default_converter +from tensorrt.deploy.backend.torch.executor.operators._operators import to_py_type +from tensorrt.deploy.ir.operator_attr import BaseOperatorAttr, EmptyAttr +from tensorrt.deploy.ir.operator_type import OperatorType as OP +from tensorrt.deploy.ir import operator_attr as attr, Operator, generate_operator_name +from tensorrt.deploy.fusion import BasePass, PatternGraph, build_sequence_graph, GraphMatcher, PassSequence +from tensorrt.deploy.ir import Graph +from tensorrt.deploy.quantizer.quant_operator.base import quant_single_input_operator +from tensorrt.deploy.backend.onnx.converter import convert_onnx_operator +from tensorrt.deploy.api import GraphTransform, create_source, create_target + +class FuseMishPass(BasePass): + def process(self, graph: Graph) -> Graph: + pattern = build_sequence_graph([OP.SOFTPLUS, OP.TANH, OP.MUL]) + + matcher = GraphMatcher(pattern, strict=False) + self.transform = GraphTransform(graph) + matcher.findall(graph, self.fuse_mish) + return graph + + def fuse_mish(self, graph: Graph, pattern_graph: PatternGraph): + softplus = pattern_graph.nodes[0].operator + mul = pattern_graph.nodes[-1].operator + + if not self.can_fused(graph, pattern_graph): + return + + self.transform.delete_operators_between_op_op(softplus, mul) + + mish_op = Operator( + name=generate_operator_name(graph, pattern="Mish_{idx}"), + op_type=OP.MISH, + inputs=copy.copy(softplus.inputs), + outputs=copy.copy(mul.outputs), + ) + mish_op.is_quant_operator = softplus.is_quant_operator and mul.is_quant_operator + graph.add_operator(mish_op) + + def can_fused(self, graph: Graph, pattern_graph: PatternGraph): + softplus = pattern_graph.nodes[0].operator + mul = pattern_graph.nodes[-1].operator + + # 检查 Softplus, tanh 的输出是不是只有一个 OP 使用 + # 如果有多个 OP 使用,则不能融合 + for node in pattern_graph.nodes[:2]: + next_ops = graph.get_next_operators(node.operator) + if len(next_ops) != 1: + return False + + # 检查 Mul 的输入是不是和 Softplus 是同源的 + softplus_prev_op = graph.get_previous_operators(softplus) + if len(softplus_prev_op) != 1: + return False + + mul_prev_op = graph.get_previous_operators(mul) + if len(mul_prev_op) != 2: + return False + + for op in mul_prev_op: + if op is softplus_prev_op[0]: + return True + + return False + + +class Transform: + def __init__(self, graph): + self.t = GraphTransform(graph) + self.graph = graph + + def ReplaceFocus(self, input_edge, outputs, to_op): + input_var = self.graph.get_variable(input_edge) + op = self.graph.get_operator(to_op) + self.t.delete_operators_between_var_op( + from_var=input_var, to_op=op + ) + self.t.make_operator( + "Focus", inputs=input_edge, outputs=outputs + ) + return self.graph + + def AddYoloDecoderOp(self, inputs: list, outputs: list, op_type, **attributes): + if attributes["anchor"] is None: + del attributes["anchor"] + self.t.make_operator( + op_type, inputs=inputs, outputs=outputs, **attributes + ) + return self.graph + + def AddConcatOp(self, inputs: list, outputs, **attributes): + self.t.make_operator( + "Concat", inputs=inputs, outputs=outputs, **attributes + ) + return self.graph + +def customize_ops(graph, args): + t = Transform(graph) + fuse_focus = args.focus_input is not None and args.focus_output is not None and args.focus_last_node is not None + if fuse_focus: + graph = t.ReplaceFocus( + input_edge=args.focus_input, + outputs=args.focus_output, + to_op=args.focus_last_node + ) + decoder_input = args.decoder_input_names + num = len(decoder_input) // 3 + graph = t.AddYoloDecoderOp( + inputs=decoder_input[:num], + outputs=["decoder_8"], + op_type=args.decoder_type, + anchor=args.decoder8_anchor, + num_class=args.num_class, + stride=8, + faster_impl=args.faster + ) + graph = t.AddYoloDecoderOp( + inputs=decoder_input[num:num*2], + outputs=["decoder_16"], + op_type=args.decoder_type, + anchor=args.decoder16_anchor, + num_class=args.num_class, + stride=16, + faster_impl=args.faster + ) + graph = t.AddYoloDecoderOp( + inputs=decoder_input[num*2:num*2+1], + outputs=["decoder_32"], + op_type=args.decoder_type, + anchor=args.decoder32_anchor, + num_class=args.num_class, + stride=32, + faster_impl=args.faster + ) + if args.decoder64_anchor is not None: + graph = t.AddYoloDecoderOp( + inputs=decoder_input[num*2+1:], + outputs=["decoder_64"], + op_type=args.decoder_type, + anchor=args.decoder64_anchor, + num_class=args.num_class, + stride=64, + faster_impl=args.faster + ) + graph = t.AddConcatOp( + inputs=["decoder_8", "decoder_16", "decoder_32", "decoder_64"], + outputs=["output"], + axis=1 + ) + else: + graph = t.AddConcatOp( + inputs=["decoder_32", "decoder_16", "decoder_8"], + outputs=["output"], + axis=1 + ) + + graph.outputs.clear() + graph.add_output("output") + graph.outputs["output"].dtype = "FLOAT" + return graph + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--src", type=str) + parser.add_argument("--dst", type=str) + parser.add_argument("--decoder_type", type=str, choices=["YoloV3Decoder", "YoloV5Decoder", "YoloV7Decoder", "YoloxDecoder"]) + parser.add_argument("--decoder_input_names", nargs='+', type=str) + parser.add_argument("--decoder8_anchor", nargs='*', type=int) + parser.add_argument("--decoder16_anchor", nargs='*', type=int) + parser.add_argument("--decoder32_anchor", nargs='*', type=int) + parser.add_argument("--decoder64_anchor", nargs='*', type=int, default=None) + parser.add_argument("--num_class", type=int, default=80) + parser.add_argument("--faster", type=int, default=1) + parser.add_argument("--focus_input", type=str, default=None) + parser.add_argument("--focus_output", type=str, default=None) + parser.add_argument("--focus_last_node", type=str, default=None) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + + args = parse_args() + graph = create_source(args.src)() + graph = customize_ops(graph, args) + graph = FuseMishPass().process(graph) + create_target(saved_path=args.dst).export(graph) + print("Surged onnx lies on", args.dst) diff --git a/models/cv/detection/yolov4/ixrt/export.py b/models/cv/detection/yolov4/ixrt/export.py new file mode 100644 index 00000000..7c8bbfa5 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/export.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import sys +sys.path.insert(0, "yolov4") +import argparse + +from yolov4.tool.darknet2onnx import * + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--cfg", + type=str, + required=True, + help="darknet cfg path.") + + parser.add_argument("--weight", + type=str, + required=True, + help="darknet weights path.") + + parser.add_argument("--batchsize", + type=int, + required=True, + help="Onnx model batchsize.") + + parser.add_argument("--output", + type=str, + required=True, + help="export onnx model path.") + + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + + transform_to_onnx(args.cfg, args.weight, args.batchsize, args.output) + +if __name__ == "__main__": + main() + diff --git a/models/cv/detection/yolov4/ixrt/inference.py b/models/cv/detection/yolov4/ixrt/inference.py new file mode 100644 index 00000000..5d740507 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/inference.py @@ -0,0 +1,211 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse +import glob +import json +import os +import time +import sys +from tqdm import tqdm + +import torch +import numpy as np +import tensorrt +from tensorrt import Dims +import pycuda.autoinit +import pycuda.driver as cuda + +from coco_labels import coco80_to_coco91_class +from common import save2json, box_class85to6 +from common import load_images, prepare_batch +from common import create_engine_context, setup_io_bindings +from common import scale_boxes, post_processing + +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from load_ixrt_plugin import load_ixrt_plugin +load_ixrt_plugin() + + + +def main(config): + + # Step1: Load dataloader + images_path = load_images(config.eval_dir) + dataloader = prepare_batch(images_path, config.bsz) + + # Step2: Load Engine + input_name = "input" + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + engine, context = create_engine_context(config.model_engine, logger) + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims((config.bsz,3,config.imgsz,config.imgsz))) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + # Warm up + if config.warm_up > 0: + print("\nWarm Start.") + for i in range(config.warm_up): + context.execute_v2(allocations) + print("Warm Done.") + + json_result = [] + forward_time = 0.0 + class_map = coco80_to_coco91_class() + num_samples = 0 + # Step3: Run on coco dataset + for batch_names, batch_images, batch_shapes in tqdm(zip(*dataloader)): + batch_data = np.ascontiguousarray(batch_images) + data_shape = batch_data.shape + h, w = zip(*batch_shapes) + batch_img_shape = [h, w] + batch_img_id = [int(x.split('.')[0]) for x in batch_names] + + cur_bsz_sample = batch_images.shape[0] + num_samples += cur_bsz_sample + # Set input + input_idx = engine.get_binding_index(input_name) + context.set_binding_shape(input_idx, Dims(data_shape)) + inputs, outputs, allocations = setup_io_bindings(engine, context) + + cuda.memcpy_htod(inputs[0]["allocation"], batch_data) + # Prepare the output data + output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + # print(f"output shape : {output.shape} output type : {output.dtype}") + + # Forward + start_time = time.time() + context.execute_v2(allocations) + end_time = time.time() + forward_time += end_time - start_time + + if config.test_mode == "MAP": + # Fetch output + cuda.memcpy_dtoh(output, outputs[0]["allocation"]) + pred_boxes = post_processing(None, 0.001, 0.6, output) + + pred_results = [] + # Calculate pred box on raw shape + for (pred_box, raw_shape) in zip(pred_boxes, batch_shapes): + h, w = raw_shape + if len(pred_box) == 0:continue # no detection results + pred_box = np.array(pred_box, dtype=np.float32) + pred_box = scale_boxes((config.imgsz, config.imgsz), pred_box, raw_shape, use_letterbox=False) + + pred_results.append(pred_box.tolist()) + + save2json(batch_img_id, pred_results, json_result, class_map) + + fps = num_samples / forward_time + + if config.test_mode == "FPS": + print("FPS : ", fps) + print(f"Performance Check : Test {fps} >= target {config.fps_target}") + if fps >= config.fps_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + if config.test_mode == "MAP": + if len(json_result) == 0: + print("Predict zero box!") + exit(1) + + if not os.path.exists(config.pred_dir): + os.makedirs(config.pred_dir) + + pred_json = os.path.join( + config.pred_dir, f"{config.model_name}_{config.precision}_preds.json" + ) + with open(pred_json, "w") as f: + json.dump(json_result, f) + + anno_json = config.coco_gt + anno = COCO(anno_json) # init annotations api + pred = anno.loadRes(pred_json) # init predictions api + eval = COCOeval(anno, pred, "bbox") + + eval.evaluate() + eval.accumulate() + print( + f"==============================eval {config.model_name} {config.precision} coco map ==============================" + ) + eval.summarize() + + map, map50 = eval.stats[:2] + print("MAP@0.5 : ", map50) + print(f"Accuracy Check : Test {map50} >= target {config.map_target}") + if map50 >= config.map_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", type=str, default="YOLOV4", help="YOLOV3 YOLOV4 YOLOV5 YOLOV7 YOLOX" + ) + parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8", + help="The precision of datatype") + parser.add_argument("--test_mode", type=str, default="FPS", help="FPS MAP") + parser.add_argument( + "--model_engine", + type=str, + default="", + help="model engine path", + ) + parser.add_argument( + "--coco_gt", + type=str, + default="data/datasets/cv/coco2017/annotations/instances_val2017.json", + help="coco instances_val2017.json", + ) + parser.add_argument("--warm_up", type=int, default=3, help="warm_up count") + parser.add_argument("--loop_count", type=int, default=-1, help="loop count") + parser.add_argument( + "--eval_dir", + type=str, + default="data/datasets/cv/coco2017/val2017", + help="coco image dir", + ) + parser.add_argument("--bsz", type=int, default=32, help="test batch size") + parser.add_argument( + "--imgsz", + "--img", + "--img-size", + type=int, + default=608, + help="inference size h,w", + ) + parser.add_argument("--pred_dir", type=str, default=".", help="pred save json dirs") + parser.add_argument("--map_target", type=float, default=0.56, help="target mAP") + parser.add_argument("--fps_target", type=float, default=-1.0, help="target fps") + + config = parser.parse_args() + print("config:", config) + return config + + +if __name__ == "__main__": + config = parse_config() + main(config) diff --git a/models/cv/detection/yolov4/ixrt/load_ixrt_plugin.py b/models/cv/detection/yolov4/ixrt/load_ixrt_plugin.py new file mode 100644 index 00000000..2bb0abc2 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/load_ixrt_plugin.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import ctypes +import tensorrt +from os.path import join, dirname, exists +def load_ixrt_plugin(logger=tensorrt.Logger(tensorrt.Logger.INFO), namespace="", dynamic_path=""): + if not dynamic_path: + dynamic_path = join(dirname(tensorrt.__file__), "lib", "libixrt_plugin.so") + if not exists(dynamic_path): + raise FileNotFoundError( + f"The ixrt_plugin lib {dynamic_path} is not existed, please provided effective plugin path!") + ctypes.CDLL(dynamic_path) + tensorrt.init_libnvinfer_plugins(logger, namespace) + print(f"Loaded plugin from {dynamic_path}") diff --git a/models/cv/detection/yolov4/ixrt/quant.py b/models/cv/detection/yolov4/ixrt/quant.py new file mode 100644 index 00000000..70265cbc --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/quant.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import random +import argparse +import numpy as np +from tensorrt.deploy import static_quantize + +import torch +import torchvision.datasets +from torch.utils.data import DataLoader +from common import letterbox + + +def setseed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str) + parser.add_argument("--model", type=str, default="yolov4_bs16_without_decoder.onnx") + parser.add_argument("--dataset_dir", type=str, default="./coco2017/val2017") + parser.add_argument("--ann_file", type=str, default="./coco2017/annotations/instances_val2017.json") + parser.add_argument("--observer", type=str, choices=["hist_percentile", "percentile", "minmax", "entropy", "ema"], default="hist_percentile") + parser.add_argument("--disable_quant_names", nargs='*', type=str) + parser.add_argument("--save_quant_model", type=str, help="save the quantization model path", default=None) + parser.add_argument("--bsz", type=int, default=16) + parser.add_argument("--step", type=int, default=32) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--imgsz", type=int, default=608) + parser.add_argument("--use_letterbox", action="store_true") + args = parser.parse_args() + return args + +args = parse_args() +setseed(args.seed) +model_name = args.model_name + + +def get_dataloader(data_dir, step=32, batch_size=16, new_shape=[608, 608], use_letterbox=False): + num = step * batch_size + val_list = [os.path.join(data_dir, x) for x in os.listdir(data_dir)] + random.shuffle(val_list) + pic_list = val_list[:num] + + calibration_dataset = [] + for file_path in pic_list: + pic_data = cv2.imread(file_path) + org_img = pic_data + assert org_img is not None, 'Image not Found ' + file_path + h0, w0 = org_img.shape[:2] + + if use_letterbox: + img, ratio, dwdh = letterbox(org_img, new_shape=(new_shape[1], new_shape[0]), auto=False, scaleup=True) + else: + img = cv2.resize(org_img, new_shape) + img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB + img = np.ascontiguousarray(img) / 255.0 # 0~1 np array + img = torch.from_numpy(img).float() + + calibration_dataset.append(img) + + calibration_dataloader = DataLoader( + calibration_dataset, + shuffle=True, + batch_size=batch_size, + drop_last=True + ) + return calibration_dataloader + +dataloader = get_dataloader( + data_dir=args.dataset_dir, + step=args.step, + batch_size=args.bsz, + new_shape=(args.imgsz, args.imgsz), + use_letterbox=args.use_letterbox +) + +dirname = os.path.dirname(args.save_quant_model) +quant_json_path = os.path.join(dirname, f"quantized_{model_name}.json") + +static_quantize(args.model, + calibration_dataloader=dataloader, + save_quant_onnx_path=args.save_quant_model, + save_quant_params_path=quant_json_path, + observer=args.observer, + data_preprocess=lambda x: x.to("cuda"), + quant_format="qdq", + disable_quant_names=args.disable_quant_names) diff --git a/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_accuary.sh b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_accuary.sh new file mode 100644 index 00000000..e176357c --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_accuary.sh @@ -0,0 +1,88 @@ +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov4_darknet +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=16 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov4_sim.onnx + +# Cut decoder part +echo "Cut decoder part" +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "CUT Model Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/cut_model.py \ + --input_model ${CURRENT_MODEL} \ + --output_model ${FINAL_MODEL} \ + --input_names input \ + --output_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# add decoder op +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_with_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Add Decoder Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/deploy.py \ + --src ${CURRENT_MODEL} \ + --dst ${FINAL_MODEL} \ + --decoder_type YoloV3Decoder \ + --decoder_input_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 \ + --decoder8_anchor 12 16 19 36 40 28 \ + --decoder16_anchor 36 75 76 55 72 146 \ + --decoder32_anchor 142 110 192 243 459 401 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov4_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=16 +python3 ${RUN_DIR}/inference.py \ + --test_mode MAP \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 608 \ + --loop_count 10 \ + --eval_dir ${EVAL_DIR} \ + --coco_gt ${COCO_GT} \ + --pred_dir ${CHECKPOINTS_DIR} \ + --precision float16 \ + --map_target 0.30; check_status +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_performance.sh b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_performance.sh new file mode 100644 index 00000000..0570764f --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_fp16_performance.sh @@ -0,0 +1,88 @@ +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov4_darknet +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=16 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov4_sim.onnx + +# Cut decoder part +echo "Cut decoder part" +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "CUT Model Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/cut_model.py \ + --input_model ${CURRENT_MODEL} \ + --output_model ${FINAL_MODEL} \ + --input_names input \ + --output_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# add decoder op +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_with_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Add Decoder Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/deploy.py \ + --src ${CURRENT_MODEL} \ + --dst ${FINAL_MODEL} \ + --decoder_type YoloV3Decoder \ + --decoder_input_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 \ + --decoder8_anchor 12 16 19 36 40 28 \ + --decoder16_anchor 36 75 76 55 72 146 \ + --decoder32_anchor 142 110 192 243 459 401 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov4_fp16.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision float16 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=16 +python3 ${RUN_DIR}/inference.py \ + --test_mode FPS \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 608 \ + --loop_count 10 \ + --eval_dir ${EVAL_DIR} \ + --coco_gt ${COCO_GT} \ + --pred_dir ${CHECKPOINTS_DIR} \ + --precision float16 \ + --map_target 0.30; check_status +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_accuary.sh b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_accuary.sh new file mode 100644 index 00000000..1c99cba5 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_accuary.sh @@ -0,0 +1,106 @@ +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov4_darknet +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=16 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov4_sim.onnx + +# Cut decoder part +echo "Cut decoder part" +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "CUT Model Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/cut_model.py \ + --input_model ${CURRENT_MODEL} \ + --output_model ${FINAL_MODEL} \ + --input_names input \ + --output_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Change Batchsize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV4_DARKNET" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --imgsz 608 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# add decoder op +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov4_bs${BATCH_SIZE}_with_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Add Decoder Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/deploy.py \ + --src ${CURRENT_MODEL} \ + --dst ${FINAL_MODEL} \ + --decoder_type YoloV3Decoder \ + --decoder_input_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 \ + --decoder8_anchor 12 16 19 36 40 28 \ + --decoder16_anchor 36 75 76 55 72 146 \ + --decoder32_anchor 142 110 192 243 459 401 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov4_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=16 +python3 ${RUN_DIR}/inference.py \ + --test_mode MAP \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 608 \ + --loop_count 10 \ + --eval_dir ${EVAL_DIR} \ + --coco_gt ${COCO_GT} \ + --pred_dir ${CHECKPOINTS_DIR} \ + --precision int8 \ + --map_target 0.30; check_status +exit ${EXIT_STATUS} diff --git a/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_performance.sh b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_performance.sh new file mode 100644 index 00000000..5c9108f2 --- /dev/null +++ b/models/cv/detection/yolov4/ixrt/scripts/infer_yolov4darknet_int8_performance.sh @@ -0,0 +1,106 @@ +#!/bin/bash +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +PROJ_DIR=$(cd $(dirname $0);cd ../; pwd) +DATASETS_DIR="${PROJ_DIR}/data/coco" +COCO_GT=${DATASETS_DIR}/annotations/instances_val2017.json +EVAL_DIR=${DATASETS_DIR}/images/val2017 +CHECKPOINTS_DIR="${PROJ_DIR}/data" +RUN_DIR="${PROJ_DIR}" +ORIGINE_MODEL=${CHECKPOINTS_DIR} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo RUN_DIR : ${RUN_DIR} +echo ====================== Model Info ====================== +echo Model Name : yolov4_darknet +echo Onnx Path : ${ORIGINE_MODEL} + +BATCH_SIZE=16 +CURRENT_MODEL=${CHECKPOINTS_DIR}/yolov4_sim.onnx + +# Cut decoder part +echo "Cut decoder part" +FINAL_MODEL=${CHECKPOINTS_DIR}/yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "CUT Model Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/cut_model.py \ + --input_model ${CURRENT_MODEL} \ + --output_model ${FINAL_MODEL} \ + --input_names input \ + --output_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# quant +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov4_bs${BATCH_SIZE}_without_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Change Batchsize Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/quant.py \ + --model_name "YOLOV4_DARKNET" \ + --model ${CURRENT_MODEL} \ + --bsz ${BATCH_SIZE} \ + --dataset_dir ${EVAL_DIR} \ + --ann_file ${COCO_GT} \ + --observer "hist_percentile" \ + --save_quant_model ${FINAL_MODEL} \ + --imgsz 608 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# add decoder op +FINAL_MODEL=${CHECKPOINTS_DIR}/quantized_yolov4_bs${BATCH_SIZE}_with_decoder.onnx +if [ -f $FINAL_MODEL ];then + echo " "Add Decoder Skip, $FINAL_MODEL has been existed +else + python3 ${RUN_DIR}/deploy.py \ + --src ${CURRENT_MODEL} \ + --dst ${FINAL_MODEL} \ + --decoder_type YoloV3Decoder \ + --decoder_input_names /models.138/conv94/Conv_output_0 /models.149/conv102/Conv_output_0 /models.160/conv110/Conv_output_0 \ + --decoder8_anchor 12 16 19 36 40 28 \ + --decoder16_anchor 36 75 76 55 72 146 \ + --decoder32_anchor 142 110 192 243 459 401 + echo " "Generate ${FINAL_MODEL} +fi +CURRENT_MODEL=${FINAL_MODEL} + +# Build Engine +echo Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/yolov4_int8.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 ${RUN_DIR}/build_engine.py \ + --precision int8 \ + --model ${CURRENT_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +echo Inference +RUN_BATCH_SIZE=16 +python3 ${RUN_DIR}/inference.py \ + --test_mode FPS \ + --model_engine ${ENGINE_FILE} \ + --warm_up 2 \ + --bsz ${RUN_BATCH_SIZE} \ + --imgsz 608 \ + --loop_count 10 \ + --eval_dir ${EVAL_DIR} \ + --coco_gt ${COCO_GT} \ + --pred_dir ${CHECKPOINTS_DIR} \ + --precision int8 \ + --map_target 0.30; check_status +exit ${EXIT_STATUS} -- Gitee From b4af958a28881c1ec0f0331610a49e574ea751bb Mon Sep 17 00:00:00 2001 From: may Date: Thu, 18 Jul 2024 06:11:24 +0000 Subject: [PATCH 2/2] update models/cv/detection/yolov4/ixrt/README.md. Signed-off-by: may --- models/cv/detection/yolov4/ixrt/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/models/cv/detection/yolov4/ixrt/README.md b/models/cv/detection/yolov4/ixrt/README.md index abbeb2bf..ce42d187 100644 --- a/models/cv/detection/yolov4/ixrt/README.md +++ b/models/cv/detection/yolov4/ixrt/README.md @@ -19,6 +19,7 @@ pip3 install tqdm pip3 install onnx pip3 install onnxsim pip3 install pycocotools +pip3 install pycuda ``` ### Download -- Gitee