From 972bc13001bf8079a254cffb0753f8968c475b8d Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Fri, 27 Jun 2025 17:54:26 +0800 Subject: [PATCH 1/8] Update jit_script_wrapper.py --- .../pytorch/hook_module/jit_script_wrapper.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py index ea2ee39ae7..3acfb49a13 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py @@ -14,7 +14,8 @@ # limitations under the License. import torch - +import types +from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks from msprobe.pytorch.hook_module.api_register import get_api_register @@ -31,3 +32,40 @@ def wrap_jit_script_func(): original_script = torch.jit.script api_register = get_api_register() torch.jit.script = patched_script + + +def wrap_compile_script_func(): + def _patched_convert_frame(compiler_fn, hooks): + """ + 在调用原 convert_frame 生成的 _convert_frame 之前恢复 API, + 调用完之后再重新注册所有 API。 + """ + # 拿到原来 inner 版的 _convert_frame + inner_convert = _orig_convert_frame(compiler_fn, hooks) + + def _wrapped(frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state): + reg = get_api_register() + # 进入前 restore + reg.restore_all_api() + try: + result = inner_convert(frame, cache_size, hooks, frame_state) + except Exception: + # 异常时也要确保 register + reg.register_all_api() + raise + # 正常结束后 register + reg.register_all_api() + return result + + # 保留原属性以兼容 + _wrapped._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + _wrapped._clone_with_backend = lambda backend: _patched_convert_frame(backend, + hooks) # type: ignore[attr-defined] + return _wrapped + + torch._dynamo.convert_frame = _patched_convert_frame + + +def wrap_script_func(): + wrap_jit_script_func() + wrap_compile_script_func() \ No newline at end of file -- Gitee From 9cfa7008683e104eed893ff70707bdee3c5d88fe Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Fri, 27 Jun 2025 17:55:46 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddynamo=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../hook_module/{jit_script_wrapper.py => script_wrapper.py} | 0 debug/accuracy_tools/msprobe/pytorch/pytorch_service.py | 5 +++-- .../pytorch_ut/hook_module/test_pt_jit_script_wrapper.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) rename debug/accuracy_tools/msprobe/pytorch/hook_module/{jit_script_wrapper.py => script_wrapper.py} (100%) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py similarity index 100% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py index aad49a90d9..a4d628719a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -21,7 +21,7 @@ from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook @@ -56,7 +56,8 @@ class PytorchService(BaseService): def _register_api_hook(self): super()._register_api_hook() - wrap_jit_script_func() + wrap_script_func() + redirect_wait() def _register_module_hook(self): ModuleProcesser.enable_module_dump = True diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py index 61909523fc..8e93c8de18 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py @@ -17,7 +17,7 @@ import unittest from unittest.mock import MagicMock, patch import torch -from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func +from msprobe.pytorch.hook_module.script_wrapper import wrap_jit_script_func class TestWrapJitScriptFunc(unittest.TestCase): @@ -33,7 +33,7 @@ class TestWrapJitScriptFunc(unittest.TestCase): torch.jit.script = self.original_script @patch('torch.jit.script', new_callable=MagicMock) - @patch('msprobe.pytorch.hook_module.jit_script_wrapper.get_api_register', return_value=MagicMock()) + @patch('msprobe.pytorch.hook_module.script_wrapper.get_api_register', return_value=MagicMock()) def test_patched_script(self, mock_get_api, mock_original_script): mock_original_script.return_value = "mocked_result" mock_get_api.return_value = self.mock_api_register -- Gitee From 0e3b021ada24f239377bbcc7197f66779e21294c Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Sat, 28 Jun 2025 11:04:58 +0800 Subject: [PATCH 3/8] Update script_wrapper.py --- .../msprobe/pytorch/hook_module/script_wrapper.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index 3acfb49a13..8857ac2096 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -17,7 +17,7 @@ import torch import types from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks from msprobe.pytorch.hook_module.api_register import get_api_register - +from msprobe.core.common.log import logger def wrap_jit_script_func(): def patched_script(*args, **kwargs): @@ -68,4 +68,7 @@ def wrap_compile_script_func(): def wrap_script_func(): wrap_jit_script_func() - wrap_compile_script_func() \ No newline at end of file + try: + wrap_compile_script_func() + except Exception as e: + logger.warning(f"function wrap_compile_script_func fail: {e}") -- Gitee From d2c28d895286b6aec9a512c8673114266885f209 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 30 Jun 2025 17:43:02 +0800 Subject: [PATCH 4/8] Update pytorch_service.py --- debug/accuracy_tools/msprobe/pytorch/pytorch_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py index a4d628719a..ecd2907fec 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -21,7 +21,7 @@ from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, wrap_jit_script_func from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook -- Gitee From 10444594ade0b6436580266d67b2df4ceb1f1837 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 30 Jun 2025 17:44:32 +0800 Subject: [PATCH 5/8] Update script_wrapper.py --- .../msprobe/pytorch/hook_module/script_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index 8857ac2096..fa99c1192b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import types +import torch from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.core.common.log import logger -- Gitee From e47d3165376525fcfa5a770d552c8d1825a63d0b Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 30 Jun 2025 19:21:29 +0800 Subject: [PATCH 6/8] Update test_pt_service.py --- debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py index c0c56315ec..f9d5744a95 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py @@ -75,7 +75,7 @@ class TestPytorchService(unittest.TestCase): self.service._register_hook() mock_register_opt.assert_not_called() - @patch('msprobe.pytorch.pytorch_service.wrap_jit_script_func') + @patch('msprobe.pytorch.pytorch_service.wrap_script_func') def test_register_api_hook(self, mock_wrap_jit): self.service.config.level = Const.LEVEL_L1 self.service._register_api_hook() -- Gitee From 731bc5d410794b6cbbf4442022586d56addf4bdc Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 30 Jun 2025 19:21:29 +0800 Subject: [PATCH 7/8] Update test_pt_service.py -- Gitee From 83616af2f54116ee408b754ebffeea2307f7c845 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Fri, 4 Jul 2025 17:32:27 +0800 Subject: [PATCH 8/8] Update script_wrapper.py --- .../msprobe/pytorch/hook_module/script_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index fa99c1192b..a8417cc1e3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -63,7 +63,8 @@ def wrap_compile_script_func(): hooks) # type: ignore[attr-defined] return _wrapped - torch._dynamo.convert_frame = _patched_convert_frame + import torch._dynamo.convert_frame as _cf_mod + _cf_mod.convert_frame = _patched_convert_frame def wrap_script_func(): -- Gitee