From 63d167b166623c1f7cc7813437e65aff7e63c089 Mon Sep 17 00:00:00 2001
From: TAJh
Date: Tue, 24 Jun 2025 16:47:51 +0800
Subject: [PATCH] bugfix for mon

---
 .../msprobe/mindspore/monitor/module_hook.py | 64 ++++++++++++++++---
 .../msprobe/mindspore/monitor/utils.py       | 22 +++++++
 2 files changed, 77 insertions(+), 9 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
index 0354ab53368..92b4a45234c 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py
@@ -32,7 +32,7 @@ from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFa
 from msprobe.mindspore.common.utils import is_mindtorch
 from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank
 from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \
-    is_skip_step, get_metrics, get_target_output_dir
+    is_skip_step, get_metrics, get_target_output_dir, is_skip_name, flatten_grads
 from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory
 from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate
@@ -394,11 +394,7 @@ class TrainerMon:

         def patch_step(func, optimizer):
             def wrapper(*args, **kwargs):
-                for hook in self.pre_step_hooks:
-                    hook(optimizer, args, kwargs)
                 out = func(*args, **kwargs)
-                for hook in self.post_step_hooks:
-                    hook(optimizer, args, kwargs)
                 step_final_hook(optimizer, args, kwargs)
                 return out
             return wrapper
@@ -406,7 +402,7 @@ class TrainerMon:
         if self.is_mindtorch:
             optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
         else:
-            optimizer.__class__.construct = patch_step(optimizer.__class__.construct, optimizer)
+            optimizer.register_forward_hook(step_final_hook)

         return

@@ -475,6 +471,29 @@ class TrainerMon:
                 hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
         logger.info(f"> {hooked_count} modules are monitored.")

+    def get_grad_for_ms(self, opt, grads):
+        """Collect the gradient dict from a native MindSpore optimizer."""
+        # Return an empty dict directly if weight-gradient distribution monitoring is disabled
+        if not self.wg_distribution:
+            return {}
+
+        # Resolve the actual optimizer instance (it may be wrapped)
+        common_opt = opt
+        if not is_valid_instance(opt):
+            common_opt = getattr(opt, 'optimizer', None)
+            if not is_valid_instance(common_opt):
+                logger.warning("Optimizer is not valid, please check usage")
+                return {}
+        prefix = f'{MonitorConst.FORWARD_STAGE}{MonitorConst.NAME_SEP}'
+        # Build the mapping from parameter names to gradients
+        try:
+            grad_names = [prefix + name for name, _ in get_parameters(opt) if not is_skip_name(name)]
+            grad_dict = dict(zip(grad_names, flatten_grads(grads)))
+            return grad_dict
+        except Exception as e:
+            logger.warning(f"Failed to get gradients: {str(e)}")
+            return {}
+
     def hook_optimizer(self, optimizer):
         def optimizer_pre_step_hook(opt, *args, **kwargs):
             context = self.optimizer_context[opt]
@@ -484,7 +503,10 @@ class TrainerMon:

             grad_dict = {}
             if self.wg_distribution:
-                grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+                if self.is_mindtorch:
+                    grad_dict = self.optimizer_mon.fetch_grad(self, self.param2name)
+                else:
+                    grad_dict = self.get_grad_for_ms(optimizer, args)

             if self.mv_distribution or self.ur_distribution or self.mg_direction:
                 if self.is_mindtorch:
@@ -516,8 +538,28 @@ class TrainerMon:
         if self.optimizer_hooked or not self.is_target_rank():
             return

-        self.pre_step_hooks.append(optimizer_pre_step_hook)
-        self.post_step_hooks.append(optimizer_post_step_hook)
+        def forward_pre_hook_fn(fn):
+            def forward_pre_hook(cell, inputs):
+                # Forward the cell inputs to optimizer_pre_step_hook as positional args
+                optimizer_pre_step_hook(optimizer, inputs)
+                return
+            return forward_pre_hook
+
+        def patch_step(func, optimizer):
+            def wrapper(*args, **kwargs):
+                optimizer_pre_step_hook(optimizer, args, kwargs)
+                out = func(*args, **kwargs)
+                optimizer_post_step_hook(optimizer, args, kwargs)
+                return out
+            return wrapper
+
+        if self.is_mindtorch:
+            optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer)
+        else:
+            handle = optimizer.register_forward_pre_hook(forward_pre_hook_fn(optimizer_pre_step_hook))
+            self.handles['opt'].append(handle)
+            handle = optimizer.register_forward_hook(optimizer_post_step_hook)
+            self.handles['opt'].append(handle)
         self.optimizer_hooked = True
         return

@@ -904,6 +946,10 @@ class TrainerMon:
         if self.optimizer_hooked:
             self.pre_step_hooks.clear()
             self.post_step_hooks.clear()
+            for handle in self.handles['opt']:
+                handle.remove()
+            self.handles['opt'].clear()
+
             for _, context in self.optimizer_context.items():
                 context.reset()
         self.optimizer_hooked = False
diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
index e0817eb2a4e..307d39d9a59 100644
--- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
+++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py
@@ -327,3 +327,25 @@ def get_target_output_dir(monitor_path, time_start, time_end):
         if start_ok and end_ok:
             result[rank] = os.path.join(monitor_path, dirname)
     return result
+
+
+def is_skip_name(name: str) -> bool:
+    """
+    Check whether a parameter name should be skipped during gradient monitoring.
+    Args:
+        name: parameter name
+    Returns:
+        bool: True if the parameter should be skipped
+    """
+    skip_keywords = {'step', 'learning', 'lr'}
+    return any(keyword in name.lower() for keyword in skip_keywords)
+
+
+def flatten_grads(grads):
+    flattened = []
+    if isinstance(grads, (tuple, list)):
+        for g in grads:
+            flattened.extend(flatten_grads(g))
+    else:
+        flattened.append(grads)
+    return flattened
\ No newline at end of file
--
Gitee
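
Note on the non-mindtorch path of this patch: instead of wrapping optimizer.__class__.construct, hook_optimizer() now registers MindSpore Cell forward hooks on the optimizer and keeps their handles in self.handles['opt'] so they can be removed during cleanup. The following is a minimal standalone sketch of that hook mechanism, not part of the patch itself; it assumes MindSpore 2.x in PyNative mode and uses a toy nn.Dense / nn.SGD setup purely for illustration.

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)  # Cell hooks take effect in PyNative mode

def pre_step(cell, inputs):
    # Runs before the optimizer Cell's construct(); `inputs` is the tuple of
    # construct() arguments, so inputs[0] is the gradient tuple here. This is
    # the point where the patch calls optimizer_pre_step_hook / get_grad_for_ms.
    print("pre-step, num gradients:", len(inputs[0]))

def post_step(cell, inputs, outputs):
    # Runs after construct() returns, i.e. after the parameter update
    print("post-step done")

net = nn.Dense(4, 2)
opt = nn.SGD(net.trainable_params(), learning_rate=0.01)

pre_handle = opt.register_forward_pre_hook(pre_step)
post_handle = opt.register_forward_hook(post_step)

# Feed zero gradients just to trigger the hooks: pre_step -> construct -> post_step
grads = tuple(Tensor(np.zeros(p.shape, np.float32)) for p in net.trainable_params())
opt(grads)

# Hooks are detached via their handles, the same way the cleanup branch
# added at the @@ -904 hunk clears self.handles['opt']
pre_handle.remove()
post_handle.remove()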