From a7a6b788d456ab9f26d35ed0a4ff3bea0e9a7e1b Mon Sep 17 00:00:00 2001 From: lcw Date: Mon, 14 Apr 2025 19:55:50 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E4=BF=AE=E5=A4=8Dmegat?= =?UTF-8?q?ron=E5=9C=BA=E6=99=AF=E4=B8=8Bpp=E5=88=87=E5=88=86=E6=8A=A5?= =?UTF-8?q?=E9=94=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 4fb6a4af9799a9a0687992073840d30a859526d7) --- .../dump/module_dump/module_processer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index d9b67c9317..33f09d4c9a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -39,6 +39,16 @@ def replace_checkpoint(): torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop +def wrap_megatron_deallocate(func): + def wrapper_func(out, deallocate_pipeline_outputs=False): + if deallocate_pipeline_outputs and isinstance(out, torch.Tensor) and getattr(out, "_base") is not None: + out_clone = out.clone() + out.data = torch.empty((1,), device=out.device, dtype=out.dtype, ) + return func(out_clone, deallocate_pipeline_outputs) + return func(out, deallocate_pipeline_outputs) + return wrapper_func + + class ModuleProcesser: module_count = {} module_stack = [] @@ -50,6 +60,14 @@ class ModuleProcesser: BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) replace_checkpoint() + try: + from megatron.core.pipeline_parallel import schedules + schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor) + logger.info_on_rank_0("Patch megatron method success.") + except ImportError: + logger.info_on_rank_0("No megatron find.") + except Exception as e: + logger.info_on_rank_0(f"Patch megatron method failed, detail:{str(e)}") @staticmethod def clone_return_value(func): -- Gitee