diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 21cae3c52a0c3e486a4c28ce08d497865ea4d582..a98f56fabaa82c95a82461cc1ee867d9b22fc972 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -68,7 +68,7 @@ class Const: DUMP_RATIO_MAX = 100 SUMMERY_DATA_NUMS = 256 FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] + SUPPORT_DUMP_MODE = ['api', 'acl', 'model'] ON = 'ON' OFF = 'OFF' BACKWARD = 'backward' diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index b2cfdaba44c48077784f35120047335c38409df9..eec386c38694338538163fb989ea624e8f1d92d1 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -16,8 +16,9 @@ class PrecisionDebugger: first_start = True hook_func = None config = None + model = None - def __init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False): + def __init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False, model=None): if hook_name is None: err_msg = "You must provide hook_name argument to PrecisionDebugger\ when config is not provided." @@ -30,6 +31,7 @@ class PrecisionDebugger: DumpUtil.target_rank = self.config.rank set_dump_path(self.config.dump_path) PrecisionDebugger.hook_func = overflow_check if self.config.hook_name == "overflow_check" else acc_cmp_dump + PrecisionDebugger.model = model if not isinstance(enable_dataloader, bool): print_error_log("Params enable_dataloader only support True or False.") raise CompareException(CompareException.INVALID_PARAM_ERROR) @@ -74,7 +76,7 @@ class PrecisionDebugger: def start(cls): if DumpUtil.iter_num in DumpUtil.target_iter or len(DumpUtil.target_iter) == 0: if cls.first_start: - register_hook_core(cls.hook_func) + register_hook_core(cls.hook_func, cls.model) cls.first_start = False DumpUtil.dump_switch = "ON" OverFlowUtil.overflow_check_switch = "ON" diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 6e9a7f7b77b7f574bb2a84c86000ce6497759b6d..7a573f70d59b187a263366bcde6e817990560157 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -20,6 +20,7 @@ import json import os import threading from pathlib import Path +from collections import defaultdict import numpy as np import torch @@ -46,6 +47,7 @@ thread_lock = threading.Lock() pkl_name = "" rank = os.getpid() multi_output_apis = ["_sort_", "npu_flash_attention"] +module_count = defaultdict(int) class DataInfo(object): @@ -352,6 +354,16 @@ def acc_cmp_dump(name, **kwargs): return RuntimeError("Not get the specified process pid.") def acc_cmp_hook(module, in_feat, out_feat=None): + nonlocal name + if "_{}_" in name: + module_name = name.split("_")[1] + if Const.BACKWARD in name: + index = module_count[module_name] - 1 + module_count[module_name] = index + else: + index = module_count[module_name] + module_count[module_name] = index + 1 + name = name.format(index) if pid == os.getpid(): dump_acc_cmp(name, in_feat, out_feat, dump_step, module) if hasattr(module, "input_args"): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index 802ffb744a9a3d22fd27c5fba3e9d51d24b8e4fa..d8aed40cfcf71b4c661355ce9b9a90e1fdd697c4 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -70,7 +70,10 @@ def register_hook(model, hook, **kwargs): if dump_mode == 'acl': DumpUtil.dump_switch_mode = dump_mode DumpUtil.set_acl_config(dump_config_file) - register_hook_core(hook) + if dump_mode == 'model': + register_hook_core(hook, model) + else: + register_hook_core(hook) def init_overflow_nums(overflow_nums): @@ -91,7 +94,7 @@ def check_register_hook(hook, **kwargs): raise CompareException(CompareException.INVALID_PARAM_ERROR) -def register_hook_core(hook): +def register_hook_core(hook, model=None): global make_dir_flag pid = os.getpid() @@ -101,6 +104,9 @@ def register_hook_core(hook): make_dir_flag = False hook_name = hook.__name__ + if "overflow_check" in hook_name and model is not None: + print_error_log("Overflow check does not support model dump mode") + raise CompareException(CompareException.INVALID_PARAM_ERROR) if "overflow_check" in hook_name and not is_gpu: if hasattr(torch_npu._C, "_enable_overflow_npu"): torch_npu._C._enable_overflow_npu() @@ -117,8 +123,18 @@ def register_hook_core(hook): hook = functools.partial(hook, dump_step=0, pid=pid) print_info_log("The {} hook function is successfully mounted to the model.".format(hook_name)) - api_register.initialize_hook(hook) - api_register.api_modularity() + if model is not None: + if not isinstance(model, torch.nn.Module): + print_error_log("The argument model must be an object of torch.nn.Module") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + for _, module in model.named_modules(): + if "torch.nn.modules" in str(module.__class__): + prefix = "Module_" + module.__class__.__name__ + module.register_forward_hook(hook(prefix + "_{}_" + "forward")) + module.register_backward_hook(hook(prefix + "_{}_" + "backward")) + else: + api_register.initialize_hook(hook) + api_register.api_modularity() if "acc_cmp_dump" in hook_name: remove_dropout()