From 52a18727be51ee61803c6b8b22eaee274695e589 Mon Sep 17 00:00:00 2001 From: curry3 <485078529@qq.com> Date: Thu, 28 Aug 2025 19:32:04 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E4=BF=AE=E5=A4=8DL1=20?= =?UTF-8?q?dump=E5=9C=BA=E6=99=AF=E4=B8=8Bempty=E7=AD=89API=E7=9A=84device?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E4=B8=A2=E5=A4=B1=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/hook_module/script_wrapper.py | 15 +++++++++++++-- .../msprobe/pytorch/pytorch_service.py | 5 +++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index 2bca426560..7da0221c22 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -16,11 +16,12 @@ import functools import importlib import types + import torch + from msprobe.core.common.log import logger -from msprobe.pytorch.hook_module.api_register import get_api_register from msprobe.pytorch.common.utils import torch_version_above_or_equal_2 - +from msprobe.pytorch.hook_module.api_register import get_api_register if torch_version_above_or_equal_2: from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks @@ -123,6 +124,16 @@ def unpatch_dynamo_compile() -> bool: return True +def preprocess_func(): + try: + from torch.utils._device import _device_constructors + _device_constructors() + except ImportError: + pass + except Exception as e: + logger.warning(f"Failed to execute _device_constructors. Error Details: {str(e)}") + + def wrap_script_func(): wrap_jit_script_func() if torch_version_above_or_equal_2: diff --git a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py index 4007553b6d..d9041ffc55 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/pytorch_service.py @@ -16,13 +16,13 @@ from msprobe.core.common.utils import Const from msprobe.core.service import BaseService from msprobe.pytorch.common.log import logger -from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2 +from msprobe.pytorch.common.utils import get_rank_if_initialized from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, wrap_jit_script_func from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook +from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func class PytorchService(BaseService): @@ -49,6 +49,7 @@ class PytorchService(BaseService): register_optimizer_hook(self.data_collector) def _register_api_hook(self): + preprocess_func() super()._register_api_hook() wrap_script_func() redirect_wait() -- Gitee