From 498e2d76438036669939e387f317bfcb8ab4a0ce Mon Sep 17 00:00:00 2001 From: j30031267 Date: Sat, 28 Oct 2023 18:50:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=BA=A2=E5=87=BA=E6=A3=80?= =?UTF-8?q?=E6=B5=8B=E8=87=AA=E5=8A=A8=E5=88=A4=E5=AE=9A=E5=AD=98=E7=9B=98?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E8=B7=AF=E5=BE=84=E5=90=8D=E5=B7=B2=E5=AD=98?= =?UTF-8?q?=E5=9C=A8=E6=97=B6=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/python/ptdbg_ascend/debugger/precision_debugger.py | 7 ++++++- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py | 1 + .../src/python/ptdbg_ascend/overflow_check/info_dump.py | 6 +++--- .../python/ptdbg_ascend/overflow_check/overflow_check.py | 3 +-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index de1ff96b26..63d2bdf120 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -53,13 +53,18 @@ class PrecisionDebugger: elif 'backward' in scope[0]: set_backward_input(backward_input) - def configure_overflow_dump(self, mode="api", acl_config=None, overflow_nums=1, filter_switch = Const.OFF): + def configure_overflow_dump(self, mode="api", acl_config=None, overflow_nums=1, filter_switch=Const.OFF, need_replicate=False): if mode == "acl": DumpUtil.dump_switch_mode = mode DumpUtil.set_acl_config(acl_config) init_overflow_nums(overflow_nums) check_switch_valid(filter_switch) OverFlowUtil.overflow_filter_switch = filter_switch + if not isinstance(need_replicate, bool): + print_error_log("Params autojudge only support True or False.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + if need_replicate: + DumpUtil.need_replicate = True @classmethod def start(cls): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index e6cfa1f444..35f5ad4a90 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -36,6 +36,7 @@ class DumpUtil(object): iter_num = 0 target_rank = None summary_only = False + need_replicate = False @staticmethod def set_dump_path(save_path): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py index 6a46a592fa..22040a62aa 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py @@ -7,7 +7,7 @@ import threading import numpy as np -from ..common.utils import print_error_log +from ..common.utils import print_error_log, get_time from ..common.file_check_util import FileOpen @@ -80,13 +80,13 @@ class APIInfo: api_args = self.api_name + '.' + str(self.args_num) rank = arg.device.index if self.is_forward: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data', f"rank{rank}") + forward_real_data_path = os.path.join(dump_path, "forward_real_data_" + get_time(), f"rank{rank}") if not os.path.exists(forward_real_data_path): os.makedirs(forward_real_data_path, 0o755) file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') else: - backward_real_data_path = os.path.join(dump_path, 'backward_real_data', f"rank{rank}") + backward_real_data_path = os.path.join(dump_path, "backward_real_data_" + get_time(), f"rank{rank}") if not os.path.exists(backward_real_data_path): os.makedirs(backward_real_data_path, 0o755) file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py index 561b118e85..f389081807 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py @@ -125,8 +125,7 @@ def overflow_check(name, **kwargs): if hasattr(module, 'input_kwargs'): del module.input_kwargs if module.has_overflow and OverFlowUtil.check_overflow_dump_times(overflow_nums): - need_replicate = overflow_type_judge(in_feat, out_feat, module_name) - if need_replicate: + if overflow_type_judge(in_feat, out_feat, module_name) and DumpUtil.need_replicate: if module_name.endswith(Const.FORWARD): forward_api_info.update({name: ForwardAPIInfo(name, True, module.input_args, module.input_kwargs)}) api_overflow.append(module_name) -- Gitee