diff --git a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py index 3b5263295e8173dc8bc1592fc44eb9eff5462fd0..4db88e7b045092ada331d3a4f1865fd8dfbe54c7 100644 --- a/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py +++ b/ACL_PyTorch/built-in/cv/SwinTransformer_for_Pytorch/pth2onnx.py @@ -17,54 +17,118 @@ import os import argparse import torch import timm +import numpy as np def pth2onnx(args): pth_path = args.input_path batch_size = args.batch_size - model_name = args.model_name out_path = args.out_path + model_name = args.model_name + is_open_source = args.open_source + + if is_open_source: + if 's3' in model_name: + size = int(model_name.split('_')[3]) + else: + size = int(model_name.split('_')[4]) + + input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) + input_names = ["image"] + output_names = ["output"] + + model = timm.create_model(model_name, checkpoint_path=pth_path) + model.eval() + + torch.onnx.export( + model, + input_data, + out_path, + verbose=True, + opset_version=11, + input_names=input_names, + output_names=output_names + ) - # get size - if 's3' in model_name: - size = int(model_name.split('_')[3]) else: - size = int(model_name.split('_')[4]) - input_data = torch.randn([batch_size, 3, size, size]).to(torch.float32) - input_names = ["image"] - output_names = ["out"] - - # build model - model = timm.create_model(model_name, checkpoint_path=pth_path) - model.eval() - - torch.onnx.export( - model, - input_data, - out_path, - verbose=True, - opset_version=11, - input_names=input_names, - output_names=output_names - ) + checkpoint = torch.load(pth_path, map_location='cpu') + + if 'config' in checkpoint and 'model' in checkpoint: + config = checkpoint['config'] + state_dict = checkpoint['model'] + if hasattr(config, 'MODEL') and hasattr(config.MODEL, 'NAME'): + model_name_from_config = config.MODEL.NAME + if model_name == "default": + model_name = model_name_from_config + else: + raise ValueError("Checkpoint文件结构不符合预期,应包含'config'和'model'键") + + model = timm.create_model( + model_name, + pretrained=False, + num_classes=config.MODEL.NUM_CLASSES + ) + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith('module.'): + new_state_dict[k[7:]] = v + else: + new_state_dict[k] = v + + # 修复relative_position_index的形状不匹配问题 + for key in list(new_state_dict.keys()): + if 'relative_position_index' in key: + tensor_value = new_state_dict.get(key) + if tensor_value is not None and tensor_value.shape == torch.Size([2401]): + new_state_dict[key] = tensor_value.view(49, 49) + + model.load_state_dict(new_state_dict, strict=False) + model.eval() + + if 's3' in model_name: + input_size = int(model_name.split('_')[3]) + else: + input_size = int(model_name.split('_')[4]) + + input_data = torch.randn([batch_size, 3, input_size, input_size], dtype=torch.float32) + input_names = ["image"] + output_names = ["output"] + + torch.onnx.export( + model, + input_data, + out_path, + verbose=True, + opset_version=11, + input_names=input_names, + output_names=output_names, + ) def parse_arguments(): - parser = argparse.ArgumentParser(description='SwinTransformer onnx export.') + parser = argparse.ArgumentParser(description='Convert SwinTransformer pth to onnx') parser.add_argument('-i', '--input_path', type=str, required=True, help='input path for pth model') parser.add_argument('-o', '--out_path', type=str, required=True, help='save path for output onnx model') - parser.add_argument('-n', '--model_name', type=str, default='swin_base_patch4_window12_384', - help='model name for swintransformer') + parser.add_argument('-n', '--model_name', type=str, default='default', + help='model name for swintransformer (e.g., swin_base_patch4_window12_384)') parser.add_argument('-b', '--batch_size', type=int, default=1, help='batch size for output model') + parser.add_argument('--open_source', action='store_false', + help='whether the model is from open source (use timm direct loading)') args = parser.parse_args() - args.out_path = os.path.abspath(args.out_path) - os.makedirs(os.path.dirname(args.out_path), exist_ok=True) + + os.makedirs(os.path.dirname(os.path.abspath(args.out_path)), exist_ok=True) + return args if __name__ == '__main__': args = parse_arguments() + if args.open_source and args.model_name == 'default': + raise ValueError("使用开源模型时必须指定 --model_name 参数") + pth2onnx(args) +