diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 4cb8684f7bcd18ff6a22873dca3079227825fbcc..82eeed19864e9caee4122138a1d12e1ba6af53ea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -21,6 +21,7 @@ from datetime import datetime import pytz import pandas as pd +import mindspore from mindspore import Tensor, mint from mindspore import nn, _no_grad @@ -30,7 +31,8 @@ from msprobe.core.common.file_utils import load_json, save_json from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.mindspore.common.utils import is_mindtorch -from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank +from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \ + comm_is_initialized from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \ get_metrics from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory @@ -250,7 +252,6 @@ class TrainerMon: self.has_collect_times = 0 # 重设采集计数器 self.print_struct = self.config.get("print_struct", False) self.targets = self.config.get("targets", None) - self.is_select = self.config.get("is_select", False) self.module_rank_list = self.config.get("module_ranks", []) self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore self.eps = self.config.get('eps', 1e-8) @@ -357,7 +358,7 @@ class TrainerMon: if self.monitoring: module_rank_valid = self.is_target_rank() step_condition = (context.step >= self.start_step and ( - context.step - self.start_step) % self.step_interval == 0) + context.step - self.start_step) % self.step_interval == 0) if module_rank_valid and step_condition: self.has_collect_times += 1 @@ -391,6 +392,7 @@ class TrainerMon: context.step += 1 self.dynamic_monitor(optimizer) + def patch_step(func, optimizer): def wrapper(*args, **kwargs): for hook in self.pre_step_hooks: @@ -400,7 +402,6 @@ class TrainerMon: hook(optimizer, args, kwargs) step_final_hook(optimizer, args, kwargs) return out - return wrapper if self.is_mindtorch: @@ -478,6 +479,9 @@ class TrainerMon: def hook_optimizer(self, optimizer): def optimizer_pre_step_hook(opt, *args, **kwargs): context = self.optimizer_context[opt] + if (self.print_struct and not all(value == {} for value in self.module_struct.values()) + and not self.struct_printed): + self._save_module_struct() if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times, self.collect_times): return @@ -489,7 +493,7 @@ class TrainerMon: if self.mv_distribution or self.ur_distribution or self.mg_direction: if self.is_mindtorch: context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \ - context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name) + context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name) else: context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer) @@ -563,9 +567,9 @@ class TrainerMon: v_dict = {} for name, param in get_parameters(common_opt): if MonitorConst.EXP_AVG_SQ in name: - m_dict[name] = param - elif MonitorConst.EXP_AVG in name: v_dict[name] = param + elif MonitorConst.EXP_AVG in name: + m_dict[name] = param return m_dict, v_dict def generate_mv_metrics(self, opt_context): @@ -695,6 +699,17 @@ class TrainerMon: } index += 1 + def _save_module_struct(self): + save_module_struct = (not comm_is_initialized() + or (self.module_rank_list and get_rank() == min(self.module_rank_list)) + or (not self.module_rank_list and get_rank() == 0)) + + if save_module_struct: + module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json')) + save_json(module_struct_file, self.module_struct, indent=2) + logger.info(f"> save module struct to {module_struct_file}") + self.struct_printed = True + def _hook_module(self, target_names, module, vpp_stage=''): if not is_valid_instance(module): # nothing to hook @@ -784,11 +799,17 @@ class TrainerMon: step_accumulates_one(context, self.micro_batch_number) return - def fwd_hook_fun_wrapper(fwd_hook_fun, name): - def wrapper(module, args, kwargs, module_output): - return fwd_hook_fun(module, args, kwargs, module_output, name) + def fwd_hook_register(module, fwd_hook_fun, name): + from packaging import version + if version.parse(mindspore.__version__) >= version.parse('2.6.0'): + def wrapper(module, args, kwargs, module_output): + return fwd_hook_fun(module, args, kwargs, module_output, name) + return module.register_forward_hook(wrapper, with_kwargs=True) - return wrapper + else: + def wrapper(module, args, module_output): + return fwd_hook_fun(module, args, None, module_output, name) + return module.register_forward_hook(wrapper) def stack_hook(module, args, kwargs, module_output, name): if module not in self.module_fwd_hook_context_by_module: @@ -804,15 +825,14 @@ class TrainerMon: for module_name, submodule in get_submodules(module): if self.stack_info: name = vpp_stage + squash_param_name(module_name) - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(stack_hook, name=name), with_kwargs=True) + handle = fwd_hook_register(submodule, stack_hook, name=name) self.handles["stack"].append(handle) name = self._is_target_module(module_name, target_names, vpp_stage) if not name: continue if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name), - with_kwargs=True) + handle = fwd_hook_register(submodule, fwd_hook_fun, name=name) self.handles['xy'].append(handle) if not self.forward_only: handle = submodule.register_backward_hook(bwd_hook_fun)