diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index 53a904f5295c08dc0b7c3eb2e32851493a62e99c..a50a61e0e8d5b50e2d109af704ae5ed5326e6499 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -70,7 +70,7 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 若有需要,用户可以通过msCheckerConfig.update_config来配置dump路径以及开启真实数据模式、指定dump某个step或配置API dump白名单,详细请参见“**msCheckerConfig.update_config**”。 -3. 将API信息输入给run_ut模块运行精度检测并比对,运行如下命令: +3. 将API信息输入给run_ut模块运行精度检测并比对,单进程运行如下命令: ```bash cd $ATT_HOME/debug/accuracy_tools/api_accuracy_checker/run_ut @@ -81,18 +81,30 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 | -------------------------------- | ------------------------------------------------------------ | -------- | | -forward或--forward_input_file | 指定前向API信息文件forward_info_{pid}.json。 | 是 | | -backward或--backward_input_file | 指定反向API信息文件backward_info_{pid}.json。 | 是 | - | -save_error_data | 保存精度未达标的API输入输出数据。 | 否 | + | -s或--save_error_data | 保存精度未达标的API输入输出数据。 | 否 | | -o或--out_path | 指指定run_ut执行结果存盘路径,默认“./”(相对于run_ut的路径)。 | 否 | | -j或--jit_compile | 开启jit编译。 | 否 | | -d或--device | 指定Device ID,选择UT代码运行所在的卡,默认值为0。 | 否 | - | -csv_path或--result_csv_path | 指定本次运行中断时生成的accuracy_checking_result_{timestamp}.csv文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。 | 否 | + | -c或--result_csv_path | 指定本次运行中断时生成的accuracy_checking_result_{timestamp}.csv文件路径,执行run_ut中断时,若想从中断处继续执行,配置此参数即可。 | 否 | run_ut执行结果包括`accuracy_checking_result_{timestamp}.csv`和`accuracy_checking_details_{timestamp}.csv`两个文件。`accuracy_checking_result_{timestamp}.csv`是API粒度的,标明每个API是否通过测试。建议用户先查看`accuracy_checking_result_{timestamp}.csv`文件,对于其中没有通过测试的或者特定感兴趣的API,根据其API name字段在`accuracy_checking_details_{timestamp}.csv`中查询其各个输出的达标情况以及比较指标。API达标情况介绍请参考“**API预检指标**”。 -4. 如果需要保存比对不达标的输入和输出数据,可以在run_ut执行命令结尾添加-save_error_data,例如: +4. 
可使用多进程进行加速,多进程运行命令如下: ```bash - python run_ut.py -forward ./forward_info_0.json -backward ./backward_info_0.json -save_error_data + cd $ATT_HOME/debug/accuracy_tools/api_accuracy_checker/run_ut + python multi_run_ut.py -forward ./forward_info_0.json -backward ./backward_info_0.json + ``` + | 参数名称 | 说明 | 是否必选 | + | -------------------------------- | ------------------------------------------------------------ | -------- | + | -n或--num_splits | 指定本次运行中多进程的进程数量,取值范围为1~20,默认为10 | 否 | + + 其他参数与run_ut相同 + +5. 如果需要保存比对不达标的输入和输出数据,可以在run_ut执行命令结尾添加--save_error_data或-s,例如: + + ```bash + python run_ut.py -forward ./forward_info_0.json -backward ./backward_info_0.json --save_error_data ``` 数据默认会存盘到'./ut_error_data'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过msCheckerConfig.update_config来配置保存路径,参数为error_data_path。 diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index b95ec8fa3e47327e24d1ecd532ad0a13b46586ee..c2ba7dbfd7bc6094095377d906125febca1e3301 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -18,12 +18,7 @@ class Comparator: def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, test_result_cnt=None, stack_info_json_path=None): self.save_path = result_csv_path self.detail_save_path = details_csv_path - if not is_continue_run_ut: - if os.path.exists(self.save_path): - raise ValueError(f"file {self.save_path} already exists, please remove it first or use a new dump path") - if os.path.exists(self.detail_save_path): - raise ValueError( - f"file {self.detail_save_path} already exists, please remove it first or use a new dump path") + if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path): self.write_csv_title() if stack_info_json_path: self.stack_info = get_json_contents(stack_info_json_path) diff --git 
def split_json_file(input_file, num_splits):
    """Split the top-level entries of a JSON object file into temporary part files.

    Args:
        input_file: Path of a JSON file whose top level is an object (dict).
        num_splits: Requested number of parts; must be at least 1.

    Returns:
        The list of ``./temp_part{i}.json`` paths that were written. At most
        ``num_splits`` paths are returned; fewer when the input has fewer
        entries than ``num_splits``, and an empty list when it has none, so
        every returned path refers to a file that really exists.

    Raises:
        ValueError: If ``num_splits`` is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")

    with FileOpen(input_file, 'r') as file:
        data = json.load(file)

    items = list(data.items())
    total_items = len(items)

    # Never create more parts than there are entries: a part without entries
    # would never be written to disk, yet its name would be handed to run_ut.py.
    num_parts = min(num_splits, total_items)
    if num_parts == 0:
        return []

    # Spread the remainder over the first parts instead of piling it all onto
    # the last one, so the worker processes get comparable amounts of work.
    base_size, remainder = divmod(total_items, num_parts)

    split_files = []
    start = 0
    for i in range(num_parts):
        end = start + base_size + (1 if i < remainder else 0)
        split_filename = os.path.join("./", f"temp_part{i}.json")
        # write_json is invoked once per entry on the same path, so it
        # presumably adds to an existing file (TODO confirm). Drop leftovers of
        # an earlier, interrupted run so old entries cannot leak into this part.
        if os.path.exists(split_filename):
            os.remove(split_filename)
        for key, value in items[start:end]:
            write_json(split_filename, {key: value})
        split_files.append(split_filename)
        start = end

    return split_files


# Bundle of everything run_parallel_ut needs to launch the worker processes.
ParallelUTConfig = namedtuple(
    'ParallelUTConfig',
    ['forward_files', 'backward_files', 'out_path', 'num_splits',
     'save_error_data_flag', 'jit_compile_flag', 'device_id', 'csv_file'],
)


def run_parallel_ut(config):
    """Run one ``run_ut.py`` subprocess per forward part file and print a summary.

    All subprocesses are started first and then waited for. The forward part
    files listed in ``config.forward_files`` are deleted afterwards, also when
    starting or waiting fails. Finally the newest result CSV found in
    ``config.out_path`` is summarised; problems in that last step are printed
    rather than raised.

    Args:
        config: A ``ParallelUTConfig``. Every worker is given
            ``config.backward_files[0]`` as its backward file.
    """

    def create_cmd(fwd, bwd):
        # Use the interpreter that runs this script so the workers see the
        # same environment; fall back to "python" if it is unknown.
        cmd = [
            sys.executable or 'python', 'run_ut.py',
            '-forward', fwd,
            '-backward', bwd,
            '-o', config.out_path,
        ]
        if config.device_id is not None:
            cmd += ['-d', str(config.device_id)]
        if config.jit_compile_flag:
            cmd.append('-j')
        if config.save_error_data_flag:
            # run_ut.py registers this option as "-s"/"--save_error_data";
            # the single-dash long spelling is rejected by its parser.
            cmd.append('--save_error_data')
        if config.csv_file:
            cmd += ['-c', config.csv_file]
        return cmd

    commands = [create_cmd(fwd, config.backward_files[0]) for fwd in config.forward_files]

    processes = []
    try:
        for cmd in commands:
            processes.append(subprocess.Popen(cmd))

        for cmd, process in zip(commands, processes):
            return_code = process.wait()
            if return_code != 0:
                print(f"Warning: '{' '.join(cmd)}' exited with code {return_code}")
    finally:
        # Reached on success and on failure/interrupt alike: stop workers that
        # are still running and always remove the temporary part files.
        for process in processes:
            if process.poll() is None:
                process.terminate()
        for file in config.forward_files:
            try:
                os.remove(file)
            except FileNotFoundError:
                pass

    try:
        process_csv_and_print_results(config.out_path, config.csv_file)
    except FileNotFoundError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")


def process_csv_and_print_results(out_path, csv_file=None):
    """Print the pre-test statistics of the newest result CSV in ``out_path``.

    Args:
        out_path: Directory searched for ``accuracy_checking_result_*.csv``.
        csv_file: Accepted for interface compatibility; not used here.

    Raises:
        FileNotFoundError: If ``out_path`` contains no matching CSV file.
    """
    csv_files = glob.glob(os.path.join(out_path, 'accuracy_checking_result_*.csv'))
    if not csv_files:
        raise FileNotFoundError("No CSV files found in the specified output path.")
    latest_csv = max(csv_files, key=os.path.getmtime)

    # NOTE(review): the Comparator is only used for printing; it is handed the
    # output directory in place of real CSV paths. Only the most recently
    # modified result CSV is summarised - confirm that all workers write to
    # the same file.
    comparator = Comparator(out_path, out_path, False)
    comparator.test_result_cnt = get_statistics_from_result_csv(latest_csv)
    comparator.print_pretest_result()


def main():
    """Parse the command line, split the forward file and run the UTs in parallel.

    On Ctrl-C the ``temp_part*.json`` files in the current directory are
    removed and the process exits with status 1.
    """
    try:
        parser = argparse.ArgumentParser(description='Run UT in parallel')
        parser.add_argument('-forward', '--forward_input_file', required=True,
                            help='The forward input JSON file.')
        parser.add_argument('-backward', '--backward_input_file', required=True,
                            help='The backward input JSON file.')
        parser.add_argument('-o', '--out_path', default="./",
                            help='The UT task result output path.')
        parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 21), default=10,
                            help='Number of splits for parallel processing. Range: 1-20')
        # "-s"/"--save_error_data" match run_ut.py; the original single-dash
        # long spelling is kept so existing command lines keep working.
        parser.add_argument('-s', '--save_error_data', '-save_error_data', dest='save_error_data',
                            action='store_true', help='Flag to save error data.')
        parser.add_argument('-j', '--jit_compile', action='store_true',
                            help='Flag to turn on jit compile.')
        parser.add_argument('-d', '--device_id', type=int, default=0, help='Device id.')
        # "--result_csv_path" is the name run_ut.py uses for the same option.
        parser.add_argument('-c', '--csv_file', '--result_csv_path', dest='csv_file', type=str,
                            help='CSV file to exclude APIs already tested.')
        args = parser.parse_args()

        check_link(args.forward_input_file)
        check_link(args.backward_input_file)
        forward_file = os.path.realpath(args.forward_input_file)
        backward_file = os.path.realpath(args.backward_input_file)
        check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX)
        check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX)
        out_path = os.path.realpath(args.out_path) if args.out_path else "./"
        out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
        out_path = out_path_checker.common_check()

        # Work on the resolved and validated paths, not on the raw arguments.
        forward_splits = split_json_file(forward_file, args.num_splits)
        backward_splits = [backward_file] * len(forward_splits)

        config = ParallelUTConfig(forward_splits, backward_splits, out_path, args.num_splits,
                                  args.save_error_data, args.jit_compile, args.device_id,
                                  args.csv_file)
        run_parallel_ut(config)
    except KeyboardInterrupt:
        print("Interrupted by user, cleaning up temporary files...")
        for file in glob.glob("temp_part*.json"):
            try:
                os.remove(file)
            except FileNotFoundError:
                # Already removed by run_parallel_ut's own cleanup.
                pass
        print("Temporary files removed.")
        sys.exit(1)


if __name__ == '__main__':
    main()
default="", type=str, help=" The ut task result out path.", required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", + parser.add_argument('-s','--save_error_data', dest="save_error_data", action="store_true", help=" Save compare failed api output.", required=False) parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", help=" whether to turn on jit compile", required=False) parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set device id to run ut", default=0, required=False) - parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, + parser.add_argument("-c", "--result_csv_path", dest="result_csv_path", default="", type=str, help=" The path of accuracy_checking_result_{timestamp}.csv, " "when run ut is interrupted, enter the file path to continue run ut.", required=False)