diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py index 7106b8448f9c8f8c219aef009f4cb7b19734337e..f5ff52c7bb7bf0860d72571a53e3c08953fb2b06 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/benchmark_compare.py @@ -236,10 +236,15 @@ def check_csv_columns(columns, csv_type): raise CompareException(CompareException.INVALID_DATA_ERROR, msg) -def _benchmark_compare(): - parser = argparse.ArgumentParser() +def _benchmark_compare(parser=None): + if not parser: + parser = argparse.ArgumentParser() _benchmark_compare_parser(parser) args = parser.parse_args(sys.argv[1:]) + _benchmark_compare_command(args) + + +def _benchmark_compare_command(args): npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') out_path = os.path.realpath(args.out_path) if args.out_path else "./" diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index 9f968c1d87064d81c0e6b723a1754bb6dbdcf9d0..2e8a12231ed31c499648f12ee93486dfed47e00c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -84,7 +84,7 @@ def run_torch_api(api_full_name, api_info_dict): return -def _run_ut_parser(parser): +def _run_overflow_check_parser(parser): parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", help=" The api param tool forward result file: generate from api param tool, " "a json file.", @@ -95,10 +95,15 @@ def _run_ut_parser(parser): default=0, required=False) -def _run_overflow_check(): - parser = argparse.ArgumentParser() - _run_ut_parser(parser) +def _run_overflow_check(parser=None): + if not parser: + parser = argparse.ArgumentParser() + _run_overflow_check_parser(parser) args = parser.parse_args(sys.argv[1:]) + _run_overflow_check_command(args) + + +def _run_overflow_check_command(args): torch.npu.set_compile_mode(jit_compile=args.jit_compile) npu_device = "npu:" + str(args.device_id) check_link(args.forward_input_file) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0035e0dab41402a9308a10d7122112c95f931b39..ed7ce7e11d7b479b1100fd16f6deba2316430da0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -373,10 +373,15 @@ def preprocess_forward_content(forward_content): return processed_content -def _run_ut(): - parser = argparse.ArgumentParser() +def _run_ut(parser=None): + if not parser: + parser = argparse.ArgumentParser() _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) + run_ut_command(args) + + +def run_ut_command(args): if not is_gpu: torch.npu.set_compile_mode(jit_compile=args.jit_compile) used_device = current_device + ":" + str(args.device_id[0]) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py new file mode 100644 index 0000000000000000000000000000000000000000..60e49cbe039cd2eb14e73d9bb71495b482e2c473 --- /dev/null +++ b/debug/accuracy_tools/atat/atat.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys +from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command +from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse +from api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut +from api_accuracy_checker.compare.benchmark_compare import _benchmark_compare_parser, _benchmark_compare_command +from api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="atat(ascend training accuracy tools), [Powered by MindStudio].\n" + "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" + f"For any issue, refer README.md first", + ) + parser.set_defaults(print_help=parser.print_help) + subparsers = parser.add_subparsers() + subparsers.add_parser('parse') + run_ut_cmd_parser = subparsers.add_parser('run_ut') + multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut') + benchmark_compare_cmd_parser = subparsers.add_parser('benchmark_compare') + run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check') + _run_ut_parser(run_ut_cmd_parser) + _run_ut_parser(multi_run_ut_cmd_parser) + multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, + help='Number of splits for parallel processing. Range: 1-64') + _benchmark_compare_parser(benchmark_compare_cmd_parser) + _run_overflow_check_parser(run_overflow_check_cmd_parser) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(0) + args = parser.parse_args(sys.argv[1:]) + if sys.argv[1] == "run_ut": + run_ut_command(args) + elif sys.argv[1] == "parse": + cli_parse() + elif sys.argv[1] == "multi_run_ut": + config = prepare_config(args) + run_parallel_ut(config) + elif sys.argv[1] == "benchmark_compare": + _benchmark_compare_command(args) + elif sys.argv[1] == "run_overflow_check": + _run_overflow_check_command(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index af16d18e11e38f515eae66216b7eab790c27a827..886d230906476909b7e88eade5424e8d20aa883a 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -33,4 +33,7 @@ setup( ], include_package_data=True, ext_modules=[], - zip_safe=False) \ No newline at end of file + zip_safe=False, + entry_points={ + 'console_scripts' : ['atat=atat.atat:main'], + },) \ No newline at end of file