From 5ccb211d6ca64b399bb10e45b08e245513d4247c Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Wed, 6 Mar 2024 15:24:38 +0800 Subject: [PATCH 1/6] [Feature] add atat one-site cli --- .../api_accuracy_checker/run_ut/run_ut.py | 11 +++-- debug/accuracy_tools/atat/atat.py | 40 +++++++++++++++++++ debug/accuracy_tools/setup.py | 5 ++- 3 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 debug/accuracy_tools/atat/atat.py diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0035e0dab..0e8d86276 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -373,10 +373,15 @@ def preprocess_forward_content(forward_content): return processed_content -def _run_ut(): - parser = argparse.ArgumentParser() - _run_ut_parser(parser) +def _run_ut(parser = None): + if not parser: + parser = argparse.ArgumentParser() + parser = _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) + run_ut_command(args) + + +def run_ut_command(args): if not is_gpu: torch.npu.set_compile_mode(jit_compile=args.jit_compile) used_device = current_device + ":" + str(args.device_id[0]) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py new file mode 100644 index 000000000..cd1429bbf --- /dev/null +++ b/debug/accuracy_tools/atat/atat.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys +from api_accuracy_checker.run_ut_run_ut import _run_ut_parser, run_ut_command + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="atat(Ascend Training Accuracy Tools), [Powered by MindStudio].\n" + "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" + f"For any issue, refer README first TODO", + ) + subparsers = parser.add_subparsers(help='commands') + #parse_cmd_parser = subparsers.add_parser('parse', help='[TODO]parse command') + run_ut_cmd_parser = subparsers.add_parser('run_ut', help='[TODO]run_ut command') + #multi_run_ut_cmd_parser = subparsers.add_parser('parse', help='[TODO]multi_run_ut command') + #benchmark_compare_cmd_parser = subparsers.add_parser('parse', help='[TODO]benchmark_compare command') + parser.set_defaults(print_help=parser.print_help) + run_parser = _run_ut_parser(run_ut_cmd_parser) + args = parser.parse_args(sys.argv[1:]) + if sys.argv[1] == "run_ut": + run_ut_command(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index af16d18e1..886d23090 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -33,4 +33,7 @@ setup( ], include_package_data=True, ext_modules=[], - zip_safe=False) \ No newline at end of file + zip_safe=False, + entry_points={ + 'console_scripts' : ['atat=atat.atat:main'], + },) \ No newline at end of file -- Gitee From bcbd688bef901785e4abe0c2c41846a13d29f227 Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Wed, 6 Mar 2024 17:58:36 +0800 Subject: [PATCH 2/6] atat one-site cli adds parse, benchmark_compare, multi_run_ut and overflow_check --- .../compare/benchmark_compare.py | 9 ++++-- .../run_ut/run_overflow_check.py | 13 ++++++--- .../api_accuracy_checker/run_ut/run_ut.py | 4 +-- debug/accuracy_tools/atat/atat.py | 28 ++++++++++++++++--- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py index 7106b8448..f5ff52c7b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py @@ -236,10 +236,15 @@ def check_csv_columns(columns, csv_type): raise CompareException(CompareException.INVALID_DATA_ERROR, msg) -def _benchmark_compare(): - parser = argparse.ArgumentParser() +def _benchmark_compare(parser=None): + if not parser: + parser = argparse.ArgumentParser() _benchmark_compare_parser(parser) args = parser.parse_args(sys.argv[1:]) + _benchmark_compare_command(args) + + +def _benchmark_compare_command(args): npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') out_path = os.path.realpath(args.out_path) if args.out_path else "./" diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index 9f968c1d8..2e8a12231 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -84,7 +84,7 @@ def run_torch_api(api_full_name, api_info_dict): return -def _run_ut_parser(parser): +def _run_overflow_check_parser(parser): parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", help=" The api param tool forward result file: generate from api param tool, " "a json file.", @@ -95,10 +95,15 @@ def _run_ut_parser(parser): default=0, required=False) -def _run_overflow_check(): - parser = argparse.ArgumentParser() - _run_ut_parser(parser) +def _run_overflow_check(parser=None): + if not parser: + parser = argparse.ArgumentParser() + _run_overflow_check_parser(parser) args = parser.parse_args(sys.argv[1:]) + _run_overflow_check_command(args) + + +def _run_overflow_check_command(args): torch.npu.set_compile_mode(jit_compile=args.jit_compile) npu_device = "npu:" + str(args.device_id) check_link(args.forward_input_file) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0e8d86276..ed7ce7e11 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -373,10 +373,10 @@ def preprocess_forward_content(forward_content): return processed_content -def _run_ut(parser = None): +def _run_ut(parser=None): if not parser: parser = argparse.ArgumentParser() - parser = _run_ut_parser(parser) + _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) run_ut_command(args) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index cd1429bbf..bc0dc4187 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -16,6 +16,11 @@ import argparse import sys from api_accuracy_checker.run_ut_run_ut import _run_ut_parser, run_ut_command +from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse +from api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut +from api_accuracy_checker.compare.benchmark_compare import _benchmark_compare_parser, _benchmark_compare_command +from api_accuracy_checker.run_ut_run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command + def main(): parser = argparse.ArgumentParser( @@ -25,15 +30,30 @@ def main(): f"For any issue, refer README first TODO", ) subparsers = parser.add_subparsers(help='commands') - #parse_cmd_parser = subparsers.add_parser('parse', help='[TODO]parse command') + parse_cmd_parser = subparsers.add_parser('parse', help='[TODO]parse command') run_ut_cmd_parser = subparsers.add_parser('run_ut', help='[TODO]run_ut command') - #multi_run_ut_cmd_parser = subparsers.add_parser('parse', help='[TODO]multi_run_ut command') - #benchmark_compare_cmd_parser = subparsers.add_parser('parse', help='[TODO]benchmark_compare command') + multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut', help='[TODO]multi_run_ut command') + benchmark_compare_cmd_parser = subparsers.add_parser('benchmark_compare', help='[TODO]benchmark_compare command') + run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check', help='[TODO]run_overflow_check command') parser.set_defaults(print_help=parser.print_help) - run_parser = _run_ut_parser(run_ut_cmd_parser) + _run_ut_parser(run_ut_cmd_parser) + _run_ut_parser(multi_run_ut_cmd_parser) + multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, + help='Number of splits for parallel processing. Range: 1-64') + _benchmark_compare_parser(benchmark_compare_cmd_parser) + _run_overflow_check_parser(run_overflow_check_cmd_parser) args = parser.parse_args(sys.argv[1:]) if sys.argv[1] == "run_ut": run_ut_command(args) + elif sys.argv[1] == "parse": + cli_parse() + elif sys.argv[1] == "multi_run_ut": + config = prepare_config(args) + run_parallel_ut(config) + elif sys.argv[1] == "benchmark_compare": + _benchmark_compare_command(args) + elif sys.argv[1] == "run_overflow_check": + _run_overflow_check_command(args) if __name__ == "__main__": -- Gitee From 2b5d94f06675105d5f8efb50b3219bd4f89da726 Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Wed, 6 Mar 2024 18:03:03 +0800 Subject: [PATCH 3/6] fix typo --- debug/accuracy_tools/atat/atat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index bc0dc4187..d2063fbd7 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -15,7 +15,7 @@ import argparse import sys -from api_accuracy_checker.run_ut_run_ut import _run_ut_parser, run_ut_command +from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse from api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut from api_accuracy_checker.compare.benchmark_compare import _benchmark_compare_parser, _benchmark_compare_command -- Gitee From 4e961a2e85a8ebb412ce49002822c4ffa841f8db Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Wed, 6 Mar 2024 18:07:19 +0800 Subject: [PATCH 4/6] fix typo 2 --- debug/accuracy_tools/atat/atat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index d2063fbd7..3f3af7ef2 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -19,7 +19,7 @@ from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse from api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut from api_accuracy_checker.compare.benchmark_compare import _benchmark_compare_parser, _benchmark_compare_command -from api_accuracy_checker.run_ut_run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command +from api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command def main(): -- Gitee From 06ab5e5304998d9467b6c06131cb26c92aaf5349 Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Wed, 6 Mar 2024 18:18:31 +0800 Subject: [PATCH 5/6] handle all TODOs --- debug/accuracy_tools/atat/atat.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index 3f3af7ef2..e82c4f9a4 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -25,16 +25,16 @@ from api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_p def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, - description="atat(Ascend Training Accuracy Tools), [Powered by MindStudio].\n" + description="atat(ascend training accuracy tools), [Powered by MindStudio].\n" "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" - f"For any issue, refer README first TODO", + f"For any issue, refer README.md first", ) subparsers = parser.add_subparsers(help='commands') - parse_cmd_parser = subparsers.add_parser('parse', help='[TODO]parse command') - run_ut_cmd_parser = subparsers.add_parser('run_ut', help='[TODO]run_ut command') - multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut', help='[TODO]multi_run_ut command') - benchmark_compare_cmd_parser = subparsers.add_parser('benchmark_compare', help='[TODO]benchmark_compare command') - run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check', help='[TODO]run_overflow_check command') + subparsers.add_parser('parse') + run_ut_cmd_parser = subparsers.add_parser('run_ut') + multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut') + benchmark_compare_cmd_parser = subparsers.add_parser('benchmark_compare') + run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check') parser.set_defaults(print_help=parser.print_help) _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) -- Gitee From d1c035e9858f6d65cc856246cecea54c0ecc07e6 Mon Sep 17 00:00:00 2001 From: caishangqiu Date: Thu, 7 Mar 2024 10:03:25 +0800 Subject: [PATCH 6/6] handle case without subcommand --- debug/accuracy_tools/atat/atat.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index e82c4f9a4..60e49cbe0 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -29,19 +29,22 @@ def main(): "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" f"For any issue, refer README.md first", ) - subparsers = parser.add_subparsers(help='commands') + parser.set_defaults(print_help=parser.print_help) + subparsers = parser.add_subparsers() subparsers.add_parser('parse') run_ut_cmd_parser = subparsers.add_parser('run_ut') multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut') benchmark_compare_cmd_parser = subparsers.add_parser('benchmark_compare') run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check') - parser.set_defaults(print_help=parser.print_help) _run_ut_parser(run_ut_cmd_parser) _run_ut_parser(multi_run_ut_cmd_parser) multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64') _benchmark_compare_parser(benchmark_compare_cmd_parser) _run_overflow_check_parser(run_overflow_check_cmd_parser) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(0) args = parser.parse_args(sys.argv[1:]) if sys.argv[1] == "run_ut": run_ut_command(args) -- Gitee