diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 3b5263295e8173dc8bc1592fc44eb9eff5462fd0..f45803c3d7325cce15be4e4cc16ab4a16cb15969 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -17,26 +17,50 @@ import os import argparse import torch import timm - +import numpy as np def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size - model_name = args.model_name out_path = args.out_path - # get size + checkpoint = torch.load(pth_path, map_location='cpu') + + config = checkpoint['config'] + state_dict = checkpoint['model'] + + model_name = config.MODEL.NAME + + model = timm.create_model( + model_name, + pretrained=False, + num_classes=config.MODEL.NUM_CLASSES + ) + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + # Fix the shape mismatch issue of relative_position_index + for key in list(new_state_dict.keys()): + if 'relative_position_index' in key: + #The original shape is [2401], and it needs to be reshaped into [49, 49] + if new_state_dict[key].shape == torch.Size([2401]): + new_state_dict[key] = new_state_dict[key].view(49, 49) + + model.load_state_dict(new_state_dict, strict=False) + + model.eval() + if 's3' in model_name: - size = int(model_name.split('_')[3]) + input_size = int(model_name.split('_')[3]) else: - size = int(model_name.split('_')[4]) - input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) - input_names = ["image"] - output_names = ["out"] + input_size = int(model_name.split('_')[4]) - # build model - model = timm.create_model(model_name, checkpoint_path=pth_path) - model.eval() + input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) torch.onnx.export( model, @@ -44,27 +68,21 @@ def pth2onnx(args): out_path, verbose=True, opset_version=11, - input_names=input_names, - output_names=output_names + input_names=["image"], + output_names=["output"], ) - def parse_arguments(): - parser = argparse.ArgumentParser(description='SwinTransformer onnx export.') + parser = argparse.ArgumentParser(description='Convert Swin-Tiny pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, help='input path for pth model') parser.add_argument('-o', '--out_path', type=str, required=True, help='save path for output onnx model') - parser.add_argument('-n', '--model_name', type=str, default='swin_base_patch4_window12_384', - help='model name for swintransformer') parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size for output model') - args = parser.parse_args() - args.out_path = os.path.abspath(args.out_path) - os.makedirs(os.path.dirname(args.out_path), exist_ok=True) - return args - + return parser.parse_args() if __name__ == '__main__': args = parse_arguments() pth2onnx(args) +