diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 8e7c063a64d590b3bafae1dcf62223b7256fb9f1..5bcc0e78af7ad24ee7e0f73f7e3e5ff86648aa26 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -55,13 +55,14 @@ class BaseAPIInfo: else: api_args = self.api_name + '.' + str(self.args_num) + from api_accuracy_checker.dump.dump import DumpUtil if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path, "rank" + str(self.rank)) + forward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.forward_path, "rank" + str(self.rank)) check_path_before_create(forward_real_data_path) create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path, "rank" + str(self.rank)) + backward_real_data_path = os.path.join(self.save_path, "step" + str(DumpUtil.call_num - 1), self.backward_path, "rank" + str(self.rank)) check_path_before_create(backward_real_data_path) create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 36df4bb01406e098b2deb7732237bc346ad9ca55..7d355b7b9b8854d3819d9fb6562451e69f5fc60d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -18,13 +18,20 @@ class Config: 'real_data': bool, 'dump_step': int, 'error_data_path': str, - 'target_iter': int, + 'target_iter': list, 'precision': int } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") - if key == 'target_iter' and value < 0: - raise ValueError("target_iter must be greater than 0") + if key == 'target_iter': + if not isinstance(value, list): + raise ValueError("target_iter must be a list type") + if any(isinstance(i, bool) for i in value): + raise ValueError("target_iter cannot contain boolean values") + if not all(isinstance(i, int) for i in value): + raise ValueError("All elements in target_iter must be of int type") + if any(i < 0 for i in value): + raise ValueError("All elements in target_iter must be greater than or equal to 0") if key == 'precision' and value < 0: raise ValueError("precision must be greater than 0") return value @@ -35,7 +42,9 @@ class Config: def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - def update_config(self, dump_path, real_data=False, target_iter=1): + def update_config(self, dump_path, real_data=False, target_iter=None): + if target_iter is None: + target_iter = self.config.get('target_iter',[1]) args = { "dump_path": dump_path, "real_data": real_data, diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 22ef99c3a0edf276b8622e657bdbb517e0cadaf8..4a1420eb4636b71530394d7ee64cf39af7a8523a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,6 +3,6 @@ jit_compile: True real_data: False dump_step: 1000 error_data_path: './' -target_iter: 1 +target_iter: [1] precision: 14 \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 0120c10ddc27be95c1d1c2f07d9ea1098329dea7..fc1a57bc7b6e9ac72a8090ce0714497eb66bcb2f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -46,10 +46,10 @@ class DumpUtil(object): @staticmethod def incr_iter_num_maybe_exit(): - if DumpUtil.call_num == msCheckerConfig.target_iter: + if DumpUtil.call_num in msCheckerConfig.target_iter: set_dump_switch("ON") - elif DumpUtil.call_num > msCheckerConfig.target_iter: - raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num)) else: set_dump_switch("OFF") DumpUtil.call_num += 1 diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index a8071f422d6577e4c5c64a39f4d26a4cac3a3978..69de65912fbc2a0d0ca3ceabe4d15a4f7f65e0c5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -4,16 +4,22 @@ import os import threading from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path +from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory +from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create from api_accuracy_checker.common.config import msCheckerConfig + from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker lock = threading.Lock() def write_api_info_json(api_info): + from api_accuracy_checker.dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path + dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str(DumpUtil.call_num - 1)) + check_path_before_create(dump_path) + create_directory(dump_path) rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') @@ -56,9 +62,6 @@ def initialize_output_json(): dump_path_checker = FileChecker(msCheckerConfig.dump_path, FileCheckConst.DIR) dump_path = dump_path_checker.common_check() files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] - if msCheckerConfig.real_data: - initialize_save_path(dump_path, 'forward_real_data') - initialize_save_path(dump_path, 'backward_real_data') for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path):