代码拉取完成,页面将自动刷新
import argparse
import ast
import os
import importlib
import sys
import time
import torch
from mindspore import context
from src.compare import Compare
from src import utils
# Name of the network class that both the PyTorch and the MindSpore net files
# must define; the __main__ block passes it to init_net_class() for each file.
CLASS_NAME = "TestNet"
def _accuracy_compare_parser(parser):
parser.add_argument("-ms", "--mindspore_net", dest="ms_net", default="",
help="<Required> The MindSpore net file path", required=True)
parser.add_argument("-pt", "--pytorch_net", dest="pt_net", default="",
help="<Required> The PyTorch net file path", required=True)
parser.add_argument("-i", "--input_shape", dest="input_shape", default="",
help="<Required> The input data shape of the model. It should be a list."
"Separate multiple inputs with commas(,) in list. "
"E.g: one input [[1,2]]; two inputs [[1,2],[1,2]]")
parser.add_argument("-dtype", "--input_dtype", dest="input_dtype", default="",
help="<Optional> The input data dtype of the model."
"Separate multiple inputs with commas(,). "
"E.g: one input float32; two inputs int32,float32."
"If this option is not configured, the default type of all inputs is float32.")
parser.add_argument("-o", "--out_path", dest="out_path", default="", help="<Optional> The output path")
parser.add_argument("-init", "--init_mode", dest="init_mode", choices=["random", "ones"], default="random",
help="<Optional> Net operator attribute init mode, support: ones; random")
parser.add_argument("-d", "--device", dest="device", default="0",
help="<Optional> Input device ID [0, 255], default is 0.")
parser.add_argument("-print_result", dest="print_result", action="store_true",
help="<Optional> Print compare result on console.")
def init_net_class(file_path, class_name):
    """Load the Python file at *file_path* and return its attribute *class_name*.

    The file is executed as a module named after its base name (without the
    extension) and registered in ``sys.modules`` under that name. It is loaded
    from the exact file path rather than imported by name, so that:

    * two net files sharing a base name (e.g. ``a/net.py`` and ``b/net.py``)
      each yield their own class, instead of the second call returning the
      module already cached in ``sys.modules`` by the first; and
    * a same-named module appearing earlier on ``sys.path`` cannot shadow the
      requested file.

    The file's directory is still appended to ``sys.path`` (once) so the net
    file can import sibling modules from its own directory.

    Args:
        file_path: Path to the Python source file defining the network.
        class_name: Name of the attribute (the network class) to fetch.

    Returns:
        The object bound to *class_name* in the loaded module.

    Raises:
        ImportError: If no import spec can be created for *file_path*.
        AttributeError: If the loaded module does not define *class_name*.
    """
    import importlib.util

    file_path = os.path.realpath(file_path)
    module_dir = os.path.dirname(file_path)
    module_name, _ = os.path.splitext(os.path.basename(file_path))
    if module_dir not in sys.path:
        sys.path.append(module_dir)

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot load net file: %s" % file_path)
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the regular import system does, so code
    # that looks the module up via sys.modules (e.g. inspect, dataclasses)
    # can find it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind on failure.
        sys.modules.pop(module_name, None)
        raise
    return getattr(module, class_name)
def load_shape_type(cli_args=None):
    """Parse the input shapes and dtypes given on the command line.

    Args:
        cli_args: Parsed command-line namespace providing the ``input_shape``
            and ``input_dtype`` strings. When omitted, the module-level
            ``args`` set by the ``__main__`` block is used, which keeps the
            original zero-argument call working.

    Returns:
        tuple: ``(input_shape, input_dtype)``. ``input_shape`` is the Python
        literal parsed from ``cli_args.input_shape`` (a list or tuple).
        ``input_dtype`` is a list of dtype-name strings split on commas with
        surrounding whitespace removed. It defaults to ``"float32"`` for every
        entry of ``input_shape`` when no dtype string is given.

    Raises:
        utils.CompareException: If the shape string is not a valid Python
            literal, or does not evaluate to a list/tuple.
    """
    if cli_args is None:
        cli_args = args
    try:
        input_shape = ast.literal_eval(cli_args.input_shape)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as ex:
        utils.print_error_log("Failed to parse input shape: %s" % ex)
        raise utils.CompareException(utils.ACCURACY_COMPARISON_INVALID_PARAM_ERROR) from ex
    # A scalar such as "5" parses fine but is not a shape list; reject it here
    # with a clear message instead of failing later on len() or downstream.
    if not isinstance(input_shape, (list, tuple)):
        utils.print_error_log("Input shape should be a list, e.g. [[1,2]], but got: %s"
                              % cli_args.input_shape)
        raise utils.CompareException(utils.ACCURACY_COMPARISON_INVALID_PARAM_ERROR)
    if cli_args.input_dtype:
        # str.split cannot fail, so no error handling is needed here. Strip so
        # that "int32, float32" is accepted as well as "int32,float32".
        input_dtype = [dtype.strip() for dtype in cli_args.input_dtype.split(",")]
    else:
        input_dtype = ["float32"] * len(input_shape)
    return input_shape, input_dtype
if __name__ == "__main__":
    # Script entry point: build a PyTorch net and a MindSpore net from the
    # given files and run a Compare task on them with the requested inputs.
    start = time.perf_counter()
    parser = argparse.ArgumentParser()
    _accuracy_compare_parser(parser)
    # Kept as a module-level name because load_shape_type() falls back to it.
    args = parser.parse_args(sys.argv[1:])

    # check and set device
    utils.check_device_param_valid(args.device)
    context.set_context(device_target="Ascend", device_id=int(args.device))

    # Validate every remaining user input (net script paths, output path,
    # input shapes/dtypes) before importing and instantiating the networks,
    # so bad CLI input is reported before the nets' code runs.
    args.ms_net = os.path.realpath(args.ms_net)
    args.pt_net = os.path.realpath(args.pt_net)
    utils.check_file_or_directory_path(args.ms_net)
    utils.check_file_or_directory_path(args.pt_net)

    utils.check_file_or_directory_path(os.path.realpath(args.out_path), True)
    # Compare receives a timestamped sub-directory of the requested out path.
    time_dir = time.strftime("%Y%m%d%H%M%S", time.localtime())
    out_dir_path = os.path.realpath(os.path.join(args.out_path, time_dir))

    input_shape, input_dtype = load_shape_type()

    # Load the PyTorch and MindSpore net classes and instantiate them.
    pt_net = init_net_class(args.pt_net, CLASS_NAME)()
    ms_net = init_net_class(args.ms_net, CLASS_NAME)()

    compare_task = Compare(ptmodule=pt_net, msmodule=ms_net,
                           input_shape=input_shape,
                           input_dtype=input_dtype,
                           init_mode=args.init_mode,
                           out_path=out_dir_path,
                           print_result=args.print_result)
    compare_task.start_compare()
    end = time.perf_counter()
    utils.print_info_log("The command was completed and took %s seconds." % (end - start))
    sys.exit(0)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。