diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
index ecd7f9c0ce60f87fa5ed03726c408cf43665343a..d3e4fb7fc03d34d08fcaa2fbdb229dc3019ce823 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
@@ -101,6 +101,7 @@ class ApiWrapper:
                 ori_api = _get_attr(api_modules[0], api_name)
                 if callable(ori_api):
                     def wrap_api_func(api_name, api_func, prefix, hook_build_func, api_template):
+                        hooked_instance = api_template(api_name, api_func, prefix, hook_build_func)
                         def api_function(*args, **kwargs):
                             api_name_with_prefix = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1])
                             enable_wrap, args, kwargs = self.deal_with_self_kwargs(api_name_with_prefix,
@@ -110,7 +111,7 @@ class ApiWrapper:
                                                'It may be fixed by passing the value of "self" '
                                                'as a positional argument instead of a keyword argument. ')
                                 return api_func(*args, **kwargs)
-                            return api_template(api_name, api_func, prefix, hook_build_func)(*args, **kwargs)
+                            return hooked_instance(*args, **kwargs)
                         api_function.__name__ = api_name
                         return api_function
                     wrapped_functions[api_name] = wrap_api_func(api_name, ori_api, name_prefix,
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
index f2d5d22ade2c52057b969a93b73e0897e5d64ae3..d46e4da953fa1352591ef74d996d771fd0986d4f 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -638,7 +638,7 @@ tensor:
 
 torch:
   - linalg.norm
-  - linalg.vector_norm
+  # - linalg.vector_norm
   - linalg.matrix_norm
   - linalg.diagonal
   - linalg.det
@@ -974,15 +974,15 @@ torch:
   - matrix_exp
   - matrix_power
   - matrix_rank
-  - max
+  # - max
   - max_pool1d
   - max_pool1d_with_indices
   - max_pool2d
   - max_pool3d
   - maximum
-  - mean
+  # - mean
   - median
-  - min
+  # - min
   - minimum
   - mm
   - mode
@@ -1013,7 +1013,7 @@ torch:
   - negative_
   - nextafter
   - nonzero
-  - norm
+  # - norm
   - norm_except_dim
   - normal
   - not_equal
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index d230ce4dbf170682a77cd28649e81056ced7a7a4..134725acb4a43f45f4f5686fb4a9e605040fb716 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -65,7 +65,11 @@ class Service:
         self.init_for_debug_level()
 
     def build_hook(self, module_type, name):
-        def pre_hook(api_or_module_name, module, args, kwargs=None):
+
+        def pre_hook(module, args, kwargs=None):
+            api_or_module_name = name
+            if module_type == BaseScope.Module_Type_API:
+                api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
             kwargs = {} if kwargs is None else kwargs
             if module_type == BaseScope.Module_Type_Module or \
@@ -75,7 +79,6 @@ class Service:
 
             self.inner_switch = True
             module.forward_data_collected = True
-            HOOKModule.add_module_count(name)
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
@@ -138,13 +141,18 @@ class Service:
             # Mark that a placeholder for the current module's parameter gradient info has been recorded
            self.params_grad_info[grad_name] = True
 
-        def forward_hook(api_or_module_name, module, args, kwargs_or_output, output_or_kwargs=None):
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
             if not self.should_execute_hook(module_type, module, True):
                 return None
             is_recompute = is_recomputation()
+            api_or_module_name = name
+            if module_type == BaseScope.Module_Type_API:
+                api_or_module_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
             kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
             output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
+            HOOKModule.add_module_count(name)
+
             self.inner_switch = True
             if self.config.online_run_ut:
                 self.data_collector.update_api_or_module_name(api_or_module_name)
@@ -203,9 +211,13 @@ class Service:
             self.inner_switch = False
             return output
 
-        def backward_hook(api_or_module_name, module, grad_input, grad_output):
+        def backward_hook(module, grad_input, grad_output):
             if not self.should_execute_hook(module_type, module, False):
                 return
+            api_or_module_name = name
+            if module_type == BaseScope.Module_Type_API:
+                api_or_module_name = name + str(HOOKModule.get_module_count(name) - 1) + Const.SEP + Const.FORWARD
+            api_or_module_name = replace_last_occurrence(api_or_module_name, Const.FORWARD, Const.BACKWARD)
             is_recompute = is_recomputation()
 
             self.inner_switch = True
@@ -228,13 +240,10 @@ class Service:
             self.inner_switch = False
 
         pid = os.getpid()
-        full_forward_name = name
-        if module_type == BaseScope.Module_Type_API:
-            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
-        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
-        pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
-        forward_hook_fn = functools.partial(forward_hook, full_forward_name)
-        backward_hook_fn = functools.partial(backward_hook, full_backward_name)
+
+        pre_forward_hook_fn = pre_hook
+        forward_hook_fn = forward_hook
+        backward_hook_fn = backward_hook
 
         return pre_forward_hook_fn, forward_hook_fn, backward_hook_fn
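Note on the api_registry.py hunks: hoisting the api_template(...) construction out of api_function means the hooked callable is built once when the API is wrapped, rather than on every invocation. A minimal sketch of the pattern follows; HookedCall and the wrap_* helpers are hypothetical stand-ins, not msprobe APIs:

# Sketch of the closure hoist; HookedCall stands in for api_template.
class HookedCall:
    instances = 0  # counts how many wrapper objects get constructed

    def __init__(self, func):
        HookedCall.instances += 1
        self.func = func

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


def wrap_old(func):
    # Before: a fresh HookedCall is constructed on every single call.
    def api_function(*args, **kwargs):
        return HookedCall(func)(*args, **kwargs)
    return api_function


def wrap_new(func):
    # After: the instance is created once, at wrap time, and reused.
    hooked_instance = HookedCall(func)
    def api_function(*args, **kwargs):
        return hooked_instance(*args, **kwargs)
    return api_function


g = wrap_old(abs)
for _ in range(3):
    g(-1)
assert HookedCall.instances == 3  # one instance per call

HookedCall.instances = 0
f = wrap_new(abs)
for _ in range(3):
    f(-1)
assert HookedCall.instances == 1  # one instance total

The trade-off is that a single instance is now shared across all calls to the wrapped API, so per-call state can no longer live on the template instance, which is presumably why the service.py hunks move the call-index bookkeeping into the hooks themselves.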
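Note on the service.py hunks: build_hook no longer bakes the dump name into the hooks via functools.partial; each hook now derives api_or_module_name at call time from HOOKModule's per-API counter. The counter bump also moves from pre_hook into forward_hook, and backward_hook reads get_module_count(name) - 1 because it runs after forward_hook has incremented the count. A self-contained sketch of that naming scheme, with simplified stand-ins for HOOKModule and the msprobe helpers:

# Sketch of per-call name derivation; _counts and the helpers below are
# simplified stand-ins for HOOKModule and msprobe utilities.
SEP, FORWARD, BACKWARD = ".", "forward", "backward"
_counts = {}


def get_module_count(name):
    return _counts.get(name, 0)


def add_module_count(name):
    _counts[name] = get_module_count(name) + 1


def replace_last_occurrence(text, old, new):
    # Mirrors the semantics implied by the real helper's name.
    i = text.rfind(old)
    return text if i == -1 else text[:i] + new + text[i + len(old):]


def forward_name(name):
    # pre_hook/forward_hook: current count, e.g. "Torch.matmul.0.forward"
    return name + str(get_module_count(name)) + SEP + FORWARD


def backward_name(name):
    # backward_hook: count - 1, since forward_hook has already bumped it
    fwd = name + str(get_module_count(name) - 1) + SEP + FORWARD
    return replace_last_occurrence(fwd, FORWARD, BACKWARD)


name = "Torch.matmul."
assert forward_name(name) == "Torch.matmul.0.forward"
add_module_count(name)  # done in forward_hook after the name is derived
assert backward_name(name) == "Torch.matmul.0.backward"

The support_wrap_ops.yaml hunks simply comment out max, mean, min, norm, and linalg.vector_norm in the hook list, so those torch APIs are no longer wrapped for dumping.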