diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py index 9090c1fa206f7149d3094ac2e2066c580b6ec1f7..890c3480b8ae9606e4032e4325f779ef780e097a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py @@ -103,6 +103,8 @@ class ApiWrapper: ori_api = _get_attr(api_modules[0], api_name) if callable(ori_api): def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template): + hooked_instance = api_template(api_name, api_func, prefix, hook_build_func) + def api_function(*args, **kwargs): api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix, @@ -112,7 +114,8 @@ class ApiWrapper: 'It may be fixed by passing the value of "self" ' 'as a positional argument instead of a keyword argument. ') return api_func(*args, **kwargs) - return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + # return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs) + return hooked_instance(*args, **kwargs) api_function.__name__ = api_name return api_function wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix, diff --git a/debug/accuracy_tools/msprobe/core/hook_manager.py b/debug/accuracy_tools/msprobe/core/hook_manager.py index 17d9ee0b54023789bd6624c8c48c7f2c460c0a03..d54f4802d3caeddad3a5ac48cab46b43b3f54c7d 100644 --- a/debug/accuracy_tools/msprobe/core/hook_manager.py +++ b/debug/accuracy_tools/msprobe/core/hook_manager.py @@ -140,8 +140,10 @@ class BaseHookManager(ABC): return return hook_fn - def _build_forward_pre_hook(self, hook_type, full_name, api_name): + def _build_forward_pre_hook(self, hook_type, full_name, api_name, HOOKCell): def forward_pre_hook(module, args, kwargs=None): + if hook_type == Const.API: + full_forward_name = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP + Const.FORWARD if hook_type == Const.MODULE: return if not self._should_execute_hook(hook_type, module, True): @@ -151,14 +153,14 @@ class BaseHookManager(ABC): with self._no_grad_context(): BaseHookManager.inner_switch = False module.forward_data_collected = True - self._add_count(api_name) + # self._add_count(api_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_forward_name) if getattr(self.config, "online_run_ut", False): BaseHookManager.inner_switch = False return self.data_collector.forward_input_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -167,31 +169,36 @@ class BaseHookManager(ABC): BaseHookManager.inner_switch = False return forward_pre_hook - def _build_forward_hook(self, hook_type, full_name): + def _build_forward_hook(self, hook_type, full_name, api_name, HOOKCell): def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): + if hook_type == Const.API: + full_forward_name = api_name + str(HOOKCell.get_cell_count(api_name)) + Const.SEP + Const.FORWARD + else: + full_forward_name = full_name + if not self._should_execute_hook(hook_type, module, True): self._clear_input_kwargs(module) return None kwargs, output = self._process_kwargs_and_output(module, hook_type, kwargs_or_output, output_or_kwargs) BaseHookManager.inner_switch = True - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_forward_name) module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) with self._no_grad_context(): if getattr(self.config, "online_run_ut", False): - if self.data_collector.scope and not self.data_collector.scope.check(full_name): + if self.data_collector.scope and not self.data_collector.scope.check(full_forward_name): return None if self.attl_manager: - self.attl_manager.attl_send(full_name, args, kwargs, output) + self.attl_manager.attl_send(full_forward_name, args, kwargs, output) BaseHookManager.inner_switch = False return None if hook_type == Const.MODULE: params_dict = self._get_params_dict(module) setattr(module_input_output, Const.PARAMS, params_dict) if params_dict: - self._register_param_hook(full_name, module, params_dict) - self.data_collector.update_api_or_module_name(full_name) + self._register_param_hook(full_forward_name, module, params_dict) + self.data_collector.update_api_or_module_name(full_forward_name) self.data_collector.forward_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -199,8 +206,9 @@ class BaseHookManager(ABC): ) self._init_params_grad_info(module, params_dict) else: + self._add_count(api_name) self.data_collector.forward_output_data_collect( - full_name, + full_forward_name, module, self._pid, module_input_output, @@ -217,12 +225,20 @@ class BaseHookManager(ABC): return output return forward_hook - def _build_backward_hook(self, hook_type, full_name): + def _build_backward_hook(self, hook_type, full_name, api_name, HOOKCell): def backward_hook(module, grad_input, grad_output): + if hook_type == Const.API: + full_backward_name = api_name + str(HOOKCell.get_cell_count(api_name) - 1 ) + Const.SEP + Const.BACKWARD + else: + full_backward_name = full_name + if not self._should_execute_hook(hook_type, module, False): return + if hook_type == Const.API: + HOOKCell.cell_count[api_name] = HOOKCell.cell_count[api_name] - 1 + BaseHookManager.inner_switch = True - self.data_collector.update_api_or_module_name(full_name) + self.data_collector.update_api_or_module_name(full_backward_name) if getattr(self.config, "online_run_ut", False): BaseHookManager.inner_switch = False return @@ -232,7 +248,7 @@ class BaseHookManager(ABC): else: module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) self.data_collector.backward_data_collect( - full_name, + full_backward_name, module, self._pid, module_input_output, diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py index 5581a44ca5eb1f1f0ecab4b30255b5c4e09f8b5a..d5536b688a8742a24586fe4e00f2cce9b999a4ac 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/ms_hook_manager.py @@ -52,10 +52,10 @@ class MindsproeHookManager(BaseHookManager): full_forward_name = name full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD) hookset = HookSet( - forward_hook=self._build_forward_hook(hook_type, full_forward_name), - forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name), - backward_hook=self._build_backward_hook(hook_type, full_backward_name), - backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name) + forward_hook=self._build_forward_hook(hook_type, full_forward_name, name, HOOKCell), + forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name, HOOKCell), + backward_hook=self._build_backward_hook(hook_type, full_backward_name, name, HOOKCell), + backward_pre_hook=self._build_backward_pre_hook(hook_type, full_backward_name) ) return hookset