From 78e8c3f759a88f6de316b1e727b75a08fe5c8fe1 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Thu, 5 Jun 2025 19:46:46 +0800 Subject: [PATCH 1/2] test --- .../msprobe/core/data_dump/api_registry.py | 5 +- .../msprobe/core/hook_manager.py | 46 +++++++++++++------ .../dump/hook_cell/ms_hook_manager.py | 10 ++-- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py index 9090c1fa206..890c3480b8a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -103,6 +103,8 @@ class ApiWrapper: ori_api = _get_attr(api_modules[0], api_name) if callable(ori_api): def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + hooked_instance = api_template(api_name, api_func, prefix, hook_build_func) + def api_function(*args, **kwargs): api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, @@ -112,7 +114,8 @@ class ApiWrapper: 'It may be fixed by passing the value of "self" ' 'as a positional argument instead of a keyword argument. ') return api_func(*args, **kwargs) - return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + # return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + return hooked_instance(*args, **kwargs) api_function.__name__ = api_name return api_function wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix, diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py index 17d9ee0b540..d54f4802d3c 100644 --- a/debug/accuracy_tools/msprobe/core/hook_manager.py +++ b/debug/accuracy_tools/msprobe/core/hook_manager.py @@ -140,8 +140,10 @@ class BaseHookManager(ABC): return return hook_fn - def _build_forward_pre_hook(self, hook_type, full_name, api_name): + def _build_forward_pre_hook(self, hook_type, full_name, api_name, HOOKCell): def forward_pre_hook(module, args, kwargs=None): + if hook_type == Const.API: + full_forward_name = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP + Const.FORWARD if hook_type == Const.MODULE: return if not self._should_execute_hook(hook_type, module, True): @@ -151,14 +153,14 @@ class BaseHookManager(ABC): with self._no_grad_context(): BaseHookManager.inner_switch = False module.forward_data_collected = True - self._add_count(api_name) + # self._add_count(api_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_forward_name) if getattr(self.config, "online_run_ut", False): BaseHookManager.inner_switch = False return self.data_collector.forward_input_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -167,31 +169,36 @@ class BaseHookManager(ABC): BaseHookManager.inner_switch = False return forward_pre_hook - def _build_forward_hook(self, hook_type, full_name): + def _build_forward_hook(self, hook_type, full_name, api_name, HOOKCell): def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + if hook_type == Const.API: + full_forward_name = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP + Const.FORWARD + else: + full_forward_name = full_name + if not self._should_execute_hook(hook_type, module, True): self._clear_input_kwargs(module) return None kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs) BaseHookManager.inner_switch = True - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_forward_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) with self._no_grad_context(): if getattr(self.config, "online_run_ut", False): - if self.data_collector.scope and not self.data_collector.scope.check(full_name): + if self.data_collector.scope and not self.data_collector.scope.check(full_forward_name): return None if self.attl_manager: - self.attl_manager.attl_send(full_name, args, kwargs, output) + self.attl_manager.attl_send(full_forward_name, args, kwargs, output) BaseHookManager.inner_switch = False return None if hook_type == Const.MODULE: params_dict = self._get_params_dict(module) setattr(module_input_output, Const.PARAMS, params_dict) if params_dict: - self._register_param_hook(full_name, module, params_dict) - self.data_collector.update_api_or_module_name(full_name) + self._register_param_hook(full_forward_name, module, params_dict) + self.data_collector.update_api_or_module_name(full_forward_name) self.data_collector.forward_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -199,8 +206,9 @@ class BaseHookManager(ABC): ) self._init_params_grad_info(module, params_dict) else: + self._add_count(api_name) self.data_collector.forward_output_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -217,12 +225,20 @@ class BaseHookManager(ABC): return output return forward_hook - def _build_backward_hook(self, hook_type, full_name): + def _build_backward_hook(self, hook_type, full_name, api_name, HOOKCell): def backward_hook(module, grad_input, grad_output): + if hook_type == Const.API: + full_backward_name = api_name + str(HOOKCell.get_cell_count(api_name) - 1 ) + Const.SEP + Const.BACKWARD + else: + full_backward_name = full_name + if not self._should_execute_hook(hook_type, module, False): return + if hook_type == Const.API: + HOOKCell.cell_count[api_name] = HOOKCell.cell_count[api_name] - 1 + BaseHookManager.inner_switch = True - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_backward_name) if getattr(self.config, "online_run_ut", False): BaseHookManager.inner_switch = False return @@ -232,7 +248,7 @@ class BaseHookManager(ABC): else: module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) self.data_collector.backward_data_collect( - full_name, + full_backward_name, module, self._pid, module_input_output, diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py index 5581a44ca5e..2ce9b9cb68c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py @@ -52,10 +52,10 @@ class MindsproeHookManager(BaseHookManager): full_forward_name = name full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD) hookset = HookSet( - forward_hook=self._build_forward_hook(hook_type, full_forward_name), - forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name), - backward_hook=self._build_backward_hook(hook_type, full_backward_name), - backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name) + forward_hook=self._build_forward_hook(hook_type, full_forward_name, name, HOOKCell), + forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name, HOOKCell), + backward_hook=self._build_backward_hook(hook_type, full_backward_name, name, HOOKCell), + backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name, name, HOOKCell) ) return hookset @@ -74,7 +74,7 @@ class MindsproeHookManager(BaseHookManager): } return params_dict - def _build_backward_pre_hook(self, hook_type, name): + def _build_backward_pre_hook(self, hook_type, name, hookCell): def backward_pre_hook(module, grad_input): if self.config.level != Const.LEVEL_L2: return -- Gitee From cf337100360b2ba036810da41c74bf960e9f227f Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Thu, 5 Jun 2025 19:49:49 +0800 Subject: [PATCH 2/2] V1.1 --- .../msprobe/mindspore/dump/hook_cell/ms_hook_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py index 2ce9b9cb68c..d5536b688a8 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py @@ -55,7 +55,7 @@ class MindsproeHookManager(BaseHookManager): forward_hook=self._build_forward_hook(hook_type, full_forward_name, name, HOOKCell), forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name, HOOKCell), backward_hook=self._build_backward_hook(hook_type, full_backward_name, name, HOOKCell), - backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name, name, HOOKCell) + backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name) ) return hookset @@ -74,7 +74,7 @@ class MindsproeHookManager(BaseHookManager): } return params_dict - def _build_backward_pre_hook(self, hook_type, name, hookCell): + def _build_backward_pre_hook(self, hook_type, name): def backward_pre_hook(module, grad_input): if self.config.level != Const.LEVEL_L2: return -- Gitee