From 33c58df6fa2c3a031742f306d42db362329f4e72 Mon Sep 17 00:00:00 2001 From: lcw Date: Thu, 19 Jun 2025 17:29:47 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91pytorch=E9=80=9A?= =?UTF-8?q?=E4=BF=A1=E7=AE=97=E5=AD=90wait=E6=96=B9=E5=BC=8F=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/core/hook_manager.py | 2 ++ .../pytorch/hook_module/api_register.py | 33 ++++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py index bb4f679c8b1..c662309b91e 100644 --- a/debug/accuracy_tools/msprobe/core/hook_manager.py +++ b/debug/accuracy_tools/msprobe/core/hook_manager.py @@ -117,6 +117,8 @@ class BaseHookManager(ABC): def _should_execute_hook(self, hook_type, module, is_forward): is_module_hook = hook_type == Const.MODULE + if hasattr(module, 'async_op_dump_flag') and getattr(module, 'async_op_dump_flag'): + return False if is_module_hook and not Runtime.is_running: return False elif not is_module_hook and is_forward and not Runtime.is_running: diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py index 3edbe3a8f32..7c2afeafd85 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/api_register.py @@ -44,6 +44,8 @@ torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' _inner_used_api = {} _supported_api_list_path = (os.path.join(os.path.dirname(os.path.realpath(__file__)), Const.SUPPORT_API_FILE_NAME),) _cuda_func_mapping = {"npu_fusion_attention": "gpu_fusion_attention"} +dist_data_collect_func = {} +origin_wait = getattr(dist.Work, 'wait') _api_types = { Const.PT_FRAMEWORK: { @@ -94,16 +96,35 @@ def dist_module_forward(module, *args, **kwargs): use_async_op_flag = False logger.warning(f"fail to get dist api's func signature because {e}, no wait") - if use_async_op_flag or module.api_name in ["isend", "irecv"]: - if handle and hasattr(handle, 'wait'): - handle.wait() - if module.api_name == "batch_isend_irecv": + def create_async_callback_func(catch_func): + def store_data(): + module.async_op_dump_flag = False + catch_func(module, args, kwargs, handle) + return store_data + + if len(module._forward_hooks.values()) == 0: + return handle + if use_async_op_flag or module.api_name in ['isend', 'irecv']: + module.async_op_dump_flag = True + dist_data_collect_func[handle] = create_async_callback_func(list(module._forward_hooks.values())[0]) + if module.api_name == 'batch_isend_irecv': if isinstance(handle, list): for req in handle: - req.wait() + dist_data_collect_func[req] = create_async_callback_func(list(module._forward_hooks.values())[0]) return handle +def redirect_wait(): + def wrapped_wait(work): + def wrapped_wait(*args, **kwargs): + origin_wait(*args, **kwargs) + if args[0] in dist_data_collect_func: + store_func = dist_data_collect_func.pop(args[0]) + store_func() + return wrapped_wait + dist.Work.wait = wrapped_wait(dist.Work) + + def npu_module_forward(module, *args, **kwargs): if not module.need_hook: if module.api_name not in npu_custom_functions: @@ -130,6 +151,7 @@ class ApiTemplate(HOOKModule): self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP self.need_hook = need_hook self.device = device + self.async_op_dump_flag = False if self.need_hook: super().__init__(hook_build_func) if prefix == Const.DIST_API_TYPE_PREFIX: @@ -143,6 +165,7 @@ class ApiTemplate(HOOKModule): api_register = None +redirect_wait() def get_api_register(return_new=False): -- Gitee