diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
index 0d85e153d1901db11247813424d424c1eb78bc13..ca8ee58bf34b44b3adc7467c1eb6c92e85571b65 100644
--- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py
@@ -16,10 +16,12 @@ import time
 import json
 import os
 import uuid
+import inspect
 from collections import defaultdict
 from datetime import datetime, timezone
 from functools import partial
 
+import pandas as pd
 import pytz
 import torch
 import torch.distributed as dist
@@ -38,8 +40,10 @@ from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, OptimizerMon
 from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, is_recomputation
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
+from msprobe.core.common.const import Const
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from torch.utils.hooks import BackwardHook
+from msprobe.core.common.file_utils import write_df_to_csv
 
 try:
     import torch_npu
@@ -82,6 +86,7 @@ class ModuleHookContext:
         self.focused_in_col = 0
         self.focused_out_col = 0
         self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found
+        self.stack = ""
 
     def set_format_by_arg(self, key_name: str, target_config: dict):
         cared = target_config.get(self.module_name, self.struct)
@@ -96,6 +101,24 @@
         elif key_name in ['input', 'input_grad']:
             self.ignore_in = True
 
+    @staticmethod
+    def analyze_api_call_stack(name):
+        try:
+            api_stack = inspect.stack()[5:]
+        except Exception as e:
+            logger.warning(f"Failed to retrieve the call stack of <{name}>, {e}.")
+            api_stack = None
+        stack_str = []
+        if api_stack:
+            for (_, path, line, func, code, _) in api_stack:
+                if not code:
+                    continue
+                stack_line = f"File {path}, line {str(line)}, in {func}, \n {code[0].strip()}"
+                stack_str.append(stack_line)
+        else:
+            stack_str.append(Const.WITHOUT_CALL_STACK)
+        return str(stack_str)
+
 
 class OptimizerContext:
     def __init__(self) -> None:
@@ -179,9 +202,11 @@
         self.ur_distribution = self.config.get('ur_distribution', False)
         self.mv_distribution = self.config.get("mv_distribution", False)
         self.wg_distribution = self.config.get("wg_distribution", False)
+        self.stack_info = self.config.get("stack_info", False)
         self.param_distribution = self.config.get("param_distribution", False)
         self.mg_direction = self.config.get('mg_direction', False)
         self.cc_distribution = self.config.get("cc_distribution", {})
+        self.tensorboard_dir = ""
         if not self.cc_distribution.get('enable', False):
             self.cc_log_only = False
         else:
@@ -205,12 +230,12 @@
 
         if dist.is_initialized():
             rank = dist.get_rank()
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
+            self.tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}")
             pp_stage = dist.get_group_rank(self.process_group, rank)
             group_mates = dist.get_process_group_ranks(self.process_group)
         else:
             rank = 0
-            tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
+            self.tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}")
             pp_stage = 0
             group_mates = [0]
         self.rank = rank
@@ -228,7 +253,7 @@
         if (rank in self.module_rank_list) or len(self.module_rank_list) == 0:
             self.summary_writer = writer(
                 WriterInput(
-                    tensorboard_dir,
+                    self.tensorboard_dir,
                     self.alert_rules,
                     unique_id,
                     None,
@@ -338,6 +363,19 @@
         if self.mv_distribution:
             raise Exception("mv_distribution cannot be enabled with unknown optimizer.")
 
+    def hook_for_stack(self, model):
+        def stack_hook(name, module, module_input, module_output):
+            if not self.stack_info:
+                return
+            if module not in self.module_fwd_hook_context_by_module:
+                self.module_fwd_hook_context_by_module[module] = ModuleHookContext(name)
+            context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            context.stack = ModuleHookContext.analyze_api_call_stack(name)
+
+        for m in model:
+            for name, module in m.named_modules():
+                module.register_forward_hook(partial(stack_hook, name))
+
     def hook_modules(self, model: torch.nn.Module, grad_acc_steps):
         if self.module_rank_list and (self.rank not in self.module_rank_list):
             return
@@ -362,6 +400,9 @@
                 'targets'].keys()
             hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
 
+        if self.stack_info:
+            self.hook_for_stack(model)
+
         logger.info_on_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
 
         def clone_if_tensor(args):
@@ -466,6 +507,16 @@
     def write_adhoc_check(self, step):
         TrainerMon.tensor_metrics.flush(self.summary_writer)
 
+    def write_stack_info(self):
+        stack_data = []
+        header = ["module_name", "stack_info"]
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            stack_data.append([fwd_context.module_name, fwd_context.stack])
+        filepath = os.path.join(self.tensorboard_dir, 'stack_info.csv')
+        if not os.path.exists(filepath):
+            data_frame = pd.DataFrame(stack_data, columns=header)
+            write_df_to_csv(data_frame, filepath)
+
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
@@ -572,6 +623,9 @@
             self.write_mv_tb(context)
             self.write_param_tb(context)
             self.write_adhoc_check(context.step)
+            if self.stack_info:
+                self.write_stack_info()
+                self.stack_info = False
 
             if self.ur_distribution:
                 for param_name, _ in context.param_adam_update.items():
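
Reviewer sketch: the pattern behind `hook_for_stack`, reduced to a standalone script. Only `register_forward_hook`, `named_modules`, and `inspect.stack` are real APIs here; the names `stacks` and `capture_stack`, the `[2:]` frame offset, and the toy model are assumptions for illustration, not msprobe code. The patch itself stores the formatted stack on `ModuleHookContext.stack` and flushes it to `stack_info.csv` via `write_stack_info`.

```python
# Illustrative sketch (not msprobe API): register a forward hook on every
# named submodule and record, once per module, the Python call stack that
# triggered its forward pass.
import inspect
from functools import partial

import torch

stacks = {}  # module name -> formatted call stack


def capture_stack(name, module, module_input, module_output):
    # Record only the first invocation per module, mirroring the patch's
    # "collect once, then disable" behaviour.
    if name in stacks:
        return
    # Skip the innermost frames (this hook plus torch's hook dispatch); the
    # offset is approximate, which is why the patch wraps inspect.stack()
    # in try/except and tolerates an empty result.
    frames = inspect.stack()[2:]
    stacks[name] = "\n".join(
        f"File {f.filename}, line {f.lineno}, in {f.function}"
        for f in frames
        if f.code_context
    )


net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for name, module in net.named_modules():
    module.register_forward_hook(partial(capture_stack, name))

net(torch.randn(1, 4))
for name, stack in stacks.items():
    print(f"--- {name or '<root>'} ---\n{stack}")
```

In the patch, the equivalent capture runs behind the `"stack_info"` config switch and is turned off after the first flush (`self.stack_info = False`), so `stack_info.csv` reflects the first monitored step only, with one `module_name, stack_info` row per hooked module.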