From af2684f5b189bf989f1b2df8ccbf253e0754bc58 Mon Sep 17 00:00:00 2001
From: l30044004
Date: Fri, 22 Aug 2025 15:59:46 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E5=A4=84=E7=90=86check?=
 =?UTF-8?q?point=E8=AE=BE=E7=BD=AEearly=5Fstop=E4=B8=BAFalse=E5=AF=BC?=
 =?UTF-8?q?=E8=87=B4deepspeed=20stage3=20dump=E6=8A=A5=E9=94=99=E9=97=AE?=
 =?UTF-8?q?=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../dump/module_dump/module_processer.py      | 36 ++++++++++++-------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
index bd048e63f..ef4cdfd94 100644
--- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
@@ -28,17 +28,7 @@ from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_outpu
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
-    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
-
-
-def checkpoint_without_early_stop(*args, **kwargs):
-    with set_checkpoint_early_stop(False):
-        return origin_checkpoint(*args, **kwargs)
-
-
-def replace_checkpoint():
-    if torch_version_above_or_equal_2:
-        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+    from torch.utils.checkpoint import _StopRecomputationError
 
 
 def wrap_megatron_deallocate(func):
@@ -51,6 +41,28 @@
     return wrapper_func
 
 
+def wrap_forward_with_hook_safety(module):
+    """
+    包装模块的forward方法，确保异常时也执行forward_hook。
+    """
+    original_forward = module.forward
+
+    def wrapped_forward(*args, **kwargs):
+        try:
+            output = original_forward(*args, **kwargs)
+            return output
+        except _StopRecomputationError as e:
+            exception_output = None
+            if len(module._forward_hooks.values()) > 0:
+                # msprobe的forward_hook会出现在第一个，仅执行msprobe的forward_hook
+                hook_fn = list(module._forward_hooks.values())[0]
+                hook_fn(module, args, kwargs, exception_output)
+            raise e
+
+    if torch_version_above_or_equal_2:
+        module.forward = wrapped_forward
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
@@ -63,7 +75,6 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
-        replace_checkpoint()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
@@ -151,6 +162,7 @@ class ModuleProcesser:
                     f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                 )
                 ModuleProcesser.module_with_backward_hook[prefix_name] = True
+        wrap_forward_with_hook_safety(module)
         register_forward_pre_hook(module, forward_pre_hook)
 
     def build_module_hook(self, module_name, build_data_hook):
-- 
Gitee