From 27412b8492ecfccd49bae3c5e3a68bc647a0736d Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Wed, 2 Aug 2023 11:45:55 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E5=A4=9A=E5=8D=A1=E9=80=82=E9=85=8D?= =?UTF-8?q?=E5=92=8Cdump.py=20bug=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/api_accuracy_checker/dump/api_info.py | 2 ++ debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 +- .../accuracy_tools/api_accuracy_checker/dump/info_dump.py | 8 ++++---- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index e50e95b46e..1a9a6b12ec 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,11 +1,13 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 import inspect import torch +import torch_npu from .utils import DumpUtil, DumpConst, write_npy from ..common.utils import print_error_log class APIInfo: def __init__(self, api_name): + self.rank = torch_npu.npu.current_device() self.api_name = api_name self.save_real_data = DumpUtil.save_real_data diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 6ffc77578f..21e27f5918 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -24,7 +24,7 @@ import torch import threading -from .utils import ForwardAPIInfo, BackwardAPIInfo +from .api_info import ForwardAPIInfo, BackwardAPIInfo from .info_dump import write_api_info_json from .utils import DumpConst, DumpUtil from ..common.utils import print_warn_log, print_info_log, print_error_log diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 
0f76069a7f..dc60fd2cb7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -12,15 +12,15 @@ lock = threading.Lock() def write_api_info_json(api_info): dump_path = DumpUtil.dump_path - initialize_output_json() + rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): - file_path = os.path.join(dump_path, 'forward_info.json') - stack_file_path = os.path.join(dump_path, 'stack_info.json') + file_path = os.path.join(dump_path, f'forward_info_{rank}.json') + stack_file_path = os.path.join(dump_path, f'stack_info_{rank}.json') write_json(file_path, api_info.api_info_struct) write_json(stack_file_path, api_info.stack_info_struct, indent=4) elif isinstance(api_info, BackwardAPIInfo): - file_path = os.path.join(dump_path, 'backward_info.json') + file_path = os.path.join(dump_path, f'backward_info_{rank}.json') write_json(file_path, api_info.grad_info_struct) else: raise ValueError(f"Invalid api_info type {type(api_info)}") -- Gitee From 1f3b339c87a822e17acfa9f72ba8070730542af4 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Wed, 2 Aug 2023 14:43:24 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=A2=84=E6=A3=80?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E4=BD=BF=E7=94=A8=E6=96=B9=E6=B3=95=EF=BC=8C?= =?UTF-8?q?=E4=B8=8D=E5=B8=A6config=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...77\347\224\250\346\226\271\346\263\225.md" | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 "debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" 
"b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" new file mode 100644 index 0000000000..ff62486348 --- /dev/null +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -0,0 +1,30 @@ +# Ascend模型精度预检工具 + +## 使用方式 + +1. 安装预检工具 + + 将att仓代码下载到本地,并配置环境变量。假设att仓本地路径为 {att_root},环境变量应配置为 + + ``` + export PYTHONPATH=$PYTHONPATH:{att_root}/debug/accuracy_tools/ + ``` + +2. 使用工具dump模块抓取网络所有API信息 + + ``` + from api_accuracy_checker.dump import set_dump_switch + set_dump_switch("ON") + ``` + + dump信息默认会存盘到./api_info/路径下,后缀的数字代表rank id + +3. 将上述信息输入给run_ut模块运行精度检测并比对 + + ``` + cd run_ut + python run_ut.py --forward ./api_info/forward_info_0.json --backward ./api_info/backward_info_0.json + ``` + + forward和backward两个命令行参数根据实际情况配置。比对结果存盘位置会打屏显示。 + -- Gitee From e967d2b67200a0e1fd6623a7e65c7df4390981c5 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Wed, 2 Aug 2023 15:08:53 +0800 Subject: [PATCH 3/8] fix typo --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 21e27f5918..ade72d3eba 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -35,7 +35,7 @@ def pretest_info_dump(name, out_feat, module, phase): if phase == DumpConst.forward: api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) elif phase == DumpConst.backward: - api_info = BackwardApiInfo(name, out_feat) + api_info = BackwardAPIInfo(name, out_feat) else: msg = "Unexpected training phase {}.".format(phase) print_error_log(msg) -- Gitee 
From a9168cc0ac4d318ca6cdd53dfc2afc319d19ea58 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Wed, 2 Aug 2023 15:19:17 +0800 Subject: [PATCH 4/8] fix typo --- debug/accuracy_tools/api_accuracy_checker/dump/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 1aaf7a3d53..4ad0014358 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -6,11 +6,6 @@ import numpy as np from ..common.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \ check_mode_valid, get_api_name_from_matcher -from ..common.version import __version__ - -dump_count = 0 -range_begin_flag, range_end_flag = False, False - class DumpConst: delimiter = '*' forward = 'forward' -- Gitee From b11c8a82c16d0629d71d1b7b406c7208f1af3103 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Wed, 2 Aug 2023 16:04:28 +0800 Subject: [PATCH 5/8] fix typo --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 3 ++- debug/accuracy_tools/api_accuracy_checker/dump/utils.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 1a9a6b12ec..ea27248cf3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -107,9 +107,10 @@ class ForwardAPIInfo(APIInfo): def analyze_api_call_stack(self): stack_str = [] for (_, path, line, func, code, _) in inspect.stack()[3:]: + if not code: continue stack_line = " ".join([ "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), - " ".join(["\n", code[0].strip() if code else code])])]) + " ".join(["\n", code[0].strip()])])]) stack_str.append(stack_line) self.stack_info_struct = 
{self.api_name: stack_str} diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 4ad0014358..f9c7eeaf27 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -28,7 +28,7 @@ def set_dump_switch(switch): class DumpUtil(object): save_real_data = False - dump_path = './random_data_jsons' + dump_path = './api_info' dump_switch = None @staticmethod -- Gitee From 30ff07fda59ba67d0491b20e013f03b9f3693a7d Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Thu, 3 Aug 2023 09:22:11 +0800 Subject: [PATCH 6/8] fix typo --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 5 +++++ debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9ccdb05baa..88df23caca 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -39,6 +39,11 @@ def get_max_rel_err(n_value, b_value): return 1, False +def cosine_standard(compare_result): + bool_result = np.array(compare_result) > 0.99 + return np.all(bool_result), bool_result + + def cosine_sim(cpu_output, npu_output): n_value = npu_output.cpu().detach().numpy().reshape(-1) b_value = cpu_output.detach().numpy().reshape(-1) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index eb1b2586e9..ae869e208e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,6 +1,6 @@ # 进行比对及结果展示 import os -from prettytable import Prettytable +from prettytable import PrettyTable from .algorithm import compare_core, 
cosine_sim, cosine_standard from ..common.utils import get_json_contents, print_error_log, print_info_log, write_csv from .compare_utils import CompareConst @@ -32,7 +32,7 @@ class Comparator: "forward_and_backward_not_pass": self.test_result_cnt['forward_and_backward_fail_num'], "pass": self.test_result_cnt['success_num'] } - tb = Prettytable() + tb = PrettyTable() tb.add_column("Category", list(res_dict.keys())) tb.add_column("statistics",list(res_dict.values())) info_tb = str(tb) -- Gitee From 71068908c53d93d98e557ffe04635e03d6c3a80f Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Thu, 3 Aug 2023 10:09:32 +0800 Subject: [PATCH 7/8] fix typo --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index dac54a79fc..e62c9616ca 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -510,7 +510,7 @@ LINUX_FILE_NAME_LENGTH_LIMIT = 200 def check_path_length_valid(path): path = os.path.realpath(path) - return len(os.path.basename(path) <= LINUX_FILE_NAME_LENGTH_LIMIT) + return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT def check_path_pattern_valid(path): -- Gitee From 312920ab8c21530cb3604c5bf03612cb499677f8 Mon Sep 17 00:00:00 2001 From: litian_drinksnow Date: Thu, 3 Aug 2023 11:32:21 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E5=88=B7=E6=96=B0readme?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" 
"b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" index ff62486348..c2c8f456c2 100644 --- "a/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" +++ "b/debug/accuracy_tools/api_accuracy_checker/Ascend\346\250\241\345\236\213\347\262\276\345\272\246\351\242\204\346\243\200\345\267\245\345\205\267\344\275\277\347\224\250\346\226\271\346\263\225.md" @@ -26,5 +26,5 @@ python run_ut.py --forward ./api_info/forward_info_0.json --backward ./api_info/backward_info_0.json ``` - forward和backward两个命令行参数根据实际情况配置。比对结果存盘位置会打屏显示。 + forward和backward两个命令行参数根据实际情况配置。比对结果存盘位置会打屏显示,默认是'./',可以在运行run_ut.py时通过 --out_path命令行参数配置。 -- Gitee