diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index bfd1922a63e2e8336e409e98d37a5e7e2588d034..a5bc92c093b153e96ae6e94635b26f29dba04987 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -53,6 +53,12 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 注意:目前API通过测试的标准是每个输出与标杆比对的余弦相似度大于0.99,并且float16数据要通过双千分之一标准,float32数据要通过双万分之一标准,pretest_details.csv中的相对误差供用户分析时使用。 +4. 如果需要保存比对不达标的输入和输出数据,可以在run_ut执行命令结尾添加-save_error_data,例如: + + ``` + python run_ut.py -forward ./forward_info_0.json -backward ./backward_info_0.json -save_error_data + ``` + 数据默认会存盘到'./ut_error_data'路径下(相对于启动run_ut的路径),有需要的话,用户可以通过msCheckerConfig.update_config来配置保存路径,参数为error_data_path ## FAQ 1. 多卡训练dump结果只有一组json,这正确吗? diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 813d0cb58618eae048ea4ab2057bdc1f50185500..0d84ae54c70e7740547199eafd01cb8211185570 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -1,5 +1,6 @@ import argparse import os +import copy import sys import torch_npu import yaml @@ -74,9 +75,12 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): for api_full_name, api_info_dict in tqdm(forward_content.items()): try: data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.bench_out, - data_info.npu_out, data_info.bench_grad_out, - data_info.npu_grad_out) + is_fwd_success, is_bwd_success = \ + compare.compare_output(api_full_name, + None if data_info.bench_out is None else data_info.bench_out.clone(), + None if data_info.npu_out is None else data_info.npu_out.clone(), + None if data_info.bench_grad_out is None else data_info.bench_grad_out.clone(), + None if data_info.npu_grad_out is None else data_info.npu_grad_out.clone()) if save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: