diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index 2bca42656017acfa982544100c5cd59b730e9640..7da0221c22e259f58c59eaa540bb824f427dc668 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -16,11 +16,12 @@ import functools import importlib import types + import torch + from msprobe.core.common.log import logger -from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.common.utils import torch_version_above_or_equal_2 - +from msprobe.pytorch.hook_module.api_register import get_api_register if torch_version_above_or_equal_2: from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks @@ -123,6 +124,16 @@ def unpatch_dynamo_compile() -> bool: return True +def preprocess_func(): + try: + from torch.utils._device import _device_constructors + _device_constructors() + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to execute _device_constructors. Error Details: {str(e)}") + + def wrap_script_func(): wrap_jit_script_func() if torch_version_above_or_equal_2: diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py index 4007553b6d80bade6f3f13456dab1cd95b4c5734..d9041ffc5579e6ebfe8897336c8b0985bd0f5e50 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -16,13 +16,13 @@ from msprobe.core.common.utils import Const from msprobe.core.service import BaseService from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2 +from msprobe.pytorch.common.utils import get_rank_if_initialized from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, wrap_jit_script_func from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func class PytorchService(BaseService): @@ -49,6 +49,7 @@ class PytorchService(BaseService): register_optimizer_hook(self.data_collector) def _register_api_hook(self): + preprocess_func() super()._register_api_hook() wrap_script_func() redirect_wait()