From 888bfe921e12859cbcc193f3b45193cfc9489d21 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 23 Jan 2024 15:37:09 +0800 Subject: [PATCH] init dump fix --- .../python/ptdbg_ascend/debugger/precision_debugger.py | 10 +++++++--- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py | 8 ++++++-- .../python/ptdbg_ascend/hook_module/register_hook.py | 1 + 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index 6d846b9c2..c56ce22c6 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -2,7 +2,7 @@ import os import torch from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu, print_error_log, \ CompareException, print_warn_log -from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path +from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path, reset_module_count from ..dump.utils import set_dump_path, set_dump_switch_print_info, generate_dump_path_str, \ set_dump_switch_config, set_backward_input from ..overflow_check.utils import OverFlowUtil @@ -55,7 +55,10 @@ class PrecisionDebugger: def configure_full_dump(self, mode='api_stack', scope=None, api_list=None, filter_switch=Const.OFF, input_output_mode=[Const.ALL], acl_config=None, backward_input=None, summary_only=False): - scope = scope or [] + if mode == "acl" and self.model is not None: + print_error_log("Init dump does not support ACL dump mode.") + raise CompareException(CompareException.INVALID_DUMP_MODE) + scope = scope or [] api_list = api_list or [] backward_input = backward_input or [] set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, @@ -129,6 +132,7 @@ class PrecisionDebugger: DumpUtil.iter_num += 1 DumpUtil.dump_init_enable = True HOOKModule.module_count = {} + reset_module_count() else: print_warn_log("DataLoader is enabled, step() skipped.") @@ -144,4 +148,4 @@ def iter_tracer(func): result = func(*args, **kwargs) PrecisionDebugger.incr_iter_num_maybe_exit() return result - return func_wrapper \ No newline at end of file + return func_wrapper diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 83900eb60..4bd7ee1c1 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -20,7 +20,6 @@ import json import os import threading from pathlib import Path -from collections import defaultdict import numpy as np import torch @@ -401,7 +400,7 @@ def acc_cmp_dump(name, **kwargs): except IndexError as e: print_error_log(f"Get module {name_template} index failed.") raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - name = name.format(index) + name = name_template.format(index) if pid == os.getpid(): dump_acc_cmp(name, in_feat, out_feat, dump_step, module) if hasattr(module, "input_args"): @@ -418,3 +417,8 @@ def write_to_disk(): def get_pkl_file_path(): return pkl_name + + +def reset_module_count(): + global module_count + module_count = {} diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index bc6b9c494..e9abfec4a 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -121,6 +121,7 @@ def register_hook_core(hook, model=None): print_info_log("The {} hook function is successfully mounted to the model.".format(hook_name)) if model is not None: + print_info_log("The init dump mode is enabled, and the module dump function will not be available") if not isinstance(model, torch.nn.Module): print_error_log("The argument model must be an object of torch.nn.Module") raise CompareException(CompareException.INVALID_PARAM_ERROR) -- Gitee