diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
index bd048e63feb9edfcba4ec2e65e528366e42b3253..ef4cdfd94a18d2f88f56fc39d747c8e14ad42853 100644
--- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
@@ -28,17 +28,7 @@ from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_outpu
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
-    from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop
-
-
-def checkpoint_without_early_stop(*args, **kwargs):
-    with set_checkpoint_early_stop(False):
-        return origin_checkpoint(*args, **kwargs)
-
-
-def replace_checkpoint():
-    if torch_version_above_or_equal_2:
-        torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop
+    from torch.utils.checkpoint import _StopRecomputationError
 
 
 def wrap_megatron_deallocate(func):
@@ -51,6 +41,28 @@ def wrap_megatron_deallocate(func):
     return wrapper_func
 
 
+def wrap_forward_with_hook_safety(module):
+    """
+    Wrap the module's forward method so that the forward_hook is still executed on exception.
+    """
+    original_forward = module.forward
+
+    def wrapped_forward(*args, **kwargs):
+        try:
+            output = original_forward(*args, **kwargs)
+            return output
+        except _StopRecomputationError as e:
+            exception_output = None
+            if len(module._forward_hooks.values()) > 0:
+                # msprobe's forward_hook is registered first; execute only that hook
+                hook_fn = list(module._forward_hooks.values())[0]
+                hook_fn(module, args, kwargs, exception_output)
+            raise e
+
+    if torch_version_above_or_equal_2:
+        module.forward = wrapped_forward
+
+
 class ModuleProcesser:
     module_count = {}
     module_stack = []
@@ -63,7 +75,6 @@ class ModuleProcesser:
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
         wrap_setup_input_output_hook()
-        replace_checkpoint()
         try:
             from megatron.core.pipeline_parallel import schedules
             origin_func_id = id(schedules.deallocate_output_tensor)
@@ -151,6 +162,7 @@ class ModuleProcesser:
                 f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
            )
             ModuleProcesser.module_with_backward_hook[prefix_name] = True
+            wrap_forward_with_hook_safety(module)
         register_forward_pre_hook(module, forward_pre_hook)
 
     def build_module_hook(self, module_name, build_data_hook):
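
Context for the change: the previous code globally replaced `torch.utils.checkpoint.checkpoint` with a variant that disabled early stopping, while the new code keeps early stopping and instead wraps a module's `forward` so that msprobe's forward_hook is still invoked when the recomputation pass is cut short by `_StopRecomputationError`. The sketch below illustrates the situation the wrapper targets; it is not part of the patch, assumes torch >= 2.0, and `DemoBlock`, the tensor shapes, and the lambda hook are hypothetical stand-ins rather than msprobe code.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DemoBlock(nn.Module):
    """Hypothetical layer standing in for a user module instrumented by msprobe."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


block = DemoBlock()
hook_calls = []

# with_kwargs=True gives the hook the (module, args, kwargs, output) signature
# that wrap_forward_with_hook_safety assumes when it replays the first hook.
block.register_forward_hook(
    lambda module, args, kwargs, output: hook_calls.append(output),
    with_kwargs=True,
)

x = torch.randn(2, 4, requires_grad=True)

# Non-reentrant checkpointing re-runs forward during backward to rebuild saved
# activations. With early stopping enabled (the default), that re-run can be
# interrupted internally by _StopRecomputationError, so forward hooks on the
# module may be skipped for the recomputation pass. Applying the patch's
# wrap_forward_with_hook_safety(block) would catch that error, call the first
# registered hook with output=None, and re-raise it.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(len(hook_calls))  # at least 1: the hook fired for the original forward pass
```

One implication of the design: when the wrapper replays the hook after catching `_StopRecomputationError`, it passes `exception_output = None` as the output argument, so the forward_hook has to tolerate a missing output for that call.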