From f25c2004ebe1fb10af407cf506296b8aae57bc85 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 14 Nov 2023 11:25:53 +0800 Subject: [PATCH 1/4] dump by step --- .../api_accuracy_checker/common/base_api.py | 6 +++--- .../api_accuracy_checker/common/config.py | 11 +++++++---- .../accuracy_tools/api_accuracy_checker/dump/dump.py | 6 +++--- .../api_accuracy_checker/dump/info_dump.py | 2 ++ 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 8e7c063a64..225dab81c6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -2,7 +2,7 @@ import os import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - +from api_accuracy_checker.dump.dump import DumpUtil class BaseAPIInfo: def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): @@ -56,12 +56,12 @@ class BaseAPIInfo: else: api_args = self.api_name + '.' + str(self.args_num) if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, self.forward_path, "step" + str(DumpUtil.call_num), "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, self.backward_path, "step" + str(DumpUtil.call_num), "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 36df4bb014..ce02358401 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -18,13 +18,16 @@ class Config: 'real_data': bool, 'dump_step': int, 'error_data_path': str, - 'target_iter': int, + 'target_iter': list, # 修改target_iter的类型为list 'precision': int } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") - if key == 'target_iter' and value < 0: - raise ValueError("target_iter must be greater than 0") + if key == 'target_iter': + if not all(isinstance(i, int) for i in value): + raise ValueError("All elements in target_iter must be int type") + if any(i < 0 for i in value): + raise ValueError("All elements in target_iter must be greater than 0") if key == 'precision' and value < 0: raise ValueError("precision must be greater than 0") return value @@ -35,7 +38,7 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, target_iter=1): + def update_config(self, dump_path, real_data=False, target_iter=[1]): args = { "dump_path": dump_path, "real_data": real_data, diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 0120c10ddc..fc1a57bc7b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -46,10 +46,10 @@ class DumpUtil(object): @staticmethod def incr_iter_num_maybe_exit(): - if DumpUtil.call_num == msCheckerConfig.target_iter: + if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") - elif DumpUtil.call_num > msCheckerConfig.target_iter: - raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) else: set_dump_switch("OFF") DumpUtil.call_num += 1 diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index a8071f422d..241afc742b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -6,6 +6,7 @@ import threading from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.dump.dump import DumpUtil from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker @@ -14,6 +15,7 @@ lock = threading.Lock() def write_api_info_json(api_info): dump_path = msCheckerConfig.dump_path + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') -- Gitee From 0e789bb2eff91ed786cbe653f004be0ef74e9a8a Mon Sep 17 00:00:00 2001 From: l00826754 Date: Tue, 14 Nov 2023 20:14:21 -0500 Subject: [PATCH 2/4] dump by step --- .../api_accuracy_checker/common/base_api.py | 7 ++++--- .../api_accuracy_checker/common/config.py | 6 ++++-- .../accuracy_tools/api_accuracy_checker/config.yaml | 2 +- .../api_accuracy_checker/dump/info_dump.py | 13 +++++++------ 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 225dab81c6..e97027ee94 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -2,7 +2,7 @@ import os import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create -from api_accuracy_checker.dump.dump import DumpUtil + class BaseAPIInfo: def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): @@ -55,13 +55,14 @@ class BaseAPIInfo: else: api_args = self.api_name + '.' + str(self.args_num) + from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path, "step" + str(DumpUtil.call_num), "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num-1), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path, "step" + str(DumpUtil.call_num), "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num-1), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index ce02358401..7d1bc45a6d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -18,7 +18,7 @@ class Config: 'real_data': bool, 'dump_step': int, 'error_data_path': str, - 'target_iter': list, # 修改target_iter的类型为list + 'target_iter': list, 'precision': int } if not isinstance(value, validators.get(key)): @@ -38,7 +38,9 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, target_iter=[1]): + def update_config(self, dump_path, real_data=False, target_iter=None): + if target_iter is None: + target_iter = self.config.get('target_iter') args = { "dump_path": dump_path, "real_data": real_data, diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 22ef99c3a0..4a1420eb46 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,6 +3,6 @@ jit_compile: True real_data: False dump_step: 1000 error_data_path: './' -target_iter: 1 +target_iter: [1] precision: 14 \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 241afc742b..e1773bade4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -4,9 +4,10 @@ import os import threading from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path +from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory +from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.dump import DumpUtil + from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker @@ -14,8 +15,11 @@ lock = threading.Lock() def write_api_info_json(api_info): + from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num-1)) + check_path_before_create(dump_path) + create_directory(dump_path) rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') @@ -58,9 +62,6 @@ def initialize_output_json(): dump_path_checker = FileChecker(msCheckerConfig.dump_path, FileCheckConst.DIR) dump_path = dump_path_checker.common_check() files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] - if msCheckerConfig.real_data: - initialize_save_path(dump_path, 'forward_real_data') - initialize_save_path(dump_path, 'backward_real_data') for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path): -- Gitee From af99c463999621557730a1788b7c867dbb7d3ba0 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 20 Nov 2023 09:44:02 +0800 Subject: [PATCH 3/4] clean code --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 4 ++-- debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index e97027ee94..5bcc0e78af 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -57,12 +57,12 @@ class BaseAPIInfo: api_args = self.api_name + '.' + str(self.args_num) from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num-1), self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num-1), self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index e1773bade4..69de65912f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -17,7 +17,7 @@ lock = threading.Lock() def write_api_info_json(api_info): from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path - dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num-1)) + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) check_path_before_create(dump_path) create_directory(dump_path) rank = api_info.rank -- Gitee From 84d6b1de1b8b3cdf16d18808dc61e66e7cca68c6 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 21 Nov 2023 18:30:45 +0800 Subject: [PATCH 4/4] update --- .../api_accuracy_checker/common/config.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 7d1bc45a6d..7d355b7b9b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -24,10 +24,14 @@ class Config: if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") if key == 'target_iter': - if not all(isinstance(i, int) for i in value): - raise ValueError("All elements in target_iter must be int type") + if not isinstance(value, list): + raise ValueError("target_iter must be a list type") + if any(isinstance(i, bool) for i in value): + raise ValueError("target_iter cannot contain boolean values") + if not all(isinstance(i, int) for i in value): + raise ValueError("All elements in target_iter must be of int type") if any(i < 0 for i in value): - raise ValueError("All elements in target_iter must be greater than 0") + raise ValueError("All elements in target_iter must be greater than or equal to 0") if key == 'precision' and value < 0: raise ValueError("precision must be greater than 0") return value @@ -40,7 +44,7 @@ class Config: def update_config(self, dump_path, real_data=False, target_iter=None): if target_iter is None: - target_iter = self.config.get('target_iter') + target_iter = self.config.get('target_iter',[1]) args = { "dump_path": dump_path, "real_data": real_data, -- Gitee