From fceeb07fe3ebad46c254e2a2ceb9bffbd13aa6a4 Mon Sep 17 00:00:00 2001 From: qiangge Date: Fri, 9 May 2025 17:26:53 +0800 Subject: [PATCH 1/2] instance apitemplate once --- .../msprobe/core/data_dump/api_registry.py | 3 ++- .../pytorch/hook_module/support_wrap_ops.yaml | 10 +++---- .../accuracy_tools/msprobe/pytorch/service.py | 27 ++++++++++++------- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py index ecd7f9c0ce6..d3e4fb7fc03 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -101,6 +101,7 @@ class ApiWrapper: ori_api = _get_attr(api_modules[0], api_name) if callable(ori_api): def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + hooked_instance = api_template(api_name, api_func, prefix, hook_build_func) def api_function(*args, **kwargs): api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, @@ -110,7 +111,7 @@ class ApiWrapper: 'It may be fixed by passing the value of "self" ' 'as a positional argument instead of a keyword argument. ') return api_func(*args, **kwargs) - return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + return hooked_instance(*args, **kwargs) api_function.__name__ = api_name return api_function wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix, diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml index f2d5d22ade2..d46e4da953f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml @@ -638,7 +638,7 @@ tensor: torch: - linalg.norm - - linalg.vector_norm + # - linalg.vector_norm - linalg.matrix_norm - linalg.diagonal - linalg.det @@ -974,15 +974,15 @@ torch: - matrix_exp - matrix_power - matrix_rank - - max + # - max - max_pool1d - max_pool1d_with_indices - max_pool2d - max_pool3d - maximum - - mean + # - mean - median - - min + # - min - minimum - mm - mode @@ -1013,7 +1013,7 @@ torch: - negative_ - nextafter - nonzero - - norm + # - norm - norm_except_dim - normal - not_equal diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index d230ce4dbf1..52e1392f929 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -65,7 +65,11 @@ class Service: self.init_for_debug_level() def build_hook(self, module_type, name): - def pre_hook(api_or_module_name, module, args, kwargs=None): + + def pre_hook(module, args, kwargs=None): + api_or_module_name = name + if module_type == BaseScope.Module_Type_API: + api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD kwargs = {} if kwargs is None else kwargs if module_type == BaseScope.Module_Type_Module or \ @@ -138,10 +142,13 @@ class Service: # 记录当前模块的参数梯度信息已占位 self.params_grad_info[grad_name] = True - def forward_hook(api_or_module_name, module, args, kwargs_or_output, output_or_kwargs=None): + def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): if not self.should_execute_hook(module_type, module, True): return None is_recompute = is_recomputation() + api_or_module_name = name + if module_type == BaseScope.Module_Type_API: + api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output @@ -203,9 +210,12 @@ class Service: self.inner_switch = False return output - def backward_hook(api_or_module_name, module, grad_input, grad_output): + def backward_hook(module, grad_input, grad_output): if not self.should_execute_hook(module_type, module, False): return + api_or_module_name = name + if module_type == BaseScope.Module_Type_API: + api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD is_recompute = is_recomputation() self.inner_switch = True @@ -228,13 +238,10 @@ class Service: self.inner_switch = False pid = os.getpid() - full_forward_name = name - if module_type == BaseScope.Module_Type_API: - full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD - full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD) - pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name) - forward_hook_fn = functools.partial(forward_hook, full_forward_name) - backward_hook_fn = functools.partial(backward_hook, full_backward_name) + + pre_forward_hook_fn = pre_hook + forward_hook_fn = forward_hook + backward_hook_fn = backward_hook return pre_forward_hook_fn, forward_hook_fn, backward_hook_fn -- Gitee From d7f34e63165886c083c07393261c96ea498b3fb3 Mon Sep 17 00:00:00 2001 From: qiangge Date: Fri, 9 May 2025 18:24:16 +0800 Subject: [PATCH 2/2] fix count --- debug/accuracy_tools/msprobe/pytorch/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 52e1392f929..134725acb4a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -79,7 +79,6 @@ class Service: self.inner_switch = True module.forward_data_collected = True - HOOKModule.add_module_count(name) self.data_collector.update_api_or_module_name(api_or_module_name) if self.config.online_run_ut: @@ -152,6 +151,8 @@ class Service: kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {} output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output + HOOKModule.add_module_count(name) + self.inner_switch = True if self.config.online_run_ut: self.data_collector.update_api_or_module_name(api_or_module_name) @@ -215,7 +216,8 @@ class Service: return api_or_module_name = name if module_type == BaseScope.Module_Type_API: - api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD + api_or_module_name = name + str(HOOKModule.get_module_count(name) - 1) + Const.SEP + Const.FORWARD + api_or_module_name = replace_last_occurrence(api_or_module_name, Const.FORWARD, Const.BACKWARD) is_recompute = is_recomputation() self.inner_switch = True -- Gitee