From 1a2dc3af28a373e694c468e534dd4b2b747718af Mon Sep 17 00:00:00 2001
From: tanghongyan <1349905607@qq.com>
Date: Thu, 24 Mar 2022 15:18:16 +0800
Subject: [PATCH] Regression test changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

.
.
---
 .../MnasNet/mnasnet_pthtar2onnx.py | 42 +++++++++++++++++++
 .../MnasNet/modelArts/pth2onnx.py  | 25 +++++------
 2 files changed, 54 insertions(+), 13 deletions(-)
 create mode 100644 PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py

diff --git a/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py b/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py
new file mode 100644
index 0000000000..92d41e71aa
--- /dev/null
+++ b/PyTorch/contrib/cv/classification/MnasNet/mnasnet_pthtar2onnx.py
@@ -0,0 +1,42 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import sys
+import torch
+import mnasnet
+import torch.onnx
+
+from collections import OrderedDict
+
+
+def convert():
+    checkpoint = torch.load(input_file, map_location=None)
+    model = mnasnet.mnasnet1_0()
+    model.load_state_dict(checkpoint)
+    model.eval()
+    print(model)
+
+    input_names = ["image"]
+    output_names = ["class"]
+    dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}}
+    dummy_input = torch.randn(1, 3, 224, 224)
+
+    torch.onnx.export(model, dummy_input, "mnasnet1_0.onnx", input_names=input_names, output_names=output_names,
+                      dynamic_axes=dynamic_axes, opset_version=11)
+
+
+if __name__ == "__main__":
+    input_file = sys.argv[1]
+    convert()
diff --git a/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py b/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py
index 34178af7c3..5b66c0830c 100644
--- a/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py
+++ b/PyTorch/contrib/cv/classification/MnasNet/modelArts/pth2onnx.py
@@ -49,21 +49,20 @@ def proc_node_module(checkpoint, AttrName):
     return new_state_dict
 
 
-def convert(pth_file, onnx_path, class_num, train_url, npu):
+def convert(pth_file, onnx_path, class_num, train_url):
-    loc = 'npu:{}'.format(npu)
-    checkpoint = torch.load(pth_file, map_location=loc)
-
-    checkpoint['state_dict'] = proc_node_module(checkpoint, 'state_dict')
+    checkpoint = torch.load(pth_file, map_location=None)
+    model = mnasnet.mnasnet1_0(num_classes=class_num)
+    model.load_state_dict(checkpoint)
-    model.to(loc)
-    model.load_state_dict(checkpoint['state_dict'])
     model.eval()
-    input_names = ["actual_input_1"]
-    output_names = ["output1"]
-    dummy_input = torch.randn(16, 3, 224, 224)
-    dummy_input = dummy_input.to(loc, non_blocking=False)
+
+    input_names = ["image"]
+    output_names = ["class"]
+    dynamic_axes = {'image': {0: '-1'}, 'class': {0: '-1'}}
+    dummy_input = torch.randn(1, 3, 224, 224)
+
     torch.onnx.export(model, dummy_input, onnx_path, input_names=input_names, output_names=output_names,
-                      opset_version=11)
+                      dynamic_axes=dynamic_axes, opset_version=11)
 
     mox.file.copy_parallel(onnx_path, train_url + 'model.onnx')
 
 
@@ -76,7 +75,7 @@ def convert_pth_to_onnx(config_args):
         return
     pth_file = pth_file_list[0]
     onnx_path = pth_file.split(".")[0] + '.onnx'
-    convert(pth_file, onnx_path, config_args.class_num, config_args.train_url, config_args.npu)
+    convert(pth_file, onnx_path, config_args.class_num, config_args.train_url)
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
     # modelarts
-- 
Gitee
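
Usage note (outside the patch): the exported model can be smoke-tested with onnxruntime, exercising the dynamic batch axis declared via dynamic_axes. A minimal sketch, assuming the onnxruntime and numpy packages are installed, that mnasnet1_0.onnx was produced by the new script, and that the model keeps the default 1000-class head:

    import numpy as np
    import onnxruntime as ort

    # Load the exported model on CPU.
    sess = ort.InferenceSession("mnasnet1_0.onnx", providers=["CPUExecutionProvider"])

    # A batch of 4 only resolves if dynamic_axes was passed to torch.onnx.export;
    # a fixed-shape export would only accept batch size 1.
    x = np.random.randn(4, 3, 224, 224).astype(np.float32)
    (logits,) = sess.run(["class"], {"image": x})
    print(logits.shape)  # expected: (4, 1000)

Note also that torch.load(..., map_location=None) restores tensors to the device they were saved on, so a checkpoint saved on an NPU may fail to load on a CPU-only host; map_location='cpu' may be needed there.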