From f00970141aa9dcf0453591f22813c59433c22d87 Mon Sep 17 00:00:00 2001
From: jiangchao_j
Date: Wed, 26 Mar 2025 20:00:43 +0800
Subject: [PATCH] Ensure the forward data name of a module is consistent with
 its backward data name

---
 .../msprobe/core/data_dump/data_collector.py  |   2 -
 .../dump/module_dump/module_processer.py      | 149 +++++++++---------
 .../accuracy_tools/msprobe/pytorch/service.py |  35 ++--
 3 files changed, 91 insertions(+), 95 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index 20e4489f8..83de7384c 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -179,8 +179,6 @@ class DataCollector:
                     self.data_writer.update_construct({self.optimizer_status: None})
                     self.optimizer_status_first_start[self.optimizer_status] = False
                 self.data_writer.update_construct({name: self.optimizer_status})
-            else:
-                self.data_writer.update_construct({name: self.module_processor.api_parent_node})
             self.data_writer.update_construct(self.module_processor.module_node)
 
     def handle_data(self, name, data_info, flush=False):
diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
index ae75d3c65..97ba556e1 100644
--- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
+++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py
@@ -42,8 +42,9 @@ def replace_checkpoint():
 class ModuleProcesser:
     module_count = {}
     module_stack = []
-    api_parent_node = ""
     module_node = {}
+    module_bw_hook_kernels = {}
+    module_with_backward_hook = {}
 
     def __init__(self, scope):
         self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None
@@ -75,7 +76,7 @@ class ModuleProcesser:
         return result
 
     @staticmethod
-    def module_count_func(module_name):
+    def set_and_get_calls_number(module_name):
         if module_name not in ModuleProcesser.module_count:
             ModuleProcesser.module_count[module_name] = 0
         else:
@@ -102,7 +103,6 @@ class ModuleProcesser:
     def reset_module_stats(cls):
         cls.module_count = {}
         cls.module_stack = []
-        cls.api_parent_node = ""
         cls.module_node = {}
 
     def register_module_hook(self, models, build_hook):
@@ -115,96 +115,97 @@ class ModuleProcesser:
                     continue
                 if module.__class__.__name__ == "FullyShardedDataParallel":
                     continue
+
                 module_index = (index + Const.SEP) if index != "-1" else ""
                 prefix_name = (BaseScope.Module_Type_Module + Const.SEP + module_index +
                                name + Const.SEP + module.__class__.__name__ + Const.SEP)
-                pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 = build_hook(
-                    BaseScope.Module_Type_Module,
-                    prefix_name
-                )
+
+                forward_pre_hook, forward_hook = self.build_module_hook(prefix_name, build_hook)
                 if self.has_register_backward_hook(module):
                     logger.warning(
                         f"The {prefix_name[:-1]} has registered deprecated register_backward_hook,"
                         f"which may cause abnormal data dump. The backward data dump for this module will be skipped."
                     )
+                    ModuleProcesser.module_with_backward_hook[prefix_name] = True
+                module.register_forward_pre_hook(forward_pre_hook)
                 if torch_version_above_or_equal_2:
                     module.register_forward_hook(forward_hook, with_kwargs=True)
                 else:
-                    if not self.has_register_backward_hook(module):
-                        module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
-                    module.register_forward_hook(forward_hook_torch_version_below_2)
-                if not self.has_register_backward_hook(module):
-                    module.register_full_backward_hook(backward_hook)
-
-                module.register_forward_pre_hook(self.node_hook(prefix_name + Const.FORWARD, Const.START))
-                module.register_forward_hook(self.node_hook(prefix_name + Const.FORWARD, Const.STOP))
-                if torch_version_above_or_equal_2 and not self.has_register_backward_hook(module):
-                    module.register_full_backward_pre_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.START))
-                    module.register_full_backward_hook(self.node_hook(prefix_name + Const.BACKWARD, Const.STOP))
-
-    def node_hook(self, name_prefix, start_or_stop, **kwargs):
-
-        def pre_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
+                    module.register_forward_hook(forward_hook)
+
+    def build_module_hook(self, module_name, build_data_hook):
+        def forward_pre_hook(module, args, kwargs=None):
+            forward_name_prefix = module_name + Const.FORWARD
+            backward_name_prefix = module_name + Const.BACKWARD
+            index = ModuleProcesser.set_and_get_calls_number(module_name)
+            full_forward_name = forward_name_prefix + Const.SEP + str(index)
+            full_backward_name = backward_name_prefix + Const.SEP + str(index)
+
+            self.set_construct_info(full_forward_name, 'begin', Const.FORWARD)
+
+            _, _, backward_data_hook = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
+
+            def get_backward_pre_hook(full_backward_name):
+                def backward_pre_hook_fn(module, grad_output):
+                    self.set_construct_info(full_backward_name, 'begin', Const.BACKWARD)
+                return backward_pre_hook_fn
+
+            def get_backward_hook(backward_data_hook, full_backward_name):
+                def backward_hook_fn(module, grad_input, grad_output):
+                    self.set_construct_info(full_backward_name, 'end', Const.BACKWARD)
+                    return backward_data_hook(module, grad_input, grad_output)
+                return backward_hook_fn
+
+            if not ModuleProcesser.module_with_backward_hook.get(module_name):
+                backward_pre_hook = get_backward_pre_hook(full_backward_name)
+                backward_hook = get_backward_hook(backward_data_hook, full_backward_name)
+                if torch_version_above_or_equal_2:
+                    bw_hook = BackwardHook(module, [backward_hook], [backward_pre_hook])
+                else:
+                    bw_hook = BackwardHook(module, [backward_hook])
+                ModuleProcesser.module_bw_hook_kernels[full_forward_name] = bw_hook
+                args = bw_hook.setup_input_hook(args)
+            return args
+
+        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
+            name_prefix = module_name + Const.FORWARD
+            index = ModuleProcesser.module_count.get(module_name)
             full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
+
+            self.set_construct_info(full_name, 'end', Const.FORWARD)
+
+            _, forward_data_hook, _ = build_data_hook(BaseScope.Module_Type_Module, full_name)
+            hook_result = forward_data_hook(module, args, kwargs_or_output, output_or_kwargs)
+            result = hook_result
+
+            bw_hook = ModuleProcesser.module_bw_hook_kernels.get(full_name)
+            if bw_hook:
+                if not isinstance(result, (torch.Tensor, tuple)):
+                    logger.warning("For backward hooks to be called,"
+                                   " module output should be a Tensor or a tuple of Tensors"
+                                   f" but received {type(result)}")
+                result = bw_hook.setup_output_hook(result)
+
+            return result
+
+        return forward_pre_hook, forward_hook
+
+    def set_construct_info(self, full_name, begin_or_end, forward_or_backward):
+        if begin_or_end == 'begin':
             if self.module_stack:
                 ModuleProcesser.module_node[full_name] = self.module_stack[-1]
             else:
                 ModuleProcesser.module_node[full_name] = None
             ModuleProcesser.module_stack.append(full_name)
-            if self.module_stack:
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
             if self.scope:
                 self.scope.begin_module(full_name)
-
-        def end_hook(module, input, output=None):
-            if self.module_stack:
-                ModuleProcesser.module_stack.pop()
-            if self.module_stack:
-                ModuleProcesser.api_parent_node = self.module_stack[-1]
-            else:
-                ModuleProcesser.api_parent_node = None
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                raise RuntimeError(f"module reserve name is None when pop")
-            current_name = module.mindstudio_reserved_name.pop()
-            if self.scope:
-                self.scope.end_module(current_name)
-
-        def backward_hook(module, input, output=None):
-            try:
-                index = ModuleProcesser.module_count_func(name_prefix)
-            except IndexError as e:
-                index = None
-                pass
-            full_name = name_prefix + Const.SEP + str(index)
-            if not hasattr(module, "mindstudio_reserved_name") or not module.mindstudio_reserved_name:
-                module.mindstudio_reserved_name = []
-            module.mindstudio_reserved_name.append(full_name)
-            forward_full_name = replace_last_occurrence(full_name, Const.BACKWARD, Const.FORWARD)
-            ModuleProcesser.module_node[full_name] = replace_last_occurrence(
-                ModuleProcesser.module_node.get(forward_full_name), Const.FORWARD, Const.BACKWARD)
-            ModuleProcesser.api_parent_node = None
-            if self.scope:
-                self.scope.begin_module(full_name)
-
-        if torch_version_above_or_equal_2:
-            if Const.START in start_or_stop:
-                return pre_hook
-            else:
-                return end_hook
         else:
-            if Const.FORWARD in name_prefix and Const.START in start_or_stop:
-                return pre_hook
-            elif Const.BACKWARD in name_prefix:
-                return backward_hook
+            if not torch_version_above_or_equal_2 and forward_or_backward == Const.BACKWARD:
+                if self.scope:
+                    self.scope.begin_module(full_name)
             else:
-                return end_hook
+                if self.module_stack:
+                    ModuleProcesser.module_stack.pop()
+                if self.scope:
+                    self.scope.end_module(full_name)
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index b0b278032..98511f2fc 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -18,6 +18,8 @@ import os
 from collections import namedtuple, defaultdict
 
 import torch
+from torch.utils.hooks import BackwardHook
+
 from msprobe.core.common.const import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.core.common.file_utils import create_directory
@@ -27,7 +29,7 @@ from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.pytorch.api_accuracy_checker.common.utils import ApiData
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation
+from msprobe.pytorch.common.utils import get_rank_if_initialized, is_recomputation, replace_last_occurrence
 from msprobe.pytorch.dump.kernel_dump.kernel_config import create_kernel_config_json
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
 from msprobe.pytorch.hook_module.api_register import get_api_register
@@ -38,7 +40,7 @@ torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
     from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
 
-HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2'])
+HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook'])
 
 
 class Service:
@@ -73,6 +75,10 @@ class Service:
             self.inner_switch = True
             if module_type == BaseScope.Module_Type_Module:
                 api_or_module_name = module.mindstudio_reserved_name[-1]
+                if torch_version_above_or_equal_2:
+                    pass
+                else:
+                    pass
             else:
                 module.forward_data_collected = True
                 HOOKModule.add_module_count(name)
@@ -139,10 +145,12 @@ class Service:
             # Record that a placeholder has been set for the current module's parameter gradient info
             self.params_grad_info[grad_name] = True
 
-        def forward_hook(api_or_module_name, module, args, kwargs, output):
+        def forward_hook(api_or_module_name, module, args, kwargs_or_output, output_or_kwargs=None):
             if not self.should_execute_hook(module_type, module, True):
                 return None
             is_recompute = is_recomputation()
+            kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
+            output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
             self.inner_switch = True
 
             if self.config.online_run_ut:
@@ -162,9 +170,8 @@ class Service:
                 return None
 
             module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
+            self.data_collector.update_api_or_module_name(api_or_module_name)
             if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name[-1]
-                self.data_collector.update_api_or_module_name(api_or_module_name)
                 params_dict = {}
                 if self.config.task != Const.STRUCTURE:
                     params_dict = {
@@ -188,7 +195,6 @@ class Service:
                 )
                 init_params_grad_info(module, params_dict)
             else:
-                self.data_collector.update_api_or_module_name(api_or_module_name)
                 self.data_collector.forward_output_data_collect(
                     api_or_module_name,
                     module,
@@ -204,17 +210,12 @@ class Service:
             self.inner_switch = False
             return output
 
-        def forward_hook_torch_version_below_2(api_or_module_name, module, args, output):
-            return forward_hook(api_or_module_name, module, args, {}, output)
-
         def backward_hook(api_or_module_name, module, grad_input, grad_output):
             if not self.should_execute_hook(module_type, module, False):
                 return
             is_recompute = is_recomputation()
 
             self.inner_switch = True
-            if module_type == BaseScope.Module_Type_Module:
-                api_or_module_name = module.mindstudio_reserved_name[-1]
             self.data_collector.update_api_or_module_name(api_or_module_name)
 
             if self.config.online_run_ut:
@@ -234,19 +235,15 @@ class Service:
             self.inner_switch = False
 
         pid = os.getpid()
-        full_forward_name = None
-        full_backward_name = None
+        full_forward_name = name
         if module_type == BaseScope.Module_Type_API:
             full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
-            full_backward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.BACKWARD
+        full_backward_name = replace_last_occurrence(full_forward_name, Const.FORWARD, Const.BACKWARD)
         pre_forward_hook_fn = functools.partial(pre_hook, full_forward_name)
         forward_hook_fn = functools.partial(forward_hook, full_forward_name)
         backward_hook_fn = functools.partial(backward_hook, full_backward_name)
-        forward_hook_torch_version_below_2_fn = functools.partial(
-            forward_hook_torch_version_below_2,
-            full_forward_name
-        )
-        return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn)
+
+        return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn)
 
     def start(self, model):
         self.current_iter = self.loop + self.init_step
-- 
Gitee