diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index de1ff96b2633bd4e123e963c00c10538c92d72ba..63d2bdf120c9bcc826370e1d6717a740dd3e5a64 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -53,13 +53,18 @@ class PrecisionDebugger: elif 'backward' in scope[0]: set_backward_input(backward_input) - def configure_overflow_dump(self, mode="api", acl_config=None, overflow_nums=1, filter_switch = Const.OFF): + def configure_overflow_dump(self, mode="api", acl_config=None, overflow_nums=1, filter_switch=Const.OFF, need_replicate=False): if mode == "acl": DumpUtil.dump_switch_mode = mode DumpUtil.set_acl_config(acl_config) init_overflow_nums(overflow_nums) check_switch_valid(filter_switch) OverFlowUtil.overflow_filter_switch = filter_switch + if not isinstance(need_replicate, bool): + print_error_log("Params autojudge only support True or False.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + if need_replicate: + DumpUtil.need_replicate = True @classmethod def start(cls): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index e6cfa1f444b67b4a2b4371f7639cefc4334dfca8..35f5ad4a905f9d1ce7b055d1723fd4babcc3d9fd 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -36,6 +36,7 @@ class DumpUtil(object): iter_num = 0 target_rank = None summary_only = False + need_replicate = False @staticmethod def set_dump_path(save_path): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py index 6a46a592fa0b773ee3e6bfda32a24a3e81e25b4f..22040a62aa5b19e97c978f024790f262cef8d857 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py @@ -7,7 +7,7 @@ import threading import numpy as np -from ..common.utils import print_error_log +from ..common.utils import print_error_log, get_time from ..common.file_check_util import FileOpen @@ -80,13 +80,13 @@ class APIInfo: api_args = self.api_name + '.' + str(self.args_num) rank = arg.device.index if self.is_forward: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data', f"rank{rank}") + forward_real_data_path = os.path.join(dump_path, "forward_real_data_" + get_time(), f"rank{rank}") if not os.path.exists(forward_real_data_path): os.makedirs(forward_real_data_path, 0o755) file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') else: - backward_real_data_path = os.path.join(dump_path, 'backward_real_data', f"rank{rank}") + backward_real_data_path = os.path.join(dump_path, "backward_real_data_" + get_time(), f"rank{rank}") if not os.path.exists(backward_real_data_path): os.makedirs(backward_real_data_path, 0o755) file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py index 561b118e8541bafd4ffb171eb7329c94401636c6..f3890818070079db93f39bfb39ac504a9e6ca82a 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py @@ -125,8 +125,7 @@ def overflow_check(name, **kwargs): if hasattr(module, 'input_kwargs'): del module.input_kwargs if module.has_overflow and OverFlowUtil.check_overflow_dump_times(overflow_nums): - need_replicate = overflow_type_judge(in_feat, out_feat, module_name) - if need_replicate: + if overflow_type_judge(in_feat, out_feat, module_name) and DumpUtil.need_replicate: if module_name.endswith(Const.FORWARD): forward_api_info.update({name: ForwardAPIInfo(name, True, module.input_args, module.input_kwargs)}) api_overflow.append(module_name)