From ac1a469144554b2f991897927b6c9b2fd1d37d29 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Fri, 24 Nov 2023 17:17:43 +0800 Subject: [PATCH 1/2] support init dump --- .../src/python/ptdbg_ascend/common/utils.py | 2 +- .../debugger/precision_debugger.py | 6 ++++-- .../src/python/ptdbg_ascend/dump/dump.py | 12 ++++++++++++ .../ptdbg_ascend/hook_module/register_hook.py | 19 +++++++++++++++---- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 21cae3c52a..a98f56faba 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -68,7 +68,7 @@ class Const: DUMP_RATIO_MAX = 100 SUMMERY_DATA_NUMS = 256 FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] + SUPPORT_DUMP_MODE = ['api', 'acl', 'model'] ON = 'ON' OFF = 'OFF' BACKWARD = 'backward' diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index b2cfdaba44..eec386c386 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -16,8 +16,9 @@ class PrecisionDebugger: first_start = True hook_func = None config = None + model = None - def __init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False): + def __init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False, model=None): if hook_name is None: err_msg = "You must provide hook_name argument to PrecisionDebugger\ when config is not provided." @@ -30,6 +31,7 @@ class PrecisionDebugger: DumpUtil.target_rank = self.config.rank set_dump_path(self.config.dump_path) PrecisionDebugger.hook_func = overflow_check if self.config.hook_name == "overflow_check" else acc_cmp_dump + PrecisionDebugger.model = model if not isinstance(enable_dataloader, bool): print_error_log("Params enable_dataloader only support True or False.") raise CompareException(CompareException.INVALID_PARAM_ERROR) @@ -74,7 +76,7 @@ class PrecisionDebugger: def start(cls): if DumpUtil.iter_num in DumpUtil.target_iter or len(DumpUtil.target_iter) == 0: if cls.first_start: - register_hook_core(cls.hook_func) + register_hook_core(cls.hook_func, cls.model) cls.first_start = False DumpUtil.dump_switch = "ON" OverFlowUtil.overflow_check_switch = "ON" diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 87eee25f11..b6560297f6 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -20,6 +20,7 @@ import json import os import threading from pathlib import Path +from collections import defaultdict import numpy as np import torch @@ -46,6 +47,7 @@ thread_lock = threading.Lock() pkl_name = "" rank = os.getpid() multi_output_apis = ["_sort_", "npu_flash_attention"] +module_count = defaultdict(int) class DataInfo(object): @@ -347,6 +349,16 @@ def acc_cmp_dump(name, **kwargs): return RuntimeError("Not get the specified process pid.") def acc_cmp_hook(module, in_feat, out_feat=None): + nonlocal name + if "_{}_" in name: + module_name = name.split("_")[1] + if Const.BACKWARD in name: + index = module_count[module_name] - 1 + module_count[module_name] = index + else: + index = module_count[module_name] + module_count[module_name] = index + 1 + name = name.format(index) if pid == os.getpid(): dump_acc_cmp(name, in_feat, out_feat, dump_step, module) if hasattr(module, "input_args"): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index 802ffb744a..64faaf2239 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -70,7 +70,10 @@ def register_hook(model, hook, **kwargs): if dump_mode == 'acl': DumpUtil.dump_switch_mode = dump_mode DumpUtil.set_acl_config(dump_config_file) - register_hook_core(hook) + if dump_mode == 'model': + register_hook_core(hook, model) + else: + register_hook_core(hook) def init_overflow_nums(overflow_nums): @@ -91,7 +94,7 @@ def check_register_hook(hook, **kwargs): raise CompareException(CompareException.INVALID_PARAM_ERROR) -def register_hook_core(hook): +def register_hook_core(hook, model=None): global make_dir_flag pid = os.getpid() @@ -101,6 +104,9 @@ def register_hook_core(hook): make_dir_flag = False hook_name = hook.__name__ + if "overflow_check" in hook_name and model is not None: + print_error_log("Overflow check does not support model dump mode") + raise CompareException(CompareException.INVALID_PARAM_ERROR) if "overflow_check" in hook_name and not is_gpu: if hasattr(torch_npu._C, "_enable_overflow_npu"): torch_npu._C._enable_overflow_npu() @@ -117,8 +123,13 @@ def register_hook_core(hook): hook = functools.partial(hook, dump_step=0, pid=pid) print_info_log("The {} hook function is successfully mounted to the model.".format(hook_name)) - api_register.initialize_hook(hook) - api_register.api_modularity() + if model is not None: + if not isinstance(model, torch.nn.Module): + print_error_log("The argument model must be an object of torch.nn.Module") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + else: + api_register.initialize_hook(hook) + api_register.api_modularity() if "acc_cmp_dump" in hook_name: remove_dropout() -- Gitee From 8c4ca79098063d69304c961456a6ba58b8baf96c Mon Sep 17 00:00:00 2001 From: l30036321 Date: Fri, 24 Nov 2023 17:32:35 +0800 Subject: [PATCH 2/2] support init dump --- .../src/python/ptdbg_ascend/hook_module/register_hook.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index 64faaf2239..d8aed40cfc 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -127,6 +127,11 @@ def register_hook_core(hook, model=None): if not isinstance(model, torch.nn.Module): print_error_log("The argument model must be an object of torch.nn.Module") raise CompareException(CompareException.INVALID_PARAM_ERROR) + for _, module in model.named_modules(): + if "torch.nn.modules" in str(module.__class__): + prefix = "Module_" + module.__class__.__name__ + module.register_forward_hook(hook(prefix + "_{}_" + "forward")) + module.register_backward_hook(hook(prefix + "_{}_" + "backward")) else: api_register.initialize_hook(hook) api_register.api_modularity() -- Gitee