diff --git a/models/cv/classification/shufflenetv2_x0_5/igie/README.md b/models/cv/classification/shufflenetv2_x0_5/igie/README.md
index e6d97ab8a65c70c57825b176a1869d02c2cc7529..f6fc6790b38cdd2d2fdeb75ac11eaf5efa0f1fe7 100644
--- a/models/cv/classification/shufflenetv2_x0_5/igie/README.md
+++ b/models/cv/classification/shufflenetv2_x0_5/igie/README.md
@@ -2,7 +2,9 @@
 
 ## Model Description
 
-ShuffleNetV2_x0_5 is a lightweight convolutional neural network architecture designed for efficient image classification and feature extraction, it also incorporates other design optimizations such as depthwise separable convolutions, group convolutions, and efficient building blocks to further reduce computational complexity and improve efficiency.
+ShuffleNetV2_x0_5 is a lightweight convolutional neural network architecture designed for efficient image classification
+and feature extraction. It also incorporates design optimizations such as depthwise separable convolutions, group
+convolutions, and efficient building blocks to further reduce computational complexity and improve efficiency.
 
 ## Supported Environments
 
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/README.md b/models/cv/classification/shufflenetv2_x0_5/ixrt/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6238d35ac349d48c233e3db81990997b31e621e
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/README.md
@@ -0,0 +1,66 @@
+# ShuffleNetV2 x0_5 (IxRT)
+
+## Model Description
+
+ShuffleNetV2_x0_5 is a lightweight convolutional neural network architecture designed for efficient image classification
+and feature extraction. It also incorporates design optimizations such as depthwise separable convolutions, group
+convolutions, and efficient building blocks to further reduce computational complexity and improve efficiency.
+
+## Supported Environments
+
+| GPU | [IXUCA SDK](https://gitee.com/deep-spark/deepspark#%E5%A4%A9%E6%95%B0%E6%99%BA%E7%AE%97%E8%BD%AF%E4%BB%B6%E6%A0%88-ixuca) | Release |
+|--------|-----------|---------|
+| MR-V100 | 4.3.0 | 25.06 |
+
+## Model Preparation
+
+### Prepare Resources
+
+Pretrained model: <https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth>
+
+Dataset: <https://www.image-net.org/download.php> to download the validation dataset.
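+
+The accuracy flow reads labels from a `val_map.txt` file in the dataset root (one `<image name>\t<label>` pair per
+line, as expected by `calibration_dataset.py`). A typical layout, with illustrative file names:
+
+```
+imagenet_val/
+├── val_map.txt
+├── n01440764/
+│   ├── ILSVRC2012_val_00000293.JPEG
+│   └── ...
+└── ...
+```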
+ +### Install Dependencies + +```bash +# Install libGL +## CentOS +yum install -y mesa-libGL +## Ubuntu +apt install -y libgl1-mesa-glx + +pip3 install -r requirements.txt +``` + +### Model Conversion + +```bash +mkdir checkpoints +python3 export.py --weight shufflenetv2_x0.5-f707e7126e.pth --output checkpoints/shufflenetv2_x0_5.onnx +``` + +## Model Inference + +```bash +export PROJ_DIR=./ +export DATASETS_DIR=/path/to/imagenet_val/ +export CHECKPOINTS_DIR=./checkpoints +export RUN_DIR=./ +export CONFIG_DIR=config/SHUFFLENET_V2_X0_5_CONFIG + +``` + +### FP16 + +```bash +# Accuracy +bash scripts/infer_shufflenetv2_x0_5_fp16_accuracy.sh +# Performance +bash scripts/infer_shufflenetv2_x0_5_fp16_performance.sh +``` + +## Model Results + +| Model | BatchSize | Precision | FPS | Top-1(%) | Top-5(%) | +|-------------------|-----------|-----------|----------|----------|----------| +| ShuffleNetV2 x0_5 | 32 | FP16 | 10680.65 | 60.53 | 81.74 | diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/build_engine.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/build_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..e0a3bb4bcd143c68406d915245eefb7060010127 --- /dev/null +++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/build_engine.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
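+#
+# Parses an ONNX model and builds a serialized IxRT engine at the requested
+# precision, writing the result to the path given by --engine.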
+
+import argparse
+
+import tensorrt
+
+def main(config):
+    IXRT_LOGGER = tensorrt.Logger(tensorrt.Logger.WARNING)
+    builder = tensorrt.Builder(IXRT_LOGGER)
+    EXPLICIT_BATCH = 1 << int(tensorrt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
+    network = builder.create_network(EXPLICIT_BATCH)
+    build_config = builder.create_builder_config()
+    parser = tensorrt.OnnxParser(network, IXRT_LOGGER)
+    # Fail loudly if the ONNX model cannot be parsed instead of building an empty network.
+    if not parser.parse_from_file(config.model):
+        for i in range(parser.num_errors):
+            print(parser.get_error(i))
+        raise RuntimeError(f"Failed to parse ONNX model: {config.model}")
+
+    # Only int8/float16 need a builder flag; float32 is the default precision.
+    if config.precision == "int8":
+        build_config.set_flag(tensorrt.BuilderFlag.INT8)
+    elif config.precision == "float16":
+        build_config.set_flag(tensorrt.BuilderFlag.FP16)
+
+    plan = builder.build_serialized_network(network, build_config)
+    engine_file_path = config.engine
+    with open(engine_file_path, "wb") as f:
+        f.write(plan)
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str)
+    parser.add_argument("--precision", type=str, choices=["float16", "int8", "float32"], default="int8",
+                        help="The precision of datatype")
+    parser.add_argument("--engine", type=str, default=None)
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
\ No newline at end of file
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/calibration_dataset.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/calibration_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..774ed840114097ea75eb78f67ffc401f7806ff29
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/calibration_dataset.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
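+#
+# ImageFolder-style dataset for calibration and accuracy runs: ground-truth
+# labels are read from a val_map.txt file ("<image name>\t<label>" per line)
+# in the dataset root, and getdataloader() yields at most step * batch_size
+# samples.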
+ +import os +import torch +import torchvision.datasets +from torch.utils.data import DataLoader +from torchvision import models +from torchvision import transforms as T + + +class CalibrationImageNet(torchvision.datasets.ImageFolder): + def __init__(self, *args, **kwargs): + super(CalibrationImageNet, self).__init__(*args, **kwargs) + img2label_path = os.path.join(self.root, "val_map.txt") + if not os.path.exists(img2label_path): + raise FileNotFoundError(f"Not found label file `{img2label_path}`.") + + self.img2label_map = self.make_img2label_map(img2label_path) + + def make_img2label_map(self, path): + with open(path) as f: + lines = f.readlines() + + img2lable_map = dict() + for line in lines: + line = line.lstrip().rstrip().split("\t") + if len(line) != 2: + continue + img_name, label = line + img_name = img_name.strip() + if img_name in [None, ""]: + continue + label = int(label.strip()) + img2lable_map[img_name] = label + return img2lable_map + + def __getitem__(self, index): + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + # if self.target_transform is not None: + # target = self.target_transform(target) + img_name = os.path.basename(path) + target = self.img2label_map[img_name] + + return sample, target + + +def create_dataloaders(data_path, num_samples=1024, img_sz=224, batch_size=2, workers=0): + dataset = CalibrationImageNet( + data_path, + transform=T.Compose( + [ + T.Resize(256), + T.CenterCrop(img_sz), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ), + ) + + calibration_dataset = dataset + if num_samples is not None: + calibration_dataset = torch.utils.data.Subset( + dataset, indices=range(num_samples) + ) + + calibration_dataloader = DataLoader( + calibration_dataset, + shuffle=False, + batch_size=batch_size, + drop_last=False, + num_workers=workers, + ) + + verify_dataloader = DataLoader( + dataset, + shuffle=False, + batch_size=batch_size, + drop_last=False, + num_workers=workers, + ) + + return calibration_dataloader, verify_dataloader + + +def getdataloader(dataset_dir, step=20, batch_size=32, workers=2, img_sz=224, total_sample=50000): + num_samples = min(total_sample, step * batch_size) + if step < 0: + num_samples = None + calibration_dataloader, _ = create_dataloaders( + dataset_dir, + img_sz=img_sz, + batch_size=batch_size, + workers=workers, + num_samples=num_samples, + ) + return calibration_dataloader \ No newline at end of file diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/ci/prepare.sh b/models/cv/classification/shufflenetv2_x0_5/ixrt/ci/prepare.sh new file mode 100644 index 0000000000000000000000000000000000000000..9cfc3f2c5c4ced4848dfef4044d41a60875774a3 --- /dev/null +++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/ci/prepare.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
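+#
+# CI preparation: installs libGL for OpenCV, installs the Python requirements,
+# and exports the pretrained weights to ONNX for the inference scripts.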
+
+set -x
+
+ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
+if [[ ${ID} == "ubuntu" ]]; then
+    apt install -y libgl1-mesa-glx
+elif [[ ${ID} == "centos" ]]; then
+    yum install -y mesa-libGL
+else
+    echo "Unsupported OS: ${ID}"
+fi
+
+pip install -r requirements.txt
+mkdir -p checkpoints
+python3 export.py --weight shufflenetv2_x0.5-f707e7126e.pth --output checkpoints/shufflenetv2_x0_5.onnx
\ No newline at end of file
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/common.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7aed872184b64e4648bd79f600a59f5b5d0948c
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/common.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import cv2
+import glob
+import torch
+import tensorrt
+import numpy as np
+import pycuda.driver as cuda
+
+def eval_batch(batch_score, batch_label):
+    # Convert the raw logits to a tensor and count top-1/top-5 hits per sample.
+    batch_score = torch.from_numpy(batch_score).to(torch.float32)
+    values, indices = batch_score.topk(5)
+    top1, top5 = 0, 0
+    for idx, label in enumerate(batch_label):
+        if label == indices[idx][0]:
+            top1 += 1
+        if label in indices[idx]:
+            top5 += 1
+    return top1, top5
+
+def create_engine_context(engine_path, logger):
+    with open(engine_path, "rb") as f:
+        runtime = tensorrt.Runtime(logger)
+        assert runtime
+        engine = runtime.deserialize_cuda_engine(f.read())
+        assert engine
+        context = engine.create_execution_context()
+        assert context
+
+    return engine, context
+
+def get_io_bindings(engine):
+    # Setup I/O bindings
+    inputs = []
+    outputs = []
+    allocations = []
+
+    for i in range(engine.num_bindings):
+        is_input = False
+        if engine.binding_is_input(i):
+            is_input = True
+        name = engine.get_binding_name(i)
+        dtype = engine.get_binding_dtype(i)
+        shape = engine.get_binding_shape(i)
+        if is_input:
+            batch_size = shape[0]
+        size = np.dtype(tensorrt.nptype(dtype)).itemsize
+        for s in shape:
+            size *= s
+        allocation = cuda.mem_alloc(size)
+        binding = {
+            "index": i,
+            "name": name,
+            "dtype": np.dtype(tensorrt.nptype(dtype)),
+            "shape": list(shape),
+            "allocation": allocation,
+        }
+        print(f"binding {i}, name : {name}  dtype : {np.dtype(tensorrt.nptype(dtype))}  shape : {list(shape)}")
+        allocations.append(allocation)
+        if engine.binding_is_input(i):
+            inputs.append(binding)
+        else:
+            outputs.append(binding)
+    return inputs, outputs, allocations
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/config/SHUFFLENET_V2_X0_5_CONFIG b/models/cv/classification/shufflenetv2_x0_5/ixrt/config/SHUFFLENET_V2_X0_5_CONFIG
new file mode 100644
index 0000000000000000000000000000000000000000..dbe6d447979bbf4768cbe3f3e18adf5fb991b1e5
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/config/SHUFFLENET_V2_X0_5_CONFIG
@@ -0,0 +1,19 @@
+# IMGSIZE : model input height/width
+# MODEL_NAME : basename of the generated onnx/engine files
+# ORIGINE_MODEL : original onnx file name
+IMGSIZE=224
+MODEL_NAME=shufflenetv2_x0.5
+ORIGINE_MODEL=shufflenetv2_x0_5.onnx
+
+# QUANT CONFIG (only takes effect when PRECISION is int8)
+    # QUANT_OBSERVER : quantization observer, one of [hist_percentile, percentile, minmax, entropy, ema]
+    # QUANT_BATCHSIZE : dataloader batch size used during calibration; keep it identical to the batch size in the onnx model, otherwise some ops (e.g. Reshape) may infer wrong shapes
+    # QUANT_STEP : number of calibration steps
+    # QUANT_SEED : random seed, to make the quantization result reproducible
+    # QUANT_EXIST_ONNX : path to a pre-quantized model from another source, if available
+QUANT_OBSERVER=hist_percentile
+QUANT_BATCHSIZE=1
+QUANT_STEP=32
+QUANT_SEED=42
+DISABLE_QUANT_LIST=
+QUANT_EXIST_ONNX=
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/export.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..4733d330e753317428851d09995c84d90b1caeef
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/export.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import torch
+import torchvision
+import argparse
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--weight",
+                        type=str,
+                        required=True,
+                        help="pytorch model weight.")
+
+    parser.add_argument("--output",
+                        type=str,
+                        required=True,
+                        help="export onnx model path.")
+
+    args = parser.parse_args()
+    return args
+
+def main():
+    args = parse_args()
+
+    model = torchvision.models.shufflenet_v2_x0_5()
+    model.load_state_dict(torch.load(args.weight))
+    model.eval()
+
+    input_names = ['input']
+    output_names = ['output']
+    # Mark the batch dimension as dynamic so the ONNX model can be re-batched later.
+    dynamic_axes = {'input': {0: '-1'}, 'output': {0: '-1'}}
+    dummy_input = torch.randn(1, 3, 224, 224)
+
+    torch.onnx.export(
+        model,
+        dummy_input,
+        args.output,
+        input_names=input_names,
+        dynamic_axes=dynamic_axes,
+        output_names=output_names,
+        opset_version=13
+    )
+
+    print("Export onnx model successfully!")
+
+if __name__ == "__main__":
+    main()
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/inference.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..78126cb0ece51d9a8a1f7daea67dc6cdd38ce9bf
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/inference.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
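+#
+# Runs a serialized IxRT engine in one of two modes: FPS (raw engine
+# throughput over --loop_count batches) or ACC (ImageNet top-1/top-5 over
+# the validation set). Exits non-zero when the --fps_target / --acc_target
+# check fails.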
+ +import argparse +import json +import os +import re +import time +from tqdm import tqdm + +import cv2 +import numpy as np +import pycuda.autoinit +import pycuda.driver as cuda +import torch +import tensorrt + +from calibration_dataset import getdataloader +from common import eval_batch, create_engine_context, get_io_bindings + +def main(config): + dataloader = getdataloader(config.datasets_dir, config.loop_count, config.bsz, img_sz=config.imgsz) + + host_mem = tensorrt.IHostMemory + logger = tensorrt.Logger(tensorrt.Logger.ERROR) + + # Load Engine && I/O bindings + engine, context = create_engine_context(config.engine_file, logger) + inputs, outputs, allocations = get_io_bindings(engine) + + # Warm up + if config.warm_up > 0: + print("\nWarm Start.") + for i in range(config.warm_up): + context.execute_v2(allocations) + print("Warm Done.") + + # Inference + if config.test_mode == "FPS": + torch.cuda.synchronize() + start_time = time.time() + + for i in range(config.loop_count): + context.execute_v2(allocations) + + torch.cuda.synchronize() + end_time = time.time() + forward_time = end_time - start_time + + num_samples = 50000 + if config.loop_count * config.bsz < num_samples: + num_samples = config.loop_count * config.bsz + fps = num_samples / forward_time + + print("FPS : ", fps) + print(f"Performance Check : Test {fps} >= target {config.fps_target}") + if fps >= config.fps_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + + elif config.test_mode == "ACC": + + ## Prepare the output data + output = np.zeros(outputs[0]["shape"], outputs[0]["dtype"]) + print(f"output shape : {output.shape} output type : {output.dtype}") + + total_sample = 0 + acc_top1, acc_top5 = 0, 0 + + start_time = time.time() + with tqdm(total= len(dataloader)) as _tqdm: + for idx, (batch_data, batch_label) in enumerate(dataloader): + batch_data = batch_data.numpy().astype(inputs[0]["dtype"]) + batch_data = np.ascontiguousarray(batch_data) + total_sample += batch_data.shape[0] + + cuda.memcpy_htod(inputs[0]["allocation"], batch_data) + context.execute_v2(allocations) + cuda.memcpy_dtoh(output, outputs[0]["allocation"]) + + # squeeze output shape [32,1000,1,1] to [32,1000] for mobilenet_v2 model + if len(output.shape) == 4: + output = output.squeeze(axis=(2,3)) + + batch_top1, batch_top5 = eval_batch(output, batch_label) + acc_top1 += batch_top1 + acc_top5 += batch_top5 + + _tqdm.set_postfix(acc_1='{:.4f}'.format(acc_top1/total_sample), + acc_5='{:.4f}'.format(acc_top5/total_sample)) + _tqdm.update(1) + end_time = time.time() + end2end_time = end_time - start_time + + print(F"E2E time : {end2end_time:.3f} seconds") + print(F"Acc@1 : {acc_top1/total_sample} = {acc_top1}/{total_sample}") + print(F"Acc@5 : {acc_top5/total_sample} = {acc_top5}/{total_sample}") + acc1 = acc_top1/total_sample + print(f"Accuracy Check : Test {acc1} >= target {config.acc_target}") + if acc1 >= config.acc_target: + print("pass!") + exit() + else: + print("failed!") + exit(1) + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument("--test_mode", type=str, default="FPS", help="FPS MAP") + parser.add_argument( + "--engine_file", + type=str, + help="engine file path" + ) + parser.add_argument( + "--datasets_dir", + type=str, + default="", + help="ImageNet dir", + ) + parser.add_argument("--warm_up", type=int, default=-1, help="warm_up times") + parser.add_argument("--bsz", type=int, default=32, help="test batch size") + parser.add_argument( + "--imgsz", + "--img", + "--img-size", + type=int, + 
default=224, + help="inference size h,w", + ) + parser.add_argument("--use_async", action="store_true") + parser.add_argument( + "--device", type=int, default=0, help="cuda device, i.e. 0 or 0,1,2,3,4" + ) + parser.add_argument("--fps_target", type=float, default=-1.0) + parser.add_argument("--acc_target", type=float, default=-1.0) + parser.add_argument("--loop_count", type=int, default=-1) + + config = parser.parse_args() + return config + +if __name__ == "__main__": + config = parse_config() + main(config) diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/modify_batchsize.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/modify_batchsize.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8d086b460b14db81cc8e2d2d1487324534f551 --- /dev/null +++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/modify_batchsize.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import onnx +import argparse + +def change_input_dim(model, bsz): + batch_size = bsz + + # The following code changes the first dimension of every input to be batch_size + # Modify as appropriate ... note that this requires all inputs to + # have the same batch_size + inputs = model.graph.input + for input in inputs: + # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim. + # Add checks as needed. + dim1 = input.type.tensor_type.shape.dim[0] + # update dim to be a symbolic value + if isinstance(batch_size, str): + # set dynamic batch size + dim1.dim_param = batch_size + elif (isinstance(batch_size, str) and batch_size.isdigit()) or isinstance(batch_size, int): + # set given batch size + dim1.dim_value = int(batch_size) + else: + # set batch size of 1 + dim1.dim_value = 1 + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch_size", type=int) + parser.add_argument("--origin_model", type=str) + parser.add_argument("--output_model", type=str) + args = parser.parse_args() + return args + +args = parse_args() +model = onnx.load(args.origin_model) +change_input_dim(model, args.batch_size) +onnx.save(model, args.output_model) + + + + + diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/quant.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/quant.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9740018253e62b05cc5a0d6c56bbbeae87a181 --- /dev/null +++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/quant.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import cv2 +import random +import argparse +import numpy as np +from random import shuffle +from tensorrt.deploy import static_quantize + +import torch +import torchvision.datasets +from calibration_dataset import getdataloader + +def setseed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_name", type=str) + parser.add_argument("--model", type=str) + parser.add_argument("--dataset_dir", type=str, default="imagenet_val") + parser.add_argument("--observer", type=str, choices=["hist_percentile", "percentile", "minmax", "entropy", "ema"], default="hist_percentile") + parser.add_argument("--disable_quant_names", nargs='*', type=str) + parser.add_argument("--save_dir", type=str, help="save path", default=None) + parser.add_argument("--bsz", type=int, default=32) + parser.add_argument("--step", type=int, default=20) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--imgsz", type=int, default=224) + args = parser.parse_args() + print("Quant config:", args) + print(args.disable_quant_names) + return args + +args = parse_args() +setseed(args.seed) +calibration_dataloader = getdataloader(args.dataset_dir, args.step, args.bsz, img_sz=args.imgsz) +static_quantize(args.model, + calibration_dataloader=calibration_dataloader, + save_quant_onnx_path=os.path.join(args.save_dir, f"quantized_{args.model_name}.onnx"), + observer=args.observer, + data_preprocess=lambda x: x[0].to("cuda"), + quant_format="qdq", + disable_quant_names=args.disable_quant_names) \ No newline at end of file diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/refine_model.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/refine_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d8ca29cf7b8cc604d2ba87bd5375a2bcd61f2a --- /dev/null +++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/refine_model.py @@ -0,0 +1,291 @@ +# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
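+#
+# Optional graph-refinement passes for the exported ONNX model, built on the
+# refine_utils IR helpers. As configured in create_pipeline() below, only the
+# Gelu-fusion pass is enabled; the remaining passes are kept for reference.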
+ +import os +import argparse +import dataclasses + +import torch +import onnx + +from refine_utils.matmul_to_gemm_pass import FusedGemmPass +from refine_utils.linear_pass import FusedLinearPass + +from refine_utils.common import * + +def get_constant_input_name_of_operator(graph: Graph, operator: Operator): + const = None + for input in operator.inputs: + if not graph.containe_var(input): + continue + + if not graph.is_leaf_variable(input): + continue + + input_var = graph.get_variable(input) + if input_var.value is not None: + const = input + return const + +class FuseLayerNormPass(BasePass): + + def process(self, graph: Graph) -> Graph: + self.transform = GraphTransform(graph) + find_sequence_subgraph( + graph, + [OP.REDUCE_MEAN, OP.SUB, OP.POW, OP.REDUCE_MEAN, OP.ADD, OP.SQRT, OP.DIV, OP.MUL, OP.ADD], + self.fuse_layer_norm, + strict=False + ) + return graph + + def fuse_layer_norm(self, graph: Graph, pattern: PatternGraph): + # 检查 REDUCE_MEAN 的输入是否和 SUB 的输入是一致的 + if pattern.nodes[0].operator.inputs[0] != pattern.nodes[1].operator.inputs[0]: + return + + # 检查 POW 的输入是否和 DIV 的输入是一致的 + if pattern.nodes[2].operator.inputs[0] != pattern.nodes[6].operator.inputs[0]: + return + + # 检查部分算子的输出是否被多个算子使用 + nodes = pattern.nodes + for node in [nodes[0]] + nodes[2:-1]: + next_ops = graph.get_next_operators(node.operator) + if len(next_ops) > 1: + return + + eps = None + for input in nodes[4].operator.inputs: + input_var = graph.get_variable(input) + if input_var.value is not None and graph.is_leaf_variable(input): + eps = to_py_type(input_var.value) + + scale = get_constant_input_name_of_operator(graph, nodes[-2].operator) + bias = get_constant_input_name_of_operator(graph, nodes[-1].operator) + + self.transform.delete_operators_between_op_op(nodes[0].operator, nodes[-1].operator) + + bias_var = graph.get_variable(bias) + print(bias_var) + + attributes = { + "axis": nodes[0].operator.attributes.axes, + "epsilon": eps, + } + + + layer_norm_op = self.transform.make_operator( + op_type="LayerNormalization", + inputs=[nodes[0].operator.inputs[0], scale, bias], + outputs=[nodes[-1].operator.outputs[0]], + **attributes + ) + + self.transform.add_operator(layer_norm_op) + +class FusedGeluPass(BasePass): + + def process(self, graph: Graph) -> Graph: + self.transform = GraphTransform(graph) + + find_sequence_subgraph( + graph, pattern=[OP.DIV, OP.ERF, OP.ADD, OP.MUL, OP.MUL], callback=self.fuse_gelu, strict=True + ) + return graph + + def fuse_gelu(self, graph: Graph, pattern: PatternGraph): + nodes = pattern.nodes + prev_op = self.transform.get_previous_operators(nodes[0].operator)[0] + next_ops = self.transform.get_next_operators(prev_op) + if len(next_ops) != 2: + return + + if nodes[0].operator not in next_ops or nodes[3].operator not in next_ops: + return + + gelu_op_input = None + for input in nodes[3].operator.inputs: + if input in nodes[0].operator.inputs: + gelu_op_input = input + break + + self.transform.delete_operators_between_op_op(nodes[0].operator, nodes[-1].operator) + + gelu_op = self.transform.make_operator( + op_type=OP.GELU, + inputs=[gelu_op_input], + outputs=[nodes[-1].operator.outputs[0]] + ) + self.transform.add_operator(gelu_op) + +@dataclasses.dataclass +class NormalizeAttr(BaseOperatorAttr): + p: float = 2.0 + epsilon: float = 1e-12 + axis: int = 1 + + +@registe_operator(OP.GELU) +class GeluOperator(BaseOperator): + + def call( + self, + executor, + operator: Operator, + inputs: List, + attr: NormalizeAttr, + ): + return F.gelu(inputs[0]) + + def convert_onnx_operator( + self, 
ir_graph: Graph, onnx_graph: onnx.GraphProto, node: onnx.NodeProto + ) -> Operator: + return default_converter(ir_graph, onnx_graph, node, attr_cls=attr.EmptyAttr) + + def quantize( + self, + graph: Graph, + op: Operator, + operator_observer_config: QuantOperatorObserverConfig, + quant_outputs: bool = False, + ): + return quant_single_input_operator(graph, op, operator_observer_config, quant_outputs=quant_outputs) + + + +class ClearUnsedVariables(BasePass): + + def process(self, graph: Graph) -> Graph: + vars = list(graph.variables) + + for var in vars: + if len(graph.get_dst_operators(var)) == 0 and graph.is_leaf_variable(var): + graph.delete_variable(var) + + quant_params = list(graph.quant_parameters.keys()) + for var in quant_params: + if not graph.containe_var(var): + graph.quant_parameters.pop(var) + + return graph + +class FormatLayerNorm(BasePass): + + def process(self, graph: Graph) -> Graph: + for op in graph.operators.values(): + if "LayerNorm" in op.op_type: + self.format_layer_norm(graph, op) + return graph + + def format_layer_norm(self, graph, operator): + if not hasattr(operator.attributes, "axis"): + return + if isinstance(operator.attributes.axis, (tuple, list)): + operator.attributes.axis = operator.attributes.axis[0] + +class FormatReshape(BasePass): + + def process(self, graph: Graph) -> Graph: + for op in graph.operators.values(): + if op.op_type == "Reshape": + self.format_reshape(graph, op) + + return graph + + def format_reshape(self, graph, operator): + shape = graph.get_variable(operator.inputs[1]) + shape.value = torch.tensor(shape.value, dtype=torch.int64) + +class FormatScalar(BasePass): + + def process(self, graph: Graph): + for var in graph.variables.values(): + var: Variable + use_ops = graph.get_dst_operators(var) + + if len(use_ops) == 0: + continue + + if use_ops[0].op_type not in [OP.MUL, OP.ADD, OP.GATHER]: + continue + + if var.value is not None and var.value.ndim == 0: + var.value = var.value.reshape(1) + print(f"Reshape scalar to tensor for {var.name}.") + + return graph + +class RenamePass(BasePass): + + def process(self, graph:Graph): + + names = [name for name in graph.operators.keys()] + for old_name in names: + new_name = old_name.replace("/", "#") + + graph.rename_operator(old_name, new_name) + + names = [name for name in graph.variables.keys()] + for name in names: + new_name = name.replace("/", ".").replace("Output", "out").replace("output", "out") + + graph.rename_vaiable(name, new_name, + with_variables=True, + with_operator_outputs=True) + + return graph + +def create_pipeline(example_inputs): + return PassSequence( + # FuseLayerNormPass(), + FusedGeluPass(), + + # ClearUnsedVariables(), + # FormatLayerNorm(), + # FormatReshape(), + # FormatScalar(), + # RenamePass() + ) + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--onnx_path", type=str) + parser.add_argument("--dst_onnx_path", type=str) + + parser.add_argument("--bsz", type=int, default=8, + help="Batch size") + parser.add_argument("--imgsz", type=int, default=224, + help="Image size") + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + example_inputs = torch.randn(args.bsz, 3, args.imgsz, args.imgsz) + + refine_pipline = Pipeline( + create_source(f"{args.onnx_path}", example_inputs=example_inputs), + create_pipeline(example_inputs), + create_target( + f"{args.dst_onnx_path}", + example_inputs=example_inputs, + ) + ) + refine_pipline.run() + + print(f"refine the model, input shape={example_inputs.shape}") diff 
--git a/models/cv/classification/shufflenetv2_x0_5/ixrt/requirements.txt b/models/cv/classification/shufflenetv2_x0_5/ixrt/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4546592bb80eedaab3374186467240fa61dde264
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/requirements.txt
@@ -0,0 +1,8 @@
+pycuda
+tqdm
+tabulate
+onnx
+onnxsim
+opencv-python==4.6.0.66
+mmcls==0.24.0
+mmcv==1.5.3
\ No newline at end of file
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_accuracy.sh b/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_accuracy.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c0d5f96ca5120fed452077d586c9975967e57434
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_accuracy.sh
@@ -0,0 +1,143 @@
+#!/bin/bash
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+EXIT_STATUS=0
+check_status()
+{
+    if ((${PIPESTATUS[0]} != 0));then
+        EXIT_STATUS=1
+    fi
+}
+
+# Run parameters
+BSZ=32
+TGT=-1
+WARM_UP=0
+LOOP_COUNT=-1
+RUN_MODE=ACC
+PRECISION=float16
+
+# Update arguments
+index=0
+options=$@
+arguments=($options)
+for argument in $options
+do
+    index=`expr $index + 1`
+    case $argument in
+      --bs) BSZ=${arguments[index]};;
+      --tgt) TGT=${arguments[index]};;
+    esac
+done
+
+source ${CONFIG_DIR}
+ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL}
+
+echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR}
+echo DATASETS_DIR : ${DATASETS_DIR}
+echo RUN_DIR : ${RUN_DIR}
+echo CONFIG_DIR : ${CONFIG_DIR}
+echo ====================== Model Info ======================
+echo Model Name : ${MODEL_NAME}
+echo Onnx Path : ${ORIGINE_MODEL}
+
+step=0
+SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx
+
+# Simplify Model
+let step++
+echo;
+echo [STEP ${step}] : Simplify Model
+if [ -f ${SIM_MODEL} ];then
+    echo "  "Simplify Model Skip, ${SIM_MODEL} already exists
+else
+    python3 ${RUN_DIR}/simplify_model.py \
+        --origin_model $ORIGINE_MODEL \
+        --output_model ${SIM_MODEL}
+    echo "  "Generate ${SIM_MODEL}
+fi
+
+# Quant Model
+if [ $PRECISION == "int8" ];then
+    let step++
+    echo;
+    echo [STEP ${step}] : Quant Model
+    if [[ -z ${QUANT_EXIST_ONNX} ]];then
+        QUANT_EXIST_ONNX=$CHECKPOINTS_DIR/quantized_${MODEL_NAME}.onnx
+    fi
+    if [[ -f ${QUANT_EXIST_ONNX} ]];then
+        SIM_MODEL=${QUANT_EXIST_ONNX}
+        echo "  "Quant Model Skip, ${QUANT_EXIST_ONNX} already exists
+    else
+        python3 ${RUN_DIR}/quant.py \
+            --model ${SIM_MODEL} \
+            --model_name ${MODEL_NAME} \
+            --dataset_dir ${DATASETS_DIR} \
+            --observer ${QUANT_OBSERVER} \
+            --disable_quant_names ${DISABLE_QUANT_LIST[@]} \
+            --save_dir $CHECKPOINTS_DIR \
+            --bsz ${QUANT_BATCHSIZE} \
+            --step ${QUANT_STEP} \
+            --seed ${QUANT_SEED} \
+            --imgsz ${IMGSIZE}
+        SIM_MODEL=${QUANT_EXIST_ONNX}
+        echo "  "Generate ${SIM_MODEL}
+    fi
+fi
+
+# Change Batchsize
+let step++
+echo;
+echo [STEP ${step}] : Change Batchsize
+FINAL_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_${BSZ}.onnx
+if [ -f $FINAL_MODEL ];then
+    echo "  "Change Batchsize Skip, $FINAL_MODEL already exists
+else
+    python3 ${RUN_DIR}/modify_batchsize.py --batch_size ${BSZ} \
+        --origin_model ${SIM_MODEL} --output_model ${FINAL_MODEL}
+    echo "  "Generate ${FINAL_MODEL}
+fi
+
+# Build Engine
+let step++
+echo;
+echo [STEP ${step}] : Build Engine
+ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine
+if [ -f $ENGINE_FILE ];then
+    echo "  "Build Engine Skip, $ENGINE_FILE already exists
+else
+    python3 ${RUN_DIR}/build_engine.py \
+        --precision ${PRECISION} \
+        --model ${FINAL_MODEL} \
+        --engine ${ENGINE_FILE}
+    echo "  "Generate Engine ${ENGINE_FILE}
+fi
+
+# Inference
+let step++
+echo;
+echo [STEP ${step}] : Inference
+python3 ${RUN_DIR}/inference.py \
+    --engine_file=${ENGINE_FILE} \
+    --datasets_dir=${DATASETS_DIR} \
+    --imgsz=${IMGSIZE} \
+    --warm_up=${WARM_UP} \
+    --loop_count ${LOOP_COUNT} \
+    --test_mode ${RUN_MODE} \
+    --acc_target ${TGT} \
+    --bsz ${BSZ}; check_status
+
+exit ${EXIT_STATUS}
\ No newline at end of file
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_performance.sh b/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_performance.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9ea9b62e3a59437e74df80aaea55cc6d03037ec7
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/scripts/infer_shufflenetv2_x0_5_fp16_performance.sh
@@ -0,0 +1,143 @@
+#!/bin/bash
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+EXIT_STATUS=0
+check_status()
+{
+    if ((${PIPESTATUS[0]} != 0));then
+        EXIT_STATUS=1
+    fi
+}
+
+# Run parameters
+BSZ=32
+TGT=-1
+WARM_UP=3
+LOOP_COUNT=20
+RUN_MODE=FPS
+PRECISION=float16
+
+# Update arguments
+index=0
+options=$@
+arguments=($options)
+for argument in $options
+do
+    index=`expr $index + 1`
+    case $argument in
+      --bs) BSZ=${arguments[index]};;
+      --tgt) TGT=${arguments[index]};;
+    esac
+done
+
+source ${CONFIG_DIR}
+ORIGINE_MODEL=${CHECKPOINTS_DIR}/${ORIGINE_MODEL}
+
+echo CHECKPOINTS_DIR : ${CHECKPOINTS_DIR}
+echo DATASETS_DIR : ${DATASETS_DIR}
+echo RUN_DIR : ${RUN_DIR}
+echo CONFIG_DIR : ${CONFIG_DIR}
+echo ====================== Model Info ======================
+echo Model Name : ${MODEL_NAME}
+echo Onnx Path : ${ORIGINE_MODEL}
+
+step=0
+SIM_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_sim.onnx
+
+# Simplify Model
+let step++
+echo;
+echo [STEP ${step}] : Simplify Model
+if [ -f ${SIM_MODEL} ];then
+    echo "  "Simplify Model Skip, ${SIM_MODEL} already exists
+else
+    python3 ${RUN_DIR}/simplify_model.py \
+        --origin_model $ORIGINE_MODEL \
+        --output_model ${SIM_MODEL}
+    echo "  "Generate ${SIM_MODEL}
+fi
+
+# Quant Model
+if [ $PRECISION == "int8" ];then
+    let step++
+    echo;
+    echo [STEP ${step}] : Quant Model
+    if [[ -z ${QUANT_EXIST_ONNX} ]];then
+        QUANT_EXIST_ONNX=$CHECKPOINTS_DIR/quantized_${MODEL_NAME}.onnx
+    fi
+    if [[ -f ${QUANT_EXIST_ONNX} ]];then
+        SIM_MODEL=${QUANT_EXIST_ONNX}
+        echo "  "Quant Model Skip, ${QUANT_EXIST_ONNX} already exists
+    else
+        python3 ${RUN_DIR}/quant.py \
+            --model ${SIM_MODEL} \
+            --model_name ${MODEL_NAME} \
+            --dataset_dir ${DATASETS_DIR} \
+            --observer ${QUANT_OBSERVER} \
+            --disable_quant_names ${DISABLE_QUANT_LIST[@]} \
+            --save_dir $CHECKPOINTS_DIR \
+            --bsz ${QUANT_BATCHSIZE} \
+            --step ${QUANT_STEP} \
+            --seed ${QUANT_SEED} \
+            --imgsz ${IMGSIZE}
+        SIM_MODEL=${QUANT_EXIST_ONNX}
+        echo "  "Generate ${SIM_MODEL}
+    fi
+fi
+
+# Change Batchsize
+let step++
+echo;
+echo [STEP ${step}] : Change Batchsize
+FINAL_MODEL=${CHECKPOINTS_DIR}/${MODEL_NAME}_${BSZ}.onnx
+if [ -f $FINAL_MODEL ];then
+    echo "  "Change Batchsize Skip, $FINAL_MODEL already exists
+else
+    python3 ${RUN_DIR}/modify_batchsize.py --batch_size ${BSZ} \
+        --origin_model ${SIM_MODEL} --output_model ${FINAL_MODEL}
+    echo "  "Generate ${FINAL_MODEL}
+fi
+
+# Build Engine
+let step++
+echo;
+echo [STEP ${step}] : Build Engine
+ENGINE_FILE=${CHECKPOINTS_DIR}/${MODEL_NAME}_${PRECISION}_bs${BSZ}.engine
+if [ -f $ENGINE_FILE ];then
+    echo "  "Build Engine Skip, $ENGINE_FILE already exists
+else
+    python3 ${RUN_DIR}/build_engine.py \
+        --precision ${PRECISION} \
+        --model ${FINAL_MODEL} \
+        --engine ${ENGINE_FILE}
+    echo "  "Generate Engine ${ENGINE_FILE}
+fi
+
+# Inference
+let step++
+echo;
+echo [STEP ${step}] : Inference
+python3 ${RUN_DIR}/inference.py \
+    --engine_file=${ENGINE_FILE} \
+    --datasets_dir=${DATASETS_DIR} \
+    --imgsz=${IMGSIZE} \
+    --warm_up=${WARM_UP} \
+    --loop_count ${LOOP_COUNT} \
+    --test_mode ${RUN_MODE} \
+    --fps_target ${TGT} \
+    --bsz ${BSZ}; check_status
+
+exit ${EXIT_STATUS}
\ No newline at end of file
diff --git a/models/cv/classification/shufflenetv2_x0_5/ixrt/simplify_model.py b/models/cv/classification/shufflenetv2_x0_5/ixrt/simplify_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bef33576a17235d24305fa32ae3dab6d7accc10e
--- /dev/null
+++ b/models/cv/classification/shufflenetv2_x0_5/ixrt/simplify_model.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import onnx +import argparse +from onnxsim import simplify + +# Simplify +def simplify_model(args): + onnx_model = onnx.load(args.origin_model) + model_simp, check = simplify(onnx_model) + model_simp = onnx.shape_inference.infer_shapes(model_simp) + onnx.save(model_simp, args.output_model) + print(" Simplify onnx Done.") + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--origin_model", type=str) + parser.add_argument("--output_model", type=str) + parser.add_argument("--reshape", action="store_true") + args = parser.parse_args() + return args + +args = parse_args() +simplify_model(args) + + + +
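+
+# Example invocation (paths are illustrative; the run scripts pass
+# ${ORIGINE_MODEL} and ${MODEL_NAME}_sim.onnx):
+#   python3 simplify_model.py \
+#       --origin_model checkpoints/shufflenetv2_x0_5.onnx \
+#       --output_model checkpoints/shufflenetv2_x0.5_sim.onnx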