From 0e1fcc3b7937a664e88b6e1b1189dbdbb993d1e4 Mon Sep 17 00:00:00 2001 From: "xiaomei.wang" Date: Thu, 17 Oct 2024 13:36:12 +0800 Subject: [PATCH] Add cspdarknet50 fp16/int8 IxRT inference. --- .../cspdarknet50/ixrt/README.md | 73 ++++++-- .../cspdarknet50/ixrt/build_engine.py | 52 ++++++ .../cspdarknet50/ixrt/build_i8_engine.py | 113 ++++++++++++ .../cspdarknet50/ixrt/calibration_dataset.py | 112 ++++++++++++ .../cspdarknet50/ixrt/common.py | 80 +++++++++ .../ixrt/config/CSPDARKNET50_CONFIG | 33 ++++ .../cspdarknet50/ixrt/export.py | 76 ++++++++ .../cspdarknet50/ixrt/inference.py | 158 +++++++++++++++++ .../classification/cspdarknet50/ixrt/quant.py | 166 ++++++++++++++++++ .../ixrt/refine_utils/__init__.py | 0 .../cspdarknet50/ixrt/refine_utils/common.py | 36 ++++ .../ixrt/refine_utils/linear_pass.py | 113 ++++++++++++ .../ixrt/refine_utils/matmul_to_gemm_pass.py | 54 ++++++ .../infer_cspdarknet50_fp16_accuracy.sh | 92 ++++++++++ .../infer_cspdarknet50_fp16_performance.sh | 92 ++++++++++ .../infer_cspdarknet50_int8_accuracy.sh | 121 +++++++++++++ .../infer_cspdarknet50_int8_performance.sh | 122 +++++++++++++ 17 files changed, 1481 insertions(+), 12 deletions(-) create mode 100644 models/cv/classification/cspdarknet50/ixrt/build_engine.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/build_i8_engine.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/calibration_dataset.py create mode 100755 models/cv/classification/cspdarknet50/ixrt/common.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/config/CSPDARKNET50_CONFIG create mode 100644 models/cv/classification/cspdarknet50/ixrt/export.py create mode 100755 models/cv/classification/cspdarknet50/ixrt/inference.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/quant.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/refine_utils/__init__.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/refine_utils/common.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/refine_utils/linear_pass.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/refine_utils/matmul_to_gemm_pass.py create mode 100644 models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_accuracy.sh create mode 100644 models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_performance.sh create mode 100644 models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_accuracy.sh create mode 100644 models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_performance.sh diff --git a/models/cv/classification/cspdarknet50/ixrt/README.md b/models/cv/classification/cspdarknet50/ixrt/README.md index be4b3e98..27036f32 100644 --- a/models/cv/classification/cspdarknet50/ixrt/README.md +++ b/models/cv/classification/cspdarknet50/ixrt/README.md @@ -1,34 +1,83 @@ -# MODEL_NAME +# CSPDarkNet50 ## Description -A brief introduction about this model. +CSPDarkNet50 is an enhanced convolutional neural network architecture that reduces redundant computations by integrating cross-stage partial network features and truncating gradient flow, thereby maintaining high accuracy while lowering computational costs. ## Setup -### Install (remove this step if not necessary) +### Install -### Download (remove this step if not necessary) +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-dev + +pip3 install onnx +pip3 install tqdm +pip3 install onnxsim +pip3 install ppq +pip3 install mmcv==1.5.3 +pip3 install mmcls +``` + +### Download + +Pretrained model: + +Dataset: to download the validation dataset. + +### Model Conversion + +```bash +# git clone mmpretrain +git clone -b v0.24.0 https://github.com/open-mmlab/mmpretrain.git + +# export onnx model +python3 export.py --cfg mmpretrain/configs/cspnet/cspdarknet50_8xb32_in1k.py --weight cspdarknet50_3rdparty_8xb32_in1k_20220329-bd275287.pth --output cspdarknet50.onnx -### Model Conversion (remove this step if not necessary) +# Use onnxsim optimize onnx model +mkdir -p data/checkpoints/cspdarknet50_ckpt +onnxsim cspdarknet50.onnx data/checkpoints/cspdarknet50_ckpt/cspdarknet50_sim.onnx + +``` ## Inference +```bash +export DATASETS_DIR=/Path/to/imagenet_val/ +export CHECKPOINTS_DIR=/Path/to/data/checkpoints/cspdarknet50_ckpt +export CONFIG_DIR=./config/CSPDARKNET50_CONFIG +``` + ### FP16 ```bash -bash test_fp16.sh +# Accuracy +bash scripts/infer_cspdarknet50_fp16_accuracy.sh +# Performance +bash scripts/infer_cspdarknet50_fp16_performance.sh ``` ### INT8 - ```bash -bash test_int8.sh +# Accuracy +bash scripts/infer_cspdarknet50_int8_accuracy.sh +# Performance +bash scripts/infer_cspdarknet50_int8_performance.sh ``` -## Results (leave empty for testing team to complete) +### INT8 + +## Results + +| Model | BatchSize | Precision | FPS | Top-1(%) | Top-5(%) | +| ------------ | --------- | --------- | -------- | -------- | -------- | +| CSPDarkNet50 | 32 | FP16 | 3282.318 | 79.09 | 94.52 | +| CSPDarkNet50 | 32 | INT8 | 6335.86 | 75.49 | 92.66 | -Model | BatchSize | Precision | FPS | ACC -------|-----------|-----------|-----|---- -MODEL_NAME | | | | +## Reference +CSPDarkNet50: \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/build_engine.py b/models/cv/classification/cspdarknet50/ixrt/build_engine.py new file mode 100644 index 00000000..126da5e6 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/build_engine.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import argparse +import numpy as np + +import torch +import tensorrt + +def main(config): + IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING) + builder = tensorrt.Builder(IXRT_LOGGER) + EXPLICIT_BATCH = 1 << (int)(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(EXPLICIT_BATCH) + build_config = builder.create_builder_config() + parser = tensorrt.OnnxParser(network, IXRT_LOGGER) + parser.parse_from_file(config.model) + + precision = tensorrt.BuilderFlag.INT8 if config.precision == "int8" else tensorrt.BuilderFlag.FP16 + # print("precision : ", precision) + build_config.set_flag(precision) + + plan = builder.build_serialized_network(network, build_config) + engine_file_path = config.engine + with open(engine_file_path, "wb") as f: + f.write(plan) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str) + parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8", + help="The precision of datatype") + parser.add_argument("--engine", type=str, default=None) + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/build_i8_engine.py b/models/cv/classification/cspdarknet50/ixrt/build_i8_engine.py new file mode 100644 index 00000000..6e356260 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/build_i8_engine.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import argparse +import json +import os + +import tensorrt +import tensorrt as trt + +TRT_LOGGER = trt.Logger(tensorrt.Logger.VERBOSE) + +EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + + +def GiB(val): + return val * 1 << 30 + + +def json_load(filename): + with open(filename) as json_file: + data = json.load(json_file) + return data + + +def setDynamicRange(network, json_file): + """Sets ranges for network layers.""" + quant_param_json = json_load(json_file) + act_quant = quant_param_json["act_quant_info"] + + for i in range(network.num_inputs): + input_tensor = network.get_input(i) + if act_quant.__contains__(input_tensor.name): + print(input_tensor.name) + value = act_quant[input_tensor.name] + tensor_max = abs(value) + tensor_min = -abs(value) + input_tensor.dynamic_range = (tensor_min, tensor_max) + + for i in range(network.num_layers): + layer = network.get_layer(i) + + for output_index in range(layer.num_outputs): + tensor = layer.get_output(output_index) + + if act_quant.__contains__(tensor.name): + value = act_quant[tensor.name] + tensor_max = abs(value) + tensor_min = -abs(value) + tensor.dynamic_range = (tensor_min, tensor_max) + else: + print("\033[1;32m%s\033[0m" % tensor.name) + + +def build_engine(onnx_file, json_file, engine_file): + builder = trt.Builder(TRT_LOGGER) + network = builder.create_network(EXPLICIT_BATCH) + + config = builder.create_builder_config() + + # If it is a dynamic onnx model , you need to add the following. + # profile = builder.create_optimization_profile() + # profile.set_shape("input_name", (batch, channels, min_h, min_w), (batch, channels, opt_h, opt_w), (batch, channels, max_h, max_w)) + # config.add_optimization_profile(profile) + + parser = trt.OnnxParser(network, TRT_LOGGER) + # config.max_workspace_size = GiB(1) + if not os.path.exists(onnx_file): + quit("ONNX file {} not found".format(onnx_file)) + + with open(onnx_file, "rb") as model: + if not parser.parse(model.read()): + print("ERROR: Failed to parse the ONNX file.") + for error in range(parser.num_errors): + print(parser.get_error(error)) + return None + + config.set_flag(trt.BuilderFlag.INT8) + + setDynamicRange(network, json_file) + + engine = builder.build_engine(network, config) + + with open(engine_file, "wb") as f: + f.write(engine.serialize()) + + +if __name__ == "__main__": + # Add plugins if needed + # import ctypes + # ctypes.CDLL("libmmdeploy_tensorrt_ops.so") + parser = argparse.ArgumentParser( + description="Writing qparams to onnx to convert tensorrt engine." + ) + parser.add_argument("--onnx", type=str, default=None) + parser.add_argument("--qparam_json", type=str, default=None) + parser.add_argument("--engine", type=str, default=None) + arg = parser.parse_args() + + build_engine(arg.onnx, arg.qparam_json, arg.engine) + print("\033[1;32mgenerate %s\033[0m" % arg.engine) \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/calibration_dataset.py b/models/cv/classification/cspdarknet50/ixrt/calibration_dataset.py new file mode 100644 index 00000000..442a5602 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/calibration_dataset.py @@ -0,0 +1,112 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os + +import torch +import torchvision.datasets +from torch.utils.data import DataLoader +from torchvision import models +from torchvision import transforms as T + + +class CalibrationImageNet(torchvision.datasets.ImageFolder): + def __init__(self, *args, **kwargs): + super(CalibrationImageNet, self).__init__(*args, **kwargs) + img2label_path = os.path.join(self.root, "val_map.txt") + if not os.path.exists(img2label_path): + raise FileNotFoundError(f"Not found label file `{img2label_path}`.") + + self.img2label_map = self.make_img2label_map(img2label_path) + + def make_img2label_map(self, path): + with open(path) as f: + lines = f.readlines() + + img2lable_map = dict() + for line in lines: + line = line.lstrip().rstrip().split("\t") + if len(line) != 2: + continue + img_name, label = line + img_name = img_name.strip() + if img_name in [None, ""]: + continue + label = int(label.strip()) + img2lable_map[img_name] = label + return img2lable_map + + def __getitem__(self, index): + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + # if self.target_transform is not None: + # target = self.target_transform(target) + img_name = os.path.basename(path) + target = self.img2label_map[img_name] + + return sample, target + + +def create_dataloaders(data_path, num_samples=1024, img_sz=224, batch_size=2, workers=0): + dataset = CalibrationImageNet( + data_path, + transform=T.Compose( + [ + T.Resize(256), + T.CenterCrop(img_sz), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + ) + + calibration_dataset = dataset + if num_samples is not None: + calibration_dataset = torch.utils.data.Subset( + dataset, indices=range(num_samples) + ) + + calibration_dataloader = DataLoader( + calibration_dataset, + shuffle=False, + batch_size=batch_size, + drop_last=False, + num_workers=workers, + ) + + verify_dataloader = DataLoader( + dataset, + shuffle=False, + batch_size=batch_size, + drop_last=False, + num_workers=workers, + ) + + return calibration_dataloader, verify_dataloader + + +def getdataloader(dataset_dir, step=20, batch_size=32, workers=2, img_sz=224, total_sample=50000): + num_samples = min(total_sample, step * batch_size) + if step < 0: + num_samples = None + calibration_dataloader, _ = create_dataloaders( + dataset_dir, + img_sz=img_sz, + batch_size=batch_size, + workers=workers, + num_samples=num_samples, + ) + return calibration_dataloader \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/common.py b/models/cv/classification/cspdarknet50/ixrt/common.py new file mode 100755 index 00000000..21c2b399 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/common.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import os +import cv2 +import glob +import torch +import tensorrt +import numpy as np +from cuda import cuda, cudart + +def eval_batch(batch_score, batch_label): + batch_score = torch.tensor(torch.from_numpy(batch_score), dtype=torch.float32) + values, indices = batch_score.topk(5) + top1, top5 = 0, 0 + for idx, label in enumerate(batch_label): + + if label == indices[idx][0]: + top1 += 1 + if label in indices[idx]: + top5 += 1 + return top1, top5 + +def create_engine_context(engine_path, logger): + with open(engine_path, "rb") as f: + runtime = tensorrt.Runtime(logger) + assert runtime + engine = runtime.deserialize_cuda_engine(f.read()) + assert engine + context = engine.create_execution_context() + assert context + + return engine, context + +def get_io_bindings(engine): + # Setup I/O bindings + inputs = [] + outputs = [] + allocations = [] + + for i in range(engine.num_bindings): + is_input = False + if engine.binding_is_input(i): + is_input = True + name = engine.get_binding_name(i) + dtype = engine.get_binding_dtype(i) + shape = engine.get_binding_shape(i) + if is_input: + batch_size = shape[0] + size = np.dtype(tensorrt.nptype(dtype)).itemsize + for s in shape: + size *= s + err, allocation = cudart.cudaMalloc(size) + assert err == cudart.cudaError_t.cudaSuccess + binding = { + "index": i, + "name": name, + "dtype": np.dtype(tensorrt.nptype(dtype)), + "shape": list(shape), + "allocation": allocation, + "nbytes": size, + } + print(f"binding {i}, name : {name} dtype : {np.dtype(tensorrt.nptype(dtype))} shape : {list(shape)}") + allocations.append(allocation) + if engine.binding_is_input(i): + inputs.append(binding) + else: + outputs.append(binding) + return inputs, outputs, allocations diff --git a/models/cv/classification/cspdarknet50/ixrt/config/CSPDARKNET50_CONFIG b/models/cv/classification/cspdarknet50/ixrt/config/CSPDARKNET50_CONFIG new file mode 100644 index 00000000..df51823c --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/config/CSPDARKNET50_CONFIG @@ -0,0 +1,33 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# IMGSIZE : 模型输入hw大小 +# MODEL_NAME : 生成onnx/engine的basename +# ORIGINE_MODEL : 原始onnx文件名称 +IMGSIZE=224 +MODEL_NAME=cspdarknet50 +ORIGINE_MODEL=cspdarknet50_sim.onnx + +# QUANT CONFIG (仅PRECISION为int8时生效) + # QUANT_OBSERVER : 量化策略,可选 [hist_percentile, percentile, minmax, entropy, ema] + # QUANT_BATCHSIZE : 量化时组dataloader的batchsize, 最好和onnx中的batchsize保持一致,有些op可能推导shape错误(比如Reshape) + # QUANT_STEP : 量化步数 + # QUANT_SEED : 随机种子 保证量化结果可复现 + # QUANT_EXIST_ONNX : 如果有其他来源的量化模型则填写 +QUANT_OBSERVER=hist_percentile +QUANT_BATCHSIZE=32 +QUANT_STEP=32 +QUANT_SEED=42 +DISABLE_QUANT_LIST= +QUANT_EXIST_ONNX= diff --git a/models/cv/classification/cspdarknet50/ixrt/export.py b/models/cv/classification/cspdarknet50/ixrt/export.py new file mode 100644 index 00000000..9f5514d6 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/export.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import argparse + +import torch +from mmcls.apis import init_model + +class Model(torch.nn.Module): + def __init__(self, config_file, checkpoint_file): + super().__init__() + self.model = init_model(config_file, checkpoint_file, device="cpu") + + def forward(self, x): + feat = self.model.backbone(x) + feat = self.model.neck(feat) + out_head = self.model.head.fc(feat[0]) + return out_head + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--weight", + type=str, + required=True, + help="pytorch model weight.") + + parser.add_argument("--cfg", + type=str, + required=True, + help="model config file.") + + parser.add_argument("--output", + type=str, + required=True, + help="export onnx model path.") + + args = parser.parse_args() + return args + +def main(): + args = parse_args() + + config_file = args.cfg + checkpoint_file = args.weight + model = Model(config_file, checkpoint_file).eval() + + input_names = ['input'] + output_names = ['output'] + dummy_input = torch.randn(32, 3, 224, 224) + + torch.onnx.export( + model, + dummy_input, + args.output, + input_names = input_names, + output_names = output_names, + opset_version=13 + ) + + print("Export onnx model successfully! ") + +if __name__ == '__main__': + main() + diff --git a/models/cv/classification/cspdarknet50/ixrt/inference.py b/models/cv/classification/cspdarknet50/ixrt/inference.py new file mode 100755 index 00000000..56b7f51c --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/inference.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import argparse +import json +import os +import re +import time +from tqdm import tqdm + +import cv2 +import numpy as np +from cuda import cuda, cudart +import torch +import tensorrt + +from calibration_dataset import getdataloader +from common import eval_batch, create_engine_context, get_io_bindings + +def main(config): + dataloader = getdataloader(config.datasets_dir, config.loop_count, config.bsz, img_sz=config.imgsz) + + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + + # Load Engine && I/O bindings + engine, context = create_engine_context(config.engine_file, logger) + inputs, outputs, allocations = get_io_bindings(engine) + + # Warm up + if config.warm_up > 0: + print("\nWarm Start.") + for i in range(config.warm_up): + context.execute_v2(allocations) + print("Warm Done.") + + # Inference + if config.test_mode == "FPS": + torch.cuda.synchronize() + start_time = time.time() + + for i in range(config.loop_count): + context.execute_v2(allocations) + + torch.cuda.synchronize() + end_time = time.time() + forward_time = end_time - start_time + + num_samples = 50000 + if config.loop_count * config.bsz < num_samples: + num_samples = config.loop_count * config.bsz + fps = num_samples / forward_time + + print("FPS : ", fps) + print(f"Performance Check : Test {fps} >= target {config.fps_target}") + if fps >= config.fps_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + elif config.test_mode == "ACC": + + ## Prepare the output data + output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + print(f"output shape : {output.shape} output type : {output.dtype}") + + total_sample = 0 + acc_top1, acc_top5 = 0, 0 + + with tqdm(total= len(dataloader)) as _tqdm: + for idx, (batch_data, batch_label) in enumerate(dataloader): + batch_data = batch_data.numpy().astype(inputs[0]["dtype"]) + batch_data = np.ascontiguousarray(batch_data) + total_sample += batch_data.shape[0] + + err, = cuda.cuMemcpyHtoD(inputs[0]["allocation"], batch_data, batch_data.nbytes) + assert(err == cuda.CUresult.CUDA_SUCCESS) + context.execute_v2(allocations) + err, = cuda.cuMemcpyDtoH(output, outputs[0]["allocation"], outputs[0]["nbytes"]) + assert(err == cuda.CUresult.CUDA_SUCCESS) + + # squeeze output shape [32,1000,1,1] to [32,1000] for mobilenet_v2 model + if len(output.shape) == 4: + output = output.squeeze(axis=(2,3)) + + batch_top1, batch_top5 = eval_batch(output, batch_label) + acc_top1 += batch_top1 + acc_top5 += batch_top5 + + _tqdm.set_postfix(acc_1='{:.4f}'.format(acc_top1/total_sample), + acc_5='{:.4f}'.format(acc_top5/total_sample)) + _tqdm.update(1) + + print(F"Acc@1 : {acc_top1/total_sample} = {acc_top1}/{total_sample}") + print(F"Acc@5 : {acc_top5/total_sample} = {acc_top5}/{total_sample}") + acc1 = acc_top1/total_sample + print(f"Accuracy Check : Test {acc1} >= target {config.acc_target}") + if acc1 >= config.acc_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument("--test_mode", type=str, default="FPS", help="FPS MAP") + parser.add_argument( + "--engine_file", + type=str, + help="engine file path" + ) + parser.add_argument( + "--datasets_dir", + type=str, + default="", + help="ImageNet dir", + ) + parser.add_argument("--warm_up", type=int, default=-1, help="warm_up times") + parser.add_argument("--bsz", type=int, default=32, help="test batch size") + parser.add_argument( + "--imgsz", + "--img", + "--img-size", + type=int, + default=224, + help="inference size h,w", + ) + parser.add_argument("--use_async", action="store_true") + parser.add_argument( + "--device", type=int, default=0, help="cuda device, i.e. 0 or 0,1,2,3,4" + ) + parser.add_argument("--fps_target", type=float, default=-1.0) + parser.add_argument("--acc_target", type=float, default=-1.0) + parser.add_argument("--loop_count", type=int, default=-1) + + config = parser.parse_args() + return config + +if __name__ == "__main__": + config = parse_config() + main(config) diff --git a/models/cv/classification/cspdarknet50/ixrt/quant.py b/models/cv/classification/cspdarknet50/ixrt/quant.py new file mode 100644 index 00000000..c728c7a1 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/quant.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""这是一个高度自动化的 PPQ 量化的入口脚本,将你的模型和数据按要求进行打包: + +在自动化 API 中,我们使用 QuantizationSetting 对象传递量化参数。 + +This file will show you how to quantize your network with PPQ + You should prepare your model and calibration dataset as follow: + + ~/working/model.onnx <-- your model + ~/working/data/*.npy or ~/working/data/*.bin <-- your dataset + +if you are using caffe model: + ~/working/model.caffemdoel <-- your model + ~/working/model.prototext <-- your model + +### MAKE SURE YOUR INPUT LAYOUT IS [N, C, H, W] or [C, H, W] ### + +quantized model will be generated at: ~/working/quantized.onnx +""" +from ppq import * +from ppq.api import * +import os +from calibration_dataset import getdataloader +import argparse +import random +import numpy as np +import torch + + +def setseed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str) + parser.add_argument("--model", type=str) + parser.add_argument("--dataset_dir", type=str, default="imagenet_val") + parser.add_argument("--observer", type=str, choices=["hist_percentile", "percentile", "minmax", "entropy", "ema"], + default="hist_percentile") + parser.add_argument("--disable_quant_names", nargs='*', type=str) + parser.add_argument("--save_dir", type=str, help="save path", default=None) + parser.add_argument("--bsz", type=int, default=32) + parser.add_argument("--step", type=int, default=20) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--imgsz", type=int, default=224) + args = parser.parse_args() + print("Quant config:", args) + print(args.disable_quant_names) + return args + + +config = parse_args() + +# modify configuration below: +WORKING_DIRECTORY = 'checkpoints' # choose your working directory +TARGET_PLATFORM = TargetPlatform.TRT_INT8 # choose your target platform +MODEL_TYPE = NetworkFramework.ONNX # or NetworkFramework.CAFFE +INPUT_LAYOUT = 'chw' # input data layout, chw or hwc +NETWORK_INPUTSHAPE = [32, 3, 224, 224] # input shape of your network +EXECUTING_DEVICE = 'cuda' # 'cuda' or 'cpu'. +REQUIRE_ANALYSE = False +TRAINING_YOUR_NETWORK = False # 是否需要 Finetuning 一下你的网络 +# ------------------------------------------------------------------- +# 加载你的模型文件,PPQ 将会把 onnx 或者 caffe 模型文件解析成自己的格式 +# 如果你正使用 pytorch, tensorflow 等框架,你可以先将模型导出成 onnx +# 使用 torch.onnx.export 即可,如果你在导出 torch 模型时发生错误,欢迎与我们联系。 +# ------------------------------------------------------------------- +graph = None +if MODEL_TYPE == NetworkFramework.ONNX: + graph = load_onnx_graph(onnx_import_file=config.model) +if MODEL_TYPE == NetworkFramework.CAFFE: + graph = load_caffe_graph( + caffemodel_path=os.path.join(WORKING_DIRECTORY, 'model.caffemodel'), + prototxt_path=os.path.join(WORKING_DIRECTORY, 'model.prototxt')) +assert graph is not None, 'Graph Loading Error, Check your input again.' + +# ------------------------------------------------------------------- +# SETTING 对象用于控制 PPQ 的量化逻辑,主要描述了图融合逻辑、调度方案、量化细节策略等 +# 当你的网络量化误差过高时,你需要修改 SETTING 对象中的属性来进行特定的优化 +# ------------------------------------------------------------------- +QS = QuantizationSettingFactory.default_setting() + +# ------------------------------------------------------------------- +# 下面向你展示了如何使用 finetuning 过程提升量化精度 +# 在 PPQ 中我们提供了十余种算法用来帮助你恢复精度 +# 开启他们的方式都是 QS.xxxx = True +# 按需使用,不要全部打开,容易起飞 +# ------------------------------------------------------------------- +if TRAINING_YOUR_NETWORK: + QS.lsq_optimization = True # 启动网络再训练过程,降低量化误差 + QS.lsq_optimization_setting.steps = 500 # 再训练步数,影响训练时间,500 步大概几分钟 + QS.lsq_optimization_setting.collecting_device = 'cuda' # 缓存数据放在那,cuda 就是放在gpu,如果显存超了你就换成 'cpu' + + +dataloader = getdataloader(config.dataset_dir, config.step, batch_size=config.bsz, img_sz=config.imgsz) +# ENABLE CUDA KERNEL 会加速量化效率 3x ~ 10x,但是你如果没有装相应编译环境的话是编译不了的 +# 你可以尝试安装编译环境,或者在不启动 CUDA KERNEL 的情况下完成量化:移除 with ENABLE_CUDA_KERNEL(): 即可 +with ENABLE_CUDA_KERNEL(): + print('网络正量化中,根据你的量化配置,这将需要一段时间:') + quantized = quantize_native_model( + setting=QS, # setting 对象用来控制标准量化逻辑 + model=graph, + calib_dataloader=dataloader, + calib_steps=config.step, + input_shape=NETWORK_INPUTSHAPE, # 如果你的网络只有一个输入,使用这个参数传参 + inputs=None, + # 如果你的网络有多个输入,使用这个参数传参,就是 input_shape=None, inputs=[torch.zeros(1,3,224,224), torch.zeros(1,3,224,224)] + collate_fn=lambda x: x[0].to(EXECUTING_DEVICE), # collate_fn 跟 torch dataloader 的 collate fn 是一样的,用于数据预处理, + # 你当然也可以用 torch dataloader 的那个,然后设置这个为 None + platform=TARGET_PLATFORM, + device=EXECUTING_DEVICE, + do_quantize=True) + + # ------------------------------------------------------------------- + # 如果你需要执行量化后的神经网络并得到结果,则需要创建一个 executor + # 这个 executor 的行为和 torch.Module 是类似的,你可以利用这个东西来获取执行结果 + # 请注意,必须在 export 之前执行此操作。 + # ------------------------------------------------------------------- + executor = TorchExecutor(graph=quantized, device=EXECUTING_DEVICE) + # output = executor.forward(input) + + # ------------------------------------------------------------------- + # PPQ 计算量化误差时,使用信噪比的倒数作为指标,即噪声能量 / 信号能量 + # 量化误差 0.1 表示在整体信号中,量化噪声的能量约为 10% + # 你应当注意,在 graphwise_error_analyse 分析中,我们衡量的是累计误差 + # 网络的最后一层往往都具有较大的累计误差,这些误差是其前面的所有层所共同造成的 + # 你需要使用 layerwise_error_analyse 逐层分析误差的来源 + # ------------------------------------------------------------------- + print('正计算网络量化误差(SNR),最后一层的误差应小于 0.1 以保证量化精度:') + reports = graphwise_error_analyse( + graph=quantized, running_device=EXECUTING_DEVICE, steps=32, + dataloader=dataloader, collate_fn=lambda x: x[0].to(EXECUTING_DEVICE)) + for op, snr in reports.items(): + if snr > 0.1: ppq_warning(f'层 {op} 的累计量化误差显著,请考虑进行优化') + + if REQUIRE_ANALYSE: + print('正计算逐层量化误差(SNR),每一层的独立量化误差应小于 0.1 以保证量化精度:') + layerwise_error_analyse(graph=quantized, running_device=EXECUTING_DEVICE, + interested_outputs=None, + dataloader=dataloader, collate_fn=lambda x: x.to(EXECUTING_DEVICE)) + + # ------------------------------------------------------------------- + # 使用 export_ppq_graph 函数来导出量化后的模型 + # PPQ 会根据你所选择的导出平台来修改模型格式 + # ------------------------------------------------------------------- + print('网络量化结束,正在生成目标文件:') + export_ppq_graph( + graph=quantized, platform=TARGET_PLATFORM, + graph_save_to=os.path.join(config.save_dir, f"quantized_{config.model_name}.onnx"), + config_save_to=os.path.join(config.save_dir, 'quant_cfg.json')) diff --git a/models/cv/classification/cspdarknet50/ixrt/refine_utils/__init__.py b/models/cv/classification/cspdarknet50/ixrt/refine_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/cv/classification/cspdarknet50/ixrt/refine_utils/common.py b/models/cv/classification/cspdarknet50/ixrt/refine_utils/common.py new file mode 100644 index 00000000..2af19a14 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/refine_utils/common.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from typing import Union, Callable, List + +from tensorrt.deploy.api import * +from tensorrt.deploy.backend.onnx.converter import default_converter +from tensorrt.deploy.backend.torch.executor.operators._operators import to_py_type +from tensorrt.deploy.ir.operator_attr import BaseOperatorAttr, EmptyAttr +from tensorrt.deploy.ir.operator_type import OperatorType as OP +from tensorrt.deploy.ir import operator_attr as attr, Operator, generate_operator_name +from tensorrt.deploy.fusion import BasePass, PatternGraph, build_sequence_graph, GraphMatcher, PassSequence +from tensorrt.deploy.ir import Graph +from tensorrt.deploy.quantizer.quant_operator.base import quant_single_input_operator +from tensorrt.deploy.backend.onnx.converter import convert_onnx_operator + +def find_sequence_subgraph(graph, + pattern: Union[List[str], PatternGraph], + callback: Callable[[Graph, PatternGraph], None], + strict=True): + if isinstance(pattern, List): + pattern = build_sequence_graph(pattern) + + matcher = GraphMatcher(pattern, strict=strict) + return matcher.findall(graph, callback) \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/refine_utils/linear_pass.py b/models/cv/classification/cspdarknet50/ixrt/refine_utils/linear_pass.py new file mode 100644 index 00000000..29b5e4a9 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/refine_utils/linear_pass.py @@ -0,0 +1,113 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +import dataclasses + +from refine_utils.common import * + +# AXB=C, Only for B is initializer + +class FusedLinearPass(BasePass): + + def process(self, graph: Graph) -> Graph: + self.transform = GraphTransform(graph) + + find_sequence_subgraph( + graph, pattern=[OP.MATMUL, OP.ADD], callback=self.to_linear_with_bias, strict=True + ) + find_sequence_subgraph( + graph, pattern=[OP.MATMUL], callback=self.to_linear, strict=True + ) + return graph + + def to_linear_with_bias(self, graph, pattern: PatternGraph): + matmul = pattern.nodes[0] + add = pattern.nodes[1] + if len(add.operator.inputs) != 2: + return + + b_var = graph.get_variable(matmul.operator.inputs[1]) + if not graph.is_leaf_variable(b_var) or b_var.value is None: + return + + if b_var.value.ndim != 2: + return + + bias_var = None + for input in add.operator.inputs: + if input not in matmul.operator.outputs: + bias_var = input + + inputs = matmul.operator.inputs + inputs.append(bias_var) + outputs = add.operator.outputs + + b_var.value = b_var.value.transpose(1, 0) + b_var.shape[0],b_var.shape[1] = b_var.shape[1],b_var.shape[0] + + hidden_size = b_var.shape[1] + linear_dim = b_var.shape[0] + + attributes = { + "hidden_size": hidden_size, + "linear_dim": linear_dim, + "has_bias": 1, + "act_type":"none" + } + + self.transform.make_operator( + "LinearFP16", + inputs=inputs, + outputs=outputs, + **attributes + ) + + self.transform.delete_operator(add.operator) + self.transform.delete_operator(matmul.operator) + + def to_linear(self, graph, pattern: PatternGraph): + matmul = pattern.nodes[0] + if len(matmul.operator.inputs) != 2: + return + + b_var = graph.get_variable(matmul.operator.inputs[1]) + if not graph.is_leaf_variable(b_var) or b_var.value is None: + return + + if b_var.value.ndim != 2: + return + + attributes = { + "hidden_size": hidden_size, + "linear_dim": linear_dim, + "has_bias": 0, + "act_type": "none" + } + + b_var.value = b_var.value.transpose(1, 0) + b_var.shape[0],b_var.shape[1] = b_var.shape[1], b_var.shape[0] + + hidden_size = b_var.shape[1] + linear_dim = b_var.shape[0] + + op = self.transform.make_operator( + op_type = "LinearFP16", + inputs = pattern.nodes[0].operator.inputs, + outputs=[pattern.nodes[-1].operator.outputs[0]], + **attributes + ) + + self.transform.add_operator(op) + + self.transform.delete_operator(matmul.operator) \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/refine_utils/matmul_to_gemm_pass.py b/models/cv/classification/cspdarknet50/ixrt/refine_utils/matmul_to_gemm_pass.py new file mode 100644 index 00000000..4ebfac4d --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/refine_utils/matmul_to_gemm_pass.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +from refine_utils.common import * + +# +# Common pattern Matmul to Gemm +# +class FusedGemmPass(BasePass): + + def process(self, graph: Graph) -> Graph: + self.transform = GraphTransform(graph) + + find_sequence_subgraph( + graph, pattern=[OP.MATMUL], callback=self.to_gemm, strict=True + ) + return graph + + def to_gemm(self, graph, pattern: PatternGraph): + matmul_op = pattern.nodes[0] + inputs = matmul_op.operator.inputs + outputs = matmul_op.operator.outputs + + if len(inputs)!=2 and len(outputs)!=1: + return + + for input in inputs: + if self.transform.is_leaf_variable(input): + return + + print(f"{self.transform.get_variable(inputs[0]).shape} {self.transform.get_variable(inputs[1]).shape}") + self.transform.delete_operator(matmul_op.operator) + + op = self.transform.make_operator( + op_type = "Gemm", + inputs = inputs, + outputs = outputs, + alpha = 1, + beta = 1, + transB = 1 + ) + + self.transform.add_operator(op) \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_accuracy.sh b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_accuracy.sh new file mode 100644 index 00000000..02f44d22 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_accuracy.sh @@ -0,0 +1,92 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +# Run paraments +BSZ=32 +TGT=-1 +WARM_UP=0 +LOOP_COUNT=-1 +RUN_MODE=ACC +PRECISION=float16 + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + esac +done + +DATASETS_DIR=${DATASETS_DIR} +CHECKPOINTS_DIR=${CHECKPOINTS_DIR} +CONFIG_DIR=${CONFIG_DIR} +source ${CONFIG_DIR} +ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo CONFIG_DIR : ${CONFIG_DIR} +echo ====================== Model Info ====================== +echo Model Name : ${MODEL_NAME} +echo Onnx Path : ${ORIGINE_MODEL} + +step=0 +SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx +FINAL_MODEL=${SIM_MODEL} + +# Build Engine +let step++ +echo; +echo [STEP ${step}] : Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 build_engine.py \ + --precision ${PRECISION} \ + --model ${FINAL_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +let step++ +echo; +echo [STEP ${step}] : Inference +python3 inference.py \ + --engine_file=${ENGINE_FILE} \ + --datasets_dir=${DATASETS_DIR} \ + --imgsz=${IMGSIZE} \ + --warm_up=${WARM_UP} \ + --loop_count ${LOOP_COUNT} \ + --test_mode ${RUN_MODE} \ + --acc_target ${TGT} \ + --bsz ${BSZ}; check_status + +exit ${EXIT_STATUS} \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_performance.sh b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_performance.sh new file mode 100644 index 00000000..2b6b8a66 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_fp16_performance.sh @@ -0,0 +1,92 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +# Run paraments +BSZ=32 +TGT=-1 +WARM_UP=3 +LOOP_COUNT=20 +RUN_MODE=FPS +PRECISION=float16 + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + esac +done + +DATASETS_DIR=${DATASETS_DIR} +CHECKPOINTS_DIR=${CHECKPOINTS_DIR} +CONFIG_DIR=${CONFIG_DIR} +source ${CONFIG_DIR} +ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo CONFIG_DIR : ${CONFIG_DIR} +echo ====================== Model Info ====================== +echo Model Name : ${MODEL_NAME} +echo Onnx Path : ${ORIGINE_MODEL} + +step=0 +SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx +FINAL_MODEL=${SIM_MODEL} + +# Build Engine +let step++ +echo; +echo [STEP ${step}] : Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 build_engine.py \ + --precision ${PRECISION} \ + --model ${FINAL_MODEL} \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +let step++ +echo; +echo [STEP ${step}] : Inference +python3 inference.py \ + --engine_file=${ENGINE_FILE} \ + --datasets_dir=${DATASETS_DIR} \ + --imgsz=${IMGSIZE} \ + --warm_up=${WARM_UP} \ + --loop_count ${LOOP_COUNT} \ + --test_mode ${RUN_MODE} \ + --fps_target ${TGT} \ + --bsz ${BSZ}; check_status + +exit ${EXIT_STATUS} \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_accuracy.sh b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_accuracy.sh new file mode 100644 index 00000000..30e208be --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_accuracy.sh @@ -0,0 +1,121 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +# Run paraments +BSZ=32 +TGT=-1 +WARM_UP=0 +LOOP_COUNT=-1 +RUN_MODE=ACC +PRECISION=int8 + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + esac +done + +DATASETS_DIR=${DATASETS_DIR} +CHECKPOINTS_DIR=${CHECKPOINTS_DIR} +CONFIG_DIR=${CONFIG_DIR} +source ${CONFIG_DIR} +ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo CONFIG_DIR : ${CONFIG_DIR} +echo ====================== Model Info ====================== +echo Model Name : ${MODEL_NAME} +echo Onnx Path : ${ORIGINE_MODEL} + +step=0 +SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx + +# Quant Model +if [ $PRECISION == "int8" ];then + let step++ + echo; + echo [STEP ${step}] : Quant Model + if [[ -z ${QUANT_EXIST_ONNX} ]];then + QUANT_EXIST_ONNX=$CHECKPOINTS_DIR/quantized_${MODEL_NAME}.onnx + fi + if [[ -f ${QUANT_EXIST_ONNX} ]];then + SIM_MODEL=${QUANT_EXIST_ONNX} + echo " "Quant Model Skip, ${QUANT_EXIST_ONNX} has been existed + else + python3 quant.py \ + --model ${SIM_MODEL} \ + --model_name ${MODEL_NAME} \ + --dataset_dir ${DATASETS_DIR} \ + --observer ${QUANT_OBSERVER} \ + --disable_quant_names ${DISABLE_QUANT_LIST[@]} \ + --save_dir $CHECKPOINTS_DIR \ + --bsz ${QUANT_BATCHSIZE} \ + --step ${QUANT_STEP} \ + --seed ${QUANT_SEED} \ + --imgsz ${IMGSIZE} + SIM_MODEL=${QUANT_EXIST_ONNX} + echo " "Generate ${SIM_MODEL} + fi +fi + +FINAL_MODEL=${SIM_MODEL} + +# Build Engine +let step++ +echo; +echo [STEP ${step}] : Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 build_i8_engine.py \ + --onnx ${FINAL_MODEL} \ + --qparam_json ${CHECKPOINTS_DIR}/quant_cfg.json \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + +# Inference +let step++ +echo; +echo [STEP ${step}] : Inference +python3 inference.py \ + --engine_file=${ENGINE_FILE} \ + --datasets_dir=${DATASETS_DIR} \ + --imgsz=${IMGSIZE} \ + --warm_up=${WARM_UP} \ + --loop_count ${LOOP_COUNT} \ + --test_mode ${RUN_MODE} \ + --acc_target ${TGT} \ + --bsz ${BSZ}; check_status + +exit ${EXIT_STATUS} \ No newline at end of file diff --git a/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_performance.sh b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_performance.sh new file mode 100644 index 00000000..82ebd283 --- /dev/null +++ b/models/cv/classification/cspdarknet50/ixrt/scripts/infer_cspdarknet50_int8_performance.sh @@ -0,0 +1,122 @@ +# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +#!/bin/bash + +EXIT_STATUS=0 +check_status() +{ + if ((${PIPESTATUS[0]} != 0));then + EXIT_STATUS=1 + fi +} + +# Run paraments +BSZ=32 +TGT=-1 +WARM_UP=3 +LOOP_COUNT=20 +RUN_MODE=FPS +PRECISION=int8 + +# Update arguments +index=0 +options=$@ +arguments=($options) +for argument in $options +do + index=`expr $index + 1` + case $argument in + --bs) BSZ=${arguments[index]};; + --tgt) TGT=${arguments[index]};; + esac +done + +DATASETS_DIR=${DATASETS_DIR} +CHECKPOINTS_DIR=${CHECKPOINTS_DIR} +CONFIG_DIR=${CONFIG_DIR} +source ${CONFIG_DIR} +ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL} + +echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR} +echo DATASETS_DIR : ${DATASETS_DIR} +echo CONFIG_DIR : ${CONFIG_DIR} +echo ====================== Model Info ====================== +echo Model Name : ${MODEL_NAME} +echo Onnx Path : ${ORIGINE_MODEL} + +step=0 +SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx + +# Quant Model +if [ $PRECISION == "int8" ];then + let step++ + echo; + echo [STEP ${step}] : Quant Model + if [[ -z ${QUANT_EXIST_ONNX} ]];then + QUANT_EXIST_ONNX=$CHECKPOINTS_DIR/quantized_${MODEL_NAME}.onnx + fi + if [[ -f ${QUANT_EXIST_ONNX} ]];then + SIM_MODEL=${QUANT_EXIST_ONNX} + echo " "Quant Model Skip, ${QUANT_EXIST_ONNX} has been existed + else + python3 quant.py \ + --model ${SIM_MODEL} \ + --model_name ${MODEL_NAME} \ + --dataset_dir ${DATASETS_DIR} \ + --observer ${QUANT_OBSERVER} \ + --disable_quant_names ${DISABLE_QUANT_LIST[@]} \ + --save_dir $CHECKPOINTS_DIR \ + --bsz ${QUANT_BATCHSIZE} \ + --step ${QUANT_STEP} \ + --seed ${QUANT_SEED} \ + --imgsz ${IMGSIZE} + SIM_MODEL=${QUANT_EXIST_ONNX} + echo " "Generate ${SIM_MODEL} + fi +fi + +FINAL_MODEL=${SIM_MODEL} + +# Build Engine +let step++ +echo; +echo [STEP ${step}] : Build Engine +ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine +if [ -f $ENGINE_FILE ];then + echo " "Build Engine Skip, $ENGINE_FILE has been existed +else + python3 build_i8_engine.py \ + --onnx ${FINAL_MODEL} \ + --qparam_json ${CHECKPOINTS_DIR}/quant_cfg.json \ + --engine ${ENGINE_FILE} + echo " "Generate Engine ${ENGINE_FILE} +fi + + +# Inference +let step++ +echo; +echo [STEP ${step}] : Inference +python3 inference.py \ + --engine_file=${ENGINE_FILE} \ + --datasets_dir=${DATASETS_DIR} \ + --imgsz=${IMGSIZE} \ + --warm_up=${WARM_UP} \ + --loop_count ${LOOP_COUNT} \ + --test_mode ${RUN_MODE} \ + --fps_target ${TGT} \ + --bsz ${BSZ}; check_status + +exit ${EXIT_STATUS} \ No newline at end of file -- Gitee