diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" index a72e23484c1345084c98b8f3c2989bcf50fe86b8..9980649c3688c0e77927a3578762674986adac8e 100644 --- "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -16,6 +16,10 @@ ``` export PYTHONPATH=$PYTHONPATH:{att_root}/debug/accuracy_tools/ ``` + 安装依赖tqdm + ``` + pip install tqdm + ``` 2. 在工具中加入以下代码使用工具dump模块,启动训练抓取网络所有API信息 diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e6162ca8b2611a3a166f8ae4304fdfcb82551bdc..bfb2296a42568a4f76c1732bee635360d0b874dd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -4,6 +4,7 @@ import sys import torch_npu import yaml import torch +from tqdm import tqdm from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, check_need_convert, \ print_error_log @@ -60,7 +61,7 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): backward_content = get_json_contents(backward_file) api_setting_dict = get_json_contents("torch_ut_setting.json") compare = Comparator(out_path) - for api_full_name, api_info_dict in forward_content.items(): + for api_full_name, api_info_dict in tqdm(forward_content.items()): try: grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict)