diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index d9b67c93175641cbd0009c1f58562787626d375f..33f09d4c9aa9c83b9eb0a44c645d3c0ea6257c9d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -39,6 +39,16 @@ def replace_checkpoint(): torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop +def wrap_megatron_deallocate(func): + def wrapper_func(out, deallocate_pipeline_outputs=False): + if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None: + out_clone = out.clone() + out.data = torch.empty((1,), device=out.device, dtype=out.dtype, ) + return func(out_clone, deallocate_pipeline_outputs) + return func(out, deallocate_pipeline_outputs) + return wrapper_func + + class ModuleProcesser: module_count = {} module_stack = [] @@ -50,6 +60,14 @@ class ModuleProcesser: BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) replace_checkpoint() + try: + from megatron.core.pipeline_parallel import schedules + schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor) + logger.info_on_rank_0("Patch megatron method success.") + except ImportError: + logger.info_on_rank_0("No megatron find.") + except Exception as e: + logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}") @staticmethod def clone_return_value(func):