From dd3f5b89efa3bf61850cbcdebb645044af53e24f Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 21 Aug 2025 11:36:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E5=A4=84=E7=90=86check?= =?UTF-8?q?point=E8=AE=BE=E7=BD=AEearly=5Fstop=E4=B8=BAFalse=E5=AF=BC?= =?UTF-8?q?=E8=87=B4deepspeed=20stage3=20dump=E6=8A=A5=E9=94=99=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dump/module_dump/module_processer.py | 35 +++++--- .../pytorch_ut/dump/test_module_processer.py | 89 ++++++++++++++----- 2 files changed, 88 insertions(+), 36 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index c152879d39..947fc5c097 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -30,17 +30,7 @@ from msprobe.pytorch.dump.module_dump.hook_wrapper import wrap_setup_input_outpu torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if torch_version_above_or_equal_2: - from torch.utils.checkpoint import checkpoint as origin_checkpoint, set_checkpoint_early_stop - - -def checkpoint_without_early_stop(*args, **kwargs): - with set_checkpoint_early_stop(False): - return origin_checkpoint(*args, **kwargs) - - -def replace_checkpoint(): - if torch_version_above_or_equal_2: - torch.utils.checkpoint.checkpoint = checkpoint_without_early_stop + from torch.utils.checkpoint import _StopRecomputationError def wrap_megatron_deallocate(func): @@ -54,6 +44,27 @@ def wrap_megatron_deallocate(func): return wrapper_func +def wrap_forward_with_hook_safety(module): + """ + 包装模块的forward方法,确保异常时也执行forward_hook。 + """ + original_forward = module.forward + + def wrapped_forward(*args, **kwargs): + try: + output = original_forward(*args, **kwargs) + return output + except _StopRecomputationError as e: + exception_output = None + if len(module._forward_hooks.values()) > 0: + # msprobe的forward_hook会出现在第一个,仅执行msprobe的forward_hook + hook_fn = list(module._forward_hooks.values())[0] + hook_fn(module, args, kwargs, exception_output) + raise e + if torch_version_above_or_equal_2: + module.forward = wrapped_forward + + class ModuleProcesser: module_queue = ModuleQueue() module_count = {} @@ -67,7 +78,6 @@ class ModuleProcesser: def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None wrap_setup_input_output_hook() - replace_checkpoint() try: from megatron.core.pipeline_parallel import schedules origin_func_id = id(schedules.deallocate_output_tensor) @@ -156,6 +166,7 @@ class ModuleProcesser: f"which may cause abnormal data dump. The backward data dump for this module will be skipped." ) ModuleProcesser.module_with_backward_hook[prefix_name] = True + wrap_forward_with_hook_safety(module) register_forward_pre_hook(module, forward_pre_hook) def build_module_hook(self, module_name, build_data_hook): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index ec89ff6f1d..50c43288d2 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +from io import StringIO import threading import unittest @@ -23,37 +25,45 @@ import msprobe.pytorch.dump.module_dump.module_processer as mp from msprobe.core.data_dump.scope import ModuleRangeScope from msprobe.pytorch.dump.module_dump.module_processer import ( ModuleProcesser, - replace_checkpoint, - checkpoint_without_early_stop, - wrap_megatron_deallocate + wrap_megatron_deallocate, + wrap_forward_with_hook_safety ) +from torch.utils.checkpoint import _StopRecomputationError ori_checkpoint = torch.utils.checkpoint.checkpoint +class TestModule(torch.nn.Module): + """测试用的模块类,可控制是否抛出异常""" + + def __init__(self, raise_exception=False): + super().__init__() + self.raise_exception = raise_exception + + def forward(self, x, *args, **kwargs): + if self.raise_exception: + raise _StopRecomputationError() + return x * 2 + + +def forward_hook_fn(module, args, kwargs_or_output, output_or_kwargs=None): + print(f"The forward_hook executed normally.") + + class TestWrapper(unittest.TestCase): def setUp(self): torch.utils.checkpoint.checkpoint = ori_checkpoint + self.held_output = StringIO() + self.original_stdout = sys.stdout + sys.stdout = self.held_output - def test_replace_checkpoint_for_torch_version_above_2(self): - mp.torch_version_above_or_equal_2 = True - with patch('msprobe.pytorch.dump.module_dump.module_processer.checkpoint_without_early_stop') as mock_obj: - replace_checkpoint() - self.assertEqual(torch.utils.checkpoint.checkpoint, mock_obj) - - def test_replace_checkpoint_for_torch_version_below_2(self): - mp.torch_version_above_or_equal_2 = False - replace_checkpoint() - self.assertEqual(torch.utils.checkpoint.checkpoint, ori_checkpoint) - - def test_checkpoint_without_early_stop(self): - mock_checkpoint = MagicMock(return_value="test_result") + def tearDown(self): + """恢复标准输出""" + sys.stdout = self.original_stdout - with patch('msprobe.pytorch.dump.module_dump.module_processer.set_checkpoint_early_stop', MagicMock()), \ - patch('msprobe.pytorch.dump.module_dump.module_processer.origin_checkpoint', mock_checkpoint): - result = checkpoint_without_early_stop("input") - mock_checkpoint.assert_called_once_with("input") - self.assertEqual(result, "test_result") + def get_output(self): + """获取捕获的输出内容""" + return self.held_output.getvalue().strip() def test_wrap_megatron_deallocate(self): mock_func = MagicMock(return_value="output_test") @@ -75,6 +85,39 @@ class TestWrapper(unittest.TestCase): self.assertEqual(result, "output_test") mock_func.assert_called_with("normal_input", False) + def test_normal_forward_execution(self): + """测试正常执行forward时的情况""" + # 准备测试模块和hook + module = TestModule(raise_exception=False) + module.register_forward_hook(forward_hook_fn) + + # 应用包装函数 + wrap_forward_with_hook_safety(module) + + # 执行forward + input_tensor = torch.tensor(3.0) + output = module(input_tensor) + + # 验证结果和hook调用 + self.assertEqual(output.item(), 6.0) + self.assertIn("The forward_hook executed normally.", self.get_output()) + + def test_stop_recomputation_exception_triggers_hook(self): + """测试抛出_StopRecomputationError时hook被调用""" + # 准备测试模块和hook + module = TestModule(raise_exception=True) + module.register_forward_hook(forward_hook_fn) + + # 应用包装函数 + wrap_forward_with_hook_safety(module) + + # 执行forward并验证异常 + input_tensor = torch.tensor(3.0) + with self.assertRaises(_StopRecomputationError): + module(input_tensor) + + self.assertIn("The forward_hook executed normally.", self.get_output()) + class TestModuleProcesser(unittest.TestCase): def setUp(self): @@ -87,12 +130,10 @@ class TestModuleProcesser(unittest.TestCase): self.mock_scope = MagicMock() @patch('msprobe.pytorch.dump.module_dump.module_processer.wrap_setup_input_output_hook') - @patch('msprobe.pytorch.dump.module_dump.module_processer.replace_checkpoint') - def test_init_with_valid_scope(self, mock_replace, mock_wrap): + def test_init_with_valid_scope(self, mock_wrap): processor = ModuleProcesser(self.scope) self.assertEqual(processor.scope, self.scope) mock_wrap.assert_called_once() - mock_replace.assert_called_once() @patch('msprobe.pytorch.dump.module_dump.module_processer.logger.info_on_rank_0') def test_init_without_megatron(self, mock_log): -- Gitee