From 657f854e50947c6e58ab91ea5ba2f65d3adba72e Mon Sep 17 00:00:00 2001 From: lichangwei Date: Thu, 24 Jul 2025 15:25:41 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E9=80=82=E9=85=8Dtorch?= =?UTF-8?q?1.x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/hook_module/api_register.py | 9 +++++++-- .../msprobe/pytorch/hook_module/pt_hook_manager.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py index 38d45a20f..3232e124b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -45,7 +45,6 @@ _inner_used_api = {} _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} dist_data_collect_func = {} -origin_wait = getattr(dist.Work, 'wait') _api_types = { Const.PT_FRAMEWORK: { @@ -115,6 +114,12 @@ def dist_module_forward(module, *args, **kwargs): def redirect_wait(): + if hasattr(dist, "Work"): + from torch.distributed import Work + else: + from torch._C._distributed_c10d import Work + origin_wait = Work.wait + def wrapped_wait(work): def wrapped_wait(*args, **kwargs): origin_wait(*args, **kwargs) @@ -122,7 +127,7 @@ def redirect_wait(): store_func = dist_data_collect_func.pop(args[0]) store_func() return wrapped_wait - dist.Work.wait = wrapped_wait(dist.Work) + Work.wait = wrapped_wait(Work) def npu_module_forward(module, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py index 413ad3da0..0beb5d7d3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py @@ -38,8 +38,12 @@ class PytorchHookManager(BaseHookManager): @staticmethod def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs): - kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} - output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + if hook_type == Const.API: + kwargs = kwargs_or_output + output = output_or_kwargs + else: + kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} + output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output return kwargs, output def build_hook(self, hook_type, name): -- Gitee