diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py similarity index 41% rename from debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py rename to debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index ea2ee39ae79544b5a699800cb1e7dc9e0fc9066b..a8417cc1e380bc456b4121923c2fff7e81c16d36 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/jit_script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types import torch - +from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks from msprobe.pytorch.hook_module.api_register import get_api_register - +from msprobe.core.common.log import logger def wrap_jit_script_func(): def patched_script(*args, **kwargs): @@ -31,3 +32,44 @@ def wrap_jit_script_func(): original_script = torch.jit.script api_register = get_api_register() torch.jit.script = patched_script + + +def wrap_compile_script_func(): + def _patched_convert_frame(compiler_fn, hooks): + """ + 在调用原 convert_frame 生成的 _convert_frame 之前恢复 API, + 调用完之后再重新注册所有 API。 + """ + # 拿到原来 inner 版的 _convert_frame + inner_convert = _orig_convert_frame(compiler_fn, hooks) + + def _wrapped(frame: types.FrameType, cache_size: int, hooks: Hooks, frame_state): + reg = get_api_register() + # 进入前 restore + reg.restore_all_api() + try: + result = inner_convert(frame, cache_size, hooks, frame_state) + except Exception: + # 异常时也要确保 register + reg.register_all_api() + raise + # 正常结束后 register + reg.register_all_api() + return result + + # 保留原属性以兼容 + _wrapped._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] + _wrapped._clone_with_backend = lambda backend: _patched_convert_frame(backend, + hooks) # type: ignore[attr-defined] + return _wrapped + + import torch._dynamo.convert_frame as _cf_mod + _cf_mod.convert_frame = _patched_convert_frame + + +def wrap_script_func(): + wrap_jit_script_func() + try: + wrap_compile_script_func() + except Exception as e: + logger.warning(f"function wrap_compile_script_func fail: {e}") diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py index aad49a90d95ae7bccc56576f8c34a98de7f55ca6..ecd2907fecc1ec05958903b9c3d9b98e86082c00 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -21,7 +21,7 @@ from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, wrap_jit_script_func from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook @@ -56,7 +56,8 @@ class PytorchService(BaseService): def _register_api_hook(self): super()._register_api_hook() - wrap_jit_script_func() + wrap_script_func() + redirect_wait() def _register_module_hook(self): ModuleProcesser.enable_module_dump = True diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py index 61909523fc523dead62887ba94f399424b72a098..8e93c8de18f10f477a109a6134aa653604ef3f47 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_pt_jit_script_wrapper.py @@ -17,7 +17,7 @@ import unittest from unittest.mock import MagicMock, patch import torch -from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func +from msprobe.pytorch.hook_module.script_wrapper import wrap_jit_script_func class TestWrapJitScriptFunc(unittest.TestCase): @@ -33,7 +33,7 @@ class TestWrapJitScriptFunc(unittest.TestCase): torch.jit.script = self.original_script @patch('torch.jit.script', new_callable=MagicMock) - @patch('msprobe.pytorch.hook_module.jit_script_wrapper.get_api_register', return_value=MagicMock()) + @patch('msprobe.pytorch.hook_module.script_wrapper.get_api_register', return_value=MagicMock()) def test_patched_script(self, mock_get_api, mock_original_script): mock_original_script.return_value = "mocked_result" mock_get_api.return_value = self.mock_api_register diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py index c0c56315ec3b9ddb2bdd1f0724f9edd1f4597b99..f9d5744a957eaf8b4fefbe737cfeb5866c961f6e 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/test_pt_service.py @@ -75,7 +75,7 @@ class TestPytorchService(unittest.TestCase): self.service._register_hook() mock_register_opt.assert_not_called() - @patch('msprobe.pytorch.pytorch_service.wrap_jit_script_func') + @patch('msprobe.pytorch.pytorch_service.wrap_script_func') def test_register_api_hook(self, mock_wrap_jit): self.service.config.level = Const.LEVEL_L1 self.service._register_api_hook()