From 372d68f59b822c8abb3bea1fed9a406148e273d3 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Wed, 22 May 2024 15:28:31 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E7=B2=BE=E5=BA=A6=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E5=90=88=E4=B8=80=E3=80=91=E7=94=A8=E6=88=B7=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=E8=A1=8C=E4=BD=BF=E7=94=A8=E5=BD=92=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/atat.py | 12 +++++++----- debug/accuracy_tools/setup.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index 4f69afd23..195def90c 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -30,6 +30,8 @@ def main(): f"For any issue, refer README.md first", ) parser.set_defaults(print_help=parser.print_help) + parser.add_argument('-f', '--framework', required=True, choices=['pytorch'], + help='Deep learning framework.') subparsers = parser.add_subparsers() subparsers.add_parser('parse') run_ut_cmd_parser = subparsers.add_parser('run_ut') @@ -46,16 +48,16 @@ def main(): parser.print_help() sys.exit(0) args = parser.parse_args(sys.argv[1:]) - if sys.argv[1] == "run_ut": + if sys.argv[3] == "run_ut": run_ut_command(args) - elif sys.argv[1] == "parse": + elif sys.argv[3] == "parse": cli_parse() - elif sys.argv[1] == "multi_run_ut": + elif sys.argv[3] == "multi_run_ut": config = prepare_config(args) run_parallel_ut(config) - elif sys.argv[1] == "api_precision_compare": + elif sys.argv[3] == "api_precision_compare": _api_precision_compare_command(args) - elif sys.argv[1] == "run_overflow_check": + elif sys.argv[3] == "run_overflow_check": _run_overflow_check_command(args) diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index 886d23090..1a42eed1a 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -19,7 +19,7 @@ from setuptools import setup, find_packages setup( name='ascend_training_accuracy_tools', - version='0.0.1', + version='0.0.2', description='This is a pytorch precision comparison tools', long_description='This is a pytorch precision comparison tools, include ptdbg and api accuracy checker', packages=find_packages(), -- Gitee