diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" index 486a319b6b498dd2cd7abb319960a6a5e2cd1d38..74e6ff59ac19bed0746877c4693184c376525a2c 100644 --- "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -1,8 +1,15 @@ # Ascend模型精度预检工具 +模型精度预检工具会提取模型中所有的API前反向的信息,构造相应的API单元测试,将NPU输出与标杆比对,从而检测出精度有问题的API。 + +## 工具特性 +1. 落盘数据小 +2. 不依赖标杆侧GPU训练资源,本地即可完成预检 +3. 支持随机生成模式和真实数据模式 +4. 单API测试,排除整网中的累计误差问题 ## 使用方式 -1. 安装遇见工具 +1. 安装预检工具 将att仓代码下载到本地,并配置环境变量。假设att仓本地路径为 {att_root},环境变量应配置为 @@ -10,21 +17,36 @@ export PYTHONPATH=$PYTHONPATH:{att_root}/debug/accuracy_tools/ ``` -2. 使用工具dump模块抓取网络所有API信息 +2. 在工具中加入以下代码使用工具dump模块,启动训练抓取网络所有API信息,目前工具仅支持抓取训练的第一个迭代并且在第一个迭代后会退出训练进程。 ``` from api_accuracy_checker.dump import set_dump_switch - set_dump_switch("ON") ``` -​ dump信息默认会存盘到./api_info/路径下,后缀的数字代表进程pid +​ dump信息默认会存盘到./路径下,包括前向API信息forward_info_{pid}.json, 反向API信息backward_info_{pid}.json, 调用栈信息stack_info_{pid}.json。真实数据模式下还有forward_real_data和backward_real_data文件夹,里面有每个api输入的具体数值。forward_info与stack_info中的key值一一对应,用户可根据forward_info中API的key在stack_info中查询到其调用栈及代码行位置。 + + 有需要的话,用户可以通过msCheckerConfig.update_config来配置dump路径以及启用真实数据模式(默认为关)。注意启用真实数据模式目前仅支持单卡,且会存盘较多数据,可能对磁盘空间有较大冲击。 + ``` + from api_accuracy_checker.dump import msCheckerConfig + msCheckerConfig.update_config(dump_path="my/dump/path", real_data=True) # my/dump/path需配置为用户想要的api信息存盘路径,并且需要提前创建好 + ``` -3. 将上述信息输入给run_ut模块运行精度检测并比对 +3. 将上述信息输入给run_ut模块运行精度检测并比对,运行如下命令: ``` cd run_ut - python run_ut.py --forward ./api_info/forward_info_0.json --backward ./api_info/backward_info_0.json + python run_ut.py -forward ./forward_info_0.json -backward ./backward_info_0.json ``` - forward和backward两个命令行参数根据实际情况配置。比对结果存盘位置会打屏显示,默认是'./',可以在运行run_ut.py时通过 --out_path命令行参数配置。 + forward和backward两个命令行参数根据实际存盘的json文件名配置。比对结果存盘路径默认是'./',可以在运行run_ut.py时通过 --out_path命令行参数配置。结果包括pretest_result.csv和pretest_details.csv两个文件。前者是api粒度的,标明每个api是否通过测试。建议用户先查看前者,对于其中没有通过测试的或者特定感兴趣的api,根据其API name字段在pretest_details.csv中查询其各个输出的达标情况以及比较指标。 + + 注意:目前API通过测试的标准是每个输出与标杆比对的余弦相似度大于0.99,pretest_details.csv中的相对误差供用户分析时使用。 + + + + + + + + diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 988bcaa8f040502ce04804f828fc7a43f4678fd3..79c1aa368b9ffbc056dbd247f3759df3dc835364 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -62,7 +62,9 @@ class Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' - FLOAT_TYPE = [np.half, np.single, np.double, np.float64, np.longdouble] + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble] + BOOL_TYPE = [bool, np.uint8] + INT_TYPE = [np.int32, np.int64] # dump mode ALL = "all" diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 809a24d20960cfcd827af35b6dd9a17f7da3615c..58e05cf749a9a41b46867cdb939142b696b572d7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -2,11 +2,15 @@ import torch import numpy as np -from api_accuracy_checker.compare.compare_utils import CompareConst -from api_accuracy_checker.common.utils import print_warn_log, Const +from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable +from api_accuracy_checker.common.utils import Const + def compare_torch_tensor(cpu_output, npu_output, compare_alg): - if cpu_output.dtype == torch.bool: + if not check_dtype_comparable(cpu_output, npu_output): + return CompareConst.NAN, False, f"Bench out dtype is {cpu_output.dtype} but\ + npu output dtype is {npu_output.dtype}, cannot compare." + if cpu_output.dtype == np.bool or cpu_output.dtype == np.uint8: return compare_bool_tensor(cpu_output, npu_output) return compare_alg(cpu_output, npu_output) @@ -16,27 +20,47 @@ def compare_bool_tensor(cpu_output, npu_output): cpu_shape = cpu_output.shape npu_shape = npu_output.shape if cpu_shape != npu_shape: - return error_rate, False - npu_data = npu_output.cpu().detach().numpy() - bench_data = cpu_output.detach().numpy() + return error_rate, False, "" + npu_data = npu_output + bench_data = cpu_output data_size = bench_data.size error_nums = (bench_data != npu_data).sum() error_rate = float(error_nums / data_size) - return error_rate, error_rate < 0.001 + return error_rate, error_rate < 0.001, "" def get_max_rel_err(n_value, b_value): + msg = "" if not isinstance(n_value, np.ndarray) or not isinstance(b_value, np.ndarray): - print_warn_log("Max rel err only support numpy array!") - raise ValueError("Max rel err only support numpy array!") + msg = f"Max rel err only support numpy array! The actual type is {type(n_value)}, {type(b_value)}." + return CompareConst.NAN, False, msg + if n_value.shape != b_value.shape: + msg = f"Shape of npu and bench outputs don't match. NPU: {n_value.shape}, bench: {b_value.shape}." + return CompareConst.NAN, False, msg if n_value.dtype != b_value.dtype: - return CompareConst.NA, False - if n_value.dtype in Const.FLOAT_TYPE: - rel_err = np.abs((n_value - b_value) / (b_value + np.finfo(b_value.dtype).eps)).max() - return rel_err, rel_err < 0.001 - if np.all(n_value == b_value): - return 0, True - return 1, False + msg = f"Dtype of npu and bench outputs don't match. NPU: {n_value.dtype}, bench: {b_value.dtype}." + + if b_value.dtype in Const.FLOAT_TYPE: + zero_mask = (b_value == 0) + # 给0的地方加上eps防止除0 + b_value[zero_mask] += np.finfo(b_value.dtype).eps + # 根据b_value为0的位置给n_value也加上eps,否则两者都是0的情况下相对误差会是1 + n_value[zero_mask] += np.finfo(b_value.dtype).eps + else: + # int type + float eps 会报错,所以这里要强转 + n_value, b_value = n_value.astype(float), b_value.astype(float) + zero_mask = (b_value == 0) + b_value[zero_mask] += np.finfo(float).eps + n_value[zero_mask] += np.finfo(float).eps + rel_err = np.abs((n_value - b_value) / b_value).max() + bool_result = rel_err < 0.001 + + return rel_err, bool_result, msg + + +def max_rel_err_standard(max_rel_errs): + bool_result = np.array(max_rel_errs) < 0.001 + return np.all(bool_result), bool_result def cosine_standard(compare_result): @@ -45,34 +69,34 @@ def cosine_standard(compare_result): def cosine_sim(cpu_output, npu_output): - n_value = npu_output.cpu().detach().numpy().reshape(-1) - b_value = cpu_output.detach().numpy().reshape(-1) + msg = "" + n_value = npu_output.reshape(-1) + b_value = cpu_output.reshape(-1) cos = CompareConst.NA np.seterr(divide="ignore", invalid="ignore") + if n_value.shape != b_value.shape: + msg = f"Shape of npu and bench outputs don't match. NPU: {n_value.shape}, bench: {b_value.shape}." + return -1, False, msg if len(n_value) == 1: - print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") - return get_max_rel_err(n_value, b_value) - if len(n_value) == len(b_value) == 0: - print_warn_log("The npu dump data and bench dump data is empty.") - return cos, True - if n_value.dtype == np.uint8: - return compare_uint8_data(n_value, b_value) - n_max = np.max(np.abs(n_value)) - b_max = np.max(np.abs(b_value)) - if n_max <= np.finfo(float).eps and b_max <= np.finfo(float).eps: - return cos, True - elif n_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in npu dump data. Compare by relative error.") - return get_max_rel_err(n_value, b_value) - elif b_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") + msg = "All the data in npu dump data is scalar. Please refer to other compare algorithms." + return cos, True, msg + n_value_max = np.max(np.abs(n_value)) + b_value_max = np.max(np.abs(b_value)) + if n_value_max <= np.finfo(float).eps and b_value_max <= np.finfo(float).eps: + return cos, True, msg + elif n_value_max <= np.finfo(float).eps: + msg = "All the data is zero in npu dump data." + return CompareConst.NAN, False, msg + elif b_value_max <= np.finfo(float).eps: + msg = "All the data is zero in bench dump data." + return CompareConst.NAN, False, msg else: - n_value = n_value.astype(float) / n_max - b_value = b_value.astype(float) / b_max + n_value = n_value_max.astype(float) / n_value_max + b_value = b_value_max.astype(float) / b_value_max cos = np.dot(n_value, b_value) / (np.linalg.norm(n_value) * np.linalg.norm(b_value)) if np.isnan(cos): - print_warn_log("Dump data has NaN when comparing with Cosine Similarity.") - return cos, cos > 0.99 + msg = "Dump data has NaN when comparing with Cosine Similarity." + return cos, cos > 0.99, msg def compare_uint8_data(n_value, b_value): @@ -83,9 +107,11 @@ def compare_uint8_data(n_value, b_value): def compare_builtin_type(bench_out, npu_out): + if not isinstance(bench_out, (bool, int, float, str)): + return CompareConst.NA, True, "" if bench_out != npu_out: - return CompareConst.NAN, False - return 1.0, True + return CompareConst.NAN, False, "" + return True, True, "" def flatten_compare_result(result): @@ -97,14 +123,15 @@ def flatten_compare_result(result): flatten_result.append(result_i) return flatten_result - +# 本函数用alg比对bench_out 和npu_out,返回详细比对结果compare_result和标志比对是否通过的布尔变量test_success def compare_core(bench_out, npu_out, alg): - if type(bench_out) != type(npu_out): - raise ValueError("bench and npu output type is different") + msg = "" + if not isinstance(bench_out, type(npu_out)): + return CompareConst.NAN, False, "bench and npu output type is different." if isinstance(bench_out, (list, tuple)): compare_result, test_success = [], True if len(bench_out) != len(npu_out): - raise ValueError("bench and npu output structure is different") + return CompareConst.NAN, False, "bench and npu output structure is different" for b_out_i, n_out_i in zip(bench_out, npu_out): compare_result_i, test_success_i = compare_core(b_out_i, n_out_i, alg) compare_result.append(compare_result_i) @@ -112,18 +139,21 @@ def compare_core(bench_out, npu_out, alg): elif isinstance(bench_out, dict): b_keys, n_keys = set(bench_out.keys()), set(npu_out.keys()) if b_keys != n_keys: - raise ValueError("bench and npu output dictionary keys are different") + compare_result, test_success, msg = CompareConst.NAN, False, "bench and npu output dict keys are different" compare_result, test_success = compare_core(list(bench_out.values()), list(npu_out.values())) elif isinstance(bench_out, torch.Tensor): - compare_result, test_success = compare_torch_tensor(bench_out, npu_out, alg) + compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) elif isinstance(bench_out, (bool, int, float, str)): - compare_result, test_success = compare_builtin_type(bench_out, npu_out) + compare_result, test_success, msg = compare_builtin_type(bench_out, npu_out) elif bench_out is None: - return 1.0, True + compare_result, test_success, msg = CompareConst.NA, True, "output is None" else: - raise NotImplementedError("Unexpected output type in compare_core: {}".format(type(bench_out))) + compare_result, test_success, msg = CompareConst.NA, True, "Unexpected output type \ + in compare_core: {}".format(type(bench_out)) if isinstance(compare_result, list): compare_result = flatten_compare_result(compare_result) + else: + compare_result = [(compare_result, str(test_success), msg)] return compare_result, test_success diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index ed3c50a0cd352590cea358c3c629cd3c3fedb29e..7a1c069e2eff91940d26a7bf4b74bfc54554a04e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,12 +1,15 @@ # 进行比对及结果展示 -import os +import os from prettytable import PrettyTable -from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard +from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, \ + compare_builtin_type from api_accuracy_checker.common.utils import get_json_contents, print_error_log, print_info_log, write_csv from api_accuracy_checker.compare.compare_utils import CompareConst + class Comparator: TEST_FILE_NAME = "pretest_result.csv" + DETAIL_TEST_FILE_NAME = "pretest_details.csv" # consts for result csv COLUMN_API_NAME = "API name" COLUMN_FORWARD_SUCCESS = "Forward Test Success" @@ -15,16 +18,19 @@ class Comparator: def __init__(self, result_save_path, stack_info_json_path=None): self.save_path = os.path.join(result_save_path, self.TEST_FILE_NAME) + self.detail_save_path = os.path.join(result_save_path, self.DETAIL_TEST_FILE_NAME) if stack_info_json_path: self.stack_info = get_json_contents(stack_info_json_path) else: - self.stack_info = None + self.stack_info = None self.compare_alg = {} - self.compare_alg_names = [] - self.register_compare_algorithm("Cosine Similarity", cosine_sim, cosine_standard) + self.register_compare_algorithm("Cosine Similarity", cosine_sim, cosine_standard) + self.register_compare_algorithm("Max Relative Error", get_max_rel_err, None) + self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) self.test_results = [] - self.test_result_cnt = {"forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "success_num": 0} + self.test_result_cnt = { + "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0 + } def print_pretest_result(self): res_dict = { @@ -32,7 +38,7 @@ class Comparator: "backward_not_pass": self.test_result_cnt['backward_fail_num'], "forward_and_backward_not_pass": self.test_result_cnt['forward_and_backward_fail_num'], "pass": self.test_result_cnt['success_num'] - } + } tb = PrettyTable() tb.add_column("Category", list(res_dict.keys())) tb.add_column("statistics", list(res_dict.values())) @@ -40,14 +46,15 @@ class Comparator: print_info_log(info_tb) def write_compare_csv(self): - self.write_summary_csv() + self.write_summary_csv() + self.write_detail_csv() def write_summary_csv(self): test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS]] if self.stack_info: test_rows[0].append(self.COLUMN_STACK_INFO) for result in self.test_results: - name = result[0] + name = result[0] df_row = list(result[:3]) if self.stack_info: stack_info = "\n".join(self.stack_info[name]) @@ -55,16 +62,37 @@ class Comparator: test_rows.append(df_row) write_csv(test_rows, self.save_path) + def write_detail_csv(self): + test_rows = [[ + "Subject", "Cosine Similarity", "Cosine Similarity Pass", "Cosine Similarity Message", + "Max Rel Error", "Max Rel Err Pass", "Max Rel Err Message", + "Compare Builtin Type", "Builtin Type Pass", + "Builtin Type Message" + ]] + for test_result in self.test_results: + subject_prefix = test_result[0] + fwd_result = test_result[3] + bwd_result = test_result[4] + if isinstance(fwd_result, list): + for i, test_subject in enumerate(fwd_result): + subject = subject_prefix + ".forward.output." + str(i) + test_rows.append([subject] + list(test_subject)) + if isinstance(bwd_result, list): + for i, test_subject in enumerate(bwd_result): + subject = subject_prefix + ".backward.output." + str(i) + test_rows.append([subject] + list(test_subject)) + + write_csv(test_rows, self.detail_save_path) + def record_results(self, *args): self.test_results.append(args) def register_compare_algorithm(self, name, compare_func, standard): self.compare_alg.update({name: (compare_func, standard)}) - self.compare_alg_names.append(name) def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None): if "dropout" in api_name: - is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out) + is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out) else: is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out) if bench_grad and npu_grad: @@ -73,7 +101,7 @@ class Comparator: else: is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) else: - is_bwd_success, bwd_compare_alg_results = CompareConst.NA, None + is_bwd_success, bwd_compare_alg_results = CompareConst.NA, None self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, bwd_compare_alg_results) if is_fwd_success and is_bwd_success: self.test_result_cnt['success_num'] += 1 @@ -85,10 +113,20 @@ class Comparator: self.test_result_cnt['backward_fail_num'] += 1 def _compare_core_wrapper(self, bench_out, npu_out): - name = self.compare_alg_names[0] - detailed_result, test_success = compare_core(bench_out, npu_out, self.compare_alg[name][0]) - return test_success, detailed_result - + detailed_result_total = [] + test_success_total = True + for name in self.compare_alg.keys(): + alg = self.compare_alg[name][0] + detailed_result, test_success = compare_core(bench_out, npu_out, alg) + if name != "Max Relative Error": + test_success_total = test_success_total and test_success + if detailed_result_total: + for i in range(len(detailed_result_total)): + detailed_result_total[i] += detailed_result[i] + else: + detailed_result_total = detailed_result + return test_success_total, detailed_result_total + @staticmethod def _compare_dropout(bench_out, npu_out): tensor_num = bench_out.numel() diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py index 2f4d7eb38ccd5389b5d88a3bcdb6340ee6459ac7..62044f585218cf98e26859d2ed9492289531e0eb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py @@ -1,4 +1,24 @@ +from api_accuracy_checker.common.utils import Const, print_warn_log import numpy as np + + class CompareConst: NAN = np.nan NA = "N/A" + + +def check_dtype_comparable(x, y): + if x.dtype in Const.FLOAT_TYPE: + if y.dtype in Const.FLOAT_TYPE: + return True + return False + if x.dtype in Const.BOOL_TYPE: + if y.dtype in Const.BOOL_TYPE: + return True + return False + if x.dtype in Const.INT_TYPE: + if y.dtype in Const.INT_TYPE: + return True + return False + print_warn_log(f"Compare: Unexpected dtype {x.dtype}, {y.dtype}") + return False \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 38a1a3c47b1c857a9cea34b40ab56f32eed596c0..7e2cd46fe2477ba2bfb0a16ed972e1beab6d8da0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -1,4 +1,4 @@ -dump_path: './api_info' +dump_path: './' jit_compile: True compile_option: -O3 compare_algorithm: cosine_similarity