From b70ab992ddcc109a4579d8bc1bdc0bc31c04df56 Mon Sep 17 00:00:00 2001 From: lichangwei Date: Thu, 24 Jul 2025 15:25:41 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E9=80=82=E9=85=8Dtorch?= =?UTF-8?q?1.x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/hook_module/api_register.py | 9 +++++++-- .../msprobe/pytorch/hook_module/pt_hook_manager.py | 8 ++++++-- .../msprobe/pytorch/hook_module/script_wrapper.py | 8 ++++++-- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py index 38d45a20f..3232e124b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -45,7 +45,6 @@ _inner_used_api = {} _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} dist_data_collect_func = {} -origin_wait = getattr(dist.Work, 'wait') _api_types = { Const.PT_FRAMEWORK: { @@ -115,6 +114,12 @@ def dist_module_forward(module, *args, **kwargs): def redirect_wait(): + if hasattr(dist, "Work"): + from torch.distributed import Work + else: + from torch._C._distributed_c10d import Work + origin_wait = Work.wait + def wrapped_wait(work): def wrapped_wait(*args, **kwargs): origin_wait(*args, **kwargs) @@ -122,7 +127,7 @@ def redirect_wait(): store_func = dist_data_collect_func.pop(args[0]) store_func() return wrapped_wait - dist.Work.wait = wrapped_wait(dist.Work) + Work.wait = wrapped_wait(Work) def npu_module_forward(module, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py index 3c2a13215..3c00e56bf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py @@ -38,8 +38,12 @@ class PytorchHookManager(BaseHookManager): @staticmethod def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): - kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} - output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + if hook_type == Const.API: + kwargs = kwargs_or_output + output = output_or_kwargs + else: + kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} + output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output return kwargs, output def build_hook(self, hook_type, name): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py index 136f0c997..c6d611d5c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/script_wrapper.py @@ -15,8 +15,11 @@ import types import torch -from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks from msprobe.pytorch.hook_module.api_register import get_api_register +from msprobe.pytorch.common.utils import torch_version_above_or_equal_2 + +if torch_version_above_or_equal_2: + from torch._dynamo.convert_frame import convert_frame as _orig_convert_frame, Hooks def wrap_jit_script_func(): @@ -69,4 +72,5 @@ def wrap_compile_script_func(): def wrap_script_func(): wrap_jit_script_func() - wrap_compile_script_func() \ No newline at end of file + if torch_version_above_or_equal_2: + wrap_compile_script_func() \ No newline at end of file -- Gitee