diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py
index 38d45a20fdd8839d61472f5905727f975ef36b77..3232e124b420d6d457e1c1e0c694a96b63c1170b 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py
@@ -45,7 +45,6 @@ _inner_used_api = {}
 _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),)
 _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"}
 dist_data_collect_func = {}
-origin_wait = getattr(dist.Work, 'wait')
 
 _api_types = {
     Const.PT_FRAMEWORK: {
@@ -115,6 +114,12 @@ def dist_module_forward(module, *args, **kwargs):
 
 
 def redirect_wait():
+    if hasattr(dist, "Work"):
+        from torch.distributed import Work
+    else:
+        from torch._C._distributed_c10d import Work
+    origin_wait = Work.wait
+
     def wrapped_wait(work):
         def wrapped_wait(*args, **kwargs):
             origin_wait(*args, **kwargs)
@@ -122,7 +127,7 @@ def redirect_wait():
             store_func = dist_data_collect_func.pop(args[0])
             store_func()
         return wrapped_wait
-    dist.Work.wait = wrapped_wait(dist.Work)
+    Work.wait = wrapped_wait(Work)
 
 
 def npu_module_forward(module, *args, **kwargs):
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py
index 413ad3da00adfaebfa3b2652e079dd71267457c2..0beb5d7d39a0d60588f60a1d667585b0bc21f475 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/pt_hook_manager.py
@@ -38,8 +38,12 @@ class PytorchHookManager(BaseHookManager):
 
     @staticmethod
     def _process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs):
-        kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
-        output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+        if hook_type == Const.API:
+            kwargs = kwargs_or_output
+            output = output_or_kwargs
+        else:
+            kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+            output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
         return kwargs, output
 
     def build_hook(self, hook_type, name):