diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
index 0d74fca7808e30a5a071aeb9ff556be0cd116bc1..82eeed19864e9caee4122138a1d12e1ba6af53ea 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
@@ -31,7 +31,8 @@ from msprobe.core.common.file_utils import load_json, save_json
 from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir
 from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
 from msprobe.mindspore.common.utils import is_mindtorch
-from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
+from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank, \
+    comm_is_initialized
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \
     get_metrics
 from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
@@ -478,6 +479,9 @@ class TrainerMon:
     def hook_optimizer(self, optimizer):
         def optimizer_pre_step_hook(opt, *args, **kwargs):
             context = self.optimizer_context[opt]
+            if (self.print_struct and not all(value == {} for value in self.module_struct.values())
+                    and not self.struct_printed):
+                self._save_module_struct()
             if is_skip_step(context.step, self.start_step, self.step_interval, self.has_collect_times,
                             self.collect_times):
                 return
@@ -695,6 +699,17 @@ class TrainerMon:
             }
             index += 1
 
+    def _save_module_struct(self):
+        save_module_struct = (not comm_is_initialized()
+                              or (self.module_rank_list and get_rank() == min(self.module_rank_list))
+                              or (not self.module_rank_list and get_rank() == 0))
+
+        if save_module_struct:
+            module_struct_file = os.path.realpath(os.path.join(get_output_base_dir(), 'module_struct.json'))
+            save_json(module_struct_file, self.module_struct, indent=2)
+            logger.info(f"> save module struct to {module_struct_file}")
+            self.struct_printed = True
+
     def _hook_module(self, target_names, module, vpp_stage=''):
         if not is_valid_instance(module):
             # nothing to hook
@@ -785,7 +800,8 @@ class TrainerMon:
             return
 
         def fwd_hook_register(module, fwd_hook_fun, name):
-            if mindspore.__version__ >= '2.6.0':
+            from packaging import version
+            if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
                 def wrapper(module, args, kwargs, module_output):
                     return fwd_hook_fun(module, args, kwargs, module_output, name)
                 return module.register_forward_hook(wrapper, with_kwargs=True)