diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" new file mode 100644 index 0000000000000000000000000000000000000000..c2c8f456c23f0832f9e3c484fb6a62b2b461963e --- /dev/null +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -0,0 +1,30 @@ +# Ascend模型精度预检工具 + +## 使用方式 + +1. 安装预检工具 + + 将att仓代码下载到本地,并配置环境变量。假设att仓本地路径为 {att_root},环境变量应配置为 + + ``` + export PYTHONPATH=$PYTHONPATH:{att_root}/debug/accuracy_tools/ + ``` + +2. 使用工具dump模块抓取网络所有API信息 + + ``` + from api_accuracy_checker.dump import set_dump_switch + set_dump_switch("ON") + ``` + + dump信息默认会存盘到./api_info/路径下,后缀的数字代表rank id + +3. 
将上述信息输入给run_ut模块运行精度检测并比对 + + ``` + cd run_ut + python run_ut.py --forward ./api_info/forward_info_0.json --backward ./api_info/backward_info_0.json + ``` + + forward和backward两个命令行参数根据实际情况配置。比对结果存盘位置会打屏显示,默认是'./',可以在运行run_ut.py时通过 --out_path命令行参数配置。 + diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index dac54a79fc8bca5caa2bc783852992104a1d7a0e..e62c9616ca710aa4eabc8e3e7cbb099444054b3f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -510,7 +510,7 @@ LINUX_FILE_NAME_LENGTH_LIMIT = 200 def check_path_length_valid(path): path = os.path.realpath(path) - return len(os.path.basename(path) <= LINUX_FILE_NAME_LENGTH_LIMIT) + return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT def check_path_pattern_valid(path): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9ccdb05baa877c42488ca1ae4a5aa2bb0f66dbbc..88df23cacacd5b17b5ed5f4e6ccd23183d9ff80f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -39,6 +39,11 @@ def get_max_rel_err(n_value, b_value): return 1, False +def cosine_standard(compare_result): + bool_result = np.array(compare_result) > 0.99 + return np.all(bool_result), bool_result + + def cosine_sim(cpu_output, npu_output): n_value = npu_output.cpu().detach().numpy().reshape(-1) b_value = cpu_output.detach().numpy().reshape(-1) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index eb1b2586e9a8625da4eac6a7995bca117a8e9b6c..ae869e208e25a96cd363b71762f9d1399a8db8d2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ 
-1,6 +1,6 @@ # 进行比对及结果展示 import os -from prettytable import Prettytable +from prettytable import PrettyTable from .algorithm import compare_core, cosine_sim, cosine_standard from ..common.utils import get_json_contents, print_error_log, print_info_log, write_csv from .compare_utils import CompareConst @@ -32,7 +32,7 @@ class Comparator: "forward_and_backward_not_pass": self.test_result_cnt['forward_and_backward_fail_num'], "pass": self.test_result_cnt['success_num'] } - tb = Prettytable() + tb = PrettyTable() tb.add_column("Category", list(res_dict.keys())) tb.add_column("statistics",list(res_dict.values())) info_tb = str(tb) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index e50e95b46eb0f663c510394a8b42e4592d7a8c3e..ea27248cf35d53c37623375addb93e89a676ffa4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,11 +1,13 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 import inspect import torch +import torch_npu from .utils import DumpUtil, DumpConst, write_npy from ..common.utils import print_error_log class APIInfo: def __init__(self, api_name): + self.rank = torch_npu.npu.current_device() self.api_name = api_name self.save_real_data = DumpUtil.save_real_data @@ -105,9 +107,10 @@ class ForwardAPIInfo(APIInfo): def analyze_api_call_stack(self): stack_str = [] for (_, path, line, func, code, _) in inspect.stack()[3:]: + if not code: continue stack_line = " ".join([ "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), - " ".join(["\n", code[0].strip() if code else code])])]) + " ".join(["\n", code[0].strip()])])]) stack_str.append(stack_line) self.stack_info_struct = {self.api_name: stack_str} diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 
6ffc77578f1ce03d434361d671b317cfae3f4e95..ade72d3ebae88b5cf22e2f79577a21a5f4312ffb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -24,7 +24,7 @@ import torch import threading -from .utils import ForwardAPIInfo, BackwardAPIInfo +from .api_info import ForwardAPIInfo, BackwardAPIInfo from .info_dump import write_api_info_json from .utils import DumpConst, DumpUtil from ..common.utils import print_warn_log, print_info_log, print_error_log @@ -35,7 +35,7 @@ def pretest_info_dump(name, out_feat, module, phase): if phase == DumpConst.forward: api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) elif phase == DumpConst.backward: - api_info = BackwardApiInfo(name, out_feat) + api_info = BackwardAPIInfo(name, out_feat) else: msg = "Unexpected training phase {}.".format(phase) print_error_log(msg) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 0f76069a7f9e81cd6ddf08a467c5cdf2f7a55f5c..dc60fd2cb7dbfbf735c843371362fd2721e2a3f1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -12,15 +12,15 @@ lock = threading.Lock() def write_api_info_json(api_info): dump_path = DumpUtil.dump_path - initialize_output_json() + rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): - file_path = os.path.join(dump_path, 'forward_info.json') - stack_file_path = os.path.join(dump_path, 'stack_info.json') + file_path = os.path.join(dump_path, f'forward_info_{rank}.json') + stack_file_path = os.path.join(dump_path, f'stack_info_{rank}.json') write_json(file_path, api_info.api_info_struct) write_json(stack_file_path, api_info.stack_info_struct, indent=4) elif isinstance(api_info, BackwardAPIInfo): - file_path = os.path.join(dump_path, 'backward_info.json') + file_path = os.path.join(dump_path, 
f'backward_info_{rank}.json') write_json(file_path, api_info.grad_info_struct) else: raise ValueError(f"Invalid api_info type {type(api_info)}") diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 1aaf7a3d53a0c4da4ef6cf7b24cb84e843898355..f9c7eeaf27ccd15f2183edcd54dd1e0fc86c670f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -6,11 +6,6 @@ import numpy as np from ..common.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \ check_mode_valid, get_api_name_from_matcher -from ..common.version import __version__ - -dump_count = 0 -range_begin_flag, range_end_flag = False, False - class DumpConst: delimiter = '*' forward = 'forward' @@ -33,7 +28,7 @@ def set_dump_switch(switch): class DumpUtil(object): save_real_data = False - dump_path = './random_data_jsons' + dump_path = './api_info' dump_switch = None @staticmethod